diff --git a/pkg/providers/v1/aws.go b/pkg/providers/v1/aws.go index 805314ed6f..e3a4b90b6e 100644 --- a/pkg/providers/v1/aws.go +++ b/pkg/providers/v1/aws.go @@ -249,6 +249,9 @@ const volumeAttachmentStuck = "VolumeAttachmentStuck" // Indicates that a node has volumes stuck in attaching state and hence it is not fit for scheduling more pods const nodeWithImpairedVolumes = "NodeWithImpairedVolumes" +const headerSourceArn = "x-amz-source-arn" +const headerSourceAccount = "x-amz-source-account" + const ( // volumeAttachmentConsecutiveErrorLimit is the number of consecutive errors we will ignore when waiting for a volume to attach/detach volumeAttachmentStatusConsecutiveErrorLimit = 10 @@ -614,6 +617,11 @@ type CloudConfig struct { // RoleARN is the IAM role to assume when interaction with AWS APIs. RoleARN string + // SourceARN is value which is passed while assuming role specified by RoleARN. When a service + // assumes a role in your account, you can include the aws:SourceAccount and aws:SourceArn global + // condition context keys in your role trust policy to limit access to the role to only requests that are generated + // by expected resources. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html + SourceARN string // KubernetesClusterTag is the legacy cluster id we'll use to identify our cluster resources KubernetesClusterTag string @@ -1260,12 +1268,15 @@ func init() { var creds *credentials.Credentials if cfg.Global.RoleARN != "" { - klog.Infof("Using AWS assumed role %v", cfg.Global.RoleARN) + stsClient, err := getSTSClient(sess, cfg.Global.RoleARN, cfg.Global.SourceARN) + if err != nil { + return nil, fmt.Errorf("unable to create sts client, %v", err) + } creds = credentials.NewChainCredentials( []credentials.Provider{ &credentials.EnvProvider{}, assumeRoleProvider(&stscreds.AssumeRoleProvider{ - Client: sts.New(sess), + Client: stsClient, RoleARN: cfg.Global.RoleARN, }), }) @@ -1276,13 +1287,33 @@ func init() { }) } +func getSTSClient(sess *session.Session, roleARN, sourceARN string) (*sts.STS, error) { + klog.Infof("Using AWS assumed role %v", roleARN) + stsClient := sts.New(sess) + sourceAcct, err := GetSourceAccount(roleARN) + if err != nil { + return nil, err + } + reqHeaders := map[string]string{ + headerSourceAccount: sourceAcct, + } + if sourceARN != "" { + reqHeaders[headerSourceArn] = sourceARN + } + stsClient.Handlers.Sign.PushFront(func(s *request.Request) { + s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders)) + }) + klog.V(4).Infof("configuring STS client with extra headers, %v", reqHeaders) + return stsClient, nil +} + // readAWSCloudConfig reads an instance of AWSCloudConfig from config reader. func readAWSCloudConfig(config io.Reader) (*CloudConfig, error) { var cfg CloudConfig var err error if config != nil { - err = gcfg.ReadInto(&cfg, config) + err = gcfg.FatalOnly(gcfg.ReadInto(&cfg, config)) if err != nil { return nil, err } diff --git a/pkg/providers/v1/aws_utils.go b/pkg/providers/v1/aws_utils.go index bd373d64e5..621731ed1c 100644 --- a/pkg/providers/v1/aws_utils.go +++ b/pkg/providers/v1/aws_utils.go @@ -17,7 +17,10 @@ limitations under the License. package aws import ( + "fmt" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" "k8s.io/apimachinery/pkg/util/sets" ) @@ -43,3 +46,21 @@ func stringSetFromPointers(in []*string) sets.String { } return out } + +// GetSourceAccount constructs source acct and return them for use +func GetSourceAccount(roleARN string) (string, error) { + // ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html) + // arn:partition:service:region:account-id:resource-type/resource-id + // IAM format, region is always blank + // arn:aws:iam::account:role/role-name-with-path + if !arn.IsARN(roleARN) { + return "", fmt.Errorf("incorrect ARN format for role %s", roleARN) + } + + parsedArn, err := arn.Parse(roleARN) + if err != nil { + return "", err + } + + return parsedArn.AccountID, nil +} diff --git a/pkg/providers/v1/aws_utils_test.go b/pkg/providers/v1/aws_utils_test.go new file mode 100644 index 0000000000..5dfe9460d6 --- /dev/null +++ b/pkg/providers/v1/aws_utils_test.go @@ -0,0 +1,60 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import "testing" + +func TestGetSourceAcctAndArn(t *testing.T) { + type args struct { + roleARN string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "corect role arn", + args: args{ + roleARN: "arn:aws:iam::123456789876:role/test-cluster", + }, + want: "123456789876", + wantErr: false, + }, + { + name: "incorect role arn", + args: args{ + roleARN: "arn:aws:iam::123456789876", + }, + want: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := GetSourceAccount(tt.args.roleARN) + if (err != nil) != tt.wantErr { + t.Errorf("GetSourceAccount() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("GetSourceAccount() got = %v, want %v", got, tt.want) + } + }) + } +}