Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

service/s3/s3manager: Add retry download object part body #843

Merged
merged 1 commit into from
Sep 20, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 73 additions & 35 deletions service/s3/s3manager/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"
"sync"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/client"
Expand Down Expand Up @@ -103,6 +104,10 @@ func NewDownloaderWithClient(svc s3iface.S3API, options ...func(*Downloader)) *D
return d
}

// maxRetrier is satisfied by service clients that expose their configured
// retry limit (e.g. clients built on client.Client). Download uses it to
// bound how many times an interrupted part body is re-requested.
type maxRetrier interface {
	MaxRetries() int
}

// Download downloads an object in S3 and writes the payload into w using
// concurrent GET requests.
//
Expand All @@ -121,6 +126,19 @@ func (d Downloader) Download(w io.WriterAt, input *s3.GetObjectInput, options ..
option(&impl.ctx)
}

if s, ok := d.S3.(maxRetrier); ok {
impl.partBodyMaxRetries = s.MaxRetries()
}

impl.totalBytes = -1
if impl.ctx.Concurrency == 0 {
impl.ctx.Concurrency = DefaultDownloadConcurrency
}

if impl.ctx.PartSize == 0 {
impl.ctx.PartSize = DefaultDownloadPartSize
}

return impl.download()
}

Expand All @@ -138,26 +156,13 @@ type downloader struct {
totalBytes int64
written int64
err error
}

// init initializes the downloader with default options.
func (d *downloader) init() {
d.totalBytes = -1

if d.ctx.Concurrency == 0 {
d.ctx.Concurrency = DefaultDownloadConcurrency
}

if d.ctx.PartSize == 0 {
d.ctx.PartSize = DefaultDownloadPartSize
}
partBodyMaxRetries int
}

// download performs the implementation of the object download across ranged
// GETs.
func (d *downloader) download() (n int64, err error) {
d.init()

// Spin off first worker to check additional header information
d.getChunk()

Expand Down Expand Up @@ -214,49 +219,82 @@ func (d *downloader) downloadPart(ch chan dlchunk) {
defer d.wg.Done()
for {
chunk, ok := <-ch
if !ok {
if !ok || d.getErr() != nil {
break
}

if err := d.downloadChunk(chunk); err != nil {
d.setErr(err)
break
}
d.downloadChunk(chunk)
}
}

// getChunk grabs a chunk of data from the body.
// Not thread safe. Should only used when grabbing data on a single thread.
func (d *downloader) getChunk() {
if d.getErr() != nil {
return
}

chunk := dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
d.pos += d.ctx.PartSize
d.downloadChunk(chunk)

if err := d.downloadChunk(chunk); err != nil {
d.setErr(err)
}
}

// downloadChunk downloads the chunk from s3
func (d *downloader) downloadChunk(chunk dlchunk) {
if d.getErr() != nil {
return
}
// Get the next byte range of data
func (d *downloader) downloadChunk(chunk dlchunk) error {
in := &s3.GetObjectInput{}
awsutil.Copy(in, d.in)
rng := fmt.Sprintf("bytes=%d-%d",
chunk.start, chunk.start+chunk.size-1)

// Get the next byte range of data
rng := fmt.Sprintf("bytes=%d-%d", chunk.start, chunk.start+chunk.size-1)
in.Range = &rng

req, resp := d.ctx.S3.GetObjectRequest(in)
req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
err := req.Send()
var n int64
var err error
for retry := 0; retry <= d.partBodyMaxRetries; retry++ {
req, resp := d.ctx.S3.GetObjectRequest(in)
req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))

if err != nil {
d.setErr(err)
} else {
err = req.Send()
if err != nil {
return err
}
d.setTotalBytes(resp) // Set total if not yet set.

n, err := io.Copy(&chunk, resp.Body)
n, err = io.Copy(&chunk, resp.Body)
resp.Body.Close()

if err != nil {
d.setErr(err)
if err == nil {
break
}
d.incrWritten(n)

chunk.cur = 0
logMessage(d.ctx.S3, aws.LogDebugWithRequestRetries,
fmt.Sprintf("DEBUG: object part body download interrupted %s, err, %v, retrying attempt %d",
aws.StringValue(in.Key), err, retry))
}

d.incrWritten(n)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then here upon for exit

return err


return err
}

// logMessage writes msg to the client's configured logger, but only when
// the provided client is a concrete *s3.S3, a Logger is configured, and
// the client's LogLevel matches the requested level. Any other client
// implementation is silently ignored.
func logMessage(svc s3iface.S3API, level aws.LogLevelType, msg string) {
	s, ok := svc.(*s3.S3)
	if !ok || s.Config.Logger == nil {
		// Not a concrete client, or no logger to write to; nothing to do.
		return
	}

	if s.Config.LogLevel.Matches(level) {
		s.Config.Logger.Log(msg)
	}
}

Expand Down
118 changes: 118 additions & 0 deletions service/s3/s3manager/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package s3manager_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"regexp"
Expand Down Expand Up @@ -151,6 +152,37 @@ func dlLoggingSvcContentRangeTotalAny(data []byte, states []int) (*s3.S3, *[]str
return svc, &names
}

// dlLoggingSvcWithErrReader builds an S3 client whose Send handler serves
// each element of cases, in order, as one GetObject response body. Each
// testErrReader fails partway through the body, letting tests exercise the
// downloader's part-body retry logic. MaxRetries is len(cases)-1 so the
// final case is the last permitted attempt. It returns the client and a
// pointer to the slice of operation names invoked (appended per request).
func dlLoggingSvcWithErrReader(cases []testErrReader) (*s3.S3, *[]string) {
	var m sync.Mutex
	names := []string{}
	index := 0 // next case to serve; guarded by m

	svc := s3.New(unit.Session, &aws.Config{
		MaxRetries: aws.Int(len(cases) - 1),
	})
	svc.Handlers.Send.Clear()
	svc.Handlers.Send.PushBack(func(r *request.Request) {
		m.Lock()
		defer m.Unlock()

		names = append(names, r.Operation.Name)

		// Copy the case so each response body gets its own reader state.
		c := cases[index]

		r.HTTPResponse = &http.Response{
			StatusCode: http.StatusOK,
			Body:       ioutil.NopCloser(&c),
			Header:     http.Header{},
		}
		r.HTTPResponse.Header.Set("Content-Range",
			fmt.Sprintf("bytes %d-%d/%d", 0, c.Len-1, c.Len))
		r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", c.Len))
		index++
	})

	return svc, &names
}

func TestDownloadOrder(t *testing.T) {
s, names, ranges := dlLoggingSvc(buf12MB)

Expand Down Expand Up @@ -307,3 +339,89 @@ func TestDownloadContentRangeTotalAny(t *testing.T) {
}
assert.Equal(t, 0, count)
}

// TestDownloadPartBodyRetry_SuccessRetry verifies that a body interrupted
// mid-read is re-requested, and that the retried content overwrites the
// partial first attempt.
func TestDownloadPartBodyRetry_SuccessRetry(t *testing.T) {
	svc, names := dlLoggingSvcWithErrReader([]testErrReader{
		{Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF},
		{Buf: []byte("123"), Len: 3, Err: io.EOF},
	})

	dl := s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) {
		d.Concurrency = 1
	})

	buf := &aws.WriteAtBuffer{}
	n, err := dl.Download(buf, &s3.GetObjectInput{
		Bucket: aws.String("bucket"),
		Key:    aws.String("key"),
	})

	assert.Nil(t, err)
	assert.Equal(t, int64(3), n)
	assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
	assert.Equal(t, []byte("123"), buf.Bytes())
}

// TestDownloadPartBodyRetry_SuccessNoRetry verifies that a body read
// completely on the first attempt results in exactly one GetObject call.
func TestDownloadPartBodyRetry_SuccessNoRetry(t *testing.T) {
	svc, names := dlLoggingSvcWithErrReader([]testErrReader{
		{Buf: []byte("abc"), Len: 3, Err: io.EOF},
	})

	dl := s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) {
		d.Concurrency = 1
	})

	buf := &aws.WriteAtBuffer{}
	n, err := dl.Download(buf, &s3.GetObjectInput{
		Bucket: aws.String("bucket"),
		Key:    aws.String("key"),
	})

	assert.Nil(t, err)
	assert.Equal(t, int64(3), n)
	assert.Equal(t, []string{"GetObject"}, *names)
	assert.Equal(t, []byte("abc"), buf.Bytes())
}

