Skip to content

Commit

Permalink
aws/request: Fix support for streamed payloads for unsigned body requ…
Browse files Browse the repository at this point in the history
…est (aws#1778)

Fixes the SDK's handling of the SDK's `ReaderSeekerCloser` helper type to
not allow erroneous request retries, and request signature generation.

This fix allows you to use the `aws.ReaderSeekerCloser` to wrap an
arbitrary `io.Reader` for request `io.ReadSeeker` input parameters.

APIs such as lex-runtime's PostContent can now make use of the
ReaderSeekerCloser type without causing unexpected failures.

```go
resp, err := svc.PostContent(&lexruntimeservice.PostContentInput{
    BotAlias:    aws.String("botAlias"),
    BotName:     aws.String("botName"),
    ContentType: aws.String("audio/l16; rate=16000; channels=1"),
    UserId:      aws.String("userID"),
    InputStream: aws.ReadSeekCloser(myReader),
})
```

Fix aws#1776
  • Loading branch information
jasdel committed Feb 14, 2018
1 parent e96f92e commit e926b11
Show file tree
Hide file tree
Showing 12 changed files with 440 additions and 92 deletions.
4 changes: 4 additions & 0 deletions aws/client/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,17 @@ func (reader *teeReaderCloser) Close() error {

func logRequest(r *request.Request) {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
bodySeekable := aws.IsReaderSeekable(r.Body)
dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}

if logBody {
if !bodySeekable {
r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body))
}
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
Expand Down
87 changes: 87 additions & 0 deletions aws/client/logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,17 @@ package client

import (
"bytes"
"fmt"
"io"
"io/ioutil"
"reflect"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
)

type mockCloser struct {
Expand Down Expand Up @@ -55,3 +64,81 @@ func TestLogWriter(t *testing.T) {
t.Errorf("Expected %q, but received %q", expected, lw.buf.String())
}
}

func TestLogRequest(t *testing.T) {
cases := []struct {
Body io.ReadSeeker
ExpectBody []byte
LogLevel aws.LogLevelType
}{
{
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("body content"))),
ExpectBody: []byte("body content"),
},
{
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("body content"))),
LogLevel: aws.LogDebugWithHTTPBody,
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewReader([]byte("body content")),
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewReader([]byte("body content")),
LogLevel: aws.LogDebugWithHTTPBody,
ExpectBody: []byte("body content"),
},
}

for i, c := range cases {
logW := bytes.NewBuffer(nil)
req := request.New(
aws.Config{
Credentials: credentials.AnonymousCredentials,
Logger: &bufLogger{w: logW},
LogLevel: aws.LogLevel(c.LogLevel),
},
metadata.ClientInfo{
Endpoint: "https://mock-service.mock-region.amazonaws.com",
},
testHandlers(),
nil,
&request.Operation{
Name: "APIName",
HTTPMethod: "POST",
HTTPPath: "/",
},
struct{}{}, nil,
)
req.SetReaderBody(c.Body)
req.Build()

logRequest(req)

b, err := ioutil.ReadAll(req.HTTPRequest.Body)
if err != nil {
t.Fatalf("%d, expect to read SDK request Body", i)
}

if e, a := c.ExpectBody, b; !reflect.DeepEqual(e, a) {
t.Errorf("%d, expect %v body, got %v", i, e, a)
}
}
}

type bufLogger struct {
w *bytes.Buffer
}

func (l *bufLogger) Log(args ...interface{}) {
fmt.Fprintln(l.w, args...)
}

func testHandlers() request.Handlers {
var handlers request.Handlers

handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)

