Skip to content

Commit

Permalink
codegen: middleware snapshot tests (aws#502)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws authored Feb 22, 2024
1 parent 1c1f3f0 commit 0dbd505
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ private GoStdlibTypes() { }

public static final class Context {
public static final Symbol Context = SmithyGoDependency.CONTEXT.valueSymbol("Context");
public static final Symbol Background = SmithyGoDependency.CONTEXT.valueSymbol("Background");
}

public static final class Fmt {
Expand All @@ -42,4 +43,8 @@ public static final class Http {
public static final class Path {
public static final Symbol Join = SmithyGoDependency.PATH.valueSymbol("Join");
}

public static final class Testing {
public static final Symbol T = SmithyGoDependency.TESTING.pointableSymbol("T");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public final class GoWriter extends AbstractCodeWriter<GoWriter> {
private final ImportDeclarations imports = new ImportDeclarations();
private final List<SymbolDependency> dependencies = new ArrayList<>();
private final boolean innerWriter;
private final List<String> buildTags = new ArrayList<>();

private int docWrapLength = DEFAULT_DOC_WRAP_LENGTH;
private AbstractCodeWriter<GoWriter> packageDocs;
Expand Down Expand Up @@ -93,6 +94,7 @@ private void init() {
putFormatter('T', new GoSymbolFormatter());
putFormatter('P', new PointableGoSymbolFormatter());
putFormatter('W', new GoWritableInjector());
putFormatter('D', new GoDependencyFormatter());

if (!innerWriter) {
packageDocs = new GoWriter(this.fullPackageName, true);
Expand Down Expand Up @@ -881,6 +883,11 @@ public void write(Writable w) {
write("$W", w);
}

public GoWriter addBuildTag(String tag) {
buildTags.add(tag);
return this;
}

@Override
public String toString() {
String contents = super.toString();
Expand All @@ -889,6 +896,9 @@ public String toString() {
return contents;
}

var tags = buildTags.isEmpty()
? ""
: "//go:build " + String.join(",", buildTags) + "\n";

String[] packageParts = fullPackageName.split("/");
String header = String.format("// Code generated by smithy-go-codegen DO NOT EDIT.%n%n");
Expand Down Expand Up @@ -919,7 +929,7 @@ public String toString() {
return header + strippedImportString + "\n" + strippedContents;
}

return header + packageDocs + packageStatement + importString + contents;
return header + packageDocs + tags + packageStatement + importString + contents;
}

/**
Expand Down Expand Up @@ -1013,6 +1023,22 @@ public String apply(Object type, String indent) {
}
}

/**
* Implements Go symbol formatting for the {@code $D} formatter.
*/
private class GoDependencyFormatter implements BiFunction<Object, String, String> {
@Override
public String apply(Object type, String indent) {
if (type instanceof GoDependency) {
addUseImports((GoDependency) type);
} else {
throw new CodegenException(
"Invalid type provided to $D. Expected a GoDependency, but found `" + type + "`");
}
return "";
}
}

public interface Writable extends Consumer<GoWriter> {
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public final class SmithyGoDependency {
public static final GoDependency JSON = stdlib("encoding/json");
public static final GoDependency IO = stdlib("io");
public static final GoDependency IOUTIL = stdlib("io/ioutil");
public static final GoDependency FS = stdlib("io/fs");
public static final GoDependency CRYPTORAND = stdlib("crypto/rand", "cryptorand");
public static final GoDependency TESTING = stdlib("testing");
public static final GoDependency ERRORS = stdlib("errors");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.smithy.go.codegen.integration;

import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;

import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoDelegator;
import software.amazon.smithy.go.codegen.GoSettings;
import software.amazon.smithy.go.codegen.GoStdlibTypes;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.SmithyGoTypes;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.knowledge.TopDownIndex;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.utils.MapUtils;

public class MiddlewareStackSnapshotTests implements GoIntegration {
@Override
public void writeAdditionalFiles(
GoSettings settings, Model model, SymbolProvider symbolProvider, GoDelegator goDelegator
) {
goDelegator.useFileWriter("snapshot_test.go", settings.getModuleName(), writer -> {
writer.addBuildTag("snapshot");
writer.write(commonTestSource());
writer.write(snapshotTests(model, settings.getService(model), symbolProvider));
writer.write(snapshotUpdaters(model, settings.getService(model), symbolProvider));
});
}

private GoWriter.Writable commonTestSource() {
return goTemplate("""
$os:D $fs:D $io:D $errors:D $fmt:D $middleware:D
const ssprefix = "snapshot"
type snapshotOK struct{}
func (snapshotOK) Error() string { return "error: success" }
func createp(path string) (*os.File, error) {
if err := os.Mkdir(ssprefix, 0700); err != nil && !errors.Is(err, fs.ErrExist) {
return nil, err
}
return os.Create(path)
}
func sspath(op string) string {
return fmt.Sprintf("%s/api_op_%s.go.snap", ssprefix, op)
}
func updateSnapshot(stack *middleware.Stack, operation string) error {
f, err := createp(sspath(operation))
if err != nil {
return err
}
defer f.Close()
if _, err := f.Write([]byte(stack.String())); err != nil {
return err
}
return snapshotOK{}
}
func testSnapshot(stack *middleware.Stack, operation string) error {
f, err := os.Open(sspath(operation))
if errors.Is(err, fs.ErrNotExist) {
return snapshotOK{}
}
if err != nil {
return err
}
defer f.Close()
expected, err := io.ReadAll(f)
if err != nil {
return err
}
if actual := stack.String(); actual != string(expected) {
return fmt.Errorf("%s != %s", expected, actual)
}
return snapshotOK{}
}
""",
MapUtils.of(
"errors", SmithyGoDependency.ERRORS, "fmt", SmithyGoDependency.FMT,
"fs", SmithyGoDependency.FS, "io", SmithyGoDependency.IO,
"middleware", SmithyGoDependency.SMITHY_MIDDLEWARE, "os", SmithyGoDependency.OS
));
}

private GoWriter.Writable snapshotUpdaters(Model model, ServiceShape service, SymbolProvider symbolProvider) {
return GoWriter.ChainWritable.of(
TopDownIndex.of(model).getContainedOperations(service).stream()
.map(it -> testUpdateSnapshot(it, symbolProvider))
.toList()
).compose();
}

private GoWriter.Writable snapshotTests(Model model, ServiceShape service, SymbolProvider symbolProvider) {
return GoWriter.ChainWritable.of(
TopDownIndex.of(model).getContainedOperations(service).stream()
.map(it -> testCheckSnapshot(it, symbolProvider))
.toList()
).compose();
}

private GoWriter.Writable testUpdateSnapshot(OperationShape operation, SymbolProvider symbolProvider) {
return goTemplate("""
func TestUpdateSnapshot_$operation:L(t $testingT:P) {
svc := New(Options{})
_, err := svc.$operation:L($contextBackground:T(), nil, func(o *Options) {
o.APIOptions = append(o.APIOptions, func(stack $middlewareStack:P) error {
return updateSnapshot(stack, $operation:S)
})
})
if _, ok := err.(snapshotOK); !ok && err != nil {
t.Fatal(err)
}
}
""",
MapUtils.of(
"testingT", GoStdlibTypes.Testing.T,
"contextBackground", GoStdlibTypes.Context.Background,
"middlewareStack", SmithyGoTypes.Middleware.Stack,
"operation", symbolProvider.toSymbol(operation).getName()
));
}

private GoWriter.Writable testCheckSnapshot(OperationShape operation, SymbolProvider symbolProvider) {
return goTemplate("""
func TestCheckSnapshot_$operation:L(t $testingT:P) {
svc := New(Options{})
_, err := svc.$operation:L($contextBackground:T(), nil, func(o *Options) {
o.APIOptions = append(o.APIOptions, func(stack $middlewareStack:P) error {
return testSnapshot(stack, $operation:S)
})
})
if _, ok := err.(snapshotOK); !ok && err != nil {
t.Fatal(err)
}
}
""",
MapUtils.of(
"testingT", GoStdlibTypes.Testing.T,
"contextBackground", GoStdlibTypes.Context.Background,
"middlewareStack", SmithyGoTypes.Middleware.Stack,
"operation", symbolProvider.toSymbol(operation).getName()
));
}
}

0 comments on commit 0dbd505

Please sign in to comment.