Skip to content

Commit

Permalink
Set namespace on API if not present (#3953)
Browse files Browse the repository at this point in the history
  • Loading branch information
wxing1292 authored Feb 16, 2023
1 parent 8917ecc commit a23489b
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 8 deletions.
70 changes: 66 additions & 4 deletions common/rpc/interceptor/namespace_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ import (
)

type (
// NamespaceValidatorInterceptor contains LengthValidationIntercept and StateValidationIntercept
TaskTokenGetter interface {
GetTaskToken() []byte
}

// NamespaceValidatorInterceptor contains NamespaceValidateIntercept and StateValidationIntercept
NamespaceValidatorInterceptor struct {
namespaceRegistry namespace.Registry
tokenSerializer common.TaskTokenSerializer
Expand Down Expand Up @@ -71,7 +75,7 @@ var (
)

var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).StateValidationIntercept
var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).LengthValidationIntercept
var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).NamespaceValidateIntercept

func NewNamespaceValidatorInterceptor(
namespaceRegistry namespace.Registry,
Expand All @@ -86,12 +90,16 @@ func NewNamespaceValidatorInterceptor(
}
}

func (ni *NamespaceValidatorInterceptor) LengthValidationIntercept(
func (ni *NamespaceValidatorInterceptor) NamespaceValidateIntercept(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
err := ni.setNamespaceIfNotPresent(req)
if err != nil {
return nil, err
}
reqWithNamespace, hasNamespace := req.(NamespaceNameGetter)
if hasNamespace {
namespaceName := namespace.Name(reqWithNamespace.GetNamespace())
Expand All @@ -103,6 +111,60 @@ func (ni *NamespaceValidatorInterceptor) LengthValidationIntercept(
return handler(ctx, req)
}

func (ni *NamespaceValidatorInterceptor) setNamespaceIfNotPresent(
req interface{},
) error {
switch request := req.(type) {
case NamespaceNameGetter:
if request.GetNamespace() == "" {
namespaceEntry, err := ni.extractNamespaceFromTaskToken(req)
if err != nil {
return err
}
ni.setNamespace(namespaceEntry, req)
}
return nil
default:
return nil
}
}

func (ni *NamespaceValidatorInterceptor) setNamespace(
namespaceEntry *namespace.Namespace,
req interface{},
) {
switch request := req.(type) {
case *workflowservice.RespondQueryTaskCompletedRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
case *workflowservice.RespondWorkflowTaskCompletedRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
case *workflowservice.RespondWorkflowTaskFailedRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
case *workflowservice.RecordActivityTaskHeartbeatRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
case *workflowservice.RespondActivityTaskCanceledRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
case *workflowservice.RespondActivityTaskCompletedRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
case *workflowservice.RespondActivityTaskFailedRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
}
}

// StateValidationIntercept validates:
// 1. Namespace is specified in task token if there is a `task_token` field.
// 2. Namespace is specified in request if there is a `namespace` field and no `task_token` field.
Expand Down Expand Up @@ -202,7 +264,7 @@ func (ni *NamespaceValidatorInterceptor) extractNamespaceFromRequest(req interfa
}

func (ni *NamespaceValidatorInterceptor) extractNamespaceFromTaskToken(req interface{}) (*namespace.Namespace, error) {
reqWithTaskToken, hasTaskToken := req.(interface{ GetTaskToken() []byte })
reqWithTaskToken, hasTaskToken := req.(TaskTokenGetter)
if !hasTaskToken {
return nil, nil
}
Expand Down
105 changes: 102 additions & 3 deletions common/rpc/interceptor/namespace_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"testing"

"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
enumspb "go.temporal.io/api/enums/v1"
Expand Down Expand Up @@ -684,18 +685,44 @@ func (s *namespaceValidatorSuite) Test_Intercept_SearchAttributeRequests() {
}
}

func (s *namespaceValidatorSuite) Test_LengthValidationIntercept() {
func (s *namespaceValidatorSuite) Test_NamespaceValidateIntercept() {
nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(false),
dynamicconfig.GetIntPropertyFn(10))
serverInfo := &grpc.UnaryServerInfo{
FullMethod: "/temporal/random",
}
requestNamespace := namespace.FromPersistentState(
&persistence.GetNamespaceResponse{
Namespace: &persistencespb.NamespaceDetail{
Config: &persistencespb.NamespaceConfig{},
ReplicationConfig: &persistencespb.NamespaceReplicationConfig{},
Info: &persistencespb.NamespaceInfo{
Id: uuid.New().String(),
Name: "namespace",
State: enumspb.NAMESPACE_STATE_REGISTERED,
},
},
})
requestNamespaceTooLong := namespace.FromPersistentState(
&persistence.GetNamespaceResponse{
Namespace: &persistencespb.NamespaceDetail{
Config: &persistencespb.NamespaceConfig{},
ReplicationConfig: &persistencespb.NamespaceReplicationConfig{},
Info: &persistencespb.NamespaceInfo{
Id: uuid.New().String(),
Name: "namespaceTooLong",
State: enumspb.NAMESPACE_STATE_REGISTERED,
},
},
})
s.mockRegistry.EXPECT().GetNamespace(namespace.Name("namespace")).Return(requestNamespace, nil).AnyTimes()
s.mockRegistry.EXPECT().GetNamespace(namespace.Name("namespaceTooLong")).Return(requestNamespaceTooLong, nil).AnyTimes()

req := &workflowservice.StartWorkflowExecutionRequest{Namespace: "namespace"}
handlerCalled := false
_, err := nvi.LengthValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err := nvi.NamespaceValidateIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.StartWorkflowExecutionResponse{}, nil
})
Expand All @@ -704,10 +731,82 @@ func (s *namespaceValidatorSuite) Test_LengthValidationIntercept() {

req = &workflowservice.StartWorkflowExecutionRequest{Namespace: "namespaceTooLong"}
handlerCalled = false
_, err = nvi.LengthValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err = nvi.NamespaceValidateIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.StartWorkflowExecutionResponse{}, nil
})
s.False(handlerCalled)
s.Error(err)
}