// TestDownloadPartBodyRetry_FailRetry verifies that once the retry budget
// (MaxRetries = 0 here) is exhausted, Download surfaces the body error and
// reports only the bytes actually written.
func TestDownloadPartBodyRetry_FailRetry(t *testing.T) {
	svc, names := dlLoggingSvcWithErrReader([]testErrReader{
		{Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF},
	})

	dl := s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) {
		d.Concurrency = 1
	})

	buf := &aws.WriteAtBuffer{}
	n, err := dl.Download(buf, &s3.GetObjectInput{
		Bucket: aws.String("bucket"),
		Key:    aws.String("key"),
	})

	assert.Error(t, err)
	assert.Equal(t, int64(2), n)
	assert.Equal(t, []string{"GetObject"}, *names)
	assert.Equal(t, []byte("ab"), buf.Bytes())
}

// testErrReader simulates a response body that serves Buf and then fails
// with Err, so part-body retry behavior can be exercised.
type testErrReader struct {
	Buf []byte // content to serve before the simulated failure
	Err error  // error returned once Buf runs out mid-read
	Len int64  // length advertised in Content-Range/Content-Length headers

	off int // read offset into Buf
}

// Read copies the unread remainder of Buf into p. If p is not completely
// filled (the buffer was exhausted), it returns Err with the bytes read,
// mimicking a connection interrupted mid-body.
func (r *testErrReader) Read(p []byte) (int, error) {
	// Fix: the original sliced r.Buf[r.off : len(r.Buf)-r.off], which yields
	// the wrong window once off > 0 and panics (low > high) after reading
	// past the midpoint. The remaining bytes are simply r.Buf[r.off:].
	n := copy(p, r.Buf[r.off:])
	r.off += n

	if n < len(p) {
		return n, r.Err
	}

	return n, nil
}