Skip to content

Commit

Permalink
Check session identity when creds are static
Browse files Browse the repository at this point in the history
  • Loading branch information
mtibben committed Mar 19, 2023
1 parent af6df83 commit bfa952d
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 94 deletions.
204 changes: 146 additions & 58 deletions cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/skratchdot/open-golang/open"
)

Expand Down Expand Up @@ -74,108 +75,99 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) {
return err
}

err = LoginCommand(input, f, keyring)
err = LoginCommand(context.Background(), input, f, keyring)
app.FatalIfError(err, "login")
return nil
})
}

func LoginCommand(input LoginCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error {
config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).GetProfileConfig(input.ProfileName)
if err != nil {
return fmt.Errorf("Error loading config: %w", err)
}

var credsProvider aws.CredentialsProvider

func getCredsProvider(input LoginCommandInput, config *vault.ProfileConfig, keyring keyring.Keyring) (credsProvider aws.CredentialsProvider, err error) {
if input.ProfileName == "" {
// When no profile is specified, source credentials from the environment
configFromEnv, err := awsconfig.NewEnvConfig()
if err != nil {
return fmt.Errorf("unable to authenticate to AWS through your environment variables: %w", err)
return nil, fmt.Errorf("unable to authenticate to AWS through your environment variables: %w", err)
}
credsProvider = credentials.StaticCredentialsProvider{Value: configFromEnv.Credentials}
if configFromEnv.Credentials.SessionToken == "" {
credsProvider, err = vault.NewFederationTokenProvider(context.TODO(), credsProvider, config)
if err != nil {
return err
}

if configFromEnv.Credentials.AccessKeyID == "" {
return nil, fmt.Errorf("argument 'profile' not provided, nor any AWS env vars found. Try --help")
}

credsProvider = credentials.StaticCredentialsProvider{Value: configFromEnv.Credentials}
} else {
// Use a profile from the AWS config file
ckr := &vault.CredentialKeyring{Keyring: keyring}
if config.HasRole() || config.HasSSOStartURL() || config.HasCredentialProcess() || config.HasWebIdentity() {
// If AssumeRole or sso.GetRoleCredentials isn't used, GetFederationToken has to be used for IAM credentials
credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, input.NoSession, false)
} else {
credsProvider, err = vault.NewFederationTokenCredentialsProvider(context.TODO(), input.ProfileName, ckr, config)
t := vault.TempCredentialsCreator{
Keyring: ckr,
DisableSessions: input.NoSession,
DisableSessionsForProfile: config.ProfileName,
}
credsProvider, err = t.GetProviderForProfile(config)
if err != nil {
return fmt.Errorf("profile %s: %w", input.ProfileName, err)
return nil, fmt.Errorf("profile %s: %w", input.ProfileName, err)
}
}

creds, err := credsProvider.Retrieve(context.TODO())
return credsProvider, err
}

// LoginCommand creates a login URL for the AWS Management Console using the method described at
// https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_providers_enable-console-custom-url.html
func LoginCommand(ctx context.Context, input LoginCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error {
config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).GetProfileConfig(input.ProfileName)
if err != nil {
return fmt.Errorf("Failed to get credentials: %w", err)
}
if creds.AccessKeyID == "" && input.ProfileName == "" {
return fmt.Errorf("argument 'profile' not provided, nor any AWS env vars found. Try --help")
return fmt.Errorf("Error loading config: %w", err)
}

jsonBytes, err := json.Marshal(map[string]string{
"sessionId": creds.AccessKeyID,
"sessionKey": creds.SecretAccessKey,
"sessionToken": creds.SessionToken,
})
credsProvider, err := getCredsProvider(input, config, keyring)
if err != nil {
return err
}

loginURLPrefix, destination := generateLoginURL(config.Region, input.Path)

req, err := http.NewRequestWithContext(context.TODO(), "GET", loginURLPrefix, nil)
// if we already know the type of credentials being created, avoid calling isCallerIdentityAssumedRole
canCredsBeUsedInLoginURL, err := canProviderBeUsedForLogin(credsProvider)
if err != nil {
return err
}

if creds.CanExpire {
log.Printf("Creating login token, expires in %s", time.Until(creds.Expires))
}
if !canCredsBeUsedInLoginURL {
// use a static creds provider so that we don't request credentials from AWS more than once
credsProvider, err = createStaticCredentialsProvider(ctx, credsProvider)
if err != nil {
return err
}

q := req.URL.Query()
q.Add("Action", "getSigninToken")
q.Add("Session", string(jsonBytes))
req.URL.RawQuery = q.Encode()
// if the credentials have come from an unknown source like credential_process, check the
// caller identity to see if it's an assumed role
isAssumedRole, err := isCallerIdentityAssumedRole(ctx, credsProvider, config)
if err != nil {
return err
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
if !isAssumedRole {
log.Println("Creating a federated session")
credsProvider, err = vault.NewFederationTokenProvider(ctx, credsProvider, config)
if err != nil {
return err
}
}
}

defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
creds, err := credsProvider.Retrieve(ctx)
if err != nil {
return err
}

if resp.StatusCode != http.StatusOK {
log.Printf("Response body was %s", body)
return fmt.Errorf("Call to getSigninToken failed with %v", resp.Status)
if creds.CanExpire {
log.Printf("Requesting a signin token for session expiring in %s", time.Until(creds.Expires))
}

var respParsed map[string]string

err = json.Unmarshal(body, &respParsed)
loginURLPrefix, destination := generateLoginURL(config.Region, input.Path)
signinToken, err := requestSigninToken(ctx, creds, loginURLPrefix)
if err != nil {
return err
}

signinToken, ok := respParsed["SigninToken"]
if !ok {
return fmt.Errorf("Expected a response with SigninToken")
}

loginURL := fmt.Sprintf("%s?Action=login&Issuer=aws-vault&Destination=%s&SigninToken=%s",
loginURLPrefix, url.QueryEscape(destination), url.QueryEscape(signinToken))

Expand Down Expand Up @@ -212,3 +204,99 @@ func generateLoginURL(region string, path string) (string, string) {
}
return loginURLPrefix, destination
}

func isCallerIdentityAssumedRole(ctx context.Context, credsProvider aws.CredentialsProvider, config *vault.ProfileConfig) (bool, error) {
cfg := vault.NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints)
client := sts.NewFromConfig(cfg)
id, err := client.GetCallerIdentity(ctx, nil)
if err != nil {
return false, err
}
arn := aws.ToString(id.Arn)
arnParts := strings.Split(arn, ":")
if len(arnParts) < 6 {
return false, fmt.Errorf("unable to parse ARN: %s", arn)
}
if strings.HasPrefix(arnParts[5], "assumed-role") {
return true, nil
}
return false, nil
}

func createStaticCredentialsProvider(ctx context.Context, credsProvider aws.CredentialsProvider) (sc credentials.StaticCredentialsProvider, err error) {
creds, err := credsProvider.Retrieve(ctx)
if err != nil {
return sc, err
}
return credentials.StaticCredentialsProvider{Value: creds}, nil
}

// canProviderBeUsedForLogin returns true if the credentials produced by the provider is known to be usable by the login URL endpoint
func canProviderBeUsedForLogin(credsProvider aws.CredentialsProvider) (bool, error) {
if _, ok := credsProvider.(*vault.AssumeRoleProvider); ok {
return true, nil
}
if _, ok := credsProvider.(*vault.SSORoleCredentialsProvider); ok {
return true, nil
}
if _, ok := credsProvider.(*vault.AssumeRoleWithWebIdentityProvider); ok {
return true, nil
}
if c, ok := credsProvider.(*vault.CachedSessionProvider); ok {
return canProviderBeUsedForLogin(c.SessionProvider)
}

return false, nil
}

// Create a signin token
func requestSigninToken(ctx context.Context, creds aws.Credentials, loginURLPrefix string) (string, error) {
jsonSession, err := json.Marshal(map[string]string{
"sessionId": creds.AccessKeyID,
"sessionKey": creds.SecretAccessKey,
"sessionToken": creds.SessionToken,
})
if err != nil {
return "", err
}

req, err := http.NewRequestWithContext(ctx, "GET", loginURLPrefix, nil)
if err != nil {
return "", err
}

q := req.URL.Query()
q.Add("Action", "getSigninToken")
q.Add("Session", string(jsonSession))
req.URL.RawQuery = q.Encode()

resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}

defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}

if resp.StatusCode != http.StatusOK {
log.Printf("Response body was %s", body)
return "", fmt.Errorf("Call to getSigninToken failed with %v", resp.Status)
}

var respParsed map[string]string

err = json.Unmarshal(body, &respParsed)
if err != nil {
return "", err
}

signinToken, ok := respParsed["SigninToken"]
if !ok {
return "", fmt.Errorf("Expected a response with SigninToken")
}

return signinToken, nil
}
4 changes: 2 additions & 2 deletions vault/assumeroleprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type AssumeRoleProvider struct {

// Retrieve generates a new set of temporary credentials using STS AssumeRole
func (p *AssumeRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
role, err := p.assumeRole(ctx)
role, err := p.RetrieveStsCredentials(ctx)
if err != nil {
return aws.Credentials{}, err
}
Expand All @@ -49,7 +49,7 @@ func (p *AssumeRoleProvider) roleSessionName() string {
return p.RoleSessionName
}

func (p *AssumeRoleProvider) assumeRole(ctx context.Context) (*ststypes.Credentials, error) {
func (p *AssumeRoleProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
var err error

input := &sts.AssumeRoleInput{
Expand Down
4 changes: 2 additions & 2 deletions vault/assumerolewithwebidentityprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type AssumeRoleWithWebIdentityProvider struct {

// Retrieve generates a new set of temporary credentials using STS AssumeRoleWithWebIdentity
func (p *AssumeRoleWithWebIdentityProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
creds, err := p.assumeRole(ctx)
creds, err := p.RetrieveStsCredentials(ctx)
if err != nil {
return aws.Credentials{}, err
}
Expand All @@ -48,7 +48,7 @@ func (p *AssumeRoleWithWebIdentityProvider) roleSessionName() string {
return p.RoleSessionName
}

func (p *AssumeRoleWithWebIdentityProvider) assumeRole(ctx context.Context) (*ststypes.Credentials, error) {
func (p *AssumeRoleWithWebIdentityProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
var err error

webIdentityToken, err := p.webIdentityToken()
Expand Down
28 changes: 21 additions & 7 deletions vault/cachedsessionprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,48 @@ import (
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
)

type StsSessionProvider interface {
aws.CredentialsProvider
RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error)
}

// CachedSessionProvider retrieves cached credentials from the keyring, or if no credentials are cached
// retrieves temporary credentials using the CredentialsFunc
type CachedSessionProvider struct {
SessionKey SessionMetadata
CredentialsFunc func(context.Context) (*ststypes.Credentials, error)
SessionProvider StsSessionProvider
Keyring *SessionKeyring
ExpiryWindow time.Duration
}

// Retrieve returns cached credentials from the keyring, or if no credentials are cached
// generates a new set of temporary credentials using the CredentialsFunc
func (p *CachedSessionProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
func (p *CachedSessionProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
creds, err := p.Keyring.Get(p.SessionKey)

if err != nil || time.Until(*creds.Expiration) < p.ExpiryWindow {
// lookup missed, we need to create a new one.
creds, err = p.CredentialsFunc(ctx)
creds, err = p.SessionProvider.RetrieveStsCredentials(ctx)
if err != nil {
return aws.Credentials{}, err
return nil, err
}
err = p.Keyring.Set(p.SessionKey, creds)
if err != nil {
return aws.Credentials{}, err
return nil, err
}
} else {
log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String())
}

return creds, nil
}

// Retrieve returns cached credentials from the keyring, or if no credentials are cached
// generates a new set of temporary credentials using the CredentialsFunc
func (p *CachedSessionProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
creds, err := p.RetrieveStsCredentials(ctx)
if err != nil {
return aws.Credentials{}, err
}

return aws.Credentials{
AccessKeyID: aws.ToString(creds.AccessKeyId),
SecretAccessKey: aws.ToString(creds.SecretAccessKey),
Expand Down
2 changes: 1 addition & 1 deletion vault/credentialprocessprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (p *CredentialProcessProvider) retrieveWith(ctx context.Context, fn func(st
}, nil
}

func (p *CredentialProcessProvider) callCredentialProcess(ctx context.Context) (*ststypes.Credentials, error) {
func (p *CredentialProcessProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
return p.callCredentialProcessWith(ctx, executeProcess)
}

Expand Down
4 changes: 2 additions & 2 deletions vault/sessiontokenprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type SessionTokenProvider struct {

// Retrieve generates a new set of temporary credentials using STS GetSessionToken
func (p *SessionTokenProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
creds, err := p.GetSessionToken(ctx)
creds, err := p.RetrieveStsCredentials(ctx)
if err != nil {
return aws.Credentials{}, err
}
Expand All @@ -34,7 +34,7 @@ func (p *SessionTokenProvider) Retrieve(ctx context.Context) (aws.Credentials, e
}

// GetSessionToken generates a new set of temporary credentials using STS GetSessionToken
func (p *SessionTokenProvider) GetSessionToken(ctx context.Context) (*ststypes.Credentials, error) {
func (p *SessionTokenProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
var err error

input := &sts.GetSessionTokenInput{
Expand Down
4 changes: 4 additions & 0 deletions vault/ssorolecredentialsprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ func (p *SSORoleCredentialsProvider) getRoleCredentials(ctx context.Context) (*s
return resp.RoleCredentials, nil
}

func (p *SSORoleCredentialsProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
return p.getRoleCredentialsAsStsCredemtials(ctx)
}

// getRoleCredentialsAsStsCredemtials returns getRoleCredentials as sts.Credentials because sessions.Store expects it
func (p *SSORoleCredentialsProvider) getRoleCredentialsAsStsCredemtials(ctx context.Context) (*ststypes.Credentials, error) {
creds, err := p.getRoleCredentials(ctx)
Expand Down
Loading

0 comments on commit bfa952d

Please sign in to comment.