func (s *namespaceValidatorSuite) TestSetNamespace() {
namespaceRequestName := uuid.New().String()
namespaceEntryName := uuid.New().String()
namespaceEntry := namespace.FromPersistentState(
&persistence.GetNamespaceResponse{
Namespace: &persistencespb.NamespaceDetail{
Config: &persistencespb.NamespaceConfig{},
ReplicationConfig: &persistencespb.NamespaceReplicationConfig{},
Info: &persistencespb.NamespaceInfo{
Id: uuid.New().String(),
Name: namespaceEntryName,
State: enumspb.NAMESPACE_STATE_REGISTERED,
},
},
})

nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(false),
dynamicconfig.GetIntPropertyFn(10),
)

queryReq := &workflowservice.RespondQueryTaskCompletedRequest{}
nvi.setNamespace(namespaceEntry, queryReq)
s.Equal(namespaceEntryName, queryReq.Namespace)
queryReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, queryReq)
s.Equal(namespaceRequestName, queryReq.Namespace)

completeWorkflowTaskReq := &workflowservice.RespondWorkflowTaskCompletedRequest{}
nvi.setNamespace(namespaceEntry, completeWorkflowTaskReq)
s.Equal(namespaceEntryName, completeWorkflowTaskReq.Namespace)
completeWorkflowTaskReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, completeWorkflowTaskReq)
s.Equal(namespaceRequestName, completeWorkflowTaskReq.Namespace)

failWorkflowTaskReq := &workflowservice.RespondWorkflowTaskFailedRequest{}
nvi.setNamespace(namespaceEntry, failWorkflowTaskReq)
s.Equal(namespaceEntryName, failWorkflowTaskReq.Namespace)
failWorkflowTaskReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, failWorkflowTaskReq)
s.Equal(namespaceRequestName, failWorkflowTaskReq.Namespace)

heartbeatActivityTaskReq := &workflowservice.RecordActivityTaskHeartbeatRequest{}
nvi.setNamespace(namespaceEntry, heartbeatActivityTaskReq)
s.Equal(namespaceEntryName, heartbeatActivityTaskReq.Namespace)
heartbeatActivityTaskReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, heartbeatActivityTaskReq)
s.Equal(namespaceRequestName, heartbeatActivityTaskReq.Namespace)

cancelActivityTaskReq := &workflowservice.RespondActivityTaskCanceledRequest{}
nvi.setNamespace(namespaceEntry, cancelActivityTaskReq)
s.Equal(namespaceEntryName, cancelActivityTaskReq.Namespace)
cancelActivityTaskReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, cancelActivityTaskReq)
s.Equal(namespaceRequestName, cancelActivityTaskReq.Namespace)

completeActivityTaskReq := &workflowservice.RespondActivityTaskCompletedRequest{}
nvi.setNamespace(namespaceEntry, completeActivityTaskReq)
s.Equal(namespaceEntryName, completeActivityTaskReq.Namespace)
completeActivityTaskReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, completeActivityTaskReq)
s.Equal(namespaceRequestName, completeActivityTaskReq.Namespace)

failActivityTaskReq := &workflowservice.RespondActivityTaskFailedRequest{}
nvi.setNamespace(namespaceEntry, failActivityTaskReq)
s.Equal(namespaceEntryName, failActivityTaskReq.Namespace)
failActivityTaskReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, failActivityTaskReq)
s.Equal(namespaceRequestName, failActivityTaskReq.Namespace)
}
2 changes: 1 addition & 1 deletion service/frontend/fx.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func GrpcServerOptionsProvider(
interceptors := []grpc.UnaryServerInterceptor{
// Service Error Interceptor should be the most outer interceptor on error handling
rpc.ServiceErrorInterceptor,
namespaceValidatorInterceptor.LengthValidationIntercept,
namespaceValidatorInterceptor.NamespaceValidateIntercept,
namespaceLogInterceptor.Intercept, // TODO: Deprecate this with a outer custom interceptor
grpc.UnaryServerInterceptor(traceInterceptor),
metrics.NewServerMetricsContextInjectorInterceptor(),
Expand Down

0 comments on commit a23489b

Please sign in to comment.