Skip to content

Commit

Permalink
fixup(sg-resolver): Allow multiple SGs with the same Name tag
Browse files Browse the repository at this point in the history
  • Loading branch information
alloveras committed Jul 22, 2024
1 parent e5d625f commit 41f7943
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 16 deletions.
18 changes: 18 additions & 0 deletions pkg/algorithm/slices.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package algorithm

import "cmp"

// RemoveSliceDuplicates returns a copy of the slice without duplicate entries.
func RemoveSliceDuplicates[S ~[]E, E cmp.Ordered](s S) []E {
result := make([]E, 0, len(s))
found := make(map[E]struct{}, len(s))

for _, x := range s {
if _, ok := found[x]; !ok {
found[x] = struct{}{}
result = append(result, x)
}
}

return result
}
46 changes: 46 additions & 0 deletions pkg/algorithm/slices_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package algorithm

import (
"testing"

"github.com/stretchr/testify/assert"
)

func Test_RemoveSliceDuplicates(t *testing.T) {
type args struct {
data []string
}
tests := []struct {
name string
args args
want []string
}{
{
name: "empty",
args: args{
data: []string{},
},
want: []string{},
},
{
name: "no duplicate entries",
args: args{
data: []string{"a", "b", "c", "d"},
},
want: []string{"a", "b", "c", "d"},
},
{
name: "with duplicates",
args: args{
data: []string{"a", "b", "a", "c", "b"},
},
want: []string{"a", "b", "c"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := RemoveSliceDuplicates(tt.args.data)
assert.Equal(t, tt.want, got)
})
}
}
46 changes: 43 additions & 3 deletions pkg/networking/security_group_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
awssdk "github.com/aws/aws-sdk-go/aws"
ec2sdk "github.com/aws/aws-sdk-go/service/ec2"
"github.com/pkg/errors"
"sigs.k8s.io/aws-load-balancer-controller/pkg/algorithm"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services"
)

