diff --git a/pkg/ec2provider/ec2provider.go b/pkg/ec2provider/ec2provider.go index 8dbd66265..d760f0bda 100644 --- a/pkg/ec2provider/ec2provider.go +++ b/pkg/ec2provider/ec2provider.go @@ -109,8 +109,7 @@ func newSession(roleARN, sourceARN, region string, qps int, burst int) *session. logrus.Errorf("Getting error = %s while creating rate limited client ", err) } - stsClient := applySTSRequestHeaders(sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)), - roleARN, sourceARN) + stsClient := applySTSRequestHeaders(sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)), sourceARN) ap := &stscreds.AssumeRoleProvider{ Client: stsClient, RoleARN: roleARN, @@ -286,23 +285,22 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string } } -func applySTSRequestHeaders(stsClient *sts.STS, roleARN, sourceARN string) *sts.STS { - logrus.Infof("Using AWS assumed role %v", roleARN) - sourceAcct, err := getSourceAccount(roleARN) - if err != nil { - logrus.Errorf("failed to parse source account from role ARN %v: %v", roleARN, err) - return stsClient - } - reqHeaders := map[string]string{ - headerSourceAccount: sourceAcct, - } +func applySTSRequestHeaders(stsClient *sts.STS, sourceARN string) *sts.STS { + // parse both source account and source arn from the sourceARN, and add them as headers to the STS client if sourceARN != "" { - reqHeaders[headerSourceArn] = sourceARN + sourceAcct, err := getSourceAccount(sourceARN) + if err != nil { + panic(fmt.Sprintf("%s is not a valid arn, err: %v", sourceARN, err)) + } + reqHeaders := map[string]string{ + headerSourceAccount: sourceAcct, + headerSourceArn: sourceARN, + } + stsClient.Handlers.Sign.PushFront(func(s *request.Request) { + s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders)) + }) + logrus.Infof("configuring STS client with extra headers, %v", reqHeaders) } - stsClient.Handlers.Sign.PushFront(func(s *request.Request) { - s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders)) - }) - logrus.Infof("configuring STS client with extra headers, %v", reqHeaders) return stsClient }