Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add MaxRequests paramete #722

Merged
merged 1 commit into from
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
feat: Add MaxRequests parameter
Signed-off-by: Guilhem Lettron <guilhem@barpilot.io>
  • Loading branch information
guilhem committed Aug 16, 2022
commit 521f43096415d2d4145535b5f6b175b1926703f1
23 changes: 23 additions & 0 deletions colly.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ type Collector struct {
// Context is the context that will be used for HTTP requests. You can set this
// to support clean cancellation of scraping.
Context context.Context
// MaxRequests limit the number of requests done by the instance.
// Set it to 0 for infinite requests (default).
MaxRequests uint32

store storage.Storage
debugger debug.Debugger
Expand Down Expand Up @@ -228,6 +231,8 @@ var (
ErrAbortedAfterHeaders = errors.New("Aborted after receiving response headers")
// ErrQueueFull is the error returned when the queue is full
ErrQueueFull = errors.New("Queue MaxSize reached")
// ErrMaxRequests is the error returned when exceeding max requests
ErrMaxRequests = errors.New("Max Requests limit reached")
)

var envMap = map[string]func(*Collector, string){
Expand Down Expand Up @@ -268,6 +273,12 @@ var envMap = map[string]func(*Collector, string){
c.MaxDepth = maxDepth
}
},
"MAX_REQUESTS": func(c *Collector, val string) {
maxRequests, err := strconv.ParseUint(val, 0, 32)
if err == nil {
c.MaxRequests = uint32(maxRequests)
}
},
"PARSE_HTTP_ERROR_RESPONSE": func(c *Collector, val string) {
c.ParseHTTPErrorResponse = isYesString(val)
},
Expand Down Expand Up @@ -320,6 +331,13 @@ func MaxDepth(depth int) CollectorOption {
}
}

// MaxDepth limits the recursion depth of visited URLs.
func MaxRequests(max uint32) CollectorOption {
return func(c *Collector) {
c.MaxRequests = max
}
}

// AllowedDomains sets the domain whitelist used by the Collector.
func AllowedDomains(domains ...string) CollectorOption {
return func(c *Collector) {
Expand Down Expand Up @@ -449,6 +467,7 @@ func (c *Collector) Init() {
c.UserAgent = "colly - https://github.com/gocolly/colly/v2"
c.Headers = nil
c.MaxDepth = 0
c.MaxRequests = 0
c.store = &storage.InMemoryStorage{}
c.store.Init()
c.MaxBodySize = 10 * 1024 * 1024
Expand Down Expand Up @@ -717,6 +736,9 @@ func (c *Collector) requestCheck(parsedURL *url.URL, method string, getBody func
if c.MaxDepth > 0 && c.MaxDepth < depth {
return ErrMaxDepth
}
if c.MaxRequests > 0 && c.requestCount >= c.MaxRequests {
return ErrMaxRequests
}
if err := c.checkFilters(u, parsedURL.Hostname()); err != nil {
return err
}
Expand Down Expand Up @@ -1278,6 +1300,7 @@ func (c *Collector) Clone() *Collector {
IgnoreRobotsTxt: c.IgnoreRobotsTxt,
MaxBodySize: c.MaxBodySize,
MaxDepth: c.MaxDepth,
MaxRequests: c.MaxRequests,
DisallowedURLFilters: c.DisallowedURLFilters,
URLFilters: c.URLFilters,
CheckHead: c.CheckHead,
Expand Down
19 changes: 19 additions & 0 deletions colly_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,25 @@ func TestCollectorDepth(t *testing.T) {
}
}

func TestCollectorRequests(t *testing.T) {
ts := newTestServer()
defer ts.Close()
maxRequests := uint32(5)
c1 := NewCollector(
MaxRequests(maxRequests),
AllowURLRevisit(),
)
requestCount := 0
c1.OnResponse(func(resp *Response) {
requestCount++
c1.Visit(ts.URL)
})
c1.Visit(ts.URL)
if requestCount != 5 {
t.Errorf("Invalid number of requests: %d (expected 5) with MaxRequests", requestCount)
}
}

func TestCollectorContext(t *testing.T) {
// "/slow" takes 1 second to return the response.
// If context does abort the transfer after 0.5 seconds as it should,
Expand Down