feat: enable starting multiple tss processes with the same peer subset #331

Merged · 17 commits · Jul 17, 2024
Changes from 1 commit
feat: enable executing multiple tss processes with the same subset
mpetrun5 committed Jul 16, 2024
commit ba6fa437e8a490513be8b62c1ad529719fe98900
3 changes: 2 additions & 1 deletion chains/btc/executor/executor.go
@@ -181,7 +181,7 @@ func (e *Executor) executeResourceProps(props []*BtcTransferProposal, resource c
 			if err != nil {
 				return err
 			}
-			return e.coordinator.Execute(executionContext, signing, sigChn)
+			return e.coordinator.Execute(executionContext, []tss.TssProcess{signing}, sigChn)
 		})
 	}
 	return p.Wait()
@@ -223,6 +223,7 @@ func (e *Executor) watchExecution(
 
 				e.storeProposalsStatus(proposals, store.ExecutedProp)
 				log.Info().Str("messageID", messageID).Msgf("Sent proposals execution with hash: %s", hash)
+				return nil
 			}
 		case <-timeout.C:
 			{
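The executor-side change is mechanical and repeats below for the EVM and Substrate executors: `Coordinator.Execute` now takes a slice of processes, so each existing single-process call site wraps its process in a one-element slice. A minimal sketch of the migration, assuming `signing` and `sigChn` are in scope as in the executors in this diff:

    // Before: e.coordinator.Execute(executionContext, signing, sigChn)
    // After: wrap the process in a slice; a slice with several processes
    // would run them all against the same coordinated peer subset.
    return e.coordinator.Execute(executionContext, []tss.TssProcess{signing}, sigChn)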
2 changes: 1 addition & 1 deletion chains/evm/executor/executor.go
@@ -117,7 +117,7 @@ func (e *Executor) Execute(proposals []*proposal.Proposal) error {
 	watchContext, cancelWatch := context.WithCancel(context.Background())
 	ep := pool.New().WithErrors()
 	ep.Go(func() error {
-		err := e.coordinator.Execute(executionContext, signing, sigChn)
+		err := e.coordinator.Execute(executionContext, []tss.TssProcess{signing}, sigChn)
 		if err != nil {
 			cancelWatch()
 		}
8 changes: 4 additions & 4 deletions chains/evm/listener/eventHandlers/event-handler.go
@@ -224,7 +224,7 @@ func (eh *KeygenEventHandler) HandleEvents(
 
 	keygenBlockNumber := big.NewInt(0).SetUint64(keygenEvents[0].BlockNumber)
 	keygen := keygen.NewKeygen(eh.sessionID(keygenBlockNumber), eh.threshold, eh.host, eh.communication, eh.storer)
-	err = eh.coordinator.Execute(context.Background(), keygen, make(chan interface{}, 1))
+	err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{keygen}, make(chan interface{}, 1))
 	if err != nil {
 		log.Err(err).Msgf("Failed executing keygen")
 	}
@@ -289,7 +289,7 @@ func (eh *FrostKeygenEventHandler) HandleEvents(
 
 	keygenBlockNumber := big.NewInt(0).SetUint64(keygenEvents[0].BlockNumber)
 	keygen := frostKeygen.NewKeygen(eh.sessionID(keygenBlockNumber), eh.threshold, eh.host, eh.communication, eh.storer)
-	err = eh.coordinator.Execute(context.Background(), keygen, make(chan interface{}, 1))
+	err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{keygen}, make(chan interface{}, 1))
 	if err != nil {
 		log.Err(err).Msgf("Failed executing keygen")
 	}
@@ -381,14 +381,14 @@ func (eh *RefreshEventHandler) HandleEvents(
 	resharing := resharing.NewResharing(
 		eh.sessionID(startBlock), topology.Threshold, eh.host, eh.communication, eh.ecdsaStorer,
 	)
-	err = eh.coordinator.Execute(context.Background(), resharing, make(chan interface{}, 1))
+	err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{resharing}, make(chan interface{}, 1))
 	if err != nil {
 		log.Err(err).Msgf("Failed executing ecdsa key refresh")
 	}
 	frostResharing := frostResharing.NewResharing(
 		eh.sessionID(startBlock), topology.Threshold, eh.host, eh.communication, eh.frostStorer,
 	)
-	err = eh.coordinator.Execute(context.Background(), frostResharing, make(chan interface{}, 1))
+	err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{frostResharing}, make(chan interface{}, 1))
 	if err != nil {
 		log.Err(err).Msgf("Failed executing frost key refresh")
 	}
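A side note on the keygen and refresh handlers above: their results are never consumed, so every call passes a fresh channel with a buffer of one, presumably so the process can deliver its result without blocking on a channel nobody reads. The pattern, as it appears in the diff:

    // Fire-and-forget: the buffer lets the process send its result even
    // though the handler never receives from the channel.
    err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{keygen}, make(chan interface{}, 1))
    if err != nil {
        log.Err(err).Msgf("Failed executing keygen")
    }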
2 changes: 1 addition & 1 deletion chains/substrate/executor/executor.go
@@ -125,7 +125,7 @@ func (e *Executor) Execute(proposals []*proposal.Proposal) error {
 
 	pool := pool.New().WithErrors()
 	pool.Go(func() error {
-		err := e.coordinator.Execute(executionContext, signing, sigChn)
+		err := e.coordinator.Execute(executionContext, []tss.TssProcess{signing}, sigChn)
 		if err != nil {
 			cancelWatch()
 		}
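The coordinator changes below read more easily with the `TssProcess` contract in view. The repository defines the full interface elsewhere; the sketch below reconstructs only the methods this diff actually calls, and the `params` type on `Run` and the import path are assumptions, not the repository's exact definitions:

    package tss // placement assumed

    import (
        "context"

        "github.com/libp2p/go-libp2p/core/peer" // import path assumed
    )

    // Reconstructed from the calls in tss/coordinator.go below; the real
    // interface in the tss package may declare more methods.
    type TssProcess interface {
        // Run drives the process; coordinator tells the process whether this
        // peer initiated the session. The params type here is an assumption.
        Run(ctx context.Context, coordinator bool, resultChn chan interface{}, params []byte) error
        Stop()
        Retryable() bool
        SessionID() string
        ValidCoordinators() []peer.ID
    }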
77 changes: 49 additions & 28 deletions tss/coordinator.go
@@ -69,8 +69,10 @@ func NewCoordinator(
 }
 
 // Execute calculates process leader and coordinates party readiness and start the tss processes.
-func (c *Coordinator) Execute(ctx context.Context, tssProcess TssProcess, resultChn chan interface{}) error {
-	sessionID := tssProcess.SessionID()
+// An array of processes can be passed if all the processes have to have the same peer subset and
+// the result of all of them is needed. Each process should have a unique session ID.
+func (c *Coordinator) Execute(ctx context.Context, tssProcesses []TssProcess, resultChn chan interface{}) error {
+	sessionID := tssProcesses[0].SessionID()
 	value, ok := c.pendingProcesses[sessionID]
 	if ok && value {
 		log.Warn().Str("SessionID", sessionID).Msgf("Process already pending")
@@ -89,71 +91,74 @@ func (c *Coordinator) Execute(ctx context.Context, tssProcess TssProcess, result
 		c.processLock.Lock()
 		c.pendingProcesses[sessionID] = false
 		c.processLock.Unlock()
-		tssProcess.Stop()
+		for _, process := range tssProcesses {
+			process.Stop()
+		}
 	}()
 
 	coordinatorElector := c.electorFactory.CoordinatorElector(sessionID, elector.Static)
-	coordinator, _ := coordinatorElector.Coordinator(ctx, tssProcess.ValidCoordinators())
+	coordinator, _ := coordinatorElector.Coordinator(ctx, tssProcesses[0].ValidCoordinators())
 
 	log.Info().Str("SessionID", sessionID).Msgf("Starting process with coordinator %s", coordinator.Pretty())
 
 	p.Go(func(ctx context.Context) error {
-		err := c.start(ctx, tssProcess, coordinator, resultChn, []peer.ID{})
+		err := c.start(ctx, tssProcesses, coordinator, resultChn, []peer.ID{})
 		if err == nil {
 			cancel()
 		}
 		return err
 	})
 	p.Go(func(ctx context.Context) error {
-		return c.watchExecution(ctx, tssProcess, coordinator)
+		return c.watchExecution(ctx, tssProcesses[0], coordinator)
 	})
 	err := p.Wait()
 	if err == nil {
 		return nil
 	}
 
-	if !tssProcess.Retryable() {
+	if !tssProcesses[0].Retryable() {
		return err
 	}
 
-	return c.handleError(ctx, err, tssProcess, resultChn)
+	return c.handleError(ctx, err, tssProcesses, resultChn)
 }
 
-func (c *Coordinator) handleError(ctx context.Context, err error, tssProcess TssProcess, resultChn chan interface{}) error {
+func (c *Coordinator) handleError(ctx context.Context, err error, tssProcesses []TssProcess, resultChn chan interface{}) error {
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
 	rp := pool.New().WithContext(ctx).WithCancelOnError()
 	rp.Go(func(ctx context.Context) error {
-		return c.watchExecution(ctx, tssProcess, peer.ID(""))
+		return c.watchExecution(ctx, tssProcesses[0], peer.ID(""))
 	})
+	sessionID := tssProcesses[0].SessionID()
 	switch err := err.(type) {
 	case *CoordinatorError:
 		{
-			log.Err(err).Str("SessionID", tssProcess.SessionID()).Msgf("Tss process failed with error %+v", err)
+			log.Err(err).Str("SessionID", sessionID).Msgf("Tss process failed with error %+v", err)
 
 			excludedPeers := []peer.ID{err.Peer}
-			rp.Go(func(ctx context.Context) error { return c.retry(ctx, tssProcess, resultChn, excludedPeers) })
+			rp.Go(func(ctx context.Context) error { return c.retry(ctx, tssProcesses, resultChn, excludedPeers) })
 		}
 	case *comm.CommunicationError:
 		{
-			log.Err(err).Str("SessionID", tssProcess.SessionID()).Msgf("Tss process failed with error %+v", err)
-			rp.Go(func(ctx context.Context) error { return c.retry(ctx, tssProcess, resultChn, []peer.ID{}) })
+			log.Err(err).Str("SessionID", sessionID).Msgf("Tss process failed with error %+v", err)
+			rp.Go(func(ctx context.Context) error { return c.retry(ctx, tssProcesses, resultChn, []peer.ID{}) })
 		}
 	case *tss.Error:
 		{
-			log.Err(err).Str("SessionID", tssProcess.SessionID()).Msgf("Tss process failed with error %+v", err)
+			log.Err(err).Str("SessionID", sessionID).Msgf("Tss process failed with error %+v", err)
 			excludedPeers, err := common.PeersFromParties(err.Culprits())
 			if err != nil {
 				return err
 			}
-			rp.Go(func(ctx context.Context) error { return c.retry(ctx, tssProcess, resultChn, excludedPeers) })
+			rp.Go(func(ctx context.Context) error { return c.retry(ctx, tssProcesses, resultChn, excludedPeers) })
 		}
 	case *SubsetError:
 		{
 			// wait for start message if existing signing process fails
 			rp.Go(func(ctx context.Context) error {
-				return c.waitForStart(ctx, tssProcess, resultChn, peer.ID(""), c.TssTimeout)
+				return c.waitForStart(ctx, tssProcesses, resultChn, peer.ID(""), c.TssTimeout)
 			})
 		}
 	default:
@@ -197,24 +202,24 @@ func (c *Coordinator) watchExecution(ctx context.Context, tssProcess TssProcess,
 }
 
 // start initiates listeners for coordinator and participants with static calculated coordinator
-func (c *Coordinator) start(ctx context.Context, tssProcess TssProcess, coordinator peer.ID, resultChn chan interface{}, excludedPeers []peer.ID) error {
+func (c *Coordinator) start(ctx context.Context, tssProcesses []TssProcess, coordinator peer.ID, resultChn chan interface{}, excludedPeers []peer.ID) error {
 	if coordinator.Pretty() == c.host.ID().Pretty() {
-		return c.initiate(ctx, tssProcess, resultChn, excludedPeers)
+		return c.initiate(ctx, tssProcesses, resultChn, excludedPeers)
 	} else {
-		return c.waitForStart(ctx, tssProcess, resultChn, coordinator, c.CoordinatorTimeout)
+		return c.waitForStart(ctx, tssProcesses, resultChn, coordinator, c.CoordinatorTimeout)
 	}
 }
 
 // retry initiates full bully process to calculate coordinator and starts a new tss process after
 // an expected error occurred during regular tss execution
-func (c *Coordinator) retry(ctx context.Context, tssProcess TssProcess, resultChn chan interface{}, excludedPeers []peer.ID) error {
-	coordinatorElector := c.electorFactory.CoordinatorElector(tssProcess.SessionID(), elector.Bully)
-	coordinator, err := coordinatorElector.Coordinator(ctx, common.ExcludePeers(tssProcess.ValidCoordinators(), excludedPeers))
+func (c *Coordinator) retry(ctx context.Context, tssProcesses []TssProcess, resultChn chan interface{}, excludedPeers []peer.ID) error {
+	coordinatorElector := c.electorFactory.CoordinatorElector(tssProcesses[0].SessionID(), elector.Bully)
+	coordinator, err := coordinatorElector.Coordinator(ctx, common.ExcludePeers(tssProcesses[0].ValidCoordinators(), excludedPeers))
 	if err != nil {
 		return err
 	}
 
-	return c.start(ctx, tssProcess, coordinator, resultChn, excludedPeers)
+	return c.start(ctx, tssProcesses, coordinator, resultChn, excludedPeers)
 }
 
 // broadcastInitiateMsg sends TssInitiateMsg to all peers
Expand All @@ -228,11 +233,12 @@ func (c *Coordinator) broadcastInitiateMsg(sessionID string) {
// initiate sends initiate message to all peers and waits
// for ready response. After tss process declares that enough
// peers are ready, start message is broadcasted and tss process is started.
func (c *Coordinator) initiate(ctx context.Context, tssProcess TssProcess, resultChn chan interface{}, excludedPeers []peer.ID) error {
func (c *Coordinator) initiate(ctx context.Context, tssProcesses []TssProcess, resultChn chan interface{}, excludedPeers []peer.ID) error {
readyChan := make(chan *comm.WrappedMessage)
readyPeers := make([]peer.ID, 0)
readyPeers = append(readyPeers, c.host.ID())

tssProcess := tssProcesses[0]
subID := c.communication.Subscribe(tssProcess.SessionID(), comm.TssReadyMsg, readyChan)
defer c.communication.UnSubscribe(subID)

@@ -262,7 +268,14 @@ func (c *Coordinator) initiate(ctx context.Context, tssProcess TssProcess, resul
 			}
 
 			_ = c.communication.Broadcast(c.host.Peerstore().Peers(), startMsgBytes, comm.TssStartMsg, tssProcess.SessionID())
-			return tssProcess.Run(ctx, true, resultChn, startParams)
+			p := pool.New().WithContext(ctx).WithCancelOnError()
+			for _, process := range tssProcesses {
+				tssProcess := process
+				p.Go(func(ctx context.Context) error {
+					return tssProcess.Run(ctx, true, resultChn, startParams)
+				})
+			}
+			return p.Wait()
 		}
 	case <-ticker.C:
 		{
Expand All @@ -280,14 +293,15 @@ func (c *Coordinator) initiate(ctx context.Context, tssProcess TssProcess, resul
// when it receives the start message.
func (c *Coordinator) waitForStart(
ctx context.Context,
tssProcess TssProcess,
tssProcesses []TssProcess,
resultChn chan interface{},
coordinator peer.ID,
timeout time.Duration,
) error {
msgChan := make(chan *comm.WrappedMessage)
startMsgChn := make(chan *comm.WrappedMessage)

tssProcess := tssProcesses[0]
initSubID := c.communication.Subscribe(tssProcess.SessionID(), comm.TssInitiateMsg, msgChan)
defer c.communication.UnSubscribe(initSubID)
startSubID := c.communication.Subscribe(tssProcess.SessionID(), comm.TssStartMsg, startMsgChn)
@@ -327,7 +341,14 @@ func (c *Coordinator) waitForStart(
 				return err
 			}
 
-			return tssProcess.Run(ctx, false, resultChn, msg.Params)
+			p := pool.New().WithContext(ctx).WithCancelOnError()
+			for _, process := range tssProcesses {
+				tssProcess := process
+				p.Go(func(ctx context.Context) error {
+					return tssProcess.Run(ctx, false, resultChn, msg.Params)
+				})
+			}
+			return p.Wait()
 		}
 	case <-coordinatorTimeoutTicker.C:
 		{
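Taken together, the coordinator now elects a leader once, runs the readiness handshake once, and then fans every process out on a conc pool, so a single Execute call can drive several processes over an identical peer subset. A hypothetical call-site sketch; the helper, its parameters, and the session IDs are stand-ins, and only the Execute signature and the unique-session-ID requirement come from this diff:

    // Hypothetical helper showing the new multi-process contract: both
    // processes run with the same coordinated peer subset and report into
    // one shared result channel, one result per process.
    func signBatch(ctx context.Context, coordinator *tss.Coordinator, signingA, signingB tss.TssProcess) error {
        processes := []tss.TssProcess{signingA, signingB} // each carries a unique session ID
        resultChn := make(chan interface{}, len(processes))
        if err := coordinator.Execute(ctx, processes, resultChn); err != nil {
            return err
        }
        for range processes {
            <-resultChn // drain one result per process; handle each signature here
        }
        return nil
    }

Since Execute blocks until every process finishes, the channel is buffered so results sent during the run are not lost.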
4 changes: 2 additions & 2 deletions tss/ecdsa/keygen/keygen_test.go
@@ -50,7 +50,7 @@ func (s *KeygenTestSuite) Test_ValidKeygenProcess() {
 	s.MockECDSAStorer.EXPECT().StoreKeyshare(gomock.Any()).Times(3)
 	pool := pool.New().WithContext(context.Background()).WithCancelOnError()
 	for i, coordinator := range coordinators {
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, processes[i], nil) })
+		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
 	}
 
 	err := pool.Wait()
@@ -81,7 +81,7 @@ func (s *KeygenTestSuite) Test_KeygenTimeout() {
 	s.MockECDSAStorer.EXPECT().StoreKeyshare(gomock.Any()).Times(0)
 	pool := pool.New().WithContext(context.Background())
 	for i, coordinator := range coordinators {
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, processes[i], nil) })
+		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
 	}
 
 	err := pool.Wait()
12 changes: 8 additions & 4 deletions tss/ecdsa/resharing/resharing_test.go
@@ -68,7 +68,9 @@ func (s *ResharingTestSuite) Test_ValidResharingProcess_OldAndNewSubset() {
 	resultChn := make(chan interface{})
 	pool := pool.New().WithContext(context.Background()).WithCancelOnError()
 	for i, coordinator := range coordinators {
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, processes[i], resultChn) })
+		pool.Go(func(ctx context.Context) error {
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
+		})
 	}
 
 	err := pool.Wait()
@@ -114,7 +116,7 @@ func (s *ResharingTestSuite) Test_ValidResharingProcess_RemovePeer() {
 	pool := pool.New().WithContext(context.Background()).WithCancelOnError()
 	for i, coordinator := range coordinators {
 		pool.Go(func(ctx context.Context) error {
-			return coordinator.Execute(ctx, processes[i], resultChn)
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
 		})
 	}
 
@@ -164,7 +166,7 @@ func (s *ResharingTestSuite) Test_InvalidResharingProcess_InvalidOldThreshold_Le
 	pool := pool.New().WithContext(context.Background())
 	for i, coordinator := range coordinators {
 		pool.Go(func(ctx context.Context) error {
-			return coordinator.Execute(ctx, processes[i], resultChn)
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
 		})
 	}
 	err := pool.Wait()
@@ -212,7 +214,9 @@ func (s *ResharingTestSuite) Test_InvalidResharingProcess_InvalidOldThreshold_Bi
 	resultChn := make(chan interface{})
 	pool := pool.New().WithContext(context.Background())
 	for i, coordinator := range coordinators {
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, processes[i], resultChn) })
+		pool.Go(func(ctx context.Context) error {
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
+		})
 	}
 
 	err := pool.Wait()
10 changes: 6 additions & 4 deletions tss/ecdsa/signing/signing_test.go
@@ -63,7 +63,7 @@ func (s *SigningTestSuite) Test_ValidSigningProcess() {
 	for i, coordinator := range coordinators {
 		coordinator := coordinator
 		pool.Go(func(ctx context.Context) error {
-			return coordinator.Execute(ctx, processes[i], resultChn)
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
 		})
 	}
 
@@ -112,7 +114,9 @@ func (s *SigningTestSuite) Test_SigningTimeout() {
 	pool := pool.New().WithContext(context.Background())
 	for i, coordinator := range coordinators {
 		coordinator := coordinator
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, processes[i], resultChn) })
+		pool.Go(func(ctx context.Context) error {
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
+		})
 	}
 
 	err := pool.Wait()
@@ -140,8 +142,8 @@ func (s *SigningTestSuite) Test_PendingProcessExists() {
 	s.MockECDSAStorer.EXPECT().UnlockKeyshare().AnyTimes()
 	pool := pool.New().WithContext(context.Background()).WithCancelOnError()
 	for i, coordinator := range coordinators {
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, processes[i], nil) })
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, processes[i], nil) })
+		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
+		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
 	}
 
 	err := pool.Wait()
2 changes: 1 addition & 1 deletion tss/frost/keygen/keygen_test.go
@@ -49,7 +49,7 @@ func (s *KeygenTestSuite) Test_ValidKeygenProcess() {
 
 	pool := pool.New().WithContext(context.Background()).WithCancelOnError()
 	for i, coordinator := range coordinators {
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, processes[i], nil) })
+		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
 	}
 
 	err := pool.Wait()
8 changes: 6 additions & 2 deletions tss/frost/resharing/resharing_test.go
@@ -68,7 +68,9 @@ func (s *ResharingTestSuite) Test_ValidResharingProcess_OldAndNewSubset() {
 	resultChn := make(chan interface{})
 	pool := pool.New().WithContext(context.Background()).WithCancelOnError()
 	for i, coordinator := range coordinators {
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, processes[i], resultChn) })
+		pool.Go(func(ctx context.Context) error {
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
+		})
 	}
 
 	err := pool.Wait()
@@ -113,7 +115,9 @@ func (s *ResharingTestSuite) Test_ValidResharingProcess_RemovePeer() {
 	resultChn := make(chan interface{})
 	pool := pool.New().WithContext(context.Background()).WithCancelOnError()
 	for i, coordinator := range coordinators {
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, processes[i], resultChn) })
+		pool.Go(func(ctx context.Context) error {
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
+		})
 	}
 
 	err := pool.Wait()