Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle namespace not found case in redirection interceptor #3947

Merged
merged 6 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/rpc/interceptor/caller_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (i *CallerInfoInterceptor) Intercept(

updateInfo := false
if callerInfo.CallerName == "" {
callerInfo.CallerName = string(GetNamespace(i.namespaceRegistry, req))
callerInfo.CallerName = string(MustGetNamespaceName(i.namespaceRegistry, req))
updateInfo = callerInfo.CallerName != ""
}
if callerInfo.CallerType == "" {
Expand Down
29 changes: 23 additions & 6 deletions common/rpc/interceptor/namespace.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
package interceptor

import (
"fmt"

"go.temporal.io/api/serviceerror"
"go.temporal.io/server/common/namespace"
)

Expand All @@ -40,28 +43,42 @@ type (
}
)

func GetNamespace(
// MustGetNamespaceName returns request namespace name
// or EmptyName if there's error when retriving namespace name,
// e.g. unable to find namespace
func MustGetNamespaceName(
namespaceRegistry namespace.Registry,
req interface{},
) namespace.Name {
namespaceName, err := GetNamespaceName(namespaceRegistry, req)
if err != nil {
return namespace.EmptyName
}
return namespaceName
}

func GetNamespaceName(
namespaceRegistry namespace.Registry,
req interface{},
) (namespace.Name, error) {
switch request := req.(type) {
case NamespaceNameGetter:
namespaceName := namespace.Name(request.GetNamespace())
_, err := namespaceRegistry.GetNamespace(namespaceName)
if err != nil {
return namespace.EmptyName
return namespace.EmptyName, err
}
return namespaceName
return namespaceName, nil

case NamespaceIDGetter:
namespaceID := namespace.ID(request.GetNamespaceId())
namespaceName, err := namespaceRegistry.GetNamespaceName(namespaceID)
if err != nil {
return namespace.EmptyName
return namespace.EmptyName, err
}
return namespaceName
return namespaceName, nil

default:
return namespace.EmptyName
return namespace.EmptyName, serviceerror.NewInternal(fmt.Sprintf("unable to extract namespace info from request: %+v", req))
}
}
2 changes: 1 addition & 1 deletion common/rpc/interceptor/namespace_count_limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (ni *NamespaceCountLimitInterceptor) Intercept(
// token will default to 0
token := ni.tokens[methodName]
if token != 0 {
nsName := GetNamespace(ni.namespaceRegistry, req)
nsName := MustGetNamespaceName(ni.namespaceRegistry, req)
counter := ni.counter(nsName, methodName)
count := atomic.AddInt32(counter, int32(token))
defer atomic.AddInt32(counter, -int32(token))
Expand Down
2 changes: 1 addition & 1 deletion common/rpc/interceptor/namespace_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (nli *NamespaceLogInterceptor) Intercept(

if nli.logger != nil {
_, methodName := SplitMethodName(info.FullMethod)
namespace := GetNamespace(nli.namespaceRegistry, req)
namespace := MustGetNamespaceName(nli.namespaceRegistry, req)
tlsInfo := authorization.TLSInfoFormContext(ctx)
var serverName string
var certThumbprint string
Expand Down
2 changes: 1 addition & 1 deletion common/rpc/interceptor/namespace_rate_limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (ni *NamespaceRateLimitInterceptor) Intercept(
token = NamespaceRateLimitDefaultToken
}

namespace := GetNamespace(ni.namespaceRegistry, req)
namespace := MustGetNamespaceName(ni.namespaceRegistry, req)
if !ni.rateLimiter.Allow(time.Now().UTC(), quotas.NewRequest(
methodName,
token,
Expand Down
2 changes: 1 addition & 1 deletion common/rpc/interceptor/namespace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (s *namespaceSuite) TestGetNamespace() {
}

for _, testCase := range testCases {
extractedNamespace := GetNamespace(register, testCase.method)
extractedNamespace := MustGetNamespaceName(register, testCase.method)
s.Equal(testCase.namespaceName, extractedNamespace)
}
}
2 changes: 1 addition & 1 deletion common/rpc/interceptor/telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func (ti *TelemetryInterceptor) metricsHandlerLogTags(

overridedMethodName := ti.overrideOperationTag(fullMethod, methodName, req)

nsName := GetNamespace(ti.namespaceRegistry, req)
nsName := MustGetNamespaceName(ti.namespaceRegistry, req)
if nsName == "" {
return ti.metricsHandler.WithTags(metrics.OperationTag(overridedMethodName), metrics.NamespaceUnknownTag()),
[]tag.Tag{tag.Operation(overridedMethodName)}
Expand Down
5 changes: 4 additions & 1 deletion service/frontend/redirection_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,10 @@ func (i *RedirectionInterceptor) Intercept(
return i.handleLocalAPIInvocation(ctx, req, handler, methodName)
}
if raFn, ok := globalAPIResponses[methodName]; ok {
namespaceName := interceptor.GetNamespace(i.namespaceCache, req)
namespaceName, err := interceptor.GetNamespaceName(i.namespaceCache, req)
if err != nil {
return nil, err
}
return i.handleRedirectAPIInvocation(ctx, req, info, handler, methodName, raFn, namespaceName)
}

Expand Down
27 changes: 26 additions & 1 deletion service/frontend/redirection_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.temporal.io/api/serviceerror"
"go.temporal.io/api/workflowservice/v1"
"google.golang.org/grpc"

Expand Down Expand Up @@ -258,7 +259,7 @@ func (s *redirectionInterceptorSuite) TestHandleGlobalAPIInvocation_Local() {
s.True(functionInvoked)
}

func (s *redirectionInterceptorSuite) TestHandleLocalAPIInvocation_Redirect() {
func (s *redirectionInterceptorSuite) TestHandleGlobalAPIInvocation_Redirect() {
ctx := context.Background()
req := &workflowservice.SignalWithStartWorkflowExecutionRequest{}
info := &grpc.UnaryServerInfo{
Expand Down Expand Up @@ -300,6 +301,30 @@ func (s *redirectionInterceptorSuite) TestHandleLocalAPIInvocation_Redirect() {
s.IsType(&workflowservice.SignalWithStartWorkflowExecutionResponse{}, resp)
}

func (s *redirectionInterceptorSuite) TestHandleGlobalAPIInvocation_NamespaceNotFound() {
ctx := context.Background()
req := &workflowservice.PollWorkflowTaskQueueRequest{}
info := &grpc.UnaryServerInfo{
FullMethod: "/temporal.api.workflowservice.v1.WorkflowService/PollWorkflowTaskQueue",
}

namespaceName := namespace.Name("unknown_namespace")
s.namespaceCache.EXPECT().GetNamespace(namespaceName).Return(nil, &serviceerror.NamespaceNotFound{}).AnyTimes()
methodName := "PollWorkflowTaskQueue"

resp, err := s.redirector.handleRedirectAPIInvocation(
ctx,
req,
info,
nil,
methodName,
globalAPIResponses[methodName],
namespaceName,
)
s.Nil(resp)
s.IsType(&serviceerror.NamespaceNotFound{}, err)
}

type (
mockClientConnInterface struct {
*suite.Suite
Expand Down
10 changes: 8 additions & 2 deletions tests/activity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ func (s *integrationSuite) TestActivityHeartBeatWorkflow_Success() {
for i := 0; i < 10; i++ {
s.Logger.Info("Heartbeating for activity", tag.WorkflowActivityID(activityID), tag.Counter(i))
_, err := s.engine.RecordActivityTaskHeartbeat(NewContext(), &workflowservice.RecordActivityTaskHeartbeatRequest{
TaskToken: taskToken, Details: payloads.EncodeString("details")})
Namespace: s.namespace,
TaskToken: taskToken,
Details: payloads.EncodeString("details"),
})
s.NoError(err)
time.Sleep(10 * time.Millisecond)
}
Expand Down Expand Up @@ -685,7 +688,10 @@ func (s *integrationSuite) TestTryActivityCancellationFromWorkflow() {
s.Logger.Info("Heartbeating for activity", tag.WorkflowActivityID(activityID), tag.Counter(i))
response, err := s.engine.RecordActivityTaskHeartbeat(NewContext(),
&workflowservice.RecordActivityTaskHeartbeatRequest{
TaskToken: taskToken, Details: payloads.EncodeString("details")})
Namespace: s.namespace,
TaskToken: taskToken,
Details: payloads.EncodeString("details"),
})
if response != nil && response.CancelRequested {
activityCanceled = true
return payloads.EncodeString("Activity Cancelled"), true, nil
Expand Down
15 changes: 14 additions & 1 deletion tests/taskpoller.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,10 @@ Loop:
if response.Query != nil {
blob, err := p.QueryHandler(response)

completeRequest := &workflowservice.RespondQueryTaskCompletedRequest{TaskToken: response.TaskToken}
completeRequest := &workflowservice.RespondQueryTaskCompletedRequest{
Namespace: p.Namespace,
TaskToken: response.TaskToken,
}
if err != nil {
completeType := enumspb.QUERY_RESULT_TYPE_FAILED
completeRequest.CompletedType = completeType
Expand All @@ -252,6 +255,7 @@ Loop:
if err != nil {
p.Logger.Error("Failing workflow task. Workflow messages handler failed with error", tag.Error(err))
_, err = p.Engine.RespondWorkflowTaskFailed(NewContext(), &workflowservice.RespondWorkflowTaskFailedRequest{
Namespace: p.Namespace,
TaskToken: response.TaskToken,
Cause: enumspb.WORKFLOW_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE,
Failure: newApplicationFailure(err, false, nil),
Expand All @@ -276,6 +280,7 @@ Loop:
if err != nil {
p.Logger.Error("Failing workflow task. Workflow task handler failed with error", tag.Error(err))
_, err = p.Engine.RespondWorkflowTaskFailed(NewContext(), &workflowservice.RespondWorkflowTaskFailedRequest{
Namespace: p.Namespace,
TaskToken: response.TaskToken,
Cause: enumspb.WORKFLOW_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE,
Failure: newApplicationFailure(err, false, nil),
Expand All @@ -293,6 +298,7 @@ Loop:
if !respondStickyTaskQueue {
// non sticky taskqueue
newTask, err := p.Engine.RespondWorkflowTaskCompleted(NewContext(), &workflowservice.RespondWorkflowTaskCompletedRequest{
Namespace: p.Namespace,
TaskToken: response.TaskToken,
Identity: p.Identity,
Commands: commands,
Expand All @@ -307,6 +313,7 @@ Loop:
newTask, err := p.Engine.RespondWorkflowTaskCompleted(
NewContext(),
&workflowservice.RespondWorkflowTaskCompletedRequest{
Namespace: p.Namespace,
TaskToken: response.TaskToken,
Identity: p.Identity,
Commands: commands,
Expand Down Expand Up @@ -354,6 +361,7 @@ func (p *TaskPoller) HandlePartialWorkflowTask(response *workflowservice.PollWor
if err != nil {
p.Logger.Error("Failing workflow task. Workflow messages handler failed with error", tag.Error(err))
_, err = p.Engine.RespondWorkflowTaskFailed(NewContext(), &workflowservice.RespondWorkflowTaskFailedRequest{
Namespace: p.Namespace,
TaskToken: response.TaskToken,
Cause: enumspb.WORKFLOW_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE,
Failure: newApplicationFailure(err, false, nil),
Expand All @@ -368,6 +376,7 @@ func (p *TaskPoller) HandlePartialWorkflowTask(response *workflowservice.PollWor
if err != nil {
p.Logger.Error("Failing workflow task. Workflow task handler failed with error", tag.Error(err))
_, err = p.Engine.RespondWorkflowTaskFailed(NewContext(), &workflowservice.RespondWorkflowTaskFailedRequest{
Namespace: p.Namespace,
TaskToken: response.TaskToken,
Cause: enumspb.WORKFLOW_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE,
Failure: newApplicationFailure(err, false, nil),
Expand All @@ -386,6 +395,7 @@ func (p *TaskPoller) HandlePartialWorkflowTask(response *workflowservice.PollWor
newTask, err := p.Engine.RespondWorkflowTaskCompleted(
NewContext(),
&workflowservice.RespondWorkflowTaskCompletedRequest{
Namespace: p.Namespace,
TaskToken: response.TaskToken,
Identity: p.Identity,
Commands: commands,
Expand Down Expand Up @@ -437,6 +447,7 @@ retry:
if cancel {
p.Logger.Info("Executing RespondActivityTaskCanceled")
_, err := p.Engine.RespondActivityTaskCanceled(NewContext(), &workflowservice.RespondActivityTaskCanceledRequest{
Namespace: p.Namespace,
TaskToken: response.TaskToken,
Details: payloads.EncodeString("details"),
Identity: p.Identity,
Expand All @@ -446,6 +457,7 @@ retry:

if err2 != nil {
_, err := p.Engine.RespondActivityTaskFailed(NewContext(), &workflowservice.RespondActivityTaskFailedRequest{
Namespace: p.Namespace,
TaskToken: response.TaskToken,
Failure: newApplicationFailure(err2, false, nil),
Identity: p.Identity,
Expand All @@ -454,6 +466,7 @@ retry:
}

_, err = p.Engine.RespondActivityTaskCompleted(NewContext(), &workflowservice.RespondActivityTaskCompletedRequest{
Namespace: p.Namespace,
TaskToken: response.TaskToken,
Identity: p.Identity,
Result: result,
Expand Down
11 changes: 11 additions & 0 deletions tests/workflow_task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ func (s *integrationSuite) TestWorkflowTaskHeartbeatingWithEmptyResult() {
hbTimeout := 0
for i := 0; i < 12; i++ {
resp2, err2 := s.engine.RespondWorkflowTaskCompleted(NewContext(), &workflowservice.RespondWorkflowTaskCompletedRequest{
Namespace: s.namespace,
TaskToken: taskToken,
Commands: []*commandpb.Command{},
StickyAttributes: &taskqueuepb.StickyExecutionAttributes{
Expand Down Expand Up @@ -125,6 +126,7 @@ func (s *integrationSuite) TestWorkflowTaskHeartbeatingWithEmptyResult() {
s.Equal(2, hbTimeout)

resp5, err5 := s.engine.RespondWorkflowTaskCompleted(NewContext(), &workflowservice.RespondWorkflowTaskCompletedRequest{
Namespace: s.namespace,
TaskToken: taskToken,
Commands: []*commandpb.Command{
{
Expand Down Expand Up @@ -198,6 +200,7 @@ func (s *integrationSuite) TestWorkflowTaskHeartbeatingWithLocalActivitiesResult
s.assertLastHistoryEvent(we, 3, enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED)

resp2, err2 := s.engine.RespondWorkflowTaskCompleted(NewContext(), &workflowservice.RespondWorkflowTaskCompletedRequest{
Namespace: s.namespace,
TaskToken: resp1.GetTaskToken(),
Commands: []*commandpb.Command{},
StickyAttributes: &taskqueuepb.StickyExecutionAttributes{
Expand All @@ -210,6 +213,7 @@ func (s *integrationSuite) TestWorkflowTaskHeartbeatingWithLocalActivitiesResult
s.NoError(err2)

resp3, err3 := s.engine.RespondWorkflowTaskCompleted(NewContext(), &workflowservice.RespondWorkflowTaskCompletedRequest{
Namespace: s.namespace,
TaskToken: resp2.WorkflowTask.GetTaskToken(),
Commands: []*commandpb.Command{
{
Expand All @@ -231,6 +235,7 @@ func (s *integrationSuite) TestWorkflowTaskHeartbeatingWithLocalActivitiesResult
s.NoError(err3)

resp4, err4 := s.engine.RespondWorkflowTaskCompleted(NewContext(), &workflowservice.RespondWorkflowTaskCompletedRequest{
Namespace: s.namespace,
TaskToken: resp3.WorkflowTask.GetTaskToken(),
Commands: []*commandpb.Command{
{
Expand All @@ -252,6 +257,7 @@ func (s *integrationSuite) TestWorkflowTaskHeartbeatingWithLocalActivitiesResult
s.NoError(err4)

resp5, err5 := s.engine.RespondWorkflowTaskCompleted(NewContext(), &workflowservice.RespondWorkflowTaskCompletedRequest{
Namespace: s.namespace,
TaskToken: resp4.WorkflowTask.GetTaskToken(),
Commands: []*commandpb.Command{
{
Expand Down Expand Up @@ -496,6 +502,7 @@ func (s *integrationSuite) TestWorkflowTerminationSignalAfterRegularWorkflowTask

// fail this workflow task to flush buffer, and then another workflow task will be scheduled
_, err2 := s.engine.RespondWorkflowTaskFailed(NewContext(), &workflowservice.RespondWorkflowTaskFailedRequest{
Namespace: s.namespace,
TaskToken: resp1.GetTaskToken(),
Cause: cause,
Identity: "integ test",
Expand Down Expand Up @@ -573,6 +580,7 @@ func (s *integrationSuite) TestWorkflowTerminationSignalBeforeTransientWorkflowT
}

_, err2 := s.engine.RespondWorkflowTaskFailed(NewContext(), &workflowservice.RespondWorkflowTaskFailedRequest{
Namespace: s.namespace,
TaskToken: resp1.GetTaskToken(),
Cause: cause,
Identity: "integ test",
Expand Down Expand Up @@ -676,6 +684,7 @@ func (s *integrationSuite) TestWorkflowTerminationSignalAfterTransientWorkflowTa
}

_, err2 := s.engine.RespondWorkflowTaskFailed(NewContext(), &workflowservice.RespondWorkflowTaskFailedRequest{
Namespace: s.namespace,
TaskToken: resp1.GetTaskToken(),
Cause: cause,
Identity: "integ test",
Expand Down Expand Up @@ -776,6 +785,7 @@ func (s *integrationSuite) TestWorkflowTerminationSignalAfterTransientWorkflowTa
}

_, err2 := s.engine.RespondWorkflowTaskFailed(NewContext(), &workflowservice.RespondWorkflowTaskFailedRequest{
Namespace: s.namespace,
TaskToken: resp1.GetTaskToken(),
Cause: cause,
Identity: "integ test",
Expand Down Expand Up @@ -809,6 +819,7 @@ func (s *integrationSuite) TestWorkflowTerminationSignalAfterTransientWorkflowTa

// fail this workflow task to flush buffer
_, err2 := s.engine.RespondWorkflowTaskFailed(NewContext(), &workflowservice.RespondWorkflowTaskFailedRequest{
Namespace: s.namespace,
TaskToken: resp1.GetTaskToken(),
Cause: cause,
Identity: "integ test",
Expand Down