Skip to content

Commit

Permalink
Modify and Merge protocol test request unit tests codegen logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Tianyi Wang committed Aug 11, 2023
1 parent 341bcbe commit 9a47e68
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.util.function.Consumer;
import java.util.logging.Logger;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.SymbolUtils;
Expand Down Expand Up @@ -195,36 +196,14 @@ protected void generateTestCaseValues(GoWriter writer, HttpRequestTestCase testC
*
* @param writer writer to write generated code with.
*/
protected void generateTestBodySetup(GoWriter writer) {
writer.write("var actualReq *http.Request");
}
protected void generateTestBodySetup(GoWriter writer) {}

/**
* Hook to generate the HTTP response body of the protocol test.
*
* @param writer writer to write generated code with.
*/
protected void generateTestServerHandler(GoWriter writer) {
writer.write("actualReq = r.Clone(r.Context())");
// Go does not set RawPath on http server if nothing is escaped
writer.openBlock("if len(actualReq.URL.RawPath) == 0 {", "}", () -> {
writer.write("actualReq.URL.RawPath = actualReq.URL.Path");
});
// Go automatically removes Content-Length header setting it to the member.
writer.addUseImports(SmithyGoDependency.STRCONV);
writer.openBlock("if v := actualReq.ContentLength; v != 0 {", "}", () -> {
writer.write("actualReq.Header.Set(\"Content-Length\", strconv.FormatInt(v, 10))");
});

writer.addUseImports(SmithyGoDependency.BYTES);
writer.write("var buf bytes.Buffer");
writer.openBlock("if _, err := io.Copy(&buf, r.Body); err != nil {", "}", () -> {
writer.write("t.Errorf(\"failed to read request body, %v\", err)");
});
writer.addUseImports(SmithyGoDependency.IOUTIL);
writer.write("actualReq.Body = ioutil.NopCloser(&buf)");
writer.write("");

super.generateTestServerHandler(writer);
}

Expand All @@ -236,8 +215,19 @@ protected void generateTestServerHandler(GoWriter writer) {
*/
@Override
protected void generateTestInvokeClientOperation(GoWriter writer, String clientName) {
Symbol stackSymbol = SymbolUtils.createPointableSymbolBuilder("Stack",
SmithyGoDependency.SMITHY_MIDDLEWARE).build();
writer.addUseImports(SmithyGoDependency.CONTEXT);
writer.write("result, err := $L.$T(context.Background(), c.Params)", clientName, opSymbol);
writer.write("capturedReq := &http.Request{}");
writer.openBlock("result, err := $L.$T(context.Background(), c.Params, func(options *Options) {", "})",
clientName, opSymbol, () -> {
writer.openBlock("options.APIOptions = append(options.APIOptions, func(stack $P) error {", "})",
stackSymbol, () -> {
writer.write("return $T(stack, capturedReq)",
SymbolUtils.createValueSymbolBuilder("AddCaptureRequestMiddleware",
SmithyGoDependency.SMITHY_HTTP_TRANSPORT).build());
});
});
}

/**
Expand All @@ -250,20 +240,21 @@ protected void generateTestAssertions(GoWriter writer) {
writeAssertNil(writer, "err");
writeAssertNotNil(writer, "result");

writeAssertScalarEqual(writer, "c.ExpectMethod", "actualReq.Method", "method");
writeAssertScalarEqual(writer, "c.ExpectURIPath", "actualReq.URL.RawPath", "path");
writeAssertScalarEqual(writer, "c.ExpectMethod", "capturedReq.Method", "method");
writeAssertScalarEqual(writer, "c.ExpectURIPath", "capturedReq.URL.RawPath", "path");

writeQueryItemBreakout(writer, "capturedReq.URL.RawQuery", "queryItems");

writeQueryItemBreakout(writer, "actualReq.URL.RawQuery", "queryItems");
writeAssertHasQuery(writer, "c.ExpectQuery", "queryItems");
writeAssertRequireQuery(writer, "c.RequireQuery", "queryItems");
writeAssertForbidQuery(writer, "c.ForbidQuery", "queryItems");

writeAssertHasHeader(writer, "c.ExpectHeader", "actualReq.Header");
writeAssertRequireHeader(writer, "c.RequireHeader", "actualReq.Header");
writeAssertForbidHeader(writer, "c.ForbidHeader", "actualReq.Header");
writeAssertHasHeader(writer, "c.ExpectHeader", "capturedReq.Header");
writeAssertRequireHeader(writer, "c.RequireHeader", "capturedReq.Header");
writeAssertForbidHeader(writer, "c.ForbidHeader", "capturedReq.Header");

writer.openBlock("if c.BodyAssert != nil {", "}", () -> {
writer.openBlock("if err := c.BodyAssert(actualReq.Body); err != nil {", "}", () -> {
writer.openBlock("if err := c.BodyAssert(capturedReq.Body); err != nil {", "}", () -> {
writer.write("t.Errorf(\"expect body equal, got %v\", err)");
});
});
Expand All @@ -282,7 +273,8 @@ protected void generateTestServer(
String name,
Consumer<GoWriter> handler
) {
super.generateTestServer(writer, name, handler);
// We aren't using a test server, but we do need a URL to set.
writer.write("serverURL := \"http://localhost:8888/\"");
writer.pushState();
writer.putContext("parse", SymbolUtils.createValueSymbolBuilder("Parse", SmithyGoDependency.NET_URL)
.build());
Expand Down
46 changes: 46 additions & 0 deletions transport/http/middleware_capture_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package http

import (
"context"
"fmt"
"github.com/aws/smithy-go/middleware"
"net/http"
"strconv"
)

const captureRequestID = "CaptureProtocolTestRequest"

// AddCaptureRequestMiddleware captures serialized http request during protocol test for check
func AddCaptureRequestMiddleware(stack *middleware.Stack, req *http.Request) error {
return stack.Build.Add(&captureRequestMiddleware{
req: req,
}, middleware.After)
}

type captureRequestMiddleware struct {
req *http.Request
}

func (*captureRequestMiddleware) ID() string {
return captureRequestID
}

func (m *captureRequestMiddleware) HandleBuild(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler,
) (
output middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
request, ok := input.Request.(*Request)
if !ok {
return output, metadata, fmt.Errorf("error while retrieving http request")
}

*m.req = *request.Build(ctx)
if len(m.req.URL.RawPath) == 0 {
m.req.URL.RawPath = m.req.URL.Path
}
if v := m.req.ContentLength; v != 0 {
m.req.Header.Set("Content-Length", strconv.FormatInt(v, 10))
}

return next.HandleBuild(ctx, input)
}
109 changes: 109 additions & 0 deletions transport/http/middleware_capture_request_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package http

import (
"context"
"github.com/aws/smithy-go/middleware"
smithytesting "github.com/aws/smithy-go/testing"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing"
)

// TestAddCaptureRequestMiddleware tests AddCaptureRequestMiddleware
func TestAddCaptureRequestMiddleware(t *testing.T) {
cases := map[string]struct {
Request *http.Request
ExpectRequest *http.Request
ExpectQuery []smithytesting.QueryItem
Stream io.Reader
}{
"normal request": {
Request: &http.Request{
Method: "PUT",
Header: map[string][]string{
"Foo": {"bar", "too"},
"Checksum": {"SHA256"},
},
URL: &url.URL{
Path: "test/path",
RawQuery: "language=us&region=us-west+east",
},
ContentLength: 100,
},
ExpectRequest: &http.Request{
Method: "PUT",
Header: map[string][]string{
"Foo": {"bar", "too"},
"Checksum": {"SHA256"},
"Content-Length": {"100"},
},
URL: &url.URL{
Path: "test/path",
RawPath: "test/path",
},
Body: io.NopCloser(strings.NewReader("hello world.")),

Check failure on line 47 in transport/http/middleware_capture_request_test.go

View workflow job for this annotation

GitHub Actions / Deprecated Go version SDK Unit Tests (ubuntu-latest, 1.15)

undefined: "io".NopCloser
},
ExpectQuery: []smithytesting.QueryItem{
{
Key: "language",
Value: "us",
},
{
Key: "region",
Value: "us-west%20east",
},
},
Stream: strings.NewReader("hello world."),
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
var err error
req := &Request{
Request: c.Request,
stream: c.Stream,
}
capturedRequest := &http.Request{}
m := captureRequestMiddleware{
req: capturedRequest,
}
_, _, err = m.HandleBuild(context.Background(),
middleware.BuildInput{Request: req},
middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) (
out middleware.BuildOutput, metadata middleware.Metadata, err error) {
return out, metadata, nil
}),
)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if e, a := c.ExpectRequest.Method, capturedRequest.Method; e != a {
t.Errorf("expect request method %v found, got %v", e, a)
}
if e, a := c.ExpectRequest.URL.Path, capturedRequest.URL.RawPath; e != a {
t.Errorf("expect %v path, got %v", e, a)
}
if c.ExpectRequest.Body != nil {
expect, err := io.ReadAll(c.ExpectRequest.Body)

Check failure on line 92 in transport/http/middleware_capture_request_test.go

View workflow job for this annotation

GitHub Actions / Deprecated Go version SDK Unit Tests (ubuntu-latest, 1.15)

undefined: "io".ReadAll
if capturedRequest.Body == nil {
t.Errorf("Expect request stream %v captured, get nil", string(expect))
}
actual, err := ioutil.ReadAll(capturedRequest.Body)
if err != nil {
t.Errorf("unable to read captured request body, %v", err)
}
if e, a := string(expect), string(actual); e != a {
t.Errorf("expect request body to be %s, got %s", e, a)
}
}
queryItems := smithytesting.ParseRawQuery(capturedRequest.URL.RawQuery)
smithytesting.AssertHasQuery(t, c.ExpectQuery, queryItems)
smithytesting.AssertHasHeader(t, c.ExpectRequest.Header, capturedRequest.Header)
})
}
}

0 comments on commit 9a47e68

Please sign in to comment.