diff --git a/.gitignore b/.gitignore index 47140b99b..306ffbd83 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ _testmain.go *.exe /.vscode/ +/debug.test /generator /cluster-test/cluster-test /cluster-test/*.log diff --git a/bulk.go b/bulk.go index 3a6b63ea6..6f6fcd042 100644 --- a/bulk.go +++ b/bulk.go @@ -26,7 +26,8 @@ import ( // See https://www.elastic.co/guide/en/elasticsearch/reference/6.0/docs-bulk.html // for more details. type BulkService struct { - client *Client + client *Client + retrier Retrier index string typ string @@ -57,6 +58,13 @@ func (s *BulkService) reset() { s.sizeInBytesCursor = 0 } +// Retrier allows to set specific retry logic for this BulkService. +// If not specified, it will use the client's default retrier. +func (s *BulkService) Retrier(retrier Retrier) *BulkService { + s.retrier = retrier + return s +} + // Index specifies the index to use for all batches. You may also leave // this blank and specify the index in the individual bulk requests. func (s *BulkService) Index(index string) *BulkService { @@ -241,6 +249,7 @@ func (s *BulkService) Do(ctx context.Context) (*BulkResponse, error) { Params: params, Body: body, ContentType: "application/x-ndjson", + Retrier: s.retrier, }) if err != nil { return nil, err diff --git a/bulk_test.go b/bulk_test.go index 2aaf62915..2bfa75d26 100644 --- a/bulk_test.go +++ b/bulk_test.go @@ -482,6 +482,30 @@ func TestBulkEstimateSizeInBytesLength(t *testing.T) { } } +func TestBulkContentType(t *testing.T) { + var header http.Header + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + header = r.Header + fmt.Fprintln(w, `{}`) + })) + defer ts.Close() + + client, err := NewSimpleClient(SetURL(ts.URL)) + if err != nil { + t.Fatal(err) + } + indexReq := NewBulkIndexRequest().Index(testIndexName).Type("doc").Id("1").Doc(tweet{User: "olivere", Message: "Welcome to Golang and Elasticsearch."}) + if _, err := client.Bulk().Add(indexReq).Do(context.Background()); err != nil { + t.Fatal(err) + } + if header == nil { + t.Fatalf("expected header, got %v", header) + } + if want, have := "application/x-ndjson", header.Get("Content-Type"); want != have { + t.Fatalf("Content-Type: want %q, have %q", want, have) + } +} + var benchmarkBulkEstimatedSizeInBytes int64 func BenchmarkBulkEstimatedSizeInBytesWith1Request(b *testing.B) { @@ -516,30 +540,6 @@ func BenchmarkBulkEstimatedSizeInBytesWith100Requests(b *testing.B) { benchmarkBulkEstimatedSizeInBytes = result // ensure the compiler doesn't optimize } -func TestBulkContentType(t *testing.T) { - var header http.Header - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - header = r.Header - fmt.Fprintln(w, `{}`) - })) - defer ts.Close() - - client, err := NewSimpleClient(SetURL(ts.URL)) - if err != nil { - t.Fatal(err) - } - indexReq := NewBulkIndexRequest().Index(testIndexName).Type("doc").Id("1").Doc(tweet{User: "olivere", Message: "Welcome to Golang and Elasticsearch."}) - if _, err := client.Bulk().Add(indexReq).Do(context.Background()); err != nil { - t.Fatal(err) - } - if header == nil { - t.Fatalf("expected header, got %v", header) - } - if want, have := "application/x-ndjson", header.Get("Content-Type"); want != have { - t.Fatalf("Content-Type: want %q, have %q", want, have) - } -} - func BenchmarkBulkAllocs(b *testing.B) { b.Run("1000 docs with 64 byte", func(b *testing.B) { benchmarkBulkAllocs(b, 64, 1000) }) b.Run("1000 docs with 1 KiB", func(b *testing.B) { benchmarkBulkAllocs(b, 1024, 1000) }) diff --git a/client.go b/client.go index b82d1d1bf..fb53514d1 100644 --- a/client.go +++ b/client.go @@ -26,7 +26,7 @@ import ( const ( // Version is the current version of Elastic. - Version = "6.1.0" + Version = "6.1.1" // DefaultURL is the default endpoint of Elasticsearch on the local machine. // It is used e.g. when initializing a new Client without a specific URL. @@ -1169,6 +1169,7 @@ type PerformRequestOptions struct { Body interface{} ContentType string IgnoreErrors []int + Retrier Retrier } // PerformRequest does a HTTP request to Elasticsearch. @@ -1186,6 +1187,10 @@ func (c *Client) PerformRequest(ctx context.Context, opt PerformRequestOptions) basicAuthUsername := c.basicAuthUsername basicAuthPassword := c.basicAuthPassword sendGetBodyAs := c.sendGetBodyAs + retrier := c.retrier + if opt.Retrier != nil { + retrier = opt.Retrier + } c.mu.RUnlock() var err error @@ -1214,7 +1219,7 @@ func (c *Client) PerformRequest(ctx context.Context, opt PerformRequestOptions) // Force a healtcheck as all connections seem to be dead. c.healthcheck(timeout, false) } - wait, ok, rerr := c.retrier.Retry(ctx, n, nil, nil, err) + wait, ok, rerr := retrier.Retry(ctx, n, nil, nil, err) if rerr != nil { return nil, rerr } @@ -1270,7 +1275,7 @@ func (c *Client) PerformRequest(ctx context.Context, opt PerformRequestOptions) } if err != nil { n++ - wait, ok, rerr := c.retrier.Retry(ctx, n, (*http.Request)(req), res, err) + wait, ok, rerr := retrier.Retry(ctx, n, (*http.Request)(req), res, err) if rerr != nil { c.errorf("elastic: %s is dead", conn.URL()) conn.MarkAsDead() diff --git a/retrier_test.go b/retrier_test.go index 8580ee10f..c1c5ff524 100644 --- a/retrier_test.go +++ b/retrier_test.go @@ -127,3 +127,48 @@ func TestRetrierWithError(t *testing.T) { t.Errorf("expected %d Retrier calls; got: %d", 1, retrier.N) } } + +func TestRetrierOnPerformRequest(t *testing.T) { + var numFailedReqs int + fail := func(r *http.Request) (*http.Response, error) { + numFailedReqs += 1 + //return &http.Response{Request: r, StatusCode: 400}, nil + return nil, errors.New("request failed") + } + + tr := &failingTransport{path: "/fail", fail: fail} + httpClient := &http.Client{Transport: tr} + + defaultRetrier := &testRetrier{ + Retrier: NewStopRetrier(), + } + requestRetrier := &testRetrier{ + Retrier: NewStopRetrier(), + } + + client, err := NewClient( + SetHttpClient(httpClient), + SetHealthcheck(false), + SetRetrier(defaultRetrier)) + if err != nil { + t.Fatal(err) + } + + res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{ + Method: "GET", + Path: "/fail", + Retrier: requestRetrier, + }) + if err == nil { + t.Fatal("expected error") + } + if res != nil { + t.Fatal("expected no response") + } + if want, have := int64(0), defaultRetrier.N; want != have { + t.Errorf("defaultRetrier: expected %d calls; got: %d", want, have) + } + if want, have := int64(1), requestRetrier.N; want != have { + t.Errorf("requestRetrier: expected %d calls; got: %d", want, have) + } +} diff --git a/scroll.go b/scroll.go index 59a779e0d..ac51a8c00 100644 --- a/scroll.go +++ b/scroll.go @@ -23,6 +23,7 @@ const ( // ScrollService iterates over pages of search results from Elasticsearch. type ScrollService struct { client *Client + retrier Retrier indices []string types []string keepAlive string @@ -50,6 +51,13 @@ func NewScrollService(client *Client) *ScrollService { return builder } +// Retrier allows to set specific retry logic for this ScrollService. +// If not specified, it will use the client's default retrier. +func (s *ScrollService) Retrier(retrier Retrier) *ScrollService { + s.retrier = retrier + return s +} + // Index sets the name of one or more indices to iterate over. func (s *ScrollService) Index(indices ...string) *ScrollService { if s.indices == nil { @@ -259,10 +267,11 @@ func (s *ScrollService) Clear(ctx context.Context) error { } _, err := s.client.PerformRequest(ctx, PerformRequestOptions{ - Method: "DELETE", - Path: path, - Params: params, - Body: body, + Method: "DELETE", + Path: path, + Params: params, + Body: body, + Retrier: s.retrier, }) if err != nil { return err @@ -289,10 +298,11 @@ func (s *ScrollService) first(ctx context.Context) (*SearchResult, error) { // Get HTTP response res, err := s.client.PerformRequest(ctx, PerformRequestOptions{ - Method: "POST", - Path: path, - Params: params, - Body: body, + Method: "POST", + Path: path, + Params: params, + Body: body, + Retrier: s.retrier, }) if err != nil { return nil, err @@ -408,10 +418,11 @@ func (s *ScrollService) next(ctx context.Context) (*SearchResult, error) { // Get HTTP response res, err := s.client.PerformRequest(ctx, PerformRequestOptions{ - Method: "POST", - Path: path, - Params: params, - Body: body, + Method: "POST", + Path: path, + Params: params, + Body: body, + Retrier: s.retrier, }) if err != nil { return nil, err