Skip to content

Commit

Permalink
Messages protocol implementation (#3843)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexshtin authored Jan 25, 2023
1 parent e6113da commit 430d3a9
Show file tree
Hide file tree
Showing 14 changed files with 856 additions and 583 deletions.
818 changes: 451 additions & 367 deletions api/historyservice/v1/request_response.pb.go

Large diffs are not rendered by default.

410 changes: 247 additions & 163 deletions api/matchingservice/v1/request_response.pb.go

Large diffs are not rendered by default.

32 changes: 8 additions & 24 deletions common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@ import (

"github.com/dgryski/go-farm"
"github.com/gogo/protobuf/proto"
commandpb "go.temporal.io/api/command/v1"
commonpb "go.temporal.io/api/common/v1"
enumspb "go.temporal.io/api/enums/v1"
historypb "go.temporal.io/api/history/v1"
"go.temporal.io/api/serviceerror"
"go.temporal.io/api/workflowservice/v1"

Expand Down Expand Up @@ -388,31 +386,16 @@ func WorkflowIDToHistoryShard(
return int32(hash%uint32(numberOfShards)) + 1 // ShardID starts with 1
}

// PrettyPrintHistory prints history in human-readable format
func PrettyPrintHistory(history *historypb.History, header ...string) {
func PrettyPrint[T proto.Message](msgs []T, header ...string) {
var sb strings.Builder
sb.WriteString("==========================================================================\n")
_, _ = sb.WriteString("==========================================================================\n")
for _, h := range header {
sb.WriteString(h)
sb.WriteString("\n")
_, _ = sb.WriteString(h)
_, _ = sb.WriteString("\n")
}
sb.WriteString("--------------------------------------------------------------------------\n")
_ = proto.MarshalText(&sb, history)
sb.WriteString("\n")
fmt.Print(sb.String())
}

// PrettyPrintCommands prints commands in human-readable format
func PrettyPrintCommands(commands []*commandpb.Command, header ...string) {
var sb strings.Builder
sb.WriteString("==========================================================================\n")
for _, h := range header {
sb.WriteString(h)
sb.WriteString("\n")
}
sb.WriteString("--------------------------------------------------------------------------\n")
for _, command := range commands {
_ = proto.MarshalText(&sb, command)
_, _ = sb.WriteString("--------------------------------------------------------------------------\n")
for _, m := range msgs {
_ = proto.MarshalText(&sb, m)
}
fmt.Print(sb.String())
}
Expand Down Expand Up @@ -465,6 +448,7 @@ func CreateMatchingPollWorkflowTaskQueueResponse(historyResponse *historyservice
ScheduledTime: historyResponse.ScheduledTime,
StartedTime: historyResponse.StartedTime,
Queries: historyResponse.Queries,
Messages: historyResponse.Messages,
}

return matchingResp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import "temporal/api/taskqueue/v1/message.proto";
import "temporal/api/enums/v1/workflow.proto";
import "temporal/api/workflow/v1/message.proto";
import "temporal/api/query/v1/message.proto";
import "temporal/api/protocol/v1/message.proto";
import "temporal/api/failure/v1/message.proto";

import "temporal/server/api/clock/v1/message.proto";
Expand Down Expand Up @@ -164,6 +165,7 @@ message RecordWorkflowTaskStartedResponse {
google.protobuf.Timestamp started_time = 13 [(gogoproto.stdtime) = true];
map<string, temporal.api.query.v1.WorkflowQuery> queries = 14;
temporal.server.api.clock.v1.VectorClock clock = 15;
repeated temporal.api.protocol.v1.Message messages = 16;
}

message RecordActivityTaskStartedRequest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import "temporal/api/common/v1/message.proto";
import "temporal/api/enums/v1/task_queue.proto";
import "temporal/api/taskqueue/v1/message.proto";
import "temporal/api/query/v1/message.proto";
import "temporal/api/protocol/v1/message.proto";

import "temporal/server/api/clock/v1/message.proto";
import "temporal/server/api/enums/v1/task.proto";
Expand Down Expand Up @@ -65,6 +66,7 @@ message PollWorkflowTaskQueueResponse {
google.protobuf.Timestamp scheduled_time = 15 [(gogoproto.stdtime) = true];
google.protobuf.Timestamp started_time = 16 [(gogoproto.stdtime) = true];
map<string, temporal.api.query.v1.WorkflowQuery> queries = 17;
repeated temporal.api.protocol.v1.Message messages = 18;
}

message PollActivityTaskQueueRequest {
Expand Down
1 change: 1 addition & 0 deletions service/frontend/workflow_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4433,6 +4433,7 @@ func (wh *WorkflowHandler) createPollWorkflowTaskQueueResponse(
ScheduledTime: matchingResp.ScheduledTime,
StartedTime: matchingResp.StartedTime,
Queries: matchingResp.Queries,
Messages: matchingResp.Messages,
}

return resp, nil
Expand Down
9 changes: 9 additions & 0 deletions service/history/commandChecker.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
commandpb "go.temporal.io/api/command/v1"
commonpb "go.temporal.io/api/common/v1"
enumspb "go.temporal.io/api/enums/v1"
protocolpb "go.temporal.io/api/protocol/v1"
"go.temporal.io/api/serviceerror"
taskqueuepb "go.temporal.io/api/taskqueue/v1"

Expand Down Expand Up @@ -898,3 +899,11 @@ func (v *commandAttrValidator) commandTypes(
}
return result
}

// TODO (alex-update): move to messageValidator.
func (v *commandAttrValidator) validateMessages(
_ []*protocolpb.Message,
) error {

return nil
}
26 changes: 26 additions & 0 deletions service/history/workflowTaskHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
commonpb "go.temporal.io/api/common/v1"
enumspb "go.temporal.io/api/enums/v1"
failurepb "go.temporal.io/api/failure/v1"
protocolpb "go.temporal.io/api/protocol/v1"
"go.temporal.io/api/serviceerror"
"go.temporal.io/api/workflowservice/v1"

Expand Down Expand Up @@ -239,6 +240,31 @@ func (handler *workflowTaskHandlerImpl) handleCommand(ctx context.Context, comma
}
}

func (handler *workflowTaskHandlerImpl) handleMessages(
ctx context.Context,
messages []*protocolpb.Message,
) error {
if err := handler.attrValidator.validateMessages(
messages,
); err != nil {
return err
}

for _, message := range messages {
err := handler.handleMessage(ctx, message)
if err != nil || handler.stopProcessing {
return err
}
}

return nil
}

func (handler *workflowTaskHandlerImpl) handleMessage(_ context.Context, _ *protocolpb.Message) error {

return nil
}

func (handler *workflowTaskHandlerImpl) handleCommandScheduleActivity(
_ context.Context,
attr *commandpb.ScheduleActivityTaskCommandAttributes,
Expand Down
44 changes: 31 additions & 13 deletions service/history/workflowTaskHandlerCallbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,11 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
metrics.OperationTag(metrics.HistoryRespondWorkflowTaskCompletedScope))
}

workflowTaskHeartbeating := request.GetForceCreateNewWorkflowTask() && len(request.Commands) == 0
workflowTaskHeartbeating := request.GetForceCreateNewWorkflowTask() && len(request.Commands) == 0 && len(request.Messages) == 0
var workflowTaskHeartbeatTimeout bool
var completedEvent *historypb.HistoryEvent
var responseMutations []workflowTaskResponseMutation

if workflowTaskHeartbeating {
namespace := namespaceEntry.Name()
timeout := handler.config.WorkflowTaskHeartbeatTimeout(namespace.String())
Expand Down Expand Up @@ -423,11 +425,8 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
wtFailedCause *workflowTaskFailedCause
activityNotStartedCancelled bool
newMutableState workflow.MutableState

hasUnhandledEvents bool
responseMutations []workflowTaskResponseMutation
)
hasUnhandledEvents = ms.HasBufferedEvents()
hasBufferedEvents := ms.HasBufferedEvents()

if request.StickyAttributes == nil || request.StickyAttributes.WorkerTaskQueue == nil {
handler.metricsHandler.Counter(metrics.CompleteWorkflowTaskWithStickyDisabledCounter.GetMetricName()).Record(
Expand Down Expand Up @@ -481,7 +480,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
handler.config,
handler.shard,
handler.searchAttributesMapper,
hasUnhandledEvents,
hasBufferedEvents,
)

if responseMutations, err = workflowTaskHandler.handleCommands(
Expand All @@ -491,6 +490,13 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
return nil, err
}

if err = workflowTaskHandler.handleMessages(
ctx,
request.Messages,
); err != nil {
return nil, err
}

// set the vars used by following logic
// further refactor should also clean up the vars used below
wtFailedCause = workflowTaskHandler.workflowTaskFailedCause
Expand All @@ -501,7 +507,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(

newMutableState = workflowTaskHandler.newMutableState

hasUnhandledEvents = workflowTaskHandler.hasBufferedEvents
hasBufferedEvents = workflowTaskHandler.hasBufferedEvents
}

if wtFailedCause != nil {
Expand All @@ -522,7 +528,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
if err != nil {
return nil, err
}
hasUnhandledEvents = true
hasBufferedEvents = true
newMutableState = nil

if wtFailedCause.workflowFailure != nil {
Expand All @@ -532,24 +538,37 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
if _, err := ms.AddFailWorkflowEvent(nextEventBatchId, enumspb.RETRY_STATE_NON_RETRYABLE_FAILURE, attributes, ""); err != nil {
return nil, err
}
hasUnhandledEvents = false
hasBufferedEvents = false
}
}

createNewWorkflowTask := ms.IsWorkflowExecutionRunning() && (hasUnhandledEvents || request.GetForceCreateNewWorkflowTask() || activityNotStartedCancelled)
newWorkflowTaskType := enumsspb.WORKFLOW_TASK_TYPE_UNSPECIFIED
if ms.IsWorkflowExecutionRunning() && (hasBufferedEvents || request.GetForceCreateNewWorkflowTask() || activityNotStartedCancelled) {
newWorkflowTaskType = enumsspb.WORKFLOW_TASK_TYPE_NORMAL
}
createNewWorkflowTask := newWorkflowTaskType != enumsspb.WORKFLOW_TASK_TYPE_UNSPECIFIED

var newWorkflowTaskScheduledEventID int64
if createNewWorkflowTask {
// TODO (alex-update): Need to support case when ReturnNewWorkflowTask=false and WT.Type=Speculative.
// In this case WT needs to be added directly to matching.
// Current implementation will create normal WT.
bypassTaskGeneration := request.GetReturnNewWorkflowTask() && wtFailedCause == nil
if !bypassTaskGeneration {
// If task generation can't be bypassed workflow task must be of Normal type because Speculative workflow task always skip task generation.
newWorkflowTaskType = enumsspb.WORKFLOW_TASK_TYPE_NORMAL
}

var newWorkflowTask *workflow.WorkflowTaskInfo
var err error
if workflowTaskHeartbeating && !workflowTaskHeartbeatTimeout {
newWorkflowTask, err = ms.AddWorkflowTaskScheduledEventAsHeartbeat(
bypassTaskGeneration,
currentWorkflowTask.OriginalScheduledTime,
enumsspb.WORKFLOW_TASK_TYPE_NORMAL,
enumsspb.WORKFLOW_TASK_TYPE_NORMAL, // Heartbeat workflow task is always of Normal type.
)
} else {
newWorkflowTask, err = ms.AddWorkflowTaskScheduledEvent(bypassTaskGeneration, enumsspb.WORKFLOW_TASK_TYPE_NORMAL)
newWorkflowTask, err = ms.AddWorkflowTaskScheduledEvent(bypassTaskGeneration, newWorkflowTaskType)
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -661,7 +680,6 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
}

return resp, nil

}

func (handler *workflowTaskHandlerCallbacksImpl) verifyFirstWorkflowTaskScheduled(
Expand Down
21 changes: 21 additions & 0 deletions tests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ package tests

import (
"flag"
"reflect"
"testing"
"time"

"github.com/gogo/protobuf/proto"
"github.com/gogo/protobuf/types"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
commonpb "go.temporal.io/api/common/v1"
Expand Down Expand Up @@ -79,3 +82,21 @@ func (s *integrationSuite) sendSignal(namespace string, execution *commonpb.Work

return err
}

func unmarshalAny[T proto.Message](s *integrationSuite, a *types.Any) T {
s.T().Helper()
pb := new(T)
ppb := reflect.ValueOf(pb).Elem()
pbNew := reflect.New(reflect.TypeOf(pb).Elem().Elem())
ppb.Set(pbNew)
err := types.UnmarshalAny(a, *pb)
s.NoError(err)
return *pb
}

func marshalAny(s *integrationSuite, pb proto.Message) *types.Any {
s.T().Helper()
a, err := types.MarshalAny(pb)
s.NoError(err)
return a
}
2 changes: 1 addition & 1 deletion tests/integrationbase.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func (s *IntegrationBase) randomizeStr(id string) string {

func (s *IntegrationBase) printWorkflowHistory(namespace string, execution *commonpb.WorkflowExecution) {
events := s.getHistory(namespace, execution)
common.PrettyPrintHistory(&historypb.History{Events: events})
common.PrettyPrint(events)
}

//lint:ignore U1000 used for debugging.
Expand Down
Loading

0 comments on commit 430d3a9

Please sign in to comment.