Skip to content

Commit

Permalink
refactor backend SG provider (#2836)
Browse files Browse the repository at this point in the history
* refactor backend SG provider

* fix ExtractIngresses array append

* make classifiedIngress type satisfy ObjectMetaAccessor

* refactor backend SG provider apis
  • Loading branch information
kishorj authored Apr 24, 2023
1 parent 4cf7c33 commit ff8c13d
Show file tree
Hide file tree
Showing 15 changed files with 947 additions and 125 deletions.
24 changes: 14 additions & 10 deletions controllers/ingress/group_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
corev1 "k8s.io/api/core/v1"
networking "k8s.io/api/networking/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/tools/record"
elbv2api "sigs.k8s.io/aws-load-balancer-controller/apis/elbv2/v1beta1"
Expand Down Expand Up @@ -45,7 +46,8 @@ const (
func NewGroupReconciler(cloud aws.Cloud, k8sClient client.Client, eventRecorder record.EventRecorder,
finalizerManager k8s.FinalizerManager, networkingSGManager networkingpkg.SecurityGroupManager,
networkingSGReconciler networkingpkg.SecurityGroupReconciler, subnetsResolver networkingpkg.SubnetsResolver,
controllerConfig config.ControllerConfig, backendSGProvider networkingpkg.BackendSGProvider, logger logr.Logger) *groupReconciler {
controllerConfig config.ControllerConfig, backendSGProvider networkingpkg.BackendSGProvider,
sgResolver networkingpkg.SecurityGroupResolver, logger logr.Logger) *groupReconciler {

annotationParser := annotations.NewSuffixAnnotationParser(annotations.AnnotationPrefixIngress)
authConfigBuilder := ingress.NewDefaultAuthConfigBuilder(annotationParser)
Expand All @@ -58,7 +60,7 @@ func NewGroupReconciler(cloud aws.Cloud, k8sClient client.Client, eventRecorder
annotationParser, subnetsResolver,
authConfigBuilder, enhancedBackendBuilder, trackingProvider, elbv2TaggingManager, controllerConfig.FeatureGates,
cloud.VpcID(), controllerConfig.ClusterName, controllerConfig.DefaultTags, controllerConfig.ExternalManagedTags,
controllerConfig.DefaultSSLPolicy, controllerConfig.DefaultTargetType, backendSGProvider,
controllerConfig.DefaultSSLPolicy, controllerConfig.DefaultTargetType, backendSGProvider, sgResolver,
controllerConfig.EnableBackendSecurityGroup, controllerConfig.DisableRestrictedSGRules, controllerConfig.FeatureGates.Enabled(config.EnableIPTargetType), logger)
stackMarshaller := deploy.NewDefaultStackMarshaller()
stackDeployer := deploy.NewDefaultStackDeployer(cloud, k8sClient, networkingSGManager, networkingSGReconciler,
Expand Down Expand Up @@ -144,12 +146,6 @@ func (r *groupReconciler) reconcile(ctx context.Context, req ctrl.Request) error
}
}

if len(ingGroup.Members) == 0 {
if err := r.backendSGProvider.Release(ctx); err != nil {
return err
}
}

if len(ingGroup.InactiveMembers) > 0 {
if err := r.groupFinalizerManager.RemoveGroupFinalizer(ctx, ingGroupID, ingGroup.InactiveMembers); err != nil {
r.recordIngressGroupEvent(ctx, ingGroup, corev1.EventTypeWarning, k8s.IngressEventReasonFailedRemoveFinalizer, fmt.Sprintf("Failed remove finalizer due to %v", err))
Expand All @@ -162,7 +158,7 @@ func (r *groupReconciler) reconcile(ctx context.Context, req ctrl.Request) error
}

func (r *groupReconciler) buildAndDeployModel(ctx context.Context, ingGroup ingress.Group) (core.Stack, *elbv2model.LoadBalancer, error) {
stack, lb, secrets, err := r.modelBuilder.Build(ctx, ingGroup)
stack, lb, secrets, backendSGRequired, err := r.modelBuilder.Build(ctx, ingGroup)
if err != nil {
r.recordIngressGroupEvent(ctx, ingGroup, corev1.EventTypeWarning, k8s.IngressEventReasonFailedBuildModel, fmt.Sprintf("Failed build model due to %v", err))
return nil, nil, err
Expand All @@ -180,7 +176,15 @@ func (r *groupReconciler) buildAndDeployModel(ctx context.Context, ingGroup ingr
}
r.logger.Info("successfully deployed model", "ingressGroup", ingGroup.ID)
r.secretsManager.MonitorSecrets(ingGroup.ID.String(), secrets)
return stack, lb, err
var inactiveResources []types.NamespacedName
inactiveResources = append(inactiveResources, k8s.ToSliceOfNamespacedNames(ingGroup.InactiveMembers)...)
if !backendSGRequired {
inactiveResources = append(inactiveResources, k8s.ToSliceOfNamespacedNames(ingGroup.Members)...)
}
if err := r.backendSGProvider.Release(ctx, networkingpkg.ResourceTypeIngress, inactiveResources); err != nil {
return nil, nil, err
}
return stack, lb, nil
}

func (r *groupReconciler) recordIngressGroupEvent(_ context.Context, ingGroup ingress.Group, eventType string, reason string, message string) {
Expand Down
3 changes: 2 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ func main() {
mgr.GetEventRecorderFor("targetGroupBinding"), ctrl.Log)
backendSGProvider := networking.NewBackendSGProvider(controllerCFG.ClusterName, controllerCFG.BackendSecurityGroup,
cloud.VpcID(), cloud.EC2(), mgr.GetClient(), controllerCFG.DefaultTags, ctrl.Log.WithName("backend-sg-provider"))
sgResolver := networking.NewDefaultSecurityGroupResolver(cloud.EC2(), cloud.VpcID())
ingGroupReconciler := ingress.NewGroupReconciler(cloud, mgr.GetClient(), mgr.GetEventRecorderFor("ingress"),
finalizerManager, sgManager, sgReconciler, subnetResolver,
controllerCFG, backendSGProvider, ctrl.Log.WithName("controllers").WithName("ingress"))
controllerCFG, backendSGProvider, sgResolver, ctrl.Log.WithName("controllers").WithName("ingress"))
svcReconciler := service.NewServiceReconciler(cloud, mgr.GetClient(), mgr.GetEventRecorderFor("service"),
finalizerManager, sgManager, sgReconciler, subnetResolver, vpcInfoProvider,
controllerCFG, ctrl.Log.WithName("controllers").WithName("service"))
Expand Down
3 changes: 2 additions & 1 deletion pkg/deploy/elbv2/listener_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package elbv2

import (
"context"
"time"

awssdk "github.com/aws/aws-sdk-go/aws"
elbv2sdk "github.com/aws/aws-sdk-go/service/elbv2"
"github.com/go-logr/logr"
Expand All @@ -15,7 +17,6 @@ import (
elbv2equality "sigs.k8s.io/aws-load-balancer-controller/pkg/equality/elbv2"
elbv2model "sigs.k8s.io/aws-load-balancer-controller/pkg/model/elbv2"
"sigs.k8s.io/aws-load-balancer-controller/pkg/runtime"
"time"
)

// ListenerManager is responsible for create/update/delete Listener resources.
Expand Down
5 changes: 5 additions & 0 deletions pkg/ingress/class.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ingress

import (
networking "k8s.io/api/networking/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
elbv2api "sigs.k8s.io/aws-load-balancer-controller/apis/elbv2/v1beta1"
)

Expand All @@ -19,3 +20,7 @@ type ClassConfiguration struct {
// The IngressClassParams for Ingress if any.
IngClassParams *elbv2api.IngressClassParams
}

func (c ClassifiedIngress) GetObjectMeta() metav1.Object {
return c.Ing
}
59 changes: 5 additions & 54 deletions pkg/ingress/model_build_load_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/hex"
"fmt"
"regexp"
"strings"

awssdk "github.com/aws/aws-sdk-go/aws"
ec2sdk "github.com/aws/aws-sdk-go/service/ec2"
Expand Down Expand Up @@ -284,11 +283,12 @@ func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Cont
if !t.enableBackendSG {
t.backendSGIDToken = managedSG.GroupID()
} else {
backendSGID, err := t.backendSGProvider.Get(ctx)
backendSGID, err := t.backendSGProvider.Get(ctx, networking.ResourceTypeIngress, k8s.ToSliceOfNamespacedNames(t.ingGroup.Members))
if err != nil {
return nil, err
}
t.backendSGIDToken = core.LiteralStringToken((backendSGID))
t.backendSGAllocated = true
lbSGTokens = append(lbSGTokens, t.backendSGIDToken)
}
t.logger.Info("Auto Create SG", "LB SGs", lbSGTokens, "backend SG", t.backendSGIDToken)
Expand All @@ -297,7 +297,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Cont
if err != nil {
return nil, err
}
frontendSGIDs, err := t.resolveSecurityGroupIDsViaNameOrIDSlice(ctx, sgNameOrIDsViaAnnotation)
frontendSGIDs, err := t.sgResolver.ResolveViaNameOrID(ctx, sgNameOrIDsViaAnnotation)
if err != nil {
return nil, err
}
Expand All @@ -309,11 +309,12 @@ func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Cont
if !t.enableBackendSG {
return nil, errors.New("backendSG feature is required to manage worker node SG rules when frontendSG manually specified")
}
backendSGID, err := t.backendSGProvider.Get(ctx)
backendSGID, err := t.backendSGProvider.Get(ctx, networking.ResourceTypeIngress, k8s.ToSliceOfNamespacedNames(t.ingGroup.Members))
if err != nil {
return nil, err
}
t.backendSGIDToken = core.LiteralStringToken(backendSGID)
t.backendSGAllocated = true
lbSGTokens = append(lbSGTokens, t.backendSGIDToken)
}
t.logger.Info("SG configured via annotation", "LB SGs", lbSGTokens, "backend SG", t.backendSGIDToken)
Expand Down Expand Up @@ -390,56 +391,6 @@ func (t *defaultModelBuildTask) buildLoadBalancerTags(_ context.Context) (map[st
return algorithm.MergeStringMap(t.defaultTags, ingGroupTags), nil
}

func (t *defaultModelBuildTask) resolveSecurityGroupIDsViaNameOrIDSlice(ctx context.Context, sgNameOrIDs []string) ([]string, error) {
var sgIDs []string
var sgNames []string
for _, nameOrID := range sgNameOrIDs {
if strings.HasPrefix(nameOrID, "sg-") {
sgIDs = append(sgIDs, nameOrID)
} else {
sgNames = append(sgNames, nameOrID)
}
}
var resolvedSGs []*ec2sdk.SecurityGroup
if len(sgIDs) > 0 {
req := &ec2sdk.DescribeSecurityGroupsInput{
GroupIds: awssdk.StringSlice(sgIDs),
}
sgs, err := t.ec2Client.DescribeSecurityGroupsAsList(ctx, req)
if err != nil {
return nil, err
}
resolvedSGs = append(resolvedSGs, sgs...)
}
if len(sgNames) > 0 {
req := &ec2sdk.DescribeSecurityGroupsInput{
Filters: []*ec2sdk.Filter{
{
Name: awssdk.String("tag:Name"),
Values: awssdk.StringSlice(sgNames),
},
{
Name: awssdk.String("vpc-id"),
Values: awssdk.StringSlice([]string{t.vpcID}),
},
},
}
sgs, err := t.ec2Client.DescribeSecurityGroupsAsList(ctx, req)
if err != nil {
return nil, err
}
resolvedSGs = append(resolvedSGs, sgs...)
}
resolvedSGIDs := make([]string, 0, len(resolvedSGs))
for _, sg := range resolvedSGs {
resolvedSGIDs = append(resolvedSGIDs, awssdk.StringValue(sg.GroupId))
}
if len(resolvedSGIDs) != len(sgNameOrIDs) {
return nil, errors.Errorf("couldn't find all securityGroups, nameOrIDs: %v, found: %v", sgNameOrIDs, resolvedSGIDs)
}
return resolvedSGIDs, nil
}

func buildLoadBalancerSubnetMappingsWithSubnets(subnets []*ec2sdk.Subnet) []elbv2model.SubnetMapping {
subnetMappings := make([]elbv2model.SubnetMapping, 0, len(subnets))
for _, subnet := range subnets {
Expand Down
16 changes: 11 additions & 5 deletions pkg/ingress/model_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ const (
// ModelBuilder is responsible for build mode stack for a IngressGroup.
type ModelBuilder interface {
// build mode stack for a IngressGroup.
Build(ctx context.Context, ingGroup Group) (core.Stack, *elbv2model.LoadBalancer, []types.NamespacedName, error)
Build(ctx context.Context, ingGroup Group) (core.Stack, *elbv2model.LoadBalancer, []types.NamespacedName, bool, error)
}

// NewDefaultModelBuilder constructs new defaultModelBuilder.
Expand All @@ -42,7 +42,8 @@ func NewDefaultModelBuilder(k8sClient client.Client, eventRecorder record.EventR
authConfigBuilder AuthConfigBuilder, enhancedBackendBuilder EnhancedBackendBuilder,
trackingProvider tracking.Provider, elbv2TaggingManager elbv2deploy.TaggingManager, featureGates config.FeatureGates,
vpcID string, clusterName string, defaultTags map[string]string, externalManagedTags []string, defaultSSLPolicy string, defaultTargetType string,
backendSGProvider networkingpkg.BackendSGProvider, enableBackendSG bool, disableRestrictedSGRules bool, enableIPTargetType bool, logger logr.Logger) *defaultModelBuilder {
backendSGProvider networkingpkg.BackendSGProvider, sgResolver networkingpkg.SecurityGroupResolver,
enableBackendSG bool, disableRestrictedSGRules bool, enableIPTargetType bool, logger logr.Logger) *defaultModelBuilder {
certDiscovery := NewACMCertDiscovery(acmClient, logger)
ruleOptimizer := NewDefaultRuleOptimizer(logger)
return &defaultModelBuilder{
Expand All @@ -54,6 +55,7 @@ func NewDefaultModelBuilder(k8sClient client.Client, eventRecorder record.EventR
annotationParser: annotationParser,
subnetsResolver: subnetsResolver,
backendSGProvider: backendSGProvider,
sgResolver: sgResolver,
certDiscovery: certDiscovery,
authConfigBuilder: authConfigBuilder,
enhancedBackendBuilder: enhancedBackendBuilder,
Expand Down Expand Up @@ -86,6 +88,7 @@ type defaultModelBuilder struct {
annotationParser annotations.Parser
subnetsResolver networkingpkg.SubnetsResolver
backendSGProvider networkingpkg.BackendSGProvider
sgResolver networkingpkg.SecurityGroupResolver
certDiscovery CertDiscovery
authConfigBuilder AuthConfigBuilder
enhancedBackendBuilder EnhancedBackendBuilder
Expand All @@ -105,7 +108,7 @@ type defaultModelBuilder struct {
}

// build mode stack for a IngressGroup.
func (b *defaultModelBuilder) Build(ctx context.Context, ingGroup Group) (core.Stack, *elbv2model.LoadBalancer, []types.NamespacedName, error) {
func (b *defaultModelBuilder) Build(ctx context.Context, ingGroup Group) (core.Stack, *elbv2model.LoadBalancer, []types.NamespacedName, bool, error) {
stack := core.NewDefaultStack(core.StackID(ingGroup.ID))
task := &defaultModelBuildTask{
k8sClient: b.k8sClient,
Expand All @@ -123,6 +126,7 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ingGroup Group) (core.S
elbv2TaggingManager: b.elbv2TaggingManager,
featureGates: b.featureGates,
backendSGProvider: b.backendSGProvider,
sgResolver: b.sgResolver,
logger: b.logger,
enableBackendSG: b.enableBackendSG,
disableRestrictedSGRules: b.disableRestrictedSGRules,
Expand Down Expand Up @@ -153,9 +157,9 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ingGroup Group) (core.S
backendServices: make(map[types.NamespacedName]*corev1.Service),
}
if err := task.run(ctx); err != nil {
return nil, nil, nil, err
return nil, nil, nil, false, err
}
return task.stack, task.loadBalancer, task.secretKeys, nil
return task.stack, task.loadBalancer, task.secretKeys, task.backendSGAllocated, nil
}

// the default model build task
Expand All @@ -168,6 +172,7 @@ type defaultModelBuildTask struct {
annotationParser annotations.Parser
subnetsResolver networkingpkg.SubnetsResolver
backendSGProvider networkingpkg.BackendSGProvider
sgResolver networkingpkg.SecurityGroupResolver
certDiscovery CertDiscovery
authConfigBuilder AuthConfigBuilder
enhancedBackendBuilder EnhancedBackendBuilder
Expand All @@ -181,6 +186,7 @@ type defaultModelBuildTask struct {
sslRedirectConfig *SSLRedirectConfig
stack core.Stack
backendSGIDToken core.StringToken
backendSGAllocated bool
enableBackendSG bool
disableRestrictedSGRules bool
enableIPTargetType bool
Expand Down
10 changes: 6 additions & 4 deletions pkg/ingress/model_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2920,13 +2920,14 @@ func Test_defaultModelBuilder_Build(t *testing.T) {
trackingProvider := tracking.NewDefaultProvider("ingress.k8s.aws", clusterName)
stackMarshaller := deploy.NewDefaultStackMarshaller()
backendSGProvider := networkingpkg.NewMockBackendSGProvider(ctrl)
sgResolver := networkingpkg.NewDefaultSecurityGroupResolver(ec2Client, vpcID)
if tt.fields.enableBackendSG {
if len(tt.fields.backendSecurityGroup) > 0 {
backendSGProvider.EXPECT().Get(gomock.Any()).Return(tt.fields.backendSecurityGroup, nil).AnyTimes()
backendSGProvider.EXPECT().Get(gomock.Any(), networkingpkg.ResourceType(networkingpkg.ResourceTypeIngress), gomock.Any()).Return(tt.fields.backendSecurityGroup, nil).AnyTimes()
} else {
backendSGProvider.EXPECT().Get(gomock.Any()).Return("sg-auto", nil).AnyTimes()
backendSGProvider.EXPECT().Get(gomock.Any(), networkingpkg.ResourceType(networkingpkg.ResourceTypeIngress), gomock.Any()).Return("sg-auto", nil).AnyTimes()
}
backendSGProvider.EXPECT().Release(gomock.Any()).Return(nil).AnyTimes()
backendSGProvider.EXPECT().Release(gomock.Any(), networkingpkg.ResourceType(networkingpkg.ResourceTypeIngress), gomock.Any()).Return(nil).AnyTimes()
}
defaultTargetType := tt.defaultTargetType
if defaultTargetType == "" {
Expand All @@ -2941,6 +2942,7 @@ func Test_defaultModelBuilder_Build(t *testing.T) {
clusterName: clusterName,
annotationParser: annotationParser,
subnetsResolver: subnetsResolver,
sgResolver: sgResolver,
backendSGProvider: backendSGProvider,
certDiscovery: certDiscovery,
authConfigBuilder: authConfigBuilder,
Expand All @@ -2962,7 +2964,7 @@ func Test_defaultModelBuilder_Build(t *testing.T) {
b.enableIPTargetType = *tt.enableIPTargetType
}

gotStack, _, _, err := b.Build(context.Background(), tt.args.ingGroup)
gotStack, _, _, _, err := b.Build(context.Background(), tt.args.ingGroup)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
} else {
Expand Down
9 changes: 9 additions & 0 deletions pkg/k8s/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,12 @@ func NamespacedName(obj metav1.Object) types.NamespacedName {
Name: obj.GetName(),
}
}

// ToSliceOfNamespacedNames gets the slice of types.NamespacedName from the input slice s
func ToSliceOfNamespacedNames[T metav1.ObjectMetaAccessor](s []T) []types.NamespacedName {
result := make([]types.NamespacedName, len(s))
for i, v := range s {
result[i] = NamespacedName(v.GetObjectMeta())
}
return result
}
Loading

0 comments on commit ff8c13d

Please sign in to comment.