From a23489b9606722d69e59fae1767c39269b7bf068 Mon Sep 17 00:00:00 2001 From: wxing1292 Date: Wed, 15 Feb 2023 16:22:38 -0800 Subject: [PATCH] Set namespace on API if not present (#3953) --- common/rpc/interceptor/namespace_validator.go | 70 +++++++++++- .../interceptor/namespace_validator_test.go | 105 +++++++++++++++++- service/frontend/fx.go | 2 +- 3 files changed, 169 insertions(+), 8 deletions(-) diff --git a/common/rpc/interceptor/namespace_validator.go b/common/rpc/interceptor/namespace_validator.go index 85569fe7b57..0036a7c00cc 100644 --- a/common/rpc/interceptor/namespace_validator.go +++ b/common/rpc/interceptor/namespace_validator.go @@ -40,7 +40,11 @@ import ( ) type ( - // NamespaceValidatorInterceptor contains LengthValidationIntercept and StateValidationIntercept + TaskTokenGetter interface { + GetTaskToken() []byte + } + + // NamespaceValidatorInterceptor contains NamespaceValidateIntercept and StateValidationIntercept NamespaceValidatorInterceptor struct { namespaceRegistry namespace.Registry tokenSerializer common.TaskTokenSerializer @@ -71,7 +75,7 @@ var ( ) var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).StateValidationIntercept -var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).LengthValidationIntercept +var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).NamespaceValidateIntercept func NewNamespaceValidatorInterceptor( namespaceRegistry namespace.Registry, @@ -86,12 +90,16 @@ func NewNamespaceValidatorInterceptor( } } -func (ni *NamespaceValidatorInterceptor) LengthValidationIntercept( +func (ni *NamespaceValidatorInterceptor) NamespaceValidateIntercept( ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (interface{}, error) { + err := ni.setNamespaceIfNotPresent(req) + if err != nil { + return nil, err + } reqWithNamespace, hasNamespace := req.(NamespaceNameGetter) if hasNamespace { namespaceName := namespace.Name(reqWithNamespace.GetNamespace()) @@ -103,6 +111,60 @@ func (ni *NamespaceValidatorInterceptor) LengthValidationIntercept( return handler(ctx, req) } +func (ni *NamespaceValidatorInterceptor) setNamespaceIfNotPresent( + req interface{}, +) error { + switch request := req.(type) { + case NamespaceNameGetter: + if request.GetNamespace() == "" { + namespaceEntry, err := ni.extractNamespaceFromTaskToken(req) + if err != nil { + return err + } + ni.setNamespace(namespaceEntry, req) + } + return nil + default: + return nil + } +} + +func (ni *NamespaceValidatorInterceptor) setNamespace( + namespaceEntry *namespace.Namespace, + req interface{}, +) { + switch request := req.(type) { + case *workflowservice.RespondQueryTaskCompletedRequest: + if request.Namespace == "" { + request.Namespace = namespaceEntry.Name().String() + } + case *workflowservice.RespondWorkflowTaskCompletedRequest: + if request.Namespace == "" { + request.Namespace = namespaceEntry.Name().String() + } + case *workflowservice.RespondWorkflowTaskFailedRequest: + if request.Namespace == "" { + request.Namespace = namespaceEntry.Name().String() + } + case *workflowservice.RecordActivityTaskHeartbeatRequest: + if request.Namespace == "" { + request.Namespace = namespaceEntry.Name().String() + } + case *workflowservice.RespondActivityTaskCanceledRequest: + if request.Namespace == "" { + request.Namespace = namespaceEntry.Name().String() + } + case *workflowservice.RespondActivityTaskCompletedRequest: + if request.Namespace == "" { + request.Namespace = namespaceEntry.Name().String() + } + case *workflowservice.RespondActivityTaskFailedRequest: + if request.Namespace == "" { + request.Namespace = namespaceEntry.Name().String() + } + } +} + // StateValidationIntercept validates: // 1. Namespace is specified in task token if there is a `task_token` field. // 2. Namespace is specified in request if there is a `namespace` field and no `task_token` field. @@ -202,7 +264,7 @@ func (ni *NamespaceValidatorInterceptor) extractNamespaceFromRequest(req interfa } func (ni *NamespaceValidatorInterceptor) extractNamespaceFromTaskToken(req interface{}) (*namespace.Namespace, error) { - reqWithTaskToken, hasTaskToken := req.(interface{ GetTaskToken() []byte }) + reqWithTaskToken, hasTaskToken := req.(TaskTokenGetter) if !hasTaskToken { return nil, nil } diff --git a/common/rpc/interceptor/namespace_validator_test.go b/common/rpc/interceptor/namespace_validator_test.go index 820aac1c733..face0917a6a 100644 --- a/common/rpc/interceptor/namespace_validator_test.go +++ b/common/rpc/interceptor/namespace_validator_test.go @@ -30,6 +30,7 @@ import ( "testing" "github.com/golang/mock/gomock" + "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" enumspb "go.temporal.io/api/enums/v1" @@ -684,7 +685,7 @@ func (s *namespaceValidatorSuite) Test_Intercept_SearchAttributeRequests() { } } -func (s *namespaceValidatorSuite) Test_LengthValidationIntercept() { +func (s *namespaceValidatorSuite) Test_NamespaceValidateIntercept() { nvi := NewNamespaceValidatorInterceptor( s.mockRegistry, dynamicconfig.GetBoolPropertyFn(false), @@ -692,10 +693,36 @@ func (s *namespaceValidatorSuite) Test_LengthValidationIntercept() { serverInfo := &grpc.UnaryServerInfo{ FullMethod: "/temporal/random", } + requestNamespace := namespace.FromPersistentState( + &persistence.GetNamespaceResponse{ + Namespace: &persistencespb.NamespaceDetail{ + Config: &persistencespb.NamespaceConfig{}, + ReplicationConfig: &persistencespb.NamespaceReplicationConfig{}, + Info: &persistencespb.NamespaceInfo{ + Id: uuid.New().String(), + Name: "namespace", + State: enumspb.NAMESPACE_STATE_REGISTERED, + }, + }, + }) + requestNamespaceTooLong := namespace.FromPersistentState( + &persistence.GetNamespaceResponse{ + Namespace: &persistencespb.NamespaceDetail{ + Config: &persistencespb.NamespaceConfig{}, + ReplicationConfig: &persistencespb.NamespaceReplicationConfig{}, + Info: &persistencespb.NamespaceInfo{ + Id: uuid.New().String(), + Name: "namespaceTooLong", + State: enumspb.NAMESPACE_STATE_REGISTERED, + }, + }, + }) + s.mockRegistry.EXPECT().GetNamespace(namespace.Name("namespace")).Return(requestNamespace, nil).AnyTimes() + s.mockRegistry.EXPECT().GetNamespace(namespace.Name("namespaceTooLong")).Return(requestNamespaceTooLong, nil).AnyTimes() req := &workflowservice.StartWorkflowExecutionRequest{Namespace: "namespace"} handlerCalled := false - _, err := nvi.LengthValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) { + _, err := nvi.NamespaceValidateIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) { handlerCalled = true return &workflowservice.StartWorkflowExecutionResponse{}, nil }) @@ -704,10 +731,82 @@ func (s *namespaceValidatorSuite) Test_LengthValidationIntercept() { req = &workflowservice.StartWorkflowExecutionRequest{Namespace: "namespaceTooLong"} handlerCalled = false - _, err = nvi.LengthValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) { + _, err = nvi.NamespaceValidateIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) { handlerCalled = true return &workflowservice.StartWorkflowExecutionResponse{}, nil }) s.False(handlerCalled) s.Error(err) } + +func (s *namespaceValidatorSuite) TestSetNamespace() { + namespaceRequestName := uuid.New().String() + namespaceEntryName := uuid.New().String() + namespaceEntry := namespace.FromPersistentState( + &persistence.GetNamespaceResponse{ + Namespace: &persistencespb.NamespaceDetail{ + Config: &persistencespb.NamespaceConfig{}, + ReplicationConfig: &persistencespb.NamespaceReplicationConfig{}, + Info: &persistencespb.NamespaceInfo{ + Id: uuid.New().String(), + Name: namespaceEntryName, + State: enumspb.NAMESPACE_STATE_REGISTERED, + }, + }, + }) + + nvi := NewNamespaceValidatorInterceptor( + s.mockRegistry, + dynamicconfig.GetBoolPropertyFn(false), + dynamicconfig.GetIntPropertyFn(10), + ) + + queryReq := &workflowservice.RespondQueryTaskCompletedRequest{} + nvi.setNamespace(namespaceEntry, queryReq) + s.Equal(namespaceEntryName, queryReq.Namespace) + queryReq.Namespace = namespaceRequestName + nvi.setNamespace(namespaceEntry, queryReq) + s.Equal(namespaceRequestName, queryReq.Namespace) + + completeWorkflowTaskReq := &workflowservice.RespondWorkflowTaskCompletedRequest{} + nvi.setNamespace(namespaceEntry, completeWorkflowTaskReq) + s.Equal(namespaceEntryName, completeWorkflowTaskReq.Namespace) + completeWorkflowTaskReq.Namespace = namespaceRequestName + nvi.setNamespace(namespaceEntry, completeWorkflowTaskReq) + s.Equal(namespaceRequestName, completeWorkflowTaskReq.Namespace) + + failWorkflowTaskReq := &workflowservice.RespondWorkflowTaskFailedRequest{} + nvi.setNamespace(namespaceEntry, failWorkflowTaskReq) + s.Equal(namespaceEntryName, failWorkflowTaskReq.Namespace) + failWorkflowTaskReq.Namespace = namespaceRequestName + nvi.setNamespace(namespaceEntry, failWorkflowTaskReq) + s.Equal(namespaceRequestName, failWorkflowTaskReq.Namespace) + + heartbeatActivityTaskReq := &workflowservice.RecordActivityTaskHeartbeatRequest{} + nvi.setNamespace(namespaceEntry, heartbeatActivityTaskReq) + s.Equal(namespaceEntryName, heartbeatActivityTaskReq.Namespace) + heartbeatActivityTaskReq.Namespace = namespaceRequestName + nvi.setNamespace(namespaceEntry, heartbeatActivityTaskReq) + s.Equal(namespaceRequestName, heartbeatActivityTaskReq.Namespace) + + cancelActivityTaskReq := &workflowservice.RespondActivityTaskCanceledRequest{} + nvi.setNamespace(namespaceEntry, cancelActivityTaskReq) + s.Equal(namespaceEntryName, cancelActivityTaskReq.Namespace) + cancelActivityTaskReq.Namespace = namespaceRequestName + nvi.setNamespace(namespaceEntry, cancelActivityTaskReq) + s.Equal(namespaceRequestName, cancelActivityTaskReq.Namespace) + + completeActivityTaskReq := &workflowservice.RespondActivityTaskCompletedRequest{} + nvi.setNamespace(namespaceEntry, completeActivityTaskReq) + s.Equal(namespaceEntryName, completeActivityTaskReq.Namespace) + completeActivityTaskReq.Namespace = namespaceRequestName + nvi.setNamespace(namespaceEntry, completeActivityTaskReq) + s.Equal(namespaceRequestName, completeActivityTaskReq.Namespace) + + failActivityTaskReq := &workflowservice.RespondActivityTaskFailedRequest{} + nvi.setNamespace(namespaceEntry, failActivityTaskReq) + s.Equal(namespaceEntryName, failActivityTaskReq.Namespace) + failActivityTaskReq.Namespace = namespaceRequestName + nvi.setNamespace(namespaceEntry, failActivityTaskReq) + s.Equal(namespaceRequestName, failActivityTaskReq.Namespace) +} diff --git a/service/frontend/fx.go b/service/frontend/fx.go index 7b26f6effd0..44e50a3ec75 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -178,7 +178,7 @@ func GrpcServerOptionsProvider( interceptors := []grpc.UnaryServerInterceptor{ // Service Error Interceptor should be the most outer interceptor on error handling rpc.ServiceErrorInterceptor, - namespaceValidatorInterceptor.LengthValidationIntercept, + namespaceValidatorInterceptor.NamespaceValidateIntercept, namespaceLogInterceptor.Intercept, // TODO: Deprecate this with a outer custom interceptor grpc.UnaryServerInterceptor(traceInterceptor), metrics.NewServerMetricsContextInjectorInterceptor(),