From 5cb116bd779d47578a68a2956b1dd790368eeafc Mon Sep 17 00:00:00 2001 From: Josh Deprez Date: Wed, 26 Jun 2024 10:30:04 +1000 Subject: [PATCH 1/4] Pass ctx idiomatically It shouldn't be passed through an option pattern, since that causes it to be stored in a struct. It should be passed as the first arg. See https://pkg.go.dev/context. This adds a ctx arg to Verify, which uses it, and Sign, which doesn't, but now looks weird without it. --- signature/sign.go | 11 +++------- signature/sign_test.go | 48 +++++++++++++++++++++++++++--------------- signature/steps.go | 11 +++++----- 3 files changed, 40 insertions(+), 30 deletions(-) diff --git a/signature/sign.go b/signature/sign.go index 395c427..f04d448 100644 --- a/signature/sign.go +++ b/signature/sign.go @@ -44,7 +44,6 @@ type options struct { env map[string]string logger Logger debugSigning bool - ctx context.Context } type Option interface { @@ -54,22 +53,18 @@ type Option interface { type envOption struct{ env map[string]string } type loggerOption struct{ logger Logger } type debugSigningOption struct{ debugSigning bool } -type contextOption struct{ ctx context.Context } func (o envOption) apply(opts *options) { opts.env = o.env } func (o loggerOption) apply(opts *options) { opts.logger = o.logger } func (o debugSigningOption) apply(opts *options) { opts.debugSigning = o.debugSigning } -func (o contextOption) apply(opts *options) { opts.ctx = o.ctx } func WithEnv(env map[string]string) Option { return envOption{env} } func WithLogger(logger Logger) Option { return loggerOption{logger} } func WithDebugSigning(debugSigning bool) Option { return debugSigningOption{debugSigning} } -func WithContext(ctx context.Context) Option { return contextOption{ctx} } func configureOptions(opts ...Option) options { options := options{ env: make(map[string]string), - ctx: context.Background(), } for _, o := range opts { o.apply(&options) @@ -79,7 +74,7 @@ func configureOptions(opts ...Option) options { // Sign computes a new signature for an environment (env) combined with an // object containing values (sf) using a given key. -func Sign(key jwk.Key, sf SignedFielder, opts ...Option) (*pipeline.Signature, error) { +func Sign(_ context.Context, key jwk.Key, sf SignedFielder, opts ...Option) (*pipeline.Signature, error) { options := configureOptions(opts...) values, err := sf.SignedFields() @@ -150,7 +145,7 @@ func Sign(key jwk.Key, sf SignedFielder, opts ...Option) (*pipeline.Signature, e // Verify verifies an existing signature against environment (env) combined with // an object containing values (sf) using keys from a keySet. -func Verify(s *pipeline.Signature, keySet jwk.Set, sf SignedFielder, opts ...Option) error { +func Verify(ctx context.Context, s *pipeline.Signature, keySet jwk.Set, sf SignedFielder, opts ...Option) error { options := configureOptions(opts...) if len(s.SignedFields) == 0 { @@ -190,7 +185,7 @@ func Verify(s *pipeline.Signature, keySet jwk.Set, sf SignedFielder, opts ...Opt return err } - for it := keySet.Keys(options.ctx); it.Next(options.ctx); { + for it := keySet.Keys(ctx); it.Next(ctx); { pair := it.Pair() publicKey := pair.Value.(jwk.Key) fingerprint, err := publicKey.Thumbprint(crypto.SHA256) diff --git a/signature/sign_test.go b/signature/sign_test.go index 6522c0a..fb5abec 100644 --- a/signature/sign_test.go +++ b/signature/sign_test.go @@ -1,6 +1,7 @@ package signature import ( + "context" "errors" "fmt" "math/rand" @@ -21,6 +22,9 @@ const ( ) func TestSignVerify(t *testing.T) { + t.Parallel() + ctx := context.Background() + step := &pipeline.CommandStep{ Command: "llamas", Plugins: pipeline.Plugins{ @@ -96,7 +100,7 @@ func TestSignVerify(t *testing.T) { t.Fatalf("jwkutil.LoadKey(%v, %v) error = %v", privPath, keyName, err) } - sig, err := Sign(sKey, stepWithInvariants, WithEnv(signEnv)) + sig, err := Sign(ctx, sKey, stepWithInvariants, WithEnv(signEnv)) if err != nil { t.Fatalf("Sign(CommandStep, signer) error = %v", err) } @@ -122,8 +126,8 @@ func TestSignVerify(t *testing.T) { t.Fatalf("verifier.AddKey(%v) error = %v", vKey, err) } - if err := Verify(sig, verifier, stepWithInvariants, WithEnv(verifyEnv)); err != nil { - t.Errorf("Verify(sig,CommandStep, verifier) = %v", err) + if err := Verify(ctx, sig, verifier, stepWithInvariants, WithEnv(verifyEnv)); err != nil { + t.Errorf("Verify(sig, CommandStep, verifier) = %v", err) } }) } @@ -147,6 +151,7 @@ func (m testFields) ValuesForFields(fields []string) (map[string]any, error) { func TestSignConcatenatedFields(t *testing.T) { t.Parallel() + ctx := context.Background() // Tests that Sign is resilient to concatenation. // Specifically, these maps should all have distinct "content". (If you @@ -183,7 +188,7 @@ func TestSignConcatenatedFields(t *testing.T) { } for _, m := range maps { - sig, err := Sign(key, m) + sig, err := Sign(ctx, key, m) if err != nil { t.Fatalf("Sign(%v, pts) error = %v", m, err) } @@ -204,6 +209,7 @@ func TestSignConcatenatedFields(t *testing.T) { func TestUnknownAlgorithm(t *testing.T) { t.Parallel() + ctx := context.Background() signer, _, err := jwkutil.NewSymmetricKeyPairFromString(keyID, "alpacas", jwa.HS256) if err != nil { @@ -217,16 +223,20 @@ func TestUnknownAlgorithm(t *testing.T) { key.Set(jwk.AlgorithmKey, "rot13") - if _, err := Sign( - key, - &CommandStepWithInvariants{CommandStep: pipeline.CommandStep{Command: "llamas"}}, - ); err == nil { + step := &CommandStepWithInvariants{ + CommandStep: pipeline.CommandStep{ + Command: "llamas", + }, + } + + if _, err := Sign(ctx, key, step); err == nil { t.Errorf("Sign(nil, CommandStep, signer) = %v, want non-nil error", err) } } func TestVerifyBadSignature(t *testing.T) { t.Parallel() + ctx := context.Background() cs := &CommandStepWithInvariants{CommandStep: pipeline.CommandStep{Command: "llamas"}} @@ -241,13 +251,14 @@ func TestVerifyBadSignature(t *testing.T) { t.Fatalf("NewSymmetricKeyPairFromString(alpacas) error = %v", err) } - if err := Verify(sig, verifier, cs); err == nil { + if err := Verify(ctx, sig, verifier, cs); err == nil { t.Errorf("Verify(sig,CommandStep, alpacas) = %v, want non-nil error", err) } } func TestSignUnknownStep(t *testing.T) { t.Parallel() + ctx := context.Background() steps := pipeline.Steps{ &pipeline.UnknownStep{ @@ -265,13 +276,14 @@ func TestSignUnknownStep(t *testing.T) { t.Fatalf("signer.Key(0) = _, false, want true") } - if err := SignSteps(steps, key, ""); !errors.Is(err, errSigningRefusedUnknownStepType) { + if err := SignSteps(ctx, steps, key, ""); !errors.Is(err, errSigningRefusedUnknownStepType) { t.Errorf("steps.sign(signer) = %v, want %v", err, errSigningRefusedUnknownStepType) } } func TestSignVerifyEnv(t *testing.T) { t.Parallel() + ctx := context.Background() cases := []struct { name string @@ -353,12 +365,12 @@ func TestSignVerifyEnv(t *testing.T) { RepositoryURL: tc.repositoryURL, } - sig, err := Sign(key, stepWithInvariants, WithEnv(tc.pipelineEnv)) + sig, err := Sign(ctx, key, stepWithInvariants, WithEnv(tc.pipelineEnv)) if err != nil { t.Fatalf("Sign(CommandStep, signer) error = %v", err) } - if err := Verify(sig, verifier, stepWithInvariants, WithEnv(tc.verifyEnv)); err != nil { + if err := Verify(ctx, sig, verifier, stepWithInvariants, WithEnv(tc.verifyEnv)); err != nil { t.Errorf("Verify(sig,CommandStep, verifier) = %v", err) } }) @@ -367,6 +379,7 @@ func TestSignVerifyEnv(t *testing.T) { func TestSignatureStability(t *testing.T) { t.Parallel() + ctx := context.Background() // The idea here is to sign and verify a step that is likely to encode in a // non-stable way if there are ordering bugs. @@ -408,18 +421,19 @@ func TestSignatureStability(t *testing.T) { t.Fatalf("signer.Key(0) = _, false, want true") } - sig, err := Sign(key, stepWithInvariants, WithEnv(env)) + sig, err := Sign(ctx, key, stepWithInvariants, WithEnv(env)) if err != nil { t.Fatalf("Sign(env, CommandStep, signer) error = %v", err) } - if err := Verify(sig, verifier, stepWithInvariants, WithEnv(env)); err != nil { - t.Errorf("Verify(sig,env, CommandStep, verifier) = %v", err) + if err := Verify(ctx, sig, verifier, stepWithInvariants, WithEnv(env)); err != nil { + t.Errorf("Verify(sig, env, CommandStep, verifier) = %v", err) } } func TestDebugSigning(t *testing.T) { t.Parallel() + ctx := context.Background() step := &pipeline.CommandStep{ Command: "llamas", @@ -469,7 +483,7 @@ func TestDebugSigning(t *testing.T) { logger := &mockLogger{ expectedFormat: "Signed Step: %s", } - _, err = Sign(sKey, stepWithInvariants, WithEnv(signEnv), WithDebugSigning(false), WithLogger(logger)) + _, err = Sign(ctx, sKey, stepWithInvariants, WithEnv(signEnv), WithDebugSigning(false), WithLogger(logger)) if err != nil { t.Fatalf("Sign(CommandStep, signer) error = %v", err) } @@ -482,7 +496,7 @@ func TestDebugSigning(t *testing.T) { logger = &mockLogger{ expectedFormat: "Signed Step: %s", } - _, err = Sign(sKey, stepWithInvariants, WithEnv(signEnv), WithDebugSigning(true), WithLogger(logger)) + _, err = Sign(ctx, sKey, stepWithInvariants, WithEnv(signEnv), WithDebugSigning(true), WithLogger(logger)) if err != nil { t.Fatalf("Sign(CommandStep, signer) error = %v", err) } diff --git a/signature/steps.go b/signature/steps.go index 8b82827..0629540 100644 --- a/signature/steps.go +++ b/signature/steps.go @@ -1,6 +1,7 @@ package signature import ( + "context" "errors" "fmt" @@ -12,7 +13,7 @@ var errSigningRefusedUnknownStepType = errors.New("refusing to sign pipeline con // SignSteps adds signatures to each command step (and recursively to any command steps that are within group steps). // The steps are mutated directly, so an error part-way through may leave some steps un-signed. -func SignSteps(s pipeline.Steps, key jwk.Key, repoURL string, opts ...Option) error { +func SignSteps(ctx context.Context, s pipeline.Steps, key jwk.Key, repoURL string, opts ...Option) error { for _, step := range s { switch step := step.(type) { case *pipeline.CommandStep: @@ -21,14 +22,14 @@ func SignSteps(s pipeline.Steps, key jwk.Key, repoURL string, opts ...Option) er RepositoryURL: repoURL, } - sig, err := Sign(key, stepWithInvariants, opts...) + sig, err := Sign(ctx, key, stepWithInvariants, opts...) if err != nil { return fmt.Errorf("signing step with command %q: %w", step.Command, err) } step.Signature = sig case *pipeline.GroupStep: - if err := SignSteps(step.Steps, key, repoURL, opts...); err != nil { + if err := SignSteps(ctx, step.Steps, key, repoURL, opts...); err != nil { return fmt.Errorf("signing group step: %w", err) } @@ -45,8 +46,8 @@ func SignSteps(s pipeline.Steps, key jwk.Key, repoURL string, opts ...Option) er } // SignPipeline adds signatures to each command step (and recursively to any command steps that are within group steps) within a pipeline -func SignPipeline(s pipeline.Steps, key jwk.Key, repo string, opts ...Option) error { - if err := SignSteps(s, key, repo, opts...); err != nil { +func SignPipeline(ctx context.Context, s pipeline.Steps, key jwk.Key, repo string, opts ...Option) error { + if err := SignSteps(ctx, s, key, repo, opts...); err != nil { return fmt.Errorf("signing steps: %w", err) } return nil From 0f703185acf10309a412f4bca342dc72952cf162 Mon Sep 17 00:00:00 2001 From: Josh Deprez Date: Wed, 26 Jun 2024 10:44:14 +1000 Subject: [PATCH 2/4] Flatten debug log implementation "Handle errors before proceeding with the rest of your code" https://google.github.io/styleguide/go/decisions#indent-error-flow Additionally, some `else` keywords are redundant (due to `return` in the `if` branch). Removing those avoids overly indenting the "main flow". --- signature/sign.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/signature/sign.go b/signature/sign.go index f04d448..677658f 100644 --- a/signature/sign.go +++ b/signature/sign.go @@ -112,15 +112,16 @@ func Sign(_ context.Context, key jwk.Key, sf SignedFielder, opts ...Option) (*pi return nil, err } - if pk, err := key.PublicKey(); err == nil && options.logger != nil { + if options.logger != nil { + pk, err := key.PublicKey() + if err != nil { + return nil, fmt.Errorf("unable to generate public key: %w", err) + } fingerprint, err := pk.Thumbprint(crypto.SHA256) if err != nil { return nil, fmt.Errorf("calculating key thumbprint: %w", err) - } else { - debug(options.logger, "Public Key Thumbprint (sha256): %s", hex.EncodeToString(fingerprint)) } - } else if err != nil { - return nil, fmt.Errorf("unable to generate public key: %w", err) + debug(options.logger, "Public Key Thumbprint (sha256): %s", hex.EncodeToString(fingerprint)) } if options.debugSigning { @@ -191,9 +192,8 @@ func Verify(ctx context.Context, s *pipeline.Signature, keySet jwk.Set, sf Signe fingerprint, err := publicKey.Thumbprint(crypto.SHA256) if err != nil { return fmt.Errorf("calculating key thumbprint: %w", err) - } else if options.logger != nil { - debug(options.logger, "Public Key Thumbprint (sha256): %s", hex.EncodeToString(fingerprint)) } + debug(options.logger, "Public Key Thumbprint (sha256): %s", hex.EncodeToString(fingerprint)) } if options.debugSigning { From 438c05f229d2c7fb7766d492176f978d44365051 Mon Sep 17 00:00:00 2001 From: Josh Deprez Date: Wed, 26 Jun 2024 11:46:49 +1000 Subject: [PATCH 3/4] Fix inconsistent test error messages They seem to have drifted apart from the actual function calls. Error messages should identify the functions and arguments actually passed to aid with diagnosing test failures. --- signature/sign_test.go | 57 ++++++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/signature/sign_test.go b/signature/sign_test.go index fb5abec..065bc82 100644 --- a/signature/sign_test.go +++ b/signature/sign_test.go @@ -102,7 +102,7 @@ func TestSignVerify(t *testing.T) { sig, err := Sign(ctx, sKey, stepWithInvariants, WithEnv(signEnv)) if err != nil { - t.Fatalf("Sign(CommandStep, signer) error = %v", err) + t.Fatalf("Sign(ctx, sKey, %v, WithEnv(%v)) error = %v", stepWithInvariants, signEnv, err) } if sig.Algorithm != tc.alg.String() { @@ -127,7 +127,7 @@ func TestSignVerify(t *testing.T) { } if err := Verify(ctx, sig, verifier, stepWithInvariants, WithEnv(verifyEnv)); err != nil { - t.Errorf("Verify(sig, CommandStep, verifier) = %v", err) + t.Errorf("Verify(ctx, %v, verifier, %v, WithEnv(%v)) = %v", sig, stepWithInvariants, verifyEnv, err) } }) } @@ -177,9 +177,10 @@ func TestSignConcatenatedFields(t *testing.T) { sigs := make(map[string][]testFields) - signer, _, err := jwkutil.NewSymmetricKeyPairFromString(keyID, "alpacas", jwa.HS256) + keyStr, keyAlg := "alpacas", jwa.HS256 + signer, _, err := jwkutil.NewSymmetricKeyPairFromString(keyID, keyStr, keyAlg) if err != nil { - t.Fatalf("NewSymmetricKeyPairFromString(alpacas) error = %v", err) + t.Fatalf("jwkutil.NewSymmetricKeyPairFromString(%q, %q, %q) error = %v", keyID, keyStr, keyAlg, err) } key, ok := signer.Key(0) @@ -190,7 +191,7 @@ func TestSignConcatenatedFields(t *testing.T) { for _, m := range maps { sig, err := Sign(ctx, key, m) if err != nil { - t.Fatalf("Sign(%v, pts) error = %v", m, err) + t.Fatalf("Sign(ctx, key, %v) error = %v", m, err) } sigs[sig.Value] = append(sigs[sig.Value], m) @@ -211,9 +212,10 @@ func TestUnknownAlgorithm(t *testing.T) { t.Parallel() ctx := context.Background() - signer, _, err := jwkutil.NewSymmetricKeyPairFromString(keyID, "alpacas", jwa.HS256) + keyStr, keyAlg := "alpacas", jwa.HS256 + signer, _, err := jwkutil.NewSymmetricKeyPairFromString(keyID, keyStr, keyAlg) if err != nil { - t.Fatalf("NewSymmetricKeyPairFromString(alpacas) error = %v", err) + t.Fatalf("jwkutil.NewSymmetricKeyPairFromString(%q, %q, %q) error = %v", keyID, keyStr, keyAlg, err) } key, ok := signer.Key(0) @@ -230,7 +232,7 @@ func TestUnknownAlgorithm(t *testing.T) { } if _, err := Sign(ctx, key, step); err == nil { - t.Errorf("Sign(nil, CommandStep, signer) = %v, want non-nil error", err) + t.Errorf("Sign(ctx, key, %v) = %v, want non-nil error", step, err) } } @@ -246,13 +248,14 @@ func TestVerifyBadSignature(t *testing.T) { Value: "YWxwYWNhcw==", // base64("alpacas") } - _, verifier, err := jwkutil.NewSymmetricKeyPairFromString(keyID, "alpacas", jwa.HS256) + keyStr, keyAlg := "alpacas", jwa.HS256 + _, verifier, err := jwkutil.NewSymmetricKeyPairFromString(keyID, keyStr, keyAlg) if err != nil { - t.Fatalf("NewSymmetricKeyPairFromString(alpacas) error = %v", err) + t.Fatalf("jwkutil.NewSymmetricKeyPairFromString(%q, %q, %q) error = %v", keyID, keyStr, keyAlg, err) } if err := Verify(ctx, sig, verifier, cs); err == nil { - t.Errorf("Verify(sig,CommandStep, alpacas) = %v, want non-nil error", err) + t.Errorf("Verify(ctx, sig, verifier, %v) = %v, want non-nil error", cs, err) } } @@ -266,9 +269,10 @@ func TestSignUnknownStep(t *testing.T) { }, } - signer, _, err := jwkutil.NewSymmetricKeyPairFromString(keyID, "alpacas", jwa.HS256) + keyStr, keyAlg := "alpacas", jwa.HS256 + signer, _, err := jwkutil.NewSymmetricKeyPairFromString(keyID, keyStr, keyAlg) if err != nil { - t.Fatalf("NewSymmetricKeyPairFromString(alpacas) error = %v", err) + t.Fatalf("jwkutil.NewSymmetricKeyPairFromString(%q, %q, %q) error = %v", keyID, keyStr, keyAlg, err) } key, ok := signer.Key(0) @@ -277,7 +281,7 @@ func TestSignUnknownStep(t *testing.T) { } if err := SignSteps(ctx, steps, key, ""); !errors.Is(err, errSigningRefusedUnknownStepType) { - t.Errorf("steps.sign(signer) = %v, want %v", err, errSigningRefusedUnknownStepType) + t.Errorf(`SignSteps(ctx, %v, key, "") = %v, want %v`, steps, err, errSigningRefusedUnknownStepType) } } @@ -350,9 +354,11 @@ func TestSignVerifyEnv(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { t.Parallel() - signer, verifier, err := jwkutil.NewSymmetricKeyPairFromString(keyID, "alpacas", jwa.HS256) + + keyStr, keyAlg := "alpacas", jwa.HS256 + signer, verifier, err := jwkutil.NewSymmetricKeyPairFromString(keyID, keyStr, keyAlg) if err != nil { - t.Fatalf("NewSymmetricKeyPairFromString(alpacas) error = %v", err) + t.Fatalf("jwkutil.NewSymmetricKeyPairFromString(%q, %q, %q) error = %v", keyID, keyStr, keyAlg, err) } key, ok := signer.Key(0) @@ -367,11 +373,11 @@ func TestSignVerifyEnv(t *testing.T) { sig, err := Sign(ctx, key, stepWithInvariants, WithEnv(tc.pipelineEnv)) if err != nil { - t.Fatalf("Sign(CommandStep, signer) error = %v", err) + t.Fatalf("Sign(ctx, key, %v, WithEnv(%v)) error = %v", stepWithInvariants, tc.pipelineEnv, err) } if err := Verify(ctx, sig, verifier, stepWithInvariants, WithEnv(tc.verifyEnv)); err != nil { - t.Errorf("Verify(sig,CommandStep, verifier) = %v", err) + t.Errorf("Verify(ctx, %v, verifier, %v, WithEnv(%v)) = %v", sig, stepWithInvariants, tc.verifyEnv, err) } }) } @@ -411,9 +417,10 @@ func TestSignatureStability(t *testing.T) { pluginSubCfg[fmt.Sprintf("key%08x", rand.Uint32())] = fmt.Sprintf("value%08x", rand.Uint32()) } - signer, verifier, err := jwkutil.NewKeyPair(keyID, jwa.ES512) + keyAlg := jwa.ES512 + signer, verifier, err := jwkutil.NewKeyPair(keyID, keyAlg) if err != nil { - t.Fatalf("NewKeyPair error = %v", err) + t.Fatalf("jwk.NewKeyPair(%q, %q) error = %v", keyID, keyAlg, err) } key, ok := signer.Key(0) @@ -423,11 +430,11 @@ func TestSignatureStability(t *testing.T) { sig, err := Sign(ctx, key, stepWithInvariants, WithEnv(env)) if err != nil { - t.Fatalf("Sign(env, CommandStep, signer) error = %v", err) + t.Fatalf("Sign(ctx, key, %v, WithEnv(%v)) error = %v", stepWithInvariants, env, err) } if err := Verify(ctx, sig, verifier, stepWithInvariants, WithEnv(env)); err != nil { - t.Errorf("Verify(sig, env, CommandStep, verifier) = %v", err) + t.Errorf("Verify(ctx, %v, verifier, %v, WithEnv(%v)) = %v", sig, stepWithInvariants, env, err) } } @@ -476,7 +483,7 @@ func TestDebugSigning(t *testing.T) { sKey, err := jwkutil.LoadKey(privPath, keyName) if err != nil { - t.Fatalf("jwkutil.LoadKey(%v, %v) error = %v", privPath, keyName, err) + t.Fatalf("jwkutil.LoadKey(%q, %q) error = %v", privPath, keyName, err) } // Test that step payload is not logged when debugSigning is false @@ -485,7 +492,7 @@ func TestDebugSigning(t *testing.T) { } _, err = Sign(ctx, sKey, stepWithInvariants, WithEnv(signEnv), WithDebugSigning(false), WithLogger(logger)) if err != nil { - t.Fatalf("Sign(CommandStep, signer) error = %v", err) + t.Fatalf("Sign(ctx, sKey, %v, WithEnv(%v), WithDebugSigning(false), WithLogger(logger)) error = %v", stepWithInvariants, signEnv, err) } if logger.passed { @@ -498,7 +505,7 @@ func TestDebugSigning(t *testing.T) { } _, err = Sign(ctx, sKey, stepWithInvariants, WithEnv(signEnv), WithDebugSigning(true), WithLogger(logger)) if err != nil { - t.Fatalf("Sign(CommandStep, signer) error = %v", err) + t.Fatalf("Sign(ctx, sKey, %v, WithEnv(%v), WithDebugSigning(true), WithLogger(logger)) error = %v", stepWithInvariants, signEnv, err) } if !logger.passed { From 2e6f3cbd31b3c3e6a6337f5662b9349ba7f6c807 Mon Sep 17 00:00:00 2001 From: Josh Deprez Date: Wed, 26 Jun 2024 11:58:08 +1000 Subject: [PATCH 4/4] Replace mock logger with fake logger The distinction may seem trivial, but the "mock" approach causes confusion (even in the previous code: if mockLogger.passed { fail } ?) The test can be responsible for deciding if the logged output is valid, without hiding it inside a mock type. --- signature/sign_test.go | 42 +++++++++++++++++++----------------------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/signature/sign_test.go b/signature/sign_test.go index 065bc82..338ccf9 100644 --- a/signature/sign_test.go +++ b/signature/sign_test.go @@ -8,6 +8,7 @@ import ( "os" "path" "slices" + "strings" "testing" "github.com/buildkite/go-pipeline" @@ -487,45 +488,40 @@ func TestDebugSigning(t *testing.T) { } // Test that step payload is not logged when debugSigning is false - logger := &mockLogger{ - expectedFormat: "Signed Step: %s", - } + logger := &fakeLogger{} _, err = Sign(ctx, sKey, stepWithInvariants, WithEnv(signEnv), WithDebugSigning(false), WithLogger(logger)) if err != nil { t.Fatalf("Sign(ctx, sKey, %v, WithEnv(%v), WithDebugSigning(false), WithLogger(logger)) error = %v", stepWithInvariants, signEnv, err) } - if logger.passed { - t.Errorf("Expected \"%s\" not to be logged, but got %v", logger.expectedFormat, logger.actualFormats) + logged := logger.buf.String() + if want := "Public Key Thumbprint (sha256)"; !strings.Contains(logged, want) { + t.Errorf("logger.buf.String() = %q, missing %q", logged, want) + } + if want := "Signed Step"; strings.Contains(logged, want) { + t.Errorf("logger.buf.String() = %q, contains %q", logged, want) } // Test that step payload is logged when debugSigning is true - logger = &mockLogger{ - expectedFormat: "Signed Step: %s", - } + logger = &fakeLogger{} _, err = Sign(ctx, sKey, stepWithInvariants, WithEnv(signEnv), WithDebugSigning(true), WithLogger(logger)) if err != nil { t.Fatalf("Sign(ctx, sKey, %v, WithEnv(%v), WithDebugSigning(true), WithLogger(logger)) error = %v", stepWithInvariants, signEnv, err) } - if !logger.passed { - t.Errorf("Expected \"%s\" to be logged, but only got %v", logger.expectedFormat, logger.actualFormats) + logged = logger.buf.String() + if want := "Public Key Thumbprint (sha256)"; !strings.Contains(logged, want) { + t.Errorf("logger.buf.String() = %q, missing %q", logged, want) + } + if want := "Signed Step"; !strings.Contains(logged, want) { + t.Errorf("logger.buf.String() = %q, missing %q", logged, want) } } -type mockLogger struct { - passed bool - expectedFormat string - actualFormats []string +type fakeLogger struct { + buf strings.Builder } -func (m *mockLogger) Debug(f string, v ...any) { - if m.passed { - return - } - - m.actualFormats = append(m.actualFormats, f) - if f == m.expectedFormat { - m.passed = true - } +func (l *fakeLogger) Debug(f string, v ...any) { + fmt.Fprintf(&l.buf, f, v...) }