Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Gzip request compression feature #467

Merged
merged 16 commits into from
Dec 6, 2023
Merged
8 changes: 8 additions & 0 deletions .changelog/80ed28327bcd4301a264f318efaf8216.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "80ed2832-7bcd-4301-a264-f318efaf8216",
"type": "feature",
"description": "Support modeled request compression.",
"modules": [
"."
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public final class SmithyGoDependency {
public static final GoDependency SMITHY_HTTP_TRANSPORT = smithy("transport/http", "smithyhttp");
public static final GoDependency SMITHY_MIDDLEWARE = smithy("middleware");
public static final GoDependency SMITHY_PRIVATE_PROTOCOL = smithy("private/protocol", "smithyprivateprotocol");
public static final GoDependency SMITHY_REQUEST_COMPRESSION =
smithy("private/requestcompression", "smithyrequestcompression");
public static final GoDependency SMITHY_TIME = smithy("time", "smithytime");
public static final GoDependency SMITHY_HTTP_BINDING = smithy("encoding/httpbinding");
public static final GoDependency SMITHY_JSON = smithy("encoding/json", "smithyjson");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,11 @@ public static final class Bearer {
public static final Symbol NewSignHTTPSMessage = SmithyGoDependency.SMITHY_AUTH_BEARER.valueSymbol("NewSignHTTPSMessage");
}
}

public static final class Private {
public static final class RequestCompression {
public static final Symbol AddRequestCompression = SmithyGoDependency.SMITHY_REQUEST_COMPRESSION.valueSymbol("AddRequestCompression");
public static final Symbol AddCaptureUncompressedRequest = SmithyGoDependency.SMITHY_REQUEST_COMPRESSION.valueSymbol("AddCaptureUncompressedRequestMiddleware");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,31 @@

package software.amazon.smithy.go.codegen.integration;

import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
import static software.amazon.smithy.go.codegen.SmithyGoTypes.Private.RequestCompression.AddCaptureUncompressedRequest;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Consumer;
import java.util.logging.Logger;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.SmithyGoTypes;
import software.amazon.smithy.go.codegen.SymbolUtils;
import software.amazon.smithy.model.traits.RequestCompressionTrait;
import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase;
import software.amazon.smithy.utils.MapUtils;

/**
* Generates HTTP protocol unit tests for HTTP request test cases.
*/
public class HttpProtocolUnitTestRequestGenerator extends HttpProtocolUnitTestGenerator<HttpRequestTestCase> {
private static final Logger LOGGER = Logger.getLogger(HttpProtocolUnitTestRequestGenerator.class.getName());

private static final Set<String> ALLOWED_ALGORITHMS = new HashSet<>(Arrays.asList("gzip"));

/**
* Initializes the protocol test generator.
*
Expand Down Expand Up @@ -198,6 +209,10 @@ protected void generateTestCaseValues(GoWriter writer, HttpRequestTestCase testC
*/
protected void generateTestBodySetup(GoWriter writer) {
writer.write("actualReq := &http.Request{}");
if (operation.hasTrait(RequestCompressionTrait.class)) {
writer.addUseImports(SmithyGoDependency.BYTES);
writer.write("rawBodyBuf := &bytes.Buffer{}");
}
}

/**
Expand Down Expand Up @@ -227,8 +242,29 @@ protected void generateTestInvokeClientOperation(GoWriter writer, String clientN
writer.write("return $T(stack, actualReq)",
SymbolUtils.createValueSymbolBuilder("AddCaptureRequestMiddleware",
SmithyGoDependency.SMITHY_PRIVATE_PROTOCOL).build());
});
});
if (operation.hasTrait(RequestCompressionTrait.class)) {
writer.write(goTemplate("""
options.APIOptions = append(options.APIOptions, func(stack $stack:P) error {
return $captureRequest:T(stack, rawBodyBuf)
})
""",
MapUtils.of(
"stack", SmithyGoTypes.Middleware.Stack,
"captureRequest", AddCaptureUncompressedRequest
)));
}
});

if (operation.hasTrait(RequestCompressionTrait.class)) {
writer.write(goTemplate("""
disable := $client:L.Options().DisableRequestCompression
min := $client:L.Options().RequestMinCompressSizeBytes
""",
MapUtils.of(
"client", clientName
)));
}
}

/**
Expand Down Expand Up @@ -259,6 +295,20 @@ protected void generateTestAssertions(GoWriter writer) {
writer.write("t.Errorf(\"expect body equal, got %v\", err)");
});
});

if (operation.hasTrait(RequestCompressionTrait.class)) {
String algorithm = operation.expectTrait(RequestCompressionTrait.class).getEncodings()
.stream().filter(it -> ALLOWED_ALGORITHMS.contains(it)).findFirst().get();
writer.write(goTemplate("""
if err := smithytesting.CompareCompressedBytes(rawBodyBuf, actualReq.Body,
disable, min, $algorithm:S); err != nil {
t.Errorf("unzipped request body not match: %q", err)
}
""",
MapUtils.of(
"algorithm", algorithm
)));
}
}

public static class Builder extends HttpProtocolUnitTestGenerator.Builder<HttpRequestTestCase> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.smithy.go.codegen.requestcompression;

import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;

import java.util.ArrayList;
import java.util.List;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoCodegenPlugin;
import software.amazon.smithy.go.codegen.GoDelegator;
import software.amazon.smithy.go.codegen.GoSettings;
import software.amazon.smithy.go.codegen.GoUniverseTypes;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoTypes;
import software.amazon.smithy.go.codegen.SymbolUtils;
import software.amazon.smithy.go.codegen.integration.ConfigField;
import software.amazon.smithy.go.codegen.integration.GoIntegration;
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.knowledge.TopDownIndex;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.traits.RequestCompressionTrait;
import software.amazon.smithy.utils.ListUtils;
import software.amazon.smithy.utils.MapUtils;


public final class RequestCompression implements GoIntegration {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit overall: private fields should go below public fields, and private methods should go towards the bottom of the class below all the public methods.

private static final String DISABLE_REQUEST_COMPRESSION = "DisableRequestCompression";

private static final String REQUEST_MIN_COMPRESSION_SIZE_BYTES = "RequestMinCompressSizeBytes";

private final List<RuntimeClientPlugin> runtimeClientPlugins = new ArrayList<>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keeping state in an integration is generally not best practice.


// Write operation plugin for request compression middleware
@Override
public void processFinalizedModel(GoSettings settings, Model model) {
ServiceShape service = settings.getService(model);
TopDownIndex.of(model)
.getContainedOperations(service).forEach(operation -> {
if (!operation.hasTrait(RequestCompressionTrait.class)) {
return;
}
SymbolProvider symbolProvider = GoCodegenPlugin.createSymbolProvider(model, settings);
String funcName = getAddRequestCompressionMiddlewareFuncName(
symbolProvider.toSymbol(operation).getName()
);
runtimeClientPlugins.add(RuntimeClientPlugin.builder().operationPredicate((m, s, o) -> {
if (!o.hasTrait(RequestCompressionTrait.class)) {
return false;
}
return o.equals(operation);
syall marked this conversation as resolved.
Show resolved Hide resolved
}).registerMiddleware(MiddlewareRegistrar.builder()
.resolvedFunction(SymbolUtils.buildPackageSymbol(funcName))
.useClientOptions().build())
.build());
});
}

@Override
public void writeAdditionalFiles(
GoSettings settings,
Model model,
SymbolProvider symbolProvider,
GoDelegator goDelegator
) {
ServiceShape service = settings.getService(model);
for (ShapeId operationID : service.getAllOperations()) {
OperationShape operation = model.expectShape(operationID, OperationShape.class);
if (!operation.hasTrait(RequestCompressionTrait.class)) {
continue;
}
goDelegator.useShapeWriter(operation, writeMiddlewareHelper(symbolProvider, operation));
}
}


public static boolean isRequestCompressionService(Model model, ServiceShape service) {
return TopDownIndex.of(model)
.getContainedOperations(service).stream()
.anyMatch(it -> it.hasTrait(RequestCompressionTrait.class));
}

@Override
public List<RuntimeClientPlugin> getClientPlugins() {
runtimeClientPlugins.add(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use List.of() instead, and register all client plugins here immutably instead of using state.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI he's doing this because of having to add per-operation middlewares elsewhere in processFinalizedModel. That's the only way I've ever seen it done unless there's something I'm not considering.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getClientPlugins should probably accept the Model and service shape ID much like these other functions do long-term.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have confirmed that in resolved comments, thx for clarification

RuntimeClientPlugin.builder()
.servicePredicate(RequestCompression::isRequestCompressionService)
.configFields(ListUtils.of(
ConfigField.builder()
.name(DISABLE_REQUEST_COMPRESSION)
.type(GoUniverseTypes.Bool)
.documentation(
"Whether to disable automatic request compression for supported operations.")
.build(),
ConfigField.builder()
.name(REQUEST_MIN_COMPRESSION_SIZE_BYTES)
.type(GoUniverseTypes.Int64)
.documentation("The minimum request body size, in bytes, at which compression "
+ "should occur. The default value is 10 KiB. Values must fall within "
+ "[0, 1MiB].")
.build()
))
.build()
);

return runtimeClientPlugins;
}

private GoWriter.Writable generateAlgorithmList(List<String> algorithms) {
return goTemplate("""
[]string{
$W
}
""",
GoWriter.ChainWritable.of(
algorithms.stream()
.map(it -> goTemplate("$S,", it))
.toList()
).compose(false));
}

private static String getAddRequestCompressionMiddlewareFuncName(String operationName) {
return String.format("addOperation%sRequestCompressionMiddleware", operationName);
}

private GoWriter.Writable writeMiddlewareHelper(SymbolProvider symbolProvider, OperationShape operation) {
String operationName = symbolProvider.toSymbol(operation).getName();
RequestCompressionTrait trait = operation.expectTrait(RequestCompressionTrait.class);

return goTemplate("""
func $add:L(stack $stack:P, options Options) error {
return $addInternal:T(stack, options.DisableRequestCompression, options.RequestMinCompressSizeBytes,
$algorithms:W)
}
""",
MapUtils.of(
"add", getAddRequestCompressionMiddlewareFuncName(operationName),
"stack", SmithyGoTypes.Middleware.Stack,
"addInternal", SmithyGoTypes.Private.RequestCompression.AddRequestCompression,
"algorithms", generateAlgorithmList(trait.getEncodings())
));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ software.amazon.smithy.go.codegen.endpoints.EndpointClientPluginsGenerator
# modeled auth schemes
software.amazon.smithy.go.codegen.integration.auth.SigV4AuthScheme
software.amazon.smithy.go.codegen.integration.auth.AnonymousAuthScheme

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove extra line

software.amazon.smithy.go.codegen.requestcompression.RequestCompression
30 changes: 30 additions & 0 deletions private/requestcompression/gzip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package requestcompression

import (
"bytes"
"compress/gzip"
"fmt"
"io"
)

func gzipCompress(input io.Reader) ([]byte, error) {
var b bytes.Buffer
w, err := gzip.NewWriterLevel(&b, gzip.DefaultCompression)
if err != nil {
return nil, fmt.Errorf("failed to create gzip writer, %v", err)
}

inBytes, err := io.ReadAll(input)
if err != nil {
return nil, fmt.Errorf("failed read payload to compress, %v", err)
}

if _, err = w.Write(inBytes); err != nil {
return nil, fmt.Errorf("failed to write payload to be compressed, %v", err)
}
if err = w.Close(); err != nil {
return nil, fmt.Errorf("failed to flush payload being compressed, %v", err)
}

return b.Bytes(), nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package requestcompression

import (
"bytes"
"context"
"fmt"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"io"
"net/http"
)

const captureUncompressedRequestID = "CaptureUncompressedRequest"

// AddCaptureUncompressedRequestMiddleware captures http request before compress encoding for check
func AddCaptureUncompressedRequestMiddleware(stack *middleware.Stack, buf *bytes.Buffer) error {
return stack.Serialize.Insert(&captureUncompressedRequestMiddleware{
buf: buf,
}, "RequestCompression", middleware.Before)
}

type captureUncompressedRequestMiddleware struct {
req *http.Request
buf *bytes.Buffer
bytes []byte
}

// ID returns id of the captureUncompressedRequestMiddleware
func (*captureUncompressedRequestMiddleware) ID() string {
return captureUncompressedRequestID
}

// HandleSerialize captures request payload before it is compressed by request compression middleware
func (m *captureUncompressedRequestMiddleware) HandleSerialize(ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler,
) (
output middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
request, ok := input.Request.(*smithyhttp.Request)
if !ok {
return output, metadata, fmt.Errorf("error when retrieving http request")
}

_, err = io.Copy(m.buf, request.GetStream())
if err != nil {
return output, metadata, fmt.Errorf("error when copying http request stream: %q", err)
}
if err = request.RewindStream(); err != nil {
return output, metadata, fmt.Errorf("error when rewinding request stream: %q", err)
}

return next.HandleSerialize(ctx, input)
}
Loading
Loading