Skip to content

Commit

Permalink
Allow custom Retry/Backoff per request
Browse files Browse the repository at this point in the history
This commit extends the `PerformRequestOptions` to pass a custom
`Retrier` per request. This is enabled for the Scroll and Bulk API for
now.

See olivere#666 and olivere#610
  • Loading branch information
olivere committed Jan 8, 2018
1 parent 832286e commit d2219c2
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 40 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ _testmain.go
*.exe

/.vscode/
/debug.test
/generator
/cluster-test/cluster-test
/cluster-test/*.log
Expand Down
11 changes: 10 additions & 1 deletion bulk.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ import (
// See https://www.elastic.co/guide/en/elasticsearch/reference/6.0/docs-bulk.html
// for more details.
type BulkService struct {
client *Client
client *Client
retrier Retrier

index string
typ string
Expand Down Expand Up @@ -57,6 +58,13 @@ func (s *BulkService) reset() {
s.sizeInBytesCursor = 0
}

// Retrier allows to set specific retry logic for this BulkService.
// If not specified, it will use the client's default retrier.
func (s *BulkService) Retrier(retrier Retrier) *BulkService {
s.retrier = retrier
return s
}

// Index specifies the index to use for all batches. You may also leave
// this blank and specify the index in the individual bulk requests.
func (s *BulkService) Index(index string) *BulkService {
Expand Down Expand Up @@ -241,6 +249,7 @@ func (s *BulkService) Do(ctx context.Context) (*BulkResponse, error) {
Params: params,
Body: body,
ContentType: "application/x-ndjson",
Retrier: s.retrier,
})
if err != nil {
return nil, err
Expand Down
48 changes: 24 additions & 24 deletions bulk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,30 @@ func TestBulkEstimateSizeInBytesLength(t *testing.T) {
}
}

func TestBulkContentType(t *testing.T) {
var header http.Header
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
header = r.Header
fmt.Fprintln(w, `{}`)
}))
defer ts.Close()

client, err := NewSimpleClient(SetURL(ts.URL))
if err != nil {
t.Fatal(err)
}
indexReq := NewBulkIndexRequest().Index(testIndexName).Type("doc").Id("1").Doc(tweet{User: "olivere", Message: "Welcome to Golang and Elasticsearch."})
if _, err := client.Bulk().Add(indexReq).Do(context.Background()); err != nil {
t.Fatal(err)
}
if header == nil {
t.Fatalf("expected header, got %v", header)
}
if want, have := "application/x-ndjson", header.Get("Content-Type"); want != have {
t.Fatalf("Content-Type: want %q, have %q", want, have)
}
}

var benchmarkBulkEstimatedSizeInBytes int64

func BenchmarkBulkEstimatedSizeInBytesWith1Request(b *testing.B) {
Expand Down Expand Up @@ -516,30 +540,6 @@ func BenchmarkBulkEstimatedSizeInBytesWith100Requests(b *testing.B) {
benchmarkBulkEstimatedSizeInBytes = result // ensure the compiler doesn't optimize
}

func TestBulkContentType(t *testing.T) {
var header http.Header
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
header = r.Header
fmt.Fprintln(w, `{}`)
}))
defer ts.Close()

client, err := NewSimpleClient(SetURL(ts.URL))
if err != nil {
t.Fatal(err)
}
indexReq := NewBulkIndexRequest().Index(testIndexName).Type("doc").Id("1").Doc(tweet{User: "olivere", Message: "Welcome to Golang and Elasticsearch."})
if _, err := client.Bulk().Add(indexReq).Do(context.Background()); err != nil {
t.Fatal(err)
}
if header == nil {
t.Fatalf("expected header, got %v", header)
}
if want, have := "application/x-ndjson", header.Get("Content-Type"); want != have {
t.Fatalf("Content-Type: want %q, have %q", want, have)
}
}

func BenchmarkBulkAllocs(b *testing.B) {
b.Run("1000 docs with 64 byte", func(b *testing.B) { benchmarkBulkAllocs(b, 64, 1000) })
b.Run("1000 docs with 1 KiB", func(b *testing.B) { benchmarkBulkAllocs(b, 1024, 1000) })
Expand Down
11 changes: 8 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (

const (
// Version is the current version of Elastic.
Version = "6.1.0"
Version = "6.1.1"

// DefaultURL is the default endpoint of Elasticsearch on the local machine.
// It is used e.g. when initializing a new Client without a specific URL.
Expand Down Expand Up @@ -1169,6 +1169,7 @@ type PerformRequestOptions struct {
Body interface{}
ContentType string
IgnoreErrors []int
Retrier Retrier
}

// PerformRequest does a HTTP request to Elasticsearch.
Expand All @@ -1186,6 +1187,10 @@ func (c *Client) PerformRequest(ctx context.Context, opt PerformRequestOptions)
basicAuthUsername := c.basicAuthUsername
basicAuthPassword := c.basicAuthPassword
sendGetBodyAs := c.sendGetBodyAs
retrier := c.retrier
if opt.Retrier != nil {
retrier = opt.Retrier
}
c.mu.RUnlock()

var err error
Expand Down Expand Up @@ -1214,7 +1219,7 @@ func (c *Client) PerformRequest(ctx context.Context, opt PerformRequestOptions)
// Force a healtcheck as all connections seem to be dead.
c.healthcheck(timeout, false)
}
wait, ok, rerr := c.retrier.Retry(ctx, n, nil, nil, err)
wait, ok, rerr := retrier.Retry(ctx, n, nil, nil, err)
if rerr != nil {
return nil, rerr
}
Expand Down Expand Up @@ -1270,7 +1275,7 @@ func (c *Client) PerformRequest(ctx context.Context, opt PerformRequestOptions)
}
if err != nil {
n++
wait, ok, rerr := c.retrier.Retry(ctx, n, (*http.Request)(req), res, err)
wait, ok, rerr := retrier.Retry(ctx, n, (*http.Request)(req), res, err)
if rerr != nil {
c.errorf("elastic: %s is dead", conn.URL())
conn.MarkAsDead()
Expand Down
45 changes: 45 additions & 0 deletions retrier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,48 @@ func TestRetrierWithError(t *testing.T) {
t.Errorf("expected %d Retrier calls; got: %d", 1, retrier.N)
}
}

func TestRetrierOnPerformRequest(t *testing.T) {
var numFailedReqs int
fail := func(r *http.Request) (*http.Response, error) {
numFailedReqs += 1
//return &http.Response{Request: r, StatusCode: 400}, nil
return nil, errors.New("request failed")
}

tr := &failingTransport{path: "/fail", fail: fail}
httpClient := &http.Client{Transport: tr}

defaultRetrier := &testRetrier{
Retrier: NewStopRetrier(),
}
requestRetrier := &testRetrier{
Retrier: NewStopRetrier(),
}

client, err := NewClient(
SetHttpClient(httpClient),
SetHealthcheck(false),
SetRetrier(defaultRetrier))
if err != nil {
t.Fatal(err)
}

res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
Method: "GET",
Path: "/fail",
Retrier: requestRetrier,
})
if err == nil {
t.Fatal("expected error")
}
if res != nil {
t.Fatal("expected no response")
}
if want, have := int64(0), defaultRetrier.N; want != have {
t.Errorf("defaultRetrier: expected %d calls; got: %d", want, have)
}
if want, have := int64(1), requestRetrier.N; want != have {
t.Errorf("requestRetrier: expected %d calls; got: %d", want, have)
}
}
35 changes: 23 additions & 12 deletions scroll.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
// ScrollService iterates over pages of search results from Elasticsearch.
type ScrollService struct {
client *Client
retrier Retrier
indices []string
types []string
keepAlive string
Expand Down Expand Up @@ -50,6 +51,13 @@ func NewScrollService(client *Client) *ScrollService {
return builder
}

// Retrier allows to set specific retry logic for this ScrollService.
// If not specified, it will use the client's default retrier.
func (s *ScrollService) Retrier(retrier Retrier) *ScrollService {
s.retrier = retrier
return s
}

// Index sets the name of one or more indices to iterate over.
func (s *ScrollService) Index(indices ...string) *ScrollService {
if s.indices == nil {
Expand Down Expand Up @@ -259,10 +267,11 @@ func (s *ScrollService) Clear(ctx context.Context) error {
}

_, err := s.client.PerformRequest(ctx, PerformRequestOptions{
Method: "DELETE",
Path: path,
Params: params,
Body: body,
Method: "DELETE",
Path: path,
Params: params,
Body: body,
Retrier: s.retrier,
})
if err != nil {
return err
Expand All @@ -289,10 +298,11 @@ func (s *ScrollService) first(ctx context.Context) (*SearchResult, error) {

// Get HTTP response
res, err := s.client.PerformRequest(ctx, PerformRequestOptions{
Method: "POST",
Path: path,
Params: params,
Body: body,
Method: "POST",
Path: path,
Params: params,
Body: body,
Retrier: s.retrier,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -408,10 +418,11 @@ func (s *ScrollService) next(ctx context.Context) (*SearchResult, error) {

// Get HTTP response
res, err := s.client.PerformRequest(ctx, PerformRequestOptions{
Method: "POST",
Path: path,
Params: params,
Body: body,
Method: "POST",
Path: path,
Params: params,
Body: body,
Retrier: s.retrier,
})
if err != nil {
return nil, err
Expand Down

0 comments on commit d2219c2

Please sign in to comment.