Expand Down Expand Up @@ -35,42 +36,60 @@ type defaultSecurityGroupResolver struct {
func (r *defaultSecurityGroupResolver) ResolveViaNameOrID(ctx context.Context, sgNameOrIDs []string) ([]string, error) {
sgIDs, sgNames := r.splitIntoSgNameAndIDs(sgNameOrIDs)
var resolvedSGs []*ec2sdk.SecurityGroup

if len(sgIDs) > 0 {
sgs, err := r.resolveViaGroupID(ctx, sgIDs)
if err != nil {
return nil, err
}
resolvedSGs = append(resolvedSGs, sgs...)
}

if len(sgNames) > 0 {
sgs, err := r.resolveViaGroupName(ctx, sgNames)
if err != nil {
return nil, err
}
resolvedSGs = append(resolvedSGs, sgs...)
}

resolvedSGIDs := make([]string, 0, len(resolvedSGs))
for _, sg := range resolvedSGs {
resolvedSGIDs = append(resolvedSGIDs, awssdk.StringValue(sg.GroupId))
}
if len(resolvedSGIDs) != len(sgNameOrIDs) {
return nil, errors.Errorf("couldn't find all securityGroups, nameOrIDs: %v, found: %v", sgNameOrIDs, resolvedSGIDs)
}

return resolvedSGIDs, nil
}

func (r *defaultSecurityGroupResolver) resolveViaGroupID(ctx context.Context, sgIDs []string) ([]*ec2sdk.SecurityGroup, error) {
req := &ec2sdk.DescribeSecurityGroupsInput{
GroupIds: awssdk.StringSlice(sgIDs),
}

sgs, err := r.ec2Client.DescribeSecurityGroupsAsList(ctx, req)
if err != nil {
return nil, err
}

resolvedSGIDs := make([]string, 0, len(sgs))
for _, sg := range sgs {
resolvedSGIDs = append(resolvedSGIDs, awssdk.StringValue(sg.GroupId))
}

if len(sgIDs) != len(resolvedSGIDs) {
return nil, errors.Errorf(
"couldn't find all securityGroups, requested ids: [%s], found: [%s]",
strings.Join(sgIDs, ", "),
strings.Join(resolvedSGIDs, ", "),
)
}

return sgs, nil
}

func (r *defaultSecurityGroupResolver) resolveViaGroupName(ctx context.Context, sgNames []string) ([]*ec2sdk.SecurityGroup, error) {
sgNames = algorithm.RemoveSliceDuplicates(sgNames)

req := &ec2sdk.DescribeSecurityGroupsInput{
Filters: []*ec2sdk.Filter{
{
Expand All @@ -83,10 +102,31 @@ func (r *defaultSecurityGroupResolver) resolveViaGroupName(ctx context.Context,
},
},
}

sgs, err := r.ec2Client.DescribeSecurityGroupsAsList(ctx, req)
if err != nil {
return nil, err
}

resolvedSGNames := make([]string, 0, len(sgs))
for _, sg := range sgs {
for _, tag := range sg.Tags {
if awssdk.StringValue(tag.Key) == "Name" {
resolvedSGNames = append(resolvedSGNames, awssdk.StringValue(tag.Value))
}
}
}

resolvedSGNames = algorithm.RemoveSliceDuplicates(resolvedSGNames)

if len(sgNames) != len(resolvedSGNames) {
return nil, errors.Errorf(
"couldn't find all securityGroups, requested names: [%s], found: [%s]",
strings.Join(sgNames, ", "),
strings.Join(resolvedSGNames, ", "),
)
}

return sgs, nil
}

Expand Down
94 changes: 81 additions & 13 deletions pkg/networking/security_group_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,15 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) {
resp: []*ec2sdk.SecurityGroup{
{
GroupId: awssdk.String("sg-0912f63b"),
Tags: []*ec2sdk.Tag{
{Key: awssdk.String("Name"), Value: awssdk.String("sg group one")},
},
},
{
GroupId: awssdk.String("sg-08982de7"),
Tags: []*ec2sdk.Tag{
{Key: awssdk.String("Name"), Value: awssdk.String("sg group two")},
},
},
},
},
Expand All @@ -101,6 +107,50 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) {
"sg-0912f63b",
},
},
{
name: "single name multiple ids",
args: args{
nameOrIDs: []string{
"sg group one",
},
describeSGCalls: []describeSecurityGroupsAsListCall{
{
req: &ec2sdk.DescribeSecurityGroupsInput{
Filters: []*ec2sdk.Filter{
{
Name: awssdk.String("tag:Name"),
Values: awssdk.StringSlice([]string{
"sg group one",
}),
},
{
Name: awssdk.String("vpc-id"),
Values: awssdk.StringSlice([]string{defaultVPCID}),
},
},
},
resp: []*ec2sdk.SecurityGroup{
{
GroupId: awssdk.String("sg-id1"),
Tags: []*ec2sdk.Tag{
{Key: awssdk.String("Name"), Value: awssdk.String("sg group one")},
},
},
{
GroupId: awssdk.String("sg-id2"),
Tags: []*ec2sdk.Tag{
{Key: awssdk.String("Name"), Value: awssdk.String("sg group one")},
},
},
},
},
},
},
want: []string{
"sg-id1",
"sg-id2",
},
},
{
name: "mixed group name and id",
args: args{
Expand All @@ -127,6 +177,9 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) {
resp: []*ec2sdk.SecurityGroup{
{
GroupId: awssdk.String("sg-0912f63b"),
Tags: []*ec2sdk.Tag{
{Key: awssdk.String("Name"), Value: awssdk.String("sg group one")},
},
},
},
},
Expand Down Expand Up @@ -205,13 +258,34 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) {
wantErr: errors.New("Describe.Error: unable to describe security groups"),
},
{
name: "unable to resolve all security groups",
name: "unable to resolve all security group ids",
args: args{
nameOrIDs: []string{
"sg group one",
"sg-id1",
"sg-id404",
},
describeSGCalls: []describeSecurityGroupsAsListCall{
{
req: &ec2sdk.DescribeSecurityGroupsInput{
GroupIds: awssdk.StringSlice([]string{"sg-id1", "sg-id404"}),
},
resp: []*ec2sdk.SecurityGroup{
{
GroupId: awssdk.String("sg-id1"),
},
},
},
},
},
wantErr: errors.New("couldn't find all securityGroups, requested ids: [sg-id1, sg-id404], found: [sg-id1]"),
},
{
name: "unable to resolve all security groups names",
args: args{
nameOrIDs: []string{
"sg group one",
"sg group two",
},
describeSGCalls: []describeSecurityGroupsAsListCall{
{
req: &ec2sdk.DescribeSecurityGroupsInput{
Expand All @@ -220,6 +294,7 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) {
Name: awssdk.String("tag:Name"),
Values: awssdk.StringSlice([]string{
"sg group one",
"sg group two",
}),
},
{
Expand All @@ -231,22 +306,15 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) {
resp: []*ec2sdk.SecurityGroup{
{
GroupId: awssdk.String("sg-0912f63b"),
},
},
},
{
req: &ec2sdk.DescribeSecurityGroupsInput{
GroupIds: awssdk.StringSlice([]string{"sg-id1", "sg-id404"}),
},
resp: []*ec2sdk.SecurityGroup{
{
GroupId: awssdk.String("sg-id1"),
Tags: []*ec2sdk.Tag{
{Key: awssdk.String("Name"), Value: awssdk.String("sg group one")},
},
},
},
},
},
},
wantErr: errors.New("couldn't find all securityGroups, nameOrIDs: [sg group one sg-id1 sg-id404], found: [sg-id1 sg-0912f63b]"),
wantErr: errors.New("couldn't find all securityGroups, requested names: [sg group one, sg group two], found: [sg group one]"),
},
}

Expand Down

0 comments on commit 41f7943

Please sign in to comment.