Skip to content

Commit

Permalink
Merge pull request #722 from guilhem/limitnumber
Browse files Browse the repository at this point in the history
feat: Add MaxRequests parameter
  • Loading branch information
WGH- authored Jan 16, 2023
2 parents 9a6de69 + 521f430 commit 485293b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
23 changes: 23 additions & 0 deletions colly.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ type Collector struct {
// Context is the context that will be used for HTTP requests. You can set this
// to support clean cancellation of scraping.
Context context.Context
// MaxRequests limits the number of requests done by the instance.
// Set it to 0 for infinite requests (default).
MaxRequests uint32

store storage.Storage
debugger debug.Debugger
Expand Down Expand Up @@ -228,6 +231,8 @@ var (
ErrAbortedAfterHeaders = errors.New("Aborted after receiving response headers")
// ErrQueueFull is the error returned when the queue is full
ErrQueueFull = errors.New("Queue MaxSize reached")
// ErrMaxRequests is the error returned when exceeding max requests
ErrMaxRequests = errors.New("Max Requests limit reached")
)

var envMap = map[string]func(*Collector, string){
Expand Down Expand Up @@ -268,6 +273,12 @@ var envMap = map[string]func(*Collector, string){
c.MaxDepth = maxDepth
}
},
"MAX_REQUESTS": func(c *Collector, val string) {
maxRequests, err := strconv.ParseUint(val, 0, 32)
if err == nil {
c.MaxRequests = uint32(maxRequests)
}
},
"PARSE_HTTP_ERROR_RESPONSE": func(c *Collector, val string) {
c.ParseHTTPErrorResponse = isYesString(val)
},
Expand Down Expand Up @@ -320,6 +331,13 @@ func MaxDepth(depth int) CollectorOption {
}
}

// MaxRequests limits the total number of requests performed by the
// Collector. A value of 0 (the default) means no limit. When the limit
// is reached, further requests fail with ErrMaxRequests.
func MaxRequests(max uint32) CollectorOption {
	return func(c *Collector) {
		c.MaxRequests = max
	}
}

// AllowedDomains sets the domain whitelist used by the Collector.
func AllowedDomains(domains ...string) CollectorOption {
return func(c *Collector) {
Expand Down Expand Up @@ -449,6 +467,7 @@ func (c *Collector) Init() {
c.UserAgent = "colly - https://github.com/gocolly/colly/v2"
c.Headers = nil
c.MaxDepth = 0
c.MaxRequests = 0
c.store = &storage.InMemoryStorage{}
c.store.Init()
c.MaxBodySize = 10 * 1024 * 1024
Expand Down Expand Up @@ -717,6 +736,9 @@ func (c *Collector) requestCheck(parsedURL *url.URL, method string, getBody func
if c.MaxDepth > 0 && c.MaxDepth < depth {
return ErrMaxDepth
}
if c.MaxRequests > 0 && c.requestCount >= c.MaxRequests {
return ErrMaxRequests
}
if err := c.checkFilters(u, parsedURL.Hostname()); err != nil {
return err
}
Expand Down Expand Up @@ -1278,6 +1300,7 @@ func (c *Collector) Clone() *Collector {
IgnoreRobotsTxt: c.IgnoreRobotsTxt,
MaxBodySize: c.MaxBodySize,
MaxDepth: c.MaxDepth,
MaxRequests: c.MaxRequests,
DisallowedURLFilters: c.DisallowedURLFilters,
URLFilters: c.URLFilters,
CheckHead: c.CheckHead,
Expand Down
19 changes: 19 additions & 0 deletions colly_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,25 @@ func TestCollectorDepth(t *testing.T) {
}
}

// TestCollectorRequests verifies that MaxRequests caps the total number of
// requests a Collector performs, even when response callbacks keep queueing
// new visits to the same URL.
func TestCollectorRequests(t *testing.T) {
	ts := newTestServer()
	defer ts.Close()

	maxRequests := uint32(5)

	c1 := NewCollector(
		MaxRequests(maxRequests),
		AllowURLRevisit(),
	)

	// Each response re-visits the same URL; without the limit this would
	// recurse indefinitely, so the count must stop exactly at maxRequests.
	requestCount := uint32(0)
	c1.OnResponse(func(resp *Response) {
		requestCount++
		c1.Visit(ts.URL)
	})
	c1.Visit(ts.URL)
	if requestCount != maxRequests {
		t.Errorf("Invalid number of requests: %d (expected %d) with MaxRequests", requestCount, maxRequests)
	}
}

func TestCollectorContext(t *testing.T) {
// "/slow" takes 1 second to return the response.
// If context does abort the transfer after 0.5 seconds as it should,
Expand Down

0 comments on commit 485293b

Please sign in to comment.