return handlers
}
20 changes: 7 additions & 13 deletions aws/corehandlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package corehandlers
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
Expand Down Expand Up @@ -36,18 +35,13 @@ var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLen
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ = strconv.ParseInt(slength, 10, 64)
} else {
switch body := r.Body.(type) {
case nil:
length = 0
case lener:
length = int64(body.Len())
case io.Seeker:
r.BodyStart, _ = body.Seek(0, 1)
end, _ := body.Seek(0, 2)
body.Seek(r.BodyStart, 0) // make sure to seek back to original location
length = end - r.BodyStart
default:
panic("Cannot get length of body, must provide `ContentLength`")
if r.Body != nil {
var err error
length, err = aws.SeekerLen(r.Body)
if err != nil {
r.Error = awserr.New(request.ErrCodeSerialization, "failed to get request body's length", err)
return
}
}
}

Expand Down
45 changes: 7 additions & 38 deletions aws/request/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ func (r *Request) SetContext(ctx aws.Context) {

// WillRetry returns if the request's can be retried.
func (r *Request) WillRetry() bool {
if !aws.IsReaderSeekable(r.Body) && r.HTTPRequest.Body != NoBody {
return false
}
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
}

Expand Down Expand Up @@ -255,6 +258,7 @@ func (r *Request) SetStringBody(s string) {
// SetReaderBody will set the request's body reader.
func (r *Request) SetReaderBody(reader io.ReadSeeker) {
r.Body = reader
r.BodyStart, _ = reader.Seek(0, 1) // Get the Bodies current offset.
r.ResetBody()
}

Expand Down Expand Up @@ -393,7 +397,7 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
// of the SDK if they used that field.
//
// Related golang/go#18257
l, err := computeBodyLength(r.Body)
l, err := aws.SeekerLen(r.Body)
if err != nil {
return nil, awserr.New(ErrCodeSerialization, "failed to compute request body size", err)
}
Expand All @@ -411,7 +415,8 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
// Transfer-Encoding: chunked bodies for these methods.
//
// This would only happen if a aws.ReaderSeekerCloser was used with
// a io.Reader that was not also an io.Seeker.
// a io.Reader that was not also an io.Seeker, or did not implement
// Len() method.
switch r.Operation.HTTPMethod {
case "GET", "HEAD", "DELETE":
body = NoBody
Expand All @@ -423,42 +428,6 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
return body, nil
}

// Attempts to compute the length of the body of the reader using the
// io.Seeker interface. If the value is not seekable because of being
// a ReaderSeekerCloser without an unerlying Seeker -1 will be returned.
// If no error occurs the length of the body will be returned.
func computeBodyLength(r io.ReadSeeker) (int64, error) {
seekable := true
// Determine if the seeker is actually seekable. ReaderSeekerCloser
// hides the fact that a io.Readers might not actually be seekable.
switch v := r.(type) {
case aws.ReaderSeekerCloser:
seekable = v.IsSeeker()
case *aws.ReaderSeekerCloser:
seekable = v.IsSeeker()
}
if !seekable {
return -1, nil
}

curOffset, err := r.Seek(0, 1)
if err != nil {
return 0, err
}

endOffset, err := r.Seek(0, 2)
if err != nil {
return 0, err
}

_, err = r.Seek(curOffset, 0)
if err != nil {
return 0, err
}

return endOffset - curOffset, nil
}

// GetBody will return an io.ReadSeeker of the Request's underlying
// input body with a concurrency safe wrapper.
func (r *Request) GetBody() io.ReadSeeker {
Expand Down
71 changes: 60 additions & 11 deletions aws/request/request_resetbody_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package request

import (
"bytes"
"io"
"net/http"
"strings"
"testing"
Expand All @@ -25,30 +26,78 @@ func TestResetBody_WithBodyContents(t *testing.T) {
}
}

func TestResetBody_ExcludeUnseekableBodyByMethod(t *testing.T) {
type mockReader struct{}

func (mockReader) Read([]byte) (int, error) {
return 0, io.EOF
}

func TestResetBody_ExcludeEmptyUnseekableBodyByMethod(t *testing.T) {
cases := []struct {
Method string
Body io.ReadSeeker
IsNoBody bool
}{
{"GET", true},
{"HEAD", true},
{"DELETE", true},
{"PUT", false},
{"PATCH", false},
{"POST", false},
{
Method: "GET",
IsNoBody: true,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "HEAD",
IsNoBody: true,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "DELETE",
IsNoBody: true,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "PUT",
IsNoBody: false,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "PATCH",
IsNoBody: false,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "POST",
IsNoBody: false,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "GET",
IsNoBody: false,
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc"))),
},
{
Method: "GET",
IsNoBody: true,
Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)),
},
{
Method: "POST",
IsNoBody: false,
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc"))),
},
{
Method: "POST",
IsNoBody: true,
Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)),
},
}

reader := aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc")))

for i, c := range cases {
r := Request{
HTTPRequest: &http.Request{},
Operation: &Operation{
HTTPMethod: c.Method,
},
}

r.SetReaderBody(reader)
r.SetReaderBody(c.Body)

if a, e := r.HTTPRequest.Body == NoBody, c.IsNoBody; a != e {
t.Errorf("%d, expect body to be set to noBody(%t), but was %t", i, e, a)
Expand Down
Loading

0 comments on commit e926b11

Please sign in to comment.