Skip to content

Commit

Permalink
Parse source account from sourceARN
Browse files Browse the repository at this point in the history
  • Loading branch information
haoranleo committed Aug 26, 2024
1 parent e379bc7 commit ea8bd8a
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions pkg/ec2provider/ec2provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ func newSession(roleARN, sourceARN, region string, qps int, burst int) *session.
logrus.Errorf("Getting error = %s while creating rate limited client ", err)
}

stsClient := applySTSRequestHeaders(sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)),
roleARN, sourceARN)
stsClient := applySTSRequestHeaders(sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)), sourceARN)
ap := &stscreds.AssumeRoleProvider{
Client: stsClient,
RoleARN: roleARN,
Expand Down Expand Up @@ -286,23 +285,22 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string
}
}

func applySTSRequestHeaders(stsClient *sts.STS, roleARN, sourceARN string) *sts.STS {
logrus.Infof("Using AWS assumed role %v", roleARN)
sourceAcct, err := getSourceAccount(roleARN)
if err != nil {
logrus.Errorf("failed to parse source account from role ARN %v: %v", roleARN, err)
return stsClient
}
reqHeaders := map[string]string{
headerSourceAccount: sourceAcct,
}
func applySTSRequestHeaders(stsClient *sts.STS, sourceARN string) *sts.STS {
// parse both source account and source arn from the sourceARN, and add them as headers to the STS client
if sourceARN != "" {
reqHeaders[headerSourceArn] = sourceARN
sourceAcct, err := getSourceAccount(sourceARN)
if err != nil {
panic(fmt.Sprintf("%s is not a valid arn, err: %v", sourceARN, err))
}
reqHeaders := map[string]string{
headerSourceAccount: sourceAcct,
headerSourceArn: sourceARN,
}
stsClient.Handlers.Sign.PushFront(func(s *request.Request) {
s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders))
})
logrus.Infof("configuring STS client with extra headers, %v", reqHeaders)
}
stsClient.Handlers.Sign.PushFront(func(s *request.Request) {
s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders))
})
logrus.Infof("configuring STS client with extra headers, %v", reqHeaders)
return stsClient
}

Expand Down

0 comments on commit ea8bd8a

Please sign in to comment.