diff --git a/.github/actions/detect-workflow/main.go b/.github/actions/detect-workflow/main.go index 0a6690345..89e880fb7 100644 --- a/.github/actions/detect-workflow/main.go +++ b/.github/actions/detect-workflow/main.go @@ -28,14 +28,14 @@ import ( ) type action struct { - getenv func(string) string - event map[string]any - client *github.OIDCClient + getenv func(string) string + event map[string]any + getClient func() (*github.OIDCClient, error) } // TODO(github.com/slsa-framework/slsa-github-generator/issues/164): use the github context via the shared library -func newAction(getenv func(string) string, c *github.OIDCClient) (*action, error) { +func newAction(getenv func(string) string, getClient func() (*github.OIDCClient, error)) (*action, error) { eventPath := getenv("GITHUB_EVENT_PATH") if eventPath == "" { return nil, errors.New("GITHUB_EVENT_PATH not set") @@ -52,9 +52,9 @@ func newAction(getenv func(string) string, c *github.OIDCClient) (*action, error } return &action{ - getenv: getenv, - event: event, - client: c, + getenv: getenv, + event: event, + getClient: getClient, }, nil } @@ -107,7 +107,11 @@ func (a *action) getRepoRef(ctx context.Context) (string, string, error) { } audience = path.Join(audience, "detect-workflow") - t, err := a.client.Token(ctx, []string{audience}) + client, err := a.getClient() + if err != nil { + return "", "", fmt.Errorf("creating OIDC client: %w", err) + } + t, err := client.Token(ctx, []string{audience}) if err != nil { return "", "", fmt.Errorf("getting OIDC token: %w", err) } @@ -136,11 +140,7 @@ func (a *action) getRepoRef(ctx context.Context) (string, string, error) { } func main() { - c, err := github.NewOIDCClient() - if err != nil { - log.Fatal(err) - } - a, err := newAction(os.Getenv, c) + a, err := newAction(os.Getenv, github.NewOIDCClient) if err != nil { log.Fatal(err) } diff --git a/.github/actions/detect-workflow/main_test.go b/.github/actions/detect-workflow/main_test.go index 89950c69a..7c70b40af 100644 --- a/.github/actions/detect-workflow/main_test.go +++ b/.github/actions/detect-workflow/main_test.go @@ -122,7 +122,9 @@ func Test_action_getRepoRef(t *testing.T) { } return "" }, - client: c, + getClient: func() (*github.OIDCClient, error) { + return c, nil + }, } repo, ref, err := a.getRepoRef(context.Background()) @@ -151,7 +153,9 @@ func Test_action_getRepoRef(t *testing.T) { } return env[k] }, - client: c, + getClient: func() (*github.OIDCClient, error) { + return c, nil + }, event: map[string]any{ "pull_request": map[string]any{ "head": map[string]any{