Return error if cluster metadata is invalid #3879

Merged (4 commits) on Feb 1, 2023
25 changes: 15 additions & 10 deletions service/history/replication/poller_manager.go
@@ -25,14 +25,15 @@
package replication

import (
"errors"
"fmt"

"go.temporal.io/server/common/cluster"
)

type (
pollerManager interface {
getSourceClusterShardIDs(sourceClusterName string) []int32
getSourceClusterShardIDs(sourceClusterName string) ([]int32, error)
}

pollerManagerImpl struct {
@@ -53,18 +54,27 @@ func newPollerManager(
}
}

func (p pollerManagerImpl) getSourceClusterShardIDs(sourceClusterName string) []int32 {
func (p pollerManagerImpl) getSourceClusterShardIDs(sourceClusterName string) ([]int32, error) {
currentCluster := p.clusterMetadata.GetCurrentClusterName()
allClusters := p.clusterMetadata.GetAllClusterInfo()
currentClusterInfo, ok := allClusters[currentCluster]
if !ok {
panic("Cannot get current cluster info from cluster metadata cache")
return nil, errors.New("cannot get current cluster info from cluster metadata cache")
}
remoteClusterInfo, ok := allClusters[sourceClusterName]
if !ok {
panic(fmt.Sprintf("Cannot get source cluster %s info from cluster metadata cache", sourceClusterName))
return nil, errors.New(fmt.Sprintf("cannot get source cluster %s info from cluster metadata cache", sourceClusterName))
}
return generateShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount)

// The remote shard count and local shard count must be multiples.
large, small := remoteClusterInfo.ShardCount, currentClusterInfo.ShardCount
if small > large {
large, small = small, large
}
if large%small != 0 {
return nil, errors.New(fmt.Sprintf("remote shard count %d and local shard count %d are not multiples.", remoteClusterInfo.ShardCount, currentClusterInfo.ShardCount))
}
return generateShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount), nil
}

func generateShardIDs(localShardId int32, localShardCount int32, remoteShardCount int32) []int32 {
@@ -75,12 +85,7 @@ func generateShardIDs(localShardId int32, localShardCount int32, remoteShardCount int32) []int32 {
}
return pollingShards
}

// remoteShardCount > localShardCount, replication poller will poll from multiple remote shard.
// The remote shard count and local shard count must be multiples.
if remoteShardCount%localShardCount != 0 {
panic(fmt.Sprintf("Remote shard count %d and local shard count %d are not multiples.", remoteShardCount, localShardCount))
}
for i := localShardId; i <= remoteShardCount; i += localShardCount {
pollingShards = append(pollingShards, i)
}
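With the panic removed, generateShardIDs is a pure mapping: a local shard polls every remote shard whose ID is congruent to its own modulo the local shard count. A minimal standalone sketch of that mapping, using illustrative names that are not part of this PR:

package main

import "fmt"

// shardIDsToPoll mirrors the loop in generateShardIDs: the local shard polls every
// remote shard whose ID is congruent to the local shard ID modulo the local shard count.
func shardIDsToPoll(localShardID, localShardCount, remoteShardCount int32) []int32 {
    var shards []int32
    if remoteShardCount <= localShardCount {
        if localShardID <= remoteShardCount {
            shards = append(shards, localShardID)
        }
        return shards
    }
    for id := localShardID; id <= remoteShardCount; id += localShardCount {
        shards = append(shards, id)
    }
    return shards
}

func main() {
    fmt.Println(shardIDsToPoll(1, 4, 16)) // [1 5 9 13]
    fmt.Println(shardIDsToPoll(4, 4, 17)) // [4 8 12 16]
    fmt.Println(shardIDsToPoll(3, 4, 2))  // []
}
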
23 changes: 8 additions & 15 deletions service/history/replication/poller_manager_test.go
@@ -36,60 +36,53 @@ func TestGetPollingShardIds(t *testing.T) {
shardID int32
remoteShardCount int32
localShardCount int32
expectedPanic bool
expectedShardIDs []int32
}{
{
1,
4,
4,
false,
[]int32{1},
},
{
1,
2,
4,
false,
[]int32{1},
},
{
3,
2,
4,
false,
[]int32{},
nil,
},
{
1,
16,
4,
false,
[]int32{1, 5, 9, 13},
},
{
4,
16,
4,
false,
[]int32{4, 8, 12, 16},
},
{
4,
17,
4,
true,
[]int32{},
[]int32{4, 8, 12, 16},
},
{
1,
17,
4,
[]int32{1, 5, 9, 13, 17},
},
}
for idx, tt := range testCases {
t.Run(fmt.Sprintf("Testcase %d", idx), func(t *testing.T) {
t.Parallel()
defer func() {
if r := recover(); tt.expectedPanic && r == nil {
t.Errorf("The code did not panic")
}
}()
shardIDs := generateShardIDs(tt.shardID, tt.localShardCount, tt.remoteShardCount)
assert.Equal(t, tt.expectedShardIDs, shardIDs)
})
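Note that the 17-remote / 4-local cases above no longer expect a panic: generateShardIDs now emits the mapping as-is, and the "must be multiples" validation happens earlier in getSourceClusterShardIDs, which returns an error instead. A small sketch of that check under an assumed helper name (shardCountsCompatible is illustrative, not part of this PR):

package main

import "fmt"

// shardCountsCompatible shows the check getSourceClusterShardIDs now performs:
// one shard count must evenly divide the other, regardless of which side is larger.
func shardCountsCompatible(localShardCount, remoteShardCount int32) bool {
    large, small := remoteShardCount, localShardCount
    if small > large {
        large, small = small, large
    }
    return large%small == 0
}

func main() {
    fmt.Println(shardCountsCompatible(4, 16)) // true
    fmt.Println(shardCountsCompatible(4, 17)) // false: the caller gets an error before shard IDs are generated
}
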
88 changes: 51 additions & 37 deletions service/history/replication/task_processor_manager.go
@@ -26,7 +26,6 @@ package replication

import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
@@ -72,7 +71,7 @@ type (
logger log.Logger

taskProcessorLock sync.RWMutex
taskProcessors map[string]TaskProcessor
taskProcessors map[string][]TaskProcessor // cluster name - processor
minTxAckedTaskID int64
shutdownChan chan struct{}
}
@@ -114,7 +113,7 @@ func NewTaskProcessorManager(
),
logger: shard.GetLogger(),
metricsHandler: shard.GetMetricsHandler(),
taskProcessors: make(map[string]TaskProcessor),
taskProcessors: make(map[string][]TaskProcessor),
taskExecutorProvider: taskExecutorProvider,
taskPollerManager: newPollerManager(shard.GetShardID(), shard.GetClusterMetadata()),
minTxAckedTaskID: persistence.EmptyQueueMessageID,
@@ -149,8 +148,10 @@ func (r *taskProcessorManagerImpl) Stop() {

r.shard.GetClusterMetadata().UnRegisterMetadataChangeCallback(r)
r.taskProcessorLock.Lock()
for _, replicationTaskProcessor := range r.taskProcessors {
replicationTaskProcessor.Stop()
for _, taskProcessors := range r.taskProcessors {
for _, processor := range taskProcessors {
processor.Stop()
}
}
r.taskProcessorLock.Unlock()
}
@@ -170,44 +171,57 @@ func (r *taskProcessorManagerImpl) handleClusterMetadataUpdate(
r.taskProcessorLock.Lock()
defer r.taskProcessorLock.Unlock()
currentClusterName := r.shard.GetClusterMetadata().GetCurrentClusterName()
// The metadata triggers an update when the following fields update: 1. Enabled 2. Initial Failover Version 3. Cluster address
// The callback covers three cases:
// Case 1: Remove a cluster Case 2: Add a new cluster Case 3: Refresh cluster metadata(1 + 2).

// Case 1 and Case 3
for clusterName := range oldClusterMetadata {
if clusterName == currentClusterName {
continue
}
sourceShardIds := r.taskPollerManager.getSourceClusterShardIDs(clusterName)
for _, processor := range r.taskProcessors[clusterName] {
processor.Stop()
delete(r.taskProcessors, clusterName)
}
}

// Case 2 and Case 3
for clusterName := range newClusterMetadata {
if clusterName == currentClusterName {
continue
}
if clusterInfo := newClusterMetadata[clusterName]; clusterInfo == nil || !clusterInfo.Enabled {
continue
}
sourceShardIds, err := r.taskPollerManager.getSourceClusterShardIDs(clusterName)
if err != nil {
r.logger.Error("Failed to get source shard id list", tag.Error(err), tag.ClusterName(clusterName))
continue
}
var processors []TaskProcessor
for _, sourceShardId := range sourceShardIds {
perShardTaskProcessorKey := fmt.Sprintf(clusterCallbackKey, clusterName, sourceShardId)
// The metadata triggers an update when the following fields update: 1. Enabled 2. Initial Failover Version 3. Cluster address
// The callback covers three cases:
// Case 1: Remove a cluster Case 2: Add a new cluster Case 3: Refresh cluster metadata.
if processor, ok := r.taskProcessors[perShardTaskProcessorKey]; ok {
// Case 1 and Case 3
processor.Stop()
delete(r.taskProcessors, perShardTaskProcessorKey)
}
if clusterInfo := newClusterMetadata[clusterName]; clusterInfo != nil && clusterInfo.Enabled {
// Case 2 and Case 3
fetcher := r.replicationTaskFetcherFactory.GetOrCreateFetcher(clusterName)
replicationTaskProcessor := NewTaskProcessor(
sourceShardId,
r.shard,
r.engine,
r.config,
r.shard.GetMetricsHandler(),
fetcher,
r.taskExecutorProvider(TaskExecutorParams{
RemoteCluster: clusterName,
Shard: r.shard,
HistoryResender: r.resender,
DeleteManager: r.deleteMgr,
WorkflowCache: r.workflowCache,
}),
r.eventSerializer,
)
replicationTaskProcessor.Start()
r.taskProcessors[perShardTaskProcessorKey] = replicationTaskProcessor
}
fetcher := r.replicationTaskFetcherFactory.GetOrCreateFetcher(clusterName)
replicationTaskProcessor := NewTaskProcessor(
sourceShardId,
r.shard,
r.engine,
r.config,
r.shard.GetMetricsHandler(),
fetcher,
r.taskExecutorProvider(TaskExecutorParams{
RemoteCluster: clusterName,
Shard: r.shard,
HistoryResender: r.resender,
DeleteManager: r.deleteMgr,
WorkflowCache: r.workflowCache,
}),
r.eventSerializer,
)
replicationTaskProcessor.Start()
processors = append(processors, replicationTaskProcessor)
}
r.taskProcessors[clusterName] = processors
}
}

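The main structural change in this file is the processor bookkeeping: instead of one map entry per cluster-and-shard key, processors are now grouped per source cluster, so an entire cluster can be stopped or rebuilt in one step. A rough sketch of that pattern under assumed names (processorRegistry, replaceProcessors, and stubProcessor are illustrative, not the PR's API):

package main

import (
    "fmt"
    "sync"
)

// TaskProcessor stands in for the replication task processor; only Stop matters here.
type TaskProcessor interface {
    Stop()
}

type stubProcessor struct{ name string }

func (s *stubProcessor) Stop() { fmt.Println("stopped", s.name) }

// processorRegistry groups processors by source cluster name, one slice entry per remote shard.
type processorRegistry struct {
    mu         sync.Mutex
    processors map[string][]TaskProcessor
}

// replaceProcessors stops everything previously registered for the cluster and
// installs the new set, or drops the entry entirely when the cluster goes away.
func (r *processorRegistry) replaceProcessors(cluster string, next []TaskProcessor) {
    r.mu.Lock()
    defer r.mu.Unlock()
    for _, p := range r.processors[cluster] {
        p.Stop()
    }
    if len(next) == 0 {
        delete(r.processors, cluster)
        return
    }
    r.processors[cluster] = next
}

func main() {
    reg := &processorRegistry{processors: map[string][]TaskProcessor{}}
    reg.replaceProcessors("cluster-b", []TaskProcessor{&stubProcessor{"shard-1"}, &stubProcessor{"shard-5"}})
    reg.replaceProcessors("cluster-b", nil) // stops shard-1 and shard-5, removes the entry
}
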