From 275fbb40aa3eb959684315b71b0a882246210536 Mon Sep 17 00:00:00 2001 From: vforfreedom Date: Sat, 21 Jan 2023 20:08:27 +0800 Subject: [PATCH] chore:format codes --- README.md | 1 + cmd.go | 2 +- cmd_test.go | 2 +- context.go | 3 +- distributed.go | 3 +- distributed_test.go | 6 +-- downloader.go | 9 ++-- downloader_test.go | 8 ++-- dupefilters.go | 7 +++- dupefilters_test.go | 36 ++++++++-------- engine.go | 8 ++-- engine_test.go | 12 +++--- events.go | 42 +++++++++++++------ example/quotes/middlerwares.go | 14 ++++--- example/quotes/pipeline.go | 5 +-- example/quotes/quotes.go | 2 +- example/quotes/spider.go | 11 +++-- exception_test.go | 23 +++++----- exceptions.go | 39 ++++++++++------- init.go | 1 - items.go | 4 +- limiter.go | 21 ++++++---- logger.go | 2 +- logger_test.go | 12 +++--- middlewares.go | 11 +++-- rdb.go | 14 +++---- rdb_test.go | 4 +- request.go | 76 +++++++++++++++++++++------------- response.go | 18 ++++---- settings.go | 9 ++-- settings_test.go | 2 +- spiders.go | 10 +++-- spiders_test.go | 16 +++---- stats.go | 46 +++++++++++++++----- stats_test.go | 13 +++--- utils.go | 11 ++++- utils_test.go | 20 ++++----- 37 files changed, 308 insertions(+), 215 deletions(-) diff --git a/README.md b/README.md index 40f4cdb..5621413 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # Tegenaria crawl framework +[![Go Report Card](https://goreportcard.com/badge/github.com/wetrycode/tegenaria)](https://goreportcard.com/report/github.com/wetrycode/tegenaria) [![codecov](https://codecov.io/gh/wetrycode/tegenaria/branch/master/graph/badge.svg?token=XMW3K1JYPB)](https://codecov.io/gh/wetrycode/tegenaria) [![go workflow](https://github.com/wetrycode/tegenaria/actions/workflows/go.yml/badge.svg)](https://github.com/wetrycode/tegenaria/actions/workflows/go.yml/badge.svg) [![CodeQL](https://github.com/wetrycode/tegenaria/actions/workflows/codeql-analysis.yml/badge.svg)](https://github.com/wetrycode/tegenaria/actions/workflows/codeql-analysis.yml) diff --git a/cmd.go b/cmd.go index b83a36c..bdf4846 100644 --- a/cmd.go +++ b/cmd.go @@ -50,6 +50,6 @@ func ExecuteCmd(engine *CrawlEngine) { } rootCmd.AddCommand(crawlCmd) - crawlCmd.Flags().BoolVarP(&engine.isMaster, "master", "m", false, "Whether to set the current node as the master node,defualt false") + crawlCmd.Flags().BoolVarP(&engine.isMaster, "master", "m", false, "Whether to set the current node as the master node,default false") rootCmd.Execute() } diff --git a/cmd_test.go b/cmd_test.go index fd78775..34d51d0 100644 --- a/cmd_test.go +++ b/cmd_test.go @@ -13,7 +13,7 @@ func TestCmdStart(t *testing.T) { buf := new(bytes.Buffer) rootCmd.SetOut(buf) rootCmd.SetErr(buf) - rootCmd.SetArgs([]string{"crawl","testCmdSpider"}) + rootCmd.SetArgs([]string{"crawl", "testCmdSpider"}) convey.So(engine.Execute, convey.ShouldNotPanic) }) } diff --git a/context.go b/context.go index 4da4353..a6e0d73 100644 --- a/context.go +++ b/context.go @@ -129,11 +129,12 @@ func WithContextId(ctxId string) ContextOption { c.CtxId = ctxId } } -func WithItemChannelSize(size int) ContextOption{ +func WithItemChannelSize(size int) ContextOption { return func(c *Context) { c.Items = make(chan *ItemMeta, size) } } + // NewContext 从内存池中构建context对象 func NewContext(request *Request, Spider SpiderInterface, opts ...ContextOption) *Context { ctx := contextPool.Get().(*Context) diff --git a/distributed.go b/distributed.go index b70b1e9..b3f7ef8 100644 --- a/distributed.go +++ b/distributed.go @@ -24,7 +24,6 @@ package tegenaria import ( "bytes" - "context" goContext 
"context" "encoding/gob" "fmt" @@ -585,7 +584,7 @@ func (w *DistributedWorker) CheckMasterLive() (bool, error) { result := []*redis.StringCmd{} for _, member := range members { - result = append(result, pipe.Get(context.TODO(), member)) + result = append(result, pipe.Get(goContext.TODO(), member)) } count, err := w.executeCheck(pipe, result, count) return count != 0, err diff --git a/distributed_test.go b/distributed_test.go index e3c764c..af52165 100644 --- a/distributed_test.go +++ b/distributed_test.go @@ -129,12 +129,12 @@ func TestAddNodeError(t *testing.T) { convey.So(err.Error(), convey.ShouldContainSubstring, "sadd add node error") patch.Reset() - patch = gomonkey.ApplyFunc((*DistributedWorker).addMaster,func (_ *DistributedWorker)error { + patch = gomonkey.ApplyFunc((*DistributedWorker).addMaster, func(_ *DistributedWorker) error { return errors.New("add master error") - + }) err = worker.AddNode() - convey.So(err.Error(), convey.ShouldContainSubstring,"add master error") + convey.So(err.Error(), convey.ShouldContainSubstring, "add master error") patch.Reset() }) diff --git a/downloader.go b/downloader.go index c8cb963..f6398fd 100644 --- a/downloader.go +++ b/downloader.go @@ -61,7 +61,6 @@ type SpiderDownloader struct { ProxyFunc func(req *http.Request) (*url.URL, error) } - // DownloaderOption 下载器可选参数函数 type DownloaderOption func(d *SpiderDownloader) @@ -273,13 +272,11 @@ func (d *SpiderDownloader) Download(ctx *Context) (*Response, error) { } // 构建网络请求上下文 var asCtxKey ctxKey = "key" - var timeoutCtx context.Context = nil - var valCtx context.Context = nil + valCtx := context.WithValue(ctx, asCtxKey, ctxValue) if ctx.Request.Timeout > 0 { - timeoutCtx, _ = context.WithTimeout(ctx, ctx.Request.Timeout) + timeoutCtx, cancel := context.WithTimeout(ctx, ctx.Request.Timeout) + defer cancel() valCtx = context.WithValue(timeoutCtx, asCtxKey, ctxValue) - } else { - valCtx = context.WithValue(ctx, asCtxKey, ctxValue) } req, err := http.NewRequestWithContext(valCtx, string(ctx.Request.Method), u.String(), ctx.Request.BodyReader) if err != nil { diff --git a/downloader_test.go b/downloader_test.go index b633610..1e7cdbc 100644 --- a/downloader_test.go +++ b/downloader_test.go @@ -224,7 +224,7 @@ func TestRequestProxyWithTimeOut(t *testing.T) { ProxyUrl: proxyServer.URL, } defer proxyServer.Close() - resp, err := newRequestDownloadCase("/testTimeout", GET, RequestWithRequestProxy(proxy),RequestWithTimeout(10 * time.Second)) + resp, err := newRequestDownloadCase("/testTimeout", GET, RequestWithRequestProxy(proxy), RequestWithTimeout(10*time.Second)) convey.So(err, convey.ShouldBeNil) convey.So(resp.Status, convey.ShouldAlmostEqual, 200) convey.So(resp.String(), convey.ShouldContainSubstring, "This is proxy Server.") @@ -236,7 +236,7 @@ func TestRequestProxyWithTimeOut(t *testing.T) { ProxyUrl: proxyServer.URL, } defer proxyServer.Close() - resp, err := newRequestDownloadCase("/testTimeout", GET, RequestWithRequestProxy(proxy),RequestWithTimeout(1 * time.Second)) + resp, err := newRequestDownloadCase("/testTimeout", GET, RequestWithRequestProxy(proxy), RequestWithTimeout(1*time.Second)) convey.So(err, convey.ShouldNotBeNil) convey.So(resp, convey.ShouldBeNil) }) @@ -252,7 +252,7 @@ func TestRequestHeaders(t *testing.T) { resp, err := newRequestDownloadCase("/testHeader", GET, RequestWithRequestHeader(headers)) convey.So(err, convey.ShouldBeNil) convey.So(resp.Status, convey.ShouldAlmostEqual, 200) - content:=resp.String() + content := resp.String() convey.So(content, 
convey.ShouldContainSubstring, "value") }) @@ -260,7 +260,7 @@ func TestRequestHeaders(t *testing.T) { func TestTimeout(t *testing.T) { convey.Convey("test request timeout", t, func() { - resp, err := newRequestDownloadCase("/testTimeout", GET, RequestWithTimeout(1 * time.Second)) + resp, err := newRequestDownloadCase("/testTimeout", GET, RequestWithTimeout(1*time.Second)) convey.So(err, convey.ShouldNotBeNil) convey.So(resp, convey.ShouldBeNil) diff --git a/dupefilters.go b/dupefilters.go index 158c67d..9dc955d 100644 --- a/dupefilters.go +++ b/dupefilters.go @@ -42,10 +42,12 @@ type RFPDupeFilterInterface interface { // DoDupeFilter request去重 DoDupeFilter(ctx *Context) (bool, error) } + // RFPDupeFilter 去重组件 type RFPDupeFilter struct { bloomFilter *bloom.BloomFilter } + // NewRFPDupeFilter 新建去重组件 // bloomP容错率 // bloomN数据规模 @@ -89,15 +91,16 @@ func (f *RFPDupeFilter) encodeHeader(request *Request) string { } return buf.String() } + // Fingerprint 计算指纹 func (f *RFPDupeFilter) Fingerprint(ctx *Context) ([]byte, error) { - request:=ctx.Request + request := ctx.Request if request.Url == "" { return nil, fmt.Errorf("request is nil,maybe it had been free") } // get sha128 sha := murmur3.New128() - method:=string(request.Method) + method := string(request.Method) _, err := io.WriteString(sha, method) if err != nil { return nil, err diff --git a/dupefilters_test.go b/dupefilters_test.go index 65ed5d5..ead1f76 100644 --- a/dupefilters_test.go +++ b/dupefilters_test.go @@ -8,7 +8,7 @@ import ( func TestDoDupeFilter(t *testing.T) { - convey.Convey("test dupefilter",t,func(){ + convey.Convey("test dupefilter", t, func() { server := newTestServer() headers := map[string]string{ "Params1": "params1", @@ -27,18 +27,18 @@ func TestDoDupeFilter(t *testing.T) { request3 := NewRequest(server.URL+"/testHeader2", GET, testParser, RequestWithRequestHeader(headers)) ctx3 := NewContext(request3, spider1) - duplicates := NewRFPDupeFilter(0.001,1024*1024) + duplicates := NewRFPDupeFilter(0.001, 1024*1024) r1, _ := duplicates.DoDupeFilter(ctx1) - convey.So(r1,convey.ShouldBeFalse) + convey.So(r1, convey.ShouldBeFalse) r2, _ := duplicates.DoDupeFilter(ctx2) convey.So(r2, convey.ShouldBeTrue) r3, _ := duplicates.DoDupeFilter(ctx3) - convey.So(r3,convey.ShouldBeFalse) + convey.So(r3, convey.ShouldBeFalse) }) } func TestDoBodyDupeFilter(t *testing.T) { - convey.Convey("test body dupefilter",t,func(){ + convey.Convey("test body dupefilter", t, func() { server := newTestServer() // downloader := NewDownloader() headers := map[string]string{ @@ -53,35 +53,35 @@ func TestDoBodyDupeFilter(t *testing.T) { spider1 := &TestSpider{ NewBaseSpider("testspider", []string{"https://www.baidu.com"}), } - request1 := NewRequest(server.URL+"/testHeader", GET, testParser, RequestWithRequestHeader(headers),RequestWithRequestBody(body)) + request1 := NewRequest(server.URL+"/testHeader", GET, testParser, RequestWithRequestHeader(headers), RequestWithRequestBody(body)) ctx1 := NewContext(request1, spider1) - request2 := NewRequest(server.URL+"/testHeader", GET, testParser, RequestWithRequestHeader(headers),RequestWithRequestBody(body)) + request2 := NewRequest(server.URL+"/testHeader", GET, testParser, RequestWithRequestHeader(headers), RequestWithRequestBody(body)) ctx2 := NewContext(request2, spider1) request3 := NewRequest(server.URL+"/testHeader2", GET, testParser, RequestWithRequestHeader(headers)) ctx3 := NewContext(request3, spider1) - request4 := NewRequest(server.URL+"/testHeader", GET, testParser, 
RequestWithRequestHeader(headers),RequestWithRequestBody(body)) + request4 := NewRequest(server.URL+"/testHeader", GET, testParser, RequestWithRequestHeader(headers), RequestWithRequestBody(body)) ctx4 := NewContext(request4, spider1) - duplicates := NewRFPDupeFilter(0.001,1024*1024) + duplicates := NewRFPDupeFilter(0.001, 1024*1024) r1, err := duplicates.DoDupeFilter(ctx1) - convey.So(err,convey.ShouldBeNil) - convey.So(r1,convey.ShouldBeFalse) + convey.So(err, convey.ShouldBeNil) + convey.So(r1, convey.ShouldBeFalse) r2, err := duplicates.DoDupeFilter(ctx2) - convey.So(err,convey.ShouldBeNil) - convey.So(r2,convey.ShouldBeTrue) + convey.So(err, convey.ShouldBeNil) + convey.So(r2, convey.ShouldBeTrue) r3, err := duplicates.DoDupeFilter(ctx3) - convey.So(err,convey.ShouldBeNil) - convey.So(r3,convey.ShouldBeFalse) + convey.So(err, convey.ShouldBeNil) + convey.So(r3, convey.ShouldBeFalse) r4, err := duplicates.DoDupeFilter(ctx4) - convey.So(err,convey.ShouldBeNil) - convey.So(r4,convey.ShouldBeTrue) + convey.So(err, convey.ShouldBeNil) + convey.So(r4, convey.ShouldBeTrue) }) - + } diff --git a/engine.go b/engine.go index 4ac9212..d21b883 100644 --- a/engine.go +++ b/engine.go @@ -311,8 +311,7 @@ func (e *CrawlEngine) writeCache(ctx *Context) error { var err error = nil // 是否进入去重流程 if e.filterDuplicateReq { - var ret bool = false - ret, err = e.rfpDupeFilter.DoDupeFilter(ctx) + ret, err := e.rfpDupeFilter.DoDupeFilter(ctx) if err != nil { isDuplicated = true engineLog.WithField("request_id", ctx.CtxId).Errorf("request unique error %s", err.Error()) @@ -364,9 +363,8 @@ func (e *CrawlEngine) doDownload(ctx *Context) error { } // 增加请求发送量 e.statistic.IncrRequestSent() - var rsp *Response = nil engineLog.WithField("request_id", ctx.CtxId).Infof("%s request ready to download", ctx.CtxId) - rsp, err = e.downloader.Download(ctx) + rsp, err := e.downloader.Download(ctx) if err != nil { return err } @@ -488,7 +486,7 @@ func NewEngine(opts ...EngineOption) *CrawlEngine { checkMasterLive: func() (bool, error) { return true, nil }, limiter: NewDefaultLimiter(32), downloader: NewDownloader(), - hooker: NewDefualtHooks(), + hooker: NewDefaultHooks(), } for _, o := range opts { o(Engine) diff --git a/engine_test.go b/engine_test.go index 6a9cc3b..1b1587f 100644 --- a/engine_test.go +++ b/engine_test.go @@ -275,14 +275,14 @@ func TestEngineStartPanic(t *testing.T) { ctxManager.Clear() } engine := newTestEngine("testStartPanicSpider") - patch:=gomonkey.ApplyFunc((*Statistic).OutputStats,func (_ *Statistic)map[string]uint64 { + patch := gomonkey.ApplyFunc((*Statistic).OutputStats, func(_ *Statistic) map[string]uint64 { panic("output panic") - + }) defer patch.Reset() - f := func(){engine.start("testStartPanicSpider")} + f := func() { engine.start("testStartPanicSpider") } convey.So(f, convey.ShouldPanic) - convey.So(engine.mutex.TryLock(),convey.ShouldBeTrue) + convey.So(engine.mutex.TryLock(), convey.ShouldBeTrue) engine.Close() }) @@ -412,7 +412,7 @@ func TestParseError(t *testing.T) { ctx := NewContext(request, spider) err = engine.doDownload(ctx) convey.So(err, convey.ShouldBeNil) - + err = engine.doParse(ctx) convey.So(err, convey.ShouldNotBeNil) convey.So(err.Error(), convey.ShouldContainSubstring, "parse response error") @@ -421,7 +421,7 @@ func TestParseError(t *testing.T) { } func wokerError(ctx *Context, url string, errMsg string, t *testing.T, patch *gomonkey.Patches, engine *CrawlEngine) { convey.Convey(fmt.Sprintf("test %s", errMsg), t, func() { - 
ctxPatch:=gomonkey.ApplyFunc((*Context).Close,func(_ *Context){}) + ctxPatch := gomonkey.ApplyFunc((*Context).Close, func(_ *Context) {}) defer func() { patch.Reset() ctxPatch.Reset() diff --git a/events.go b/events.go index d32f68e..be57f09 100644 --- a/events.go +++ b/events.go @@ -41,8 +41,10 @@ const ( // EXIT 退出 EXIT ) + // Hook 事件处理函数类型 type Hook func(params ...interface{}) error + // EventHooksInterface 事件处理函数接口 type EventHooksInterface interface { // Start 处理引擎启动事件 @@ -58,12 +60,14 @@ type EventHooksInterface interface { // EventsWatcher 事件监听器 EventsWatcher(ch chan EventType) error } -type DefualtHooks struct { +type DefaultHooks struct { } + // DistributedHooks 分布式事件监听器 type DistributedHooks struct { worker DistributedWorkerInterface } + // DistributedHooks 构建新的分布式监听器组件对象 func NewDistributedHooks(worker DistributedWorkerInterface) *DistributedHooks { return &DistributedHooks{ @@ -71,33 +75,40 @@ func NewDistributedHooks(worker DistributedWorkerInterface) *DistributedHooks { } } -// NewDefualtHooks 构建新的默认事件监听器 -func NewDefualtHooks() *DefualtHooks { - return &DefualtHooks{} + +// NewDefaultHooks 构建新的默认事件监听器 +func NewDefaultHooks() *DefaultHooks { + return &DefaultHooks{} } + // Start 处理START事件 -func (d *DefualtHooks) Start(params ...interface{}) error { +func (d *DefaultHooks) Start(params ...interface{}) error { return nil } + // Stop 处理STOP事件 -func (d *DefualtHooks) Stop(params ...interface{}) error { +func (d *DefaultHooks) Stop(params ...interface{}) error { return nil } + // Error 处理ERROR事件 -func (d *DefualtHooks) Error(params ...interface{}) error { +func (d *DefaultHooks) Error(params ...interface{}) error { return nil } + // Exit 处理EXIT事件 -func (d *DefualtHooks) Exit(params ...interface{}) error { +func (d *DefaultHooks) Exit(params ...interface{}) error { return nil } + // Heartbeat 处理HEARTBEAT事件 -func (d *DefualtHooks) Heartbeat(params ...interface{}) error { +func (d *DefaultHooks) Heartbeat(params ...interface{}) error { return nil } + // DefaultventsWatcher 默认的事件监听器 // ch 用于接收事件 -// hooker 事件处理实例化接口,比如DefualtHooks +// hooker 事件处理实例化接口,比如DefaultHooks func DefaultventsWatcher(ch chan EventType, hooker EventHooksInterface) error { for { select { @@ -140,34 +151,41 @@ func DefaultventsWatcher(ch chan EventType, hooker EventHooksInterface) error { } } -// EventsWatcher DefualtHooks 的事件监听器 -func (d *DefualtHooks) EventsWatcher(ch chan EventType) error { + +// EventsWatcher DefaultHooks 的事件监听器 +func (d *DefaultHooks) EventsWatcher(ch chan EventType) error { return DefaultventsWatcher(ch, d) } + // Start 用于处理分布式模式下的START事件 func (d *DistributedHooks) Start(params ...interface{}) error { return d.worker.AddNode() } + // Stop 用于处理分布式模式下的STOP事件 func (d *DistributedHooks) Stop(params ...interface{}) error { return d.worker.StopNode() } + // Error 用于处理分布式模式下的ERROR事件 func (d *DistributedHooks) Error(params ...interface{}) error { return nil } + // Exit 用于处理分布式模式下的Exit事件 func (d *DistributedHooks) Exit(params ...interface{}) error { return d.worker.DelNode() } + // EventsWatcher 分布式模式下的事件监听器 func (d *DistributedHooks) EventsWatcher(ch chan EventType) error { return DefaultventsWatcher(ch, d) } + // Exit 用于处理分布式模式下的HEARTBEAT事件 func (d *DistributedHooks) Heartbeat(params ...interface{}) error { return d.worker.Heartbeat() diff --git a/example/quotes/middlerwares.go b/example/quotes/middlerwares.go index ba7efaf..26ba373 100644 --- a/example/quotes/middlerwares.go +++ b/example/quotes/middlerwares.go @@ -11,17 +11,20 @@ type HeadersDownloadMiddler struct { // Priority 优先级 Priority int // 
Name 中间件名称 - Name string + Name string } + // ProxyDownloadMiddler 代理挂载中间件 type ProxyDownloadMiddler struct { Priority int Name string } + // GetPriority 获取优先级,数字越小优先级越高 func (m HeadersDownloadMiddler) GetPriority() int { return m.Priority } + // ProcessRequest 处理request请求对象 // 此处用于增加请求头 // 按优先级执行 @@ -29,8 +32,8 @@ func (m HeadersDownloadMiddler) ProcessRequest(ctx *tegenaria.Context) error { header := map[string]string{ "Accept": "*/*", "Content-Type": "application/json", - - "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/101.0.4951.64 Safari/537.36", + + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/101.0.4951.64 Safari/537.36", } for key, value := range header { @@ -40,10 +43,11 @@ func (m HeadersDownloadMiddler) ProcessRequest(ctx *tegenaria.Context) error { } return nil } + // ProcessResponse 用于处理请求成功之后的response // 执行顺序你优先级,及优先级越高执行顺序越晚 func (m HeadersDownloadMiddler) ProcessResponse(ctx *tegenaria.Context, req chan<- *tegenaria.Context) error { - if ctx.Response.Status!=200{ + if ctx.Response.Status != 200 { return fmt.Errorf("非法状态码:%d", ctx.Response.Status) } return nil @@ -51,4 +55,4 @@ func (m HeadersDownloadMiddler) ProcessResponse(ctx *tegenaria.Context, req chan } func (m HeadersDownloadMiddler) GetName() string { return m.Name -} \ No newline at end of file +} diff --git a/example/quotes/pipeline.go b/example/quotes/pipeline.go index 40aab1e..9ffbece 100644 --- a/example/quotes/pipeline.go +++ b/example/quotes/pipeline.go @@ -16,8 +16,8 @@ type QuotesbotItemPipeline3 struct { // ProcessItem item处理函数 func (p *QuotesbotItemPipeline) ProcessItem(spider tegenaria.SpiderInterface, item *tegenaria.ItemMeta) error { - i:=item.Item.(*QuotesbotItem) - exampleLog.Infof("%s 抓取到数据:%s",item.CtxId, i.Text) + i := item.Item.(*QuotesbotItem) + exampleLog.Infof("%s 抓取到数据:%s", item.CtxId, i.Text) return nil } @@ -44,4 +44,3 @@ func (p *QuotesbotItemPipeline3) ProcessItem(spider tegenaria.SpiderInterface, i func (p *QuotesbotItemPipeline3) GetPriority() int { return p.Priority } - diff --git a/example/quotes/quotes.go b/example/quotes/quotes.go index 92568c1..3a9309e 100644 --- a/example/quotes/quotes.go +++ b/example/quotes/quotes.go @@ -29,4 +29,4 @@ func NewQuotesEngine(opts ...tegenaria.EngineOption) *tegenaria.CrawlEngine { Engine.RegisterDownloadMiddlewares(middleware) return Engine -} \ No newline at end of file +} diff --git a/example/quotes/spider.go b/example/quotes/spider.go index 179d0d1..d3d9f31 100644 --- a/example/quotes/spider.go +++ b/example/quotes/spider.go @@ -11,10 +11,11 @@ import ( ) var exampleLog *logrus.Entry = tegenaria.GetLogger("example") + // ExampleSpider 定义一个spider type ExampleSpider struct { // Name 爬虫名 - Name string + Name string // 种子urls FeedUrls []string } @@ -25,6 +26,7 @@ type QuotesbotItem struct { Author string Tags string } + // StartRequest 爬虫启动,请求种子urls func (e *ExampleSpider) StartRequest(req chan<- *tegenaria.Context) { for i := 0; i < 512; i++ { @@ -63,7 +65,7 @@ func (e *ExampleSpider) Parser(resp *tegenaria.Context, req chan<- *tegenaria.Co Tags: strings.Join(tags, ","), } exampleLog.Infof("text:%s,author:%s, tag: %s", qText, author, tags) - // 构建item发送到指定的channel + // 构建item发送到指定的channel itemCtx := tegenaria.NewItem(resp, &quoteItem) resp.Items <- itemCtx }) @@ -94,7 +96,8 @@ func (e *ExampleSpider) ErrorHandler(err *tegenaria.Context, req chan<- *tegenar func (e *ExampleSpider) GetName() string { return e.Name } + // GetFeedUrls 获取种子urls -func(e
*ExampleSpider)GetFeedUrls()[]string{ +func (e *ExampleSpider) GetFeedUrls() []string { return e.FeedUrls -} \ No newline at end of file +} diff --git a/exception_test.go b/exception_test.go index 0e18316..36f7607 100644 --- a/exception_test.go +++ b/exception_test.go @@ -8,17 +8,16 @@ import ( ) func TestErrorWithExtras(t *testing.T) { -convey.Convey("test error with extra",t,func(){ - extras:=map[string]interface{}{ - } - extras["errExt"] = "ext" - spider := &TestSpider{ - NewBaseSpider("testspider", []string{"https://www.baidu.com"}), - } - request := NewRequest("http://www.example.com", GET, spider.Parser) - ctx := NewContext(request, spider) - err:=NewError(ctx,errors.New("ctx error"),ErrorWithExtras(extras)) - convey.So(err.Extras, convey.ShouldContainKey,"errExt") -}) + convey.Convey("test error with extra", t, func() { + extras := map[string]interface{}{} + extras["errExt"] = "ext" + spider := &TestSpider{ + NewBaseSpider("testspider", []string{"https://www.baidu.com"}), + } + request := NewRequest("http://www.example.com", GET, spider.Parser) + ctx := NewContext(request, spider) + err := NewError(ctx, errors.New("ctx error"), ErrorWithExtras(extras)) + convey.So(err.Extras, convey.ShouldContainKey, "errExt") + }) } diff --git a/exceptions.go b/exceptions.go index 22872d0..2e84127 100644 --- a/exceptions.go +++ b/exceptions.go @@ -30,66 +30,73 @@ import ( var ( // ErrSpiderMiddleware 下载中间件处理异常 - ErrSpiderMiddleware error = errors.New("handle spider middleware error") + ErrSpiderMiddleware error = errors.New("handle spider middleware error") // ErrSpiderCrawls 抓取流程错误 - ErrSpiderCrawls error = errors.New("handle spider crawl error") + ErrSpiderCrawls error = errors.New("handle spider crawl error") // ErrDuplicateSpiderName 爬虫名重复错误 ErrDuplicateSpiderName error = errors.New("register a duplicate spider name error") // ErrEmptySpiderName 爬虫名不能为空 - ErrEmptySpiderName error = errors.New("register a empty spider name error") + ErrEmptySpiderName error = errors.New("register a empty spider name error") // ErrSpiderNotExist 爬虫实例不存在 - ErrSpiderNotExist error = errors.New("not found spider") + ErrSpiderNotExist error = errors.New("not found spider") // ErrNotAllowStatusCode 不允许的状态码 - ErrNotAllowStatusCode error = errors.New("not allow handle status code") + ErrNotAllowStatusCode error = errors.New("not allow handle status code") // ErrGetCacheItem 获取item 错误 - ErrGetCacheItem error = errors.New("getting item from cache error") + ErrGetCacheItem error = errors.New("getting item from cache error") // ErrGetHttpProxy 获取http代理错误 - ErrGetHttpProxy error = errors.New("getting http proxy ") + ErrGetHttpProxy error = errors.New("getting http proxy ") // ErrGetHttpsProxy 获取https代理错误 - ErrGetHttpsProxy error = errors.New("getting https proxy ") + ErrGetHttpsProxy error = errors.New("getting https proxy ") // ErrParseSocksProxy 解析socks代理错误 - ErrParseSocksProxy error = errors.New("parse socks proxy ") + ErrParseSocksProxy error = errors.New("parse socks proxy ") // ErrResponseRead 响应读取失败 - ErrResponseRead error = errors.New("read response to buffer error") + ErrResponseRead error = errors.New("read response to buffer error") // ErrResponseParse 响应解析失败 - ErrResponseParse error = errors.New("parse response error") + ErrResponseParse error = errors.New("parse response error") ) + // RedirectError 重定向错误 type RedirectError struct { RedirectNum int } + // HandleError 错误处理接口 type HandleError struct { // CtxId 上下文id - CtxId string + CtxId string // Err 处理过程的错误 - Err error + Err error // Extras 携带的额外信息 Extras 
map[string]interface{} } + // ErrorOption HandleError 可选参数 type ErrorOption func(e *HandleError) + // ErrorWithExtras HandleError 添加额外的数据 -func ErrorWithExtras(extras map[string]interface{}) ErrorOption{ +func ErrorWithExtras(extras map[string]interface{}) ErrorOption { return func(e *HandleError) { e.Extras = extras } } + // NewError 构建新的HandleError实例 func NewError(ctx *Context, err error, opts ...ErrorOption) *HandleError { h := &HandleError{ - CtxId: ctx.CtxId, - Err: err, + CtxId: ctx.CtxId, + Err: err, } for _, o := range opts { o(h) } return h } + // Error 获取HandleError错误信息 func (e *HandleError) Error() string { return fmt.Sprintf("%s with context id %s", e.Err.Error(), e.CtxId) } + // Error获取RedirectError错误 func (e *RedirectError) Error() string { return "exceeded the maximum number of redirects: " + strconv.Itoa(e.RedirectNum) diff --git a/init.go b/init.go index 91d6564..b7bf4a1 100644 --- a/init.go +++ b/init.go @@ -20,7 +20,6 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. - package tegenaria import ( diff --git a/items.go b/items.go index 7818ec7..43cebe7 100644 --- a/items.go +++ b/items.go @@ -25,13 +25,15 @@ package tegenaria // ItemInterface item实例接口 type ItemInterface interface { } + // ItemMeta item元数据结构 type ItemMeta struct { // CtxId 对应的context id CtxId string // Item item对象 - Item ItemInterface + Item ItemInterface } + // NewItem 构建新的ItemMeta对象 func NewItem(ctx *Context, item ItemInterface) *ItemMeta { return &ItemMeta{ diff --git a/limiter.go b/limiter.go index e4284d3..54129f4 100644 --- a/limiter.go +++ b/limiter.go @@ -30,6 +30,7 @@ import ( "github.com/go-redis/redis/v8" "go.uber.org/ratelimit" ) + // LimitInterface 限速器接口 type LimitInterface interface { // checkAndWaitLimiterPass 检查当前并发量 @@ -38,6 +39,7 @@ type LimitInterface interface { // setCurrrentSpider 设置当前正在的运行的spider setCurrrentSpider(spider string) } + // leakyBucketLuaScript 漏桶算法lua脚本 const leakyBucketLuaScript string = `-- 最高水位 local safetyLevel = tonumber(ARGV[1]) @@ -82,20 +84,21 @@ return 1` // leakyBucketLimiterWithRdb单机redis下的漏桶限速器 type leakyBucketLimiterWithRdb struct { // safetyLevel 最高水位 - safetyLevel int + safetyLevel int // currentLevel 当前水位 - currentLevel int + currentLevel int // waterVelocity 水流速度/秒 - waterVelocity int + waterVelocity int // currentSpider 当前正在运行的爬虫名 currentSpider string // rdb redis客户端实例 - rdb redis.Cmdable // redis 客户端 + rdb redis.Cmdable // redis 客户端 // script redis lua脚本 - script *redis.Script // lua脚本 + script *redis.Script // lua脚本 // keyFunc 限速器使用的缓存key函数 - keyFunc GetRDBKey + keyFunc GetRDBKey } + // defaultLimiter 默认的限速器 type defaultLimiter struct { limiter ratelimit.Limiter @@ -108,16 +111,19 @@ func NewDefaultLimiter(limitRate int) *defaultLimiter { limiter: ratelimit.New(limitRate, ratelimit.WithoutSlack), } } + // checkAndWaitLimiterPass 检查当前并发量 // 如果并发量达到上限则等待 func (d *defaultLimiter) checkAndWaitLimiterPass() error { d.limiter.Take() return nil } + // setCurrrentSpider 设置当前的spider名 func (d *defaultLimiter) setCurrrentSpider(spider string) { } + // NewLeakyBucketLimiterWithRdb leakyBucketLimiterWithRdb 构造函数 func NewLeakyBucketLimiterWithRdb(safetyLevel int, rdb redis.Cmdable, keyFunc GetRDBKey) *leakyBucketLimiterWithRdb { script := readLuaScript() @@ -131,6 +137,7 @@ func NewLeakyBucketLimiterWithRdb(safetyLevel int, rdb redis.Cmdable, keyFunc Ge } } + // tryPassLimiter 尝试通过限速器 func (l *leakyBucketLimiterWithRdb) tryPassLimiter() (bool, error) { now := time.Now().Unix() @@ -149,7 +156,7 @@ func (l 
*leakyBucketLimiterWithRdb) setCurrrentSpider(spider string) { l.currentSpider = spider key, ttl := l.keyFunc() key = fmt.Sprintf("%s:%s", key, l.currentSpider) - if ttl>0{ + if ttl > 0 { l.rdb.Expire(context.TODO(), key, ttl) } diff --git a/logger.go b/logger.go index 865b90a..4d7f76d 100644 --- a/logger.go +++ b/logger.go @@ -68,7 +68,7 @@ func initLog() { logLevel = "error" } logLevel = strings.TrimSpace(logLevel) - if logLevel == ""{ + if logLevel == "" { logLevel = "info" } level, err := logrus.ParseLevel(logLevel) diff --git a/logger_test.go b/logger_test.go index a2768e8..e94a4bb 100644 --- a/logger_test.go +++ b/logger_test.go @@ -16,22 +16,22 @@ func TestLogger(t *testing.T) { log.Infof("testtest") }) convey.Convey("test log level empty", t, func() { - patch := gomonkey.ApplyFunc((*viper.Viper).GetString, func(_ *viper.Viper, _ string)string { + patch := gomonkey.ApplyFunc((*viper.Viper).GetString, func(_ *viper.Viper, _ string) string { return "" }) defer patch.Reset() initLog() - convey.So(logger.Level.String(),convey.ShouldContainSubstring,"info") + convey.So(logger.Level.String(), convey.ShouldContainSubstring, "info") }) convey.Convey("test log level parser error", t, func() { - patch := gomonkey.ApplyFunc(logrus.ParseLevel, func(_ string)(logrus.Level,error) { - return logrus.ErrorLevel,errors.New("parse level error") + patch := gomonkey.ApplyFunc(logrus.ParseLevel, func(_ string) (logrus.Level, error) { + return logrus.ErrorLevel, errors.New("parse level error") }) defer patch.Reset() - f:=func(){ + f := func() { initLog() } - convey.So(f,convey.ShouldPanic) + convey.So(f, convey.ShouldPanic) }) } diff --git a/middlewares.go b/middlewares.go index 74e31dc..276cb34 100644 --- a/middlewares.go +++ b/middlewares.go @@ -28,25 +28,28 @@ type MiddlewaresInterface interface { // GetPriority 获取优先级,数字越小优先级越高 GetPriority() int - // ProcessRequest 处理request请求对象 - // 此处用于增加请求头 - // 按优先级执行 + // ProcessRequest 处理request请求对象 + // 此处用于增加请求头 + // 按优先级执行 ProcessRequest(ctx *Context) error // ProcessResponse 用于处理请求成功之后的response - // 执行顺序你优先级,及优先级越高执行顺序越晚 + // 执行顺序你优先级,及优先级越高执行顺序越晚 ProcessResponse(ctx *Context, req chan<- *Context) error // GetName 获取中间件的名称 GetName() string } + // ProcessResponse 处理下载之后的response函数 type ProcessResponse func(ctx *Context) error type MiddlewaresBase struct { Priority int } + // Middlewares 下载中间件队列 type Middlewares []MiddlewaresInterface + // 实现sort接口 func (p Middlewares) Len() int { return len(p) } func (p Middlewares) Swap(i, j int) { p[i], p[j] = p[j], p[i] } diff --git a/rdb.go b/rdb.go index b0bbf3c..407ed1d 100644 --- a/rdb.go +++ b/rdb.go @@ -32,7 +32,7 @@ import ( // NewRdbConfig redis 配置构造函数 func NewRdbConfig(config *DistributedWorkerConfig) *redis.Options { return &redis.Options{ - Password: config.RedisPasswd, // 密码 + Password: config.RedisPasswd, // 密码 Username: config.RedisUsername, //用户名 DB: int(config.RedisDB), // redis数据库index @@ -62,22 +62,22 @@ func NewRdbClient(config *DistributedWorkerConfig) *redis.Client { options := NewRdbConfig(config) options.Addr = config.RedisAddr rdb := redis.NewClient(options) - err:=rdb.Ping(context.TODO()).Err() + err := rdb.Ping(context.TODO()).Err() // RedisAddr 为空说明处于集群模式 - if err!=nil && config.RedisAddr !=""{ + if err != nil && config.RedisAddr != "" { panic(err) } return rdb } func NewRdbClusterCLient(config *WorkerConfigWithRdbCluster) *redis.ClusterClient { - client:=redis.NewClusterClient(&redis.ClusterOptions{ + client := redis.NewClusterClient(&redis.ClusterOptions{ Addrs: config.RdbNodes, // MaxRetries: 
config.DistributedWorkerConfig.RdbMaxRetry, //连接池容量及闲置连接数量 NewClient: func(opt *redis.Options) *redis.Client { - addr :=opt.Addr + addr := opt.Addr opt = NewRdbConfig(config.DistributedWorkerConfig) opt.Addr = addr return redis.NewClient(opt) @@ -86,8 +86,8 @@ func NewRdbClusterCLient(config *WorkerConfigWithRdbCluster) *redis.ClusterClien RouteByLatency: true, RouteRandomly: true, }) - err:=client.Ping(context.TODO()).Err() - if err!=nil{ + err := client.Ping(context.TODO()).Err() + if err != nil { panic(err) } return client diff --git a/rdb_test.go b/rdb_test.go index 778f322..8877207 100644 --- a/rdb_test.go +++ b/rdb_test.go @@ -40,10 +40,10 @@ func TestRdbClusterCLient(t *testing.T) { }) convey.Convey("test rdb connect error", t, func() { mockRedis := miniredis.RunT(t) - addr:=mockRedis.Addr() + addr := mockRedis.Addr() mockRedis.Close() f := func() { - NewDistributedWorker(addr, NewDistributedWorkerConfig("","",0)) + NewDistributedWorker(addr, NewDistributedWorkerConfig("", "", 0)) } convey.So(f, convey.ShouldPanic) }) diff --git a/request.go b/request.go index fbfbd02..2c662b8 100644 --- a/request.go +++ b/request.go @@ -35,6 +35,7 @@ import ( "github.com/mitchellh/mapstructure" "github.com/sirupsen/logrus" ) + // Proxy 代理数据结构 type Proxy struct { // ProxyUrl 代理链接 @@ -44,35 +45,35 @@ type Proxy struct { // Request 请求对象的结构 type Request struct { // Url 请求Url - Url string `json:"url"` - // Header 请求头 - Header map[string]string `json:"header"` - // Method 请求方式 - Method RequestMethod `json:"method"` - // Body 请求body - Body []byte `json:"body"` - // Params 请求url的参数 - Params map[string]string `json:"params"` - // Proxy 代理实例 - Proxy *Proxy `json:"-"` - // Cookies 请求携带的cookies - Cookies map[string]string `json:"cookies"` - // Meta 请求携带的额外的信息 - Meta map[string]interface{} `json:"meta"` - // AllowRedirects 是否允许跳转默认允许 - AllowRedirects bool `json:"allowRedirects"` - // MaxRedirects 最大的跳转次数 - MaxRedirects int `json:"maxRedirects"` - // Parser 该请求绑定的响应解析函数,必须是一个spider实例 - Parser Parser `json:"-"` - // MaxConnsPerHost 单个域名最大的连接数 - MaxConnsPerHost int `json:"maxConnsPerHost"` + Url string `json:"url"` + // Header 请求头 + Header map[string]string `json:"header"` + // Method 请求方式 + Method RequestMethod `json:"method"` + // Body 请求body + Body []byte `json:"body"` + // Params 请求url的参数 + Params map[string]string `json:"params"` + // Proxy 代理实例 + Proxy *Proxy `json:"-"` + // Cookies 请求携带的cookies + Cookies map[string]string `json:"cookies"` + // Meta 请求携带的额外的信息 + Meta map[string]interface{} `json:"meta"` + // AllowRedirects 是否允许跳转默认允许 + AllowRedirects bool `json:"allowRedirects"` + // MaxRedirects 最大的跳转次数 + MaxRedirects int `json:"maxRedirects"` + // Parser 该请求绑定的响应解析函数,必须是一个spider实例 + Parser Parser `json:"-"` + // MaxConnsPerHost 单个域名最大的连接数 + MaxConnsPerHost int `json:"maxConnsPerHost"` // BodyReader 用于读取body - BodyReader io.Reader `json:"-"` - // ResponseWriter 响应读取到本地的接口 - ResponseWriter io.Writer `json:"-"` - // AllowStatusCode 允许的状态码 - AllowStatusCode []uint64 `json:"allowStatusCode"` + BodyReader io.Reader `json:"-"` + // ResponseWriter 响应读取到本地的接口 + ResponseWriter io.Writer `json:"-"` + // AllowStatusCode 允许的状态码 + AllowStatusCode []uint64 `json:"allowStatusCode"` // Timeout 请求超时时间 Timeout time.Duration `json:"timeout"` } @@ -100,6 +101,7 @@ var bufferPool *sync.Pool = &sync.Pool{ // reqLog request logger var reqLog *logrus.Entry = GetLogger("request") + // RequestWithRequestBody 传入请求体到request func RequestWithRequestBody(body map[string]interface{}) RequestOption { return func(r *Request) { @@ -113,42 
+115,49 @@ func RequestWithRequestBody(body map[string]interface{}) RequestOption { } } } + // RequestWithRequestBytesBody request绑定bytes body func RequestWithRequestBytesBody(body []byte) RequestOption { return func(r *Request) { r.Body = body } } + // RequestWithRequestParams 设置请求的url参数 func RequestWithRequestParams(params map[string]string) RequestOption { return func(r *Request) { r.Params = params } } + // RequestWithRequestProxy 设置代理 func RequestWithRequestProxy(proxy Proxy) RequestOption { return func(r *Request) { r.Proxy = &proxy } } + // RequestWithRequestHeader 设置请求头 func RequestWithRequestHeader(header map[string]string) RequestOption { return func(r *Request) { r.Header = header } } + // RequestWithRequestCookies 设置cookie func RequestWithRequestCookies(cookies map[string]string) RequestOption { return func(r *Request) { r.Cookies = cookies } } + // RequestWithRequestMeta 设置 meta func RequestWithRequestMeta(meta map[string]interface{}) RequestOption { return func(r *Request) { r.Meta = meta } } + // RequestWithAllowRedirects 设置是否允许跳转 // 如果不允许则MaxRedirects=0 func RequestWithAllowRedirects(allowRedirects bool) RequestOption { @@ -159,42 +168,48 @@ func RequestWithAllowRedirects(allowRedirects bool) RequestOption { } } } + // RequestWithMaxRedirects 设置最大的跳转次数 // 若maxRedirects <= 0则认为不允许跳转AllowRedirects = false func RequestWithMaxRedirects(maxRedirects int) RequestOption { return func(r *Request) { if maxRedirects <= 0 { r.AllowRedirects = false - }else{ + } else { r.MaxRedirects = maxRedirects r.AllowRedirects = true } } } + // RequestWithResponseWriter 设置ResponseWriter func RequestWithResponseWriter(write io.Writer) RequestOption { return func(r *Request) { r.ResponseWriter = write } } + // RequestWithMaxConnsPerHost 设置MaxConnsPerHost func RequestWithMaxConnsPerHost(maxConnsPerHost int) RequestOption { return func(r *Request) { r.MaxConnsPerHost = maxConnsPerHost } } + // RequestWithAllowedStatusCode 设置AllowStatusCode func RequestWithAllowedStatusCode(allowStatusCode []uint64) RequestOption { return func(r *Request) { r.AllowStatusCode = allowStatusCode } } + // RequestWithParser 设置Parser func RequestWithParser(parser Parser) RequestOption { return func(r *Request) { r.Parser = parser } } + // RequestWithTimeout 设置请求超时时间 // 若timeout<=0则认为没有超时时间 func RequestWithTimeout(timeout time.Duration) RequestOption { @@ -202,6 +217,7 @@ func RequestWithTimeout(timeout time.Duration) RequestOption { r.Timeout = timeout } } + // updateQueryParams 将Params配置到url func (r *Request) updateQueryParams() { defer func() { @@ -269,6 +285,7 @@ func freeRequest(r *Request) { r = nil } + // ToMap 将request对象转为map func (r *Request) ToMap() (map[string]interface{}, error) { b, err := json.Marshal(r) @@ -280,6 +297,7 @@ func (r *Request) ToMap() (map[string]interface{}, error) { return m, err } + // RequestFromMap 从map创建requests func RequestFromMap(src map[string]interface{}, opts ...RequestOption) *Request { request := requestPool.Get().(*Request) diff --git a/response.go b/response.go index 8bd2664..5c34125 100644 --- a/response.go +++ b/response.go @@ -33,17 +33,17 @@ import ( // Response 请求响应体的结构 type Response struct { // Status状态码 - Status int + Status int // Header 响应头 - Header map[string][]string // Header response header + Header map[string][]string // Header response header // Delay 请求延迟 - Delay float64 // Delay the time of handle download request + Delay float64 // Delay the time of handle download request // ContentLength 响应体大小 - ContentLength uint64 // ContentLength response content length + 
ContentLength uint64 // ContentLength response content length // URL 请求url - URL string // URL of request url + URL string // URL of request url // Buffer 响应体缓存 - Buffer *bytes.Buffer // buffer read response buffer + Buffer *bytes.Buffer // buffer read response buffer } // responsePool Response 对象内存池 @@ -55,15 +55,15 @@ var responsePool *sync.Pool = &sync.Pool{ var respLog *logrus.Entry = GetLogger("response") // Json 将响应数据转为json -func (r *Response) Json() (map[string]interface{},error) { +func (r *Response) Json() (map[string]interface{}, error) { jsonResp := map[string]interface{}{} err := json.Unmarshal(r.Buffer.Bytes(), &jsonResp) if err != nil { respLog.Errorf("Get json response error %s", err.Error()) - + return nil, err } - return jsonResp,nil + return jsonResp, nil } // String 将响应数据转为string diff --git a/settings.go b/settings.go index e525d2e..bd42ee4 100644 --- a/settings.go +++ b/settings.go @@ -33,10 +33,9 @@ import ( type Settings interface { // GetValue 获取指定的参数值 - GetValue(key string) (interface{},error) + GetValue(key string) (interface{}, error) } - type Configuration struct { // Log *Logger `ymal:"log"` *viper.Viper @@ -53,9 +52,9 @@ func newTegenariaConfig() { }) } -func(c *Configuration)GetValue(key string) (interface{},error){ - value:=c.Get(key) - return value,nil +func (c *Configuration) GetValue(key string) (interface{}, error) { + value := c.Get(key) + return value, nil } func (c *Configuration) load(dir string) bool { c.AddConfigPath(dir) diff --git a/settings_test.go b/settings_test.go index 8ec953e..8679778 100644 --- a/settings_test.go +++ b/settings_test.go @@ -48,7 +48,7 @@ func TestSetting(t *testing.T) { convey.So(err, convey.ShouldBeNil) config.SetFs(fs) ret := config.load("/etc/viper") - value,_:=config.GetValue("log.level") + value, _ := config.GetValue("log.level") convey.So(ret, convey.ShouldBeTrue) convey.So(config.GetString("redis.addr"), convey.ShouldContainSubstring, "127.0.0.1") convey.So(config.GetString("log.level"), convey.ShouldContainSubstring, "error") diff --git a/spiders.go b/spiders.go index 88c0fd2..154eb8d 100644 --- a/spiders.go +++ b/spiders.go @@ -46,7 +46,7 @@ type SpiderInterface interface { // GetName 获取spider名称 GetName() string // GetFeedUrls 获取种子urls - GetFeedUrls()[]string + GetFeedUrls() []string } // BaseSpider base spider @@ -57,6 +57,7 @@ type BaseSpider struct { // FeedUrls feed urls FeedUrls []string } + // Spiders 全局spiders管理器 // 用于接收注册的SpiderInterface实例 type Spiders struct { @@ -64,7 +65,7 @@ type Spiders struct { SpidersModules map[string]SpiderInterface // Parsers parser函数名和函数的映射 // 用于序列化和反序列化 - Parsers map[string]Parser + Parsers map[string]Parser } var SpidersList *Spiders @@ -80,7 +81,7 @@ func (s *BaseSpider) StartRequest(req chan<- *Context) { // StartRequest start feed urls request } -// Parser parse request respone +// Parser parse request response // it will send item or new request to engine func (s *BaseSpider) Parser(resp *Context, item chan<- *ItemMeta, req chan<- *Context) error { return nil @@ -89,6 +90,7 @@ func (s *BaseSpider) ErrorHandler(err *HandleError) { // ErrorHandler error handler } + // NewSpiders 构建Spiders实例 func NewSpiders() *Spiders { onceSpiders.Do(func() { @@ -99,6 +101,7 @@ func NewSpiders() *Spiders { }) return SpidersList } + // Register spider实例注册到Spiders.SpidersModules func (s *Spiders) Register(spider SpiderInterface) error { // 爬虫名不能为空 @@ -113,6 +116,7 @@ func (s *Spiders) Register(spider SpiderInterface) error { return nil } } + // GetSpider 通过爬虫名获取spider实例 func (s *Spiders) 
GetSpider(name string) (SpiderInterface, error) { if _, ok := s.SpidersModules[name]; !ok { diff --git a/spiders_test.go b/spiders_test.go index cf6adcb..5039238 100644 --- a/spiders_test.go +++ b/spiders_test.go @@ -11,7 +11,7 @@ type TestSpider struct { } func (s *TestSpider) StartRequest(req chan<- *Context) { - + for _, url := range s.FeedUrls { request := NewRequest(url, GET, s.Parser) ctx := NewContext(request, s) @@ -21,18 +21,18 @@ func (s *TestSpider) StartRequest(req chan<- *Context) { func (s *TestSpider) Parser(resp *Context, req chan<- *Context) error { return testParser(resp, req) } -func (s *TestSpider) ErrorHandler(err *Context, req chan<- *Context){ +func (s *TestSpider) ErrorHandler(err *Context, req chan<- *Context) { } func (s *TestSpider) GetName() string { return s.Name } -func (s *TestSpider)GetFeedUrls()[]string{ +func (s *TestSpider) GetFeedUrls() []string { return s.FeedUrls } func TestSpiders(t *testing.T) { - convey.Convey("test spiders",t,func(){ + convey.Convey("test spiders", t, func() { spiders := NewSpiders() spider1 := &TestSpider{ NewBaseSpider("testspider", []string{"https://www.baidu.com"}), @@ -49,15 +49,15 @@ func TestSpiders(t *testing.T) { spider5 := &TestSpider{ NewBaseSpider("", []string{"https://www.baidu.com"}), } - err:=spiders.Register(spider1) + err := spiders.Register(spider1) convey.So(err, convey.ShouldBeNil) - + err = spiders.Register(spider2) - convey.So(err, convey.ShouldBeError,ErrDuplicateSpiderName) + convey.So(err, convey.ShouldBeError, ErrDuplicateSpiderName) err = spiders.Register(spider3) convey.So(err, convey.ShouldBeNil) - err =spiders.Register(spider4) + err = spiders.Register(spider4) convey.So(err, convey.ShouldBeNil) spiderNames := []string{"testspider", "testspider1", "testspider2"} for _, spider := range spiderNames { diff --git a/stats.go b/stats.go index 5d65071..7f87ee0 100644 --- a/stats.go +++ b/stats.go @@ -33,8 +33,10 @@ import ( "github.com/go-redis/redis/v8" ) + // StatsFieldType 统计指标的数据类型 type StatsFieldType string + const ( // RequestStats 发起的请求总数 RequestStats StatsFieldType = "requests" @@ -45,6 +47,7 @@ const ( // ErrorStats 错误总数 ErrorStats StatsFieldType = "errors" ) + // StatisticInterface 数据统计组件接口 type StatisticInterface interface { // IncrItemScraped 累加一条item @@ -70,40 +73,43 @@ type StatisticInterface interface { // setCurrentSpider 设置当前正在运行的spider setCurrentSpider(spider string) } + // Statistic 数据统计指标 type Statistic struct { // ItemScraped items总数 - ItemScraped uint64 `json:"items"` + ItemScraped uint64 `json:"items"` // RequestSent 请求总数 - RequestSent uint64 `json:"requets"` + RequestSent uint64 `json:"requets"` // DownloadFail 下载失败总数 DownloadFail uint64 `json:"download_fail"` // ErrorCount 错误总数 - ErrorCount uint64 `json:"errors"` + ErrorCount uint64 `json:"errors"` // spider 当前正在运行的spider名 - spider string `json:"-"` + spider string `json:"-"` } // DistributeStatistic 分布式统计组件 type DistributeStatistic struct { // keyPrefix 缓存key前缀,默认tegenaria:v1:nodes - keyPrefix string + keyPrefix string // nodesKey 节点池的key - nodesKey string + nodesKey string // rdb redis客户端实例 - rdb redis.Cmdable + rdb redis.Cmdable // 调度该组件的wg - wg *sync.WaitGroup + wg *sync.WaitGroup // afterResetTTL 重置数据之前缓存多久 // 默认不缓存这些统计数据 afterResetTTL time.Duration // spider 当前在运行的spider名 - spider string + spider string // fields 所有参与统计的指标 fields []StatsFieldType } + // DistributeStatisticOption 分布式组件可选参数定义 type DistributeStatisticOption func(d *DistributeStatistic) + // NewStatistic 默认统计数据组件构造函数 func NewStatistic() *Statistic { return 
&Statistic{ @@ -113,14 +119,17 @@ func NewStatistic() *Statistic { ErrorCount: 0, } } + // setCurrentSpider 设置当前的spider func (s *Statistic) setCurrentSpider(spider string) { s.spider = spider } + // IncrItemScraped 累加一条item func (s *Statistic) IncrItemScraped() { atomic.AddUint64(&s.ItemScraped, 1) } + // IncrRequestSent 累加一条request func (s *Statistic) IncrRequestSent() { atomic.AddUint64(&s.RequestSent, 1) @@ -130,6 +139,7 @@ func (s *Statistic) IncrRequestSent() { func (s *Statistic) IncrDownloadFail() { atomic.AddUint64(&s.DownloadFail, 1) } + // IncrErrorCount 累加捕获到的数量 func (s *Statistic) IncrErrorCount() { atomic.AddUint64(&s.ErrorCount, 1) @@ -144,6 +154,7 @@ func (s *Statistic) GetItemScraped() uint64 { func (s *Statistic) GetRequestSent() uint64 { return atomic.LoadUint64(&s.RequestSent) } + // GetDownloadFail 获取下载失败的总数 func (s *Statistic) GetDownloadFail() uint64 { return atomic.LoadUint64(&s.DownloadFail) @@ -161,6 +172,7 @@ func (s *Statistic) OutputStats() map[string]uint64 { _ = json.Unmarshal(b, &result) return result } + // Reset 重置统计数据 func (s *Statistic) Reset() error { atomic.StoreUint64(&s.DownloadFail, 0) @@ -169,10 +181,12 @@ func (s *Statistic) Reset() error { atomic.StoreUint64(&s.ErrorCount, 0) return nil } + // setCurrentSpider 设置当前的spider名 func (s *DistributeStatistic) setCurrentSpider(spider string) { s.spider = spider } + // NewDistributeStatistic 分布式数据统计组件构造函数 func NewDistributeStatistic(statsPrefixKey string, rdb redis.Cmdable, wg *sync.WaitGroup, opts ...DistributeStatisticOption) *DistributeStatistic { d := &DistributeStatistic{ @@ -181,13 +195,14 @@ func NewDistributeStatistic(statsPrefixKey string, rdb redis.Cmdable, wg *sync.W rdb: rdb, wg: wg, afterResetTTL: -1 * time.Second, - fields: []StatsFieldType{ItemsStats,RequestStats,DownloadFailStats,ErrorStats}, + fields: []StatsFieldType{ItemsStats, RequestStats, DownloadFailStats, ErrorStats}, } for _, o := range opts { o(d) } return d } + // IncrStats累加指定的统计指标 func (s *DistributeStatistic) IncrStats(field StatsFieldType) { f := func() error { @@ -207,11 +222,13 @@ func (s *DistributeStatistic) IncrRequestSent() { s.IncrStats(RequestStats) } + // IncrErrorCount 累加一条错误 func (s *DistributeStatistic) IncrErrorCount() { s.IncrStats(ErrorStats) } + // GetDownloadFail 累计获取下载失败的数量 func (s *DistributeStatistic) IncrDownloadFail() { s.IncrStats(DownloadFailStats) @@ -227,23 +244,28 @@ func (s *DistributeStatistic) GetStatsField(field StatsFieldType) uint64 { } return uint64(val) } + // GetItemScraped 获取items的值 func (s *DistributeStatistic) GetItemScraped() uint64 { return s.GetStatsField(ItemsStats) } + // GetRequestSent 获取request 量 func (s *DistributeStatistic) GetRequestSent() uint64 { return s.GetStatsField(RequestStats) } + // GetDownloadFail 获取下载失败数 func (s *DistributeStatistic) GetDownloadFail() uint64 { return s.GetStatsField(DownloadFailStats) } + // GetErrorCount 获取错误数 func (s *DistributeStatistic) GetErrorCount() uint64 { return s.GetStatsField(ErrorStats) } + // Reset 重置各项指标 // 若afterResetTTL>0则为每一项指标设置ttl否则直接删除指标 func (s *DistributeStatistic) Reset() error { @@ -266,12 +288,14 @@ func (s *DistributeStatistic) Reset() error { _, err := pipe.Exec(context.TODO()) return err } + // DistributeStatisticAfterResetTTL 为分布式计数器设置重置之前的ttl func DistributeStatisticAfterResetTTL(ttl time.Duration) DistributeStatisticOption { return func(d *DistributeStatistic) { d.afterResetTTL = ttl } } + // OutputStats 格式化输出所有的数据指标 func (s *DistributeStatistic) OutputStats() map[string]uint64 { @@ -279,7 +303,7 @@ func (s 
*DistributeStatistic) OutputStats() map[string]uint64 { pipe := s.rdb.Pipeline() result := []*redis.StringCmd{} for _, field := range s.fields { - key:=fmt.Sprintf("%s:%s:%s", s.keyPrefix, s.spider, field) + key := fmt.Sprintf("%s:%s:%s", s.keyPrefix, s.spider, field) result = append(result, pipe.Get(context.TODO(), key)) } _, err := pipe.Exec(context.TODO()) diff --git a/stats_test.go b/stats_test.go index 94a1824..74c5847 100644 --- a/stats_test.go +++ b/stats_test.go @@ -9,13 +9,14 @@ import ( "github.com/go-redis/redis/v8" "github.com/smartystreets/goconvey/convey" ) -func newTestStats(mockRedis *miniredis.Miniredis,t *testing.T,opts ...DistributeStatisticOption) *DistributeStatistic{ + +func newTestStats(mockRedis *miniredis.Miniredis, t *testing.T, opts ...DistributeStatisticOption) *DistributeStatistic { rdb := redis.NewClient(&redis.Options{ Addr: mockRedis.Addr(), }) wg := &sync.WaitGroup{} - stats := NewDistributeStatistic("tegenaria:v1:stats", rdb, wg,opts...) + stats := NewDistributeStatistic("tegenaria:v1:stats", rdb, wg, opts...) stats.setCurrentSpider("distributedStatsSpider") worker := NewDistributedWorker(mockRedis.Addr(), NewDistributedWorkerConfig("", "", 0)) @@ -34,9 +35,9 @@ func TestDistributedStats(t *testing.T) { mockRedis := miniredis.RunT(t) defer mockRedis.Close() - stats:=newTestStats(mockRedis,t, DistributeStatisticAfterResetTTL(10*time.Second)) + stats := newTestStats(mockRedis, t, DistributeStatisticAfterResetTTL(10*time.Second)) result := stats.OutputStats() - convey.So(len(result), convey.ShouldBeGreaterThan,0) + convey.So(len(result), convey.ShouldBeGreaterThan, 0) for _, r := range result { convey.So(r, convey.ShouldAlmostEqual, 1) } @@ -60,9 +61,9 @@ func TestDistributedStats(t *testing.T) { convey.Convey("test stats no afterTTL", t, func() { mockRedis := miniredis.RunT(t) defer mockRedis.Close() - stats:=newTestStats(mockRedis,t) + stats := newTestStats(mockRedis, t) result := stats.OutputStats() - convey.So(len(result), convey.ShouldBeGreaterThan,0) + convey.So(len(result), convey.ShouldBeGreaterThan, 0) for _, r := range result { convey.So(r, convey.ShouldAlmostEqual, 1) diff --git a/utils.go b/utils.go index 6ea3de2..88866f0 100644 --- a/utils.go +++ b/utils.go @@ -41,6 +41,7 @@ func GetUUID() string { return uuid } + // GoFunc 协程函数 type GoFunc func() error @@ -53,8 +54,8 @@ func AddGo(wg *sync.WaitGroup, funcs ...GoFunc) <-chan error { wg.Add(1) go func() { defer func() { - if p:=recover();p!=nil{ - ch<-fmt.Errorf("call go funcs paninc %s", p) + if p := recover(); p != nil { + ch <- fmt.Errorf("call go funcs paninc %s", p) } wg.Done() }() @@ -64,6 +65,7 @@ func AddGo(wg *sync.WaitGroup, funcs ...GoFunc) <-chan error { } return ch } + // GetFunctionName 提取解析函数名 func GetFunctionName(fn Parser) string { name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() @@ -71,6 +73,7 @@ func GetFunctionName(fn Parser) string { return strings.ReplaceAll(nodes[len(nodes)-1], "-fm", "") } + // GetParserByName 通过函数名从spider实例中获取解析函数 func GetParserByName(spider SpiderInterface, name string) Parser { return func(resp *Context, req chan<- *Context) error { @@ -84,6 +87,7 @@ func GetParserByName(spider SpiderInterface, name string) Parser { return rets[0].Interface().(error) } } + // GetAllParserMethod 获取spider实例所有的解析函数 func GetAllParserMethod(spider SpiderInterface) map[string]Parser { val := reflect.ValueOf(spider) @@ -102,6 +106,7 @@ func GetAllParserMethod(spider SpiderInterface) map[string]Parser { } return parsers } + // 
OptimalNumOfHashFunctions计算最优的布隆过滤器哈希函数个数 func OptimalNumOfHashFunctions(n int64, m int64) int64 { // (m / n) * log(2), but avoid truncation due to division! @@ -113,6 +118,7 @@ func OptimalNumOfHashFunctions(n int64, m int64) int64 { func OptimalNumOfBits(n int64, p float64) int64 { return (int64)(-float64(n) * math.Log(p) / (math.Log(2) * math.Log(2))) } + // Map2String 将map转为string func Map2String(m interface{}) string { dataType, _ := json.Marshal(m) @@ -120,6 +126,7 @@ func Map2String(m interface{}) string { return dataString } + // GetMachineIp 获取本机ip func GetMachineIp() string { addrs, err := net.InterfaceAddrs() diff --git a/utils_test.go b/utils_test.go index b4d7c54..d940dc1 100644 --- a/utils_test.go +++ b/utils_test.go @@ -2,20 +2,20 @@ package tegenaria import "testing" -func TestGetParserByName(t *testing.T){ - server:=newTestServer() - spider:=&TestSpider{NewBaseSpider("spiderParser", []string{"http://127.0.0.1" + "/testGET"})} +func TestGetParserByName(t *testing.T) { + server := newTestServer() + spider := &TestSpider{NewBaseSpider("spiderParser", []string{"http://127.0.0.1" + "/testGET"})} request := NewRequest(server.URL+"/testGET", GET, testParser) ctx := NewContext(request, spider) - ch:=make(chan *Context,1) - f:=GetParserByName(spider, "Parser") - f(ctx,ch) - item:=<-ctx.Items - it:=item.Item.(*testItem) - if it.test !="test"{ + ch := make(chan *Context, 1) + f := GetParserByName(spider, "Parser") + f(ctx, ch) + item := <-ctx.Items + it := item.Item.(*testItem) + if it.test != "test" { t.Error("call parser func fail") } - + close(ch) }
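Beyond the gofmt and typo cleanups (Defualt -> Default, defualt -> default), one substantive change in this patch is in downloader.go: the cancel function returned by context.WithTimeout is now captured and deferred instead of being discarded, the kind of leak go vet's lostcancel analyzer reports, and the redundant nil-initialized timeoutCtx/valCtx variables are dropped. The sketch below is a minimal, standalone illustration of that pattern, not tegenaria's actual downloader; the fetch helper, the ctxKey value, the meta payload, and the example URL are made-up stand-ins.

package main

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"time"
)

type ctxKey string

// fetch mirrors the shape of the patched Download method: a value context is
// always attached to the request, and when a timeout is configured the
// context.WithTimeout cancel function is deferred rather than discarded, so
// the timer is released on every return path.
func fetch(parent context.Context, url string, timeout time.Duration) ([]byte, error) {
	var key ctxKey = "key"
	meta := map[string]string{"request_id": "example"}

	valCtx := context.WithValue(parent, key, meta)
	if timeout > 0 {
		timeoutCtx, cancel := context.WithTimeout(parent, timeout)
		defer cancel() // the pre-patch code dropped this with `_`, leaking the timer
		valCtx = context.WithValue(timeoutCtx, key, meta)
	}

	req, err := http.NewRequestWithContext(valCtx, http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	// Read the body before returning, while the timeout context is still live.
	return io.ReadAll(resp.Body)
}

func main() {
	body, err := fetch(context.Background(), "https://example.com", 2*time.Second)
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	fmt.Println("fetched", len(body), "bytes")
}

Because the response body is consumed before fetch returns, the deferred cancel only fires after the read completes, so deferring it inside the function does not cut the request short on the success path.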