Skip to content

Commit

Permalink
Check that the L1 node is on the correct network
Browse files Browse the repository at this point in the history
How we achieve this:

- Add a `ChainID` method to the `l1.Subscriber` interface.
- Add a `network` field to the `l1.Client` struct.
- Set the `l1.Client.network` field with the value retrieved from
  `blockchain.Network()` in `l1.NewClient(...)`.
- Compare the result of `l1.Subscriber.ChainID` to `l1.Client.network`
  in `l1.Client.checkChainID`, which is called in `l1.Client.Run`.

The behavior can be manually tested with a command like:

```
juno --eth-node <mainnet-eth-node> --network goerli
```

which should log an error and exit.
  • Loading branch information
joshklop committed Jun 19, 2023
1 parent c110e15 commit 0ec74ea
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 4 deletions.
5 changes: 5 additions & 0 deletions l1/eth_subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package l1

import (
"context"
"math/big"

"github.com/NethermindEth/juno/l1/contract"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
Expand Down Expand Up @@ -41,6 +42,10 @@ func (s *EthSubscriber) WatchLogStateUpdate(ctx context.Context, sink chan<- *co
return s.filterer.WatchLogStateUpdate(&bind.WatchOpts{Context: ctx}, sink)
}

func (s *EthSubscriber) ChainID(ctx context.Context) (*big.Int, error) {
return s.ethClient.ChainID(ctx)
}

func (s *EthSubscriber) Close() {
s.ethClient.Close()
}
24 changes: 24 additions & 0 deletions l1/l1.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package l1
import (
"context"
"fmt"
"math/big"

"github.com/NethermindEth/juno/blockchain"
"github.com/NethermindEth/juno/core"
Expand All @@ -18,13 +19,15 @@ import (
type Subscriber interface {
WatchHeader(ctx context.Context, sink chan<- *types.Header) (event.Subscription, error)
WatchLogStateUpdate(ctx context.Context, sink chan<- *contract.StarknetLogStateUpdate) (event.Subscription, error)
ChainID(ctx context.Context) (*big.Int, error)
}

type Client struct {
l1 Subscriber
l2Chain *blockchain.Blockchain
log utils.SimpleLogger
confirmationQueue *queue
network utils.Network
}

var _ service.Service = (*Client)(nil)
Expand All @@ -34,6 +37,7 @@ func NewClient(l1 Subscriber, chain *blockchain.Blockchain, confirmationPeriod u
l1: l1,
l2Chain: chain,
log: log,
network: chain.Network(),
confirmationQueue: newQueue(confirmationPeriod),
}
}
Expand All @@ -56,7 +60,27 @@ func (c *Client) subscribeToUpdates(ctx context.Context,
return sub, nil
}

func (c *Client) checkChainID(ctx context.Context) error {
gotChainID, err := c.l1.ChainID(ctx)
if err != nil {
return fmt.Errorf("retrieve Ethereum chain ID: %w", err)
}

wantChainID := c.network.DefaultL1ChainID()
if gotChainID.Cmp(wantChainID) == 0 {
return nil
}

// NOTE: for now we return an error. If we want to support users who fork
// Starknet to create a "custom" Starknet network, we will need to log a warning instead.
return fmt.Errorf("mismatched L1 and L2 networks: L2 network %s; is the L1 node on the correct network?", c.network)
}

func (c *Client) Run(ctx context.Context) error {
if err := c.checkChainID(ctx); err != nil {
return err
}

buffer := 128

logStateUpdateChan := make(chan *contract.StarknetLogStateUpdate, buffer)
Expand Down
18 changes: 16 additions & 2 deletions l1/l1_pkg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ func TestClient(t *testing.T) {

ctrl := gomock.NewController(t)
nopLog := utils.NewNopZapLogger()
chain := blockchain.New(pebble.NewMemTest(), utils.MAINNET, nopLog)
network := utils.MAINNET
chain := blockchain.New(pebble.NewMemTest(), network, nopLog)
client := NewClient(nil, chain, tt.confirmationPeriod, nopLog)

// We loop over each block and check that the state of the chain aligns with our expectations.
Expand All @@ -331,6 +332,12 @@ func TestClient(t *testing.T) {
Return(newFakeSubscription(), nil).
Times(1)

subscriber.
EXPECT().
ChainID(gomock.Any()).
Return(network.DefaultL1ChainID(), nil).
Times(1)

// Replace the subscriber.
client.l1 = subscriber

Expand Down Expand Up @@ -360,7 +367,8 @@ func TestUnreliableSubscription(t *testing.T) {

ctrl := gomock.NewController(t)
nopLog := utils.NewNopZapLogger()
chain := blockchain.New(pebble.NewMemTest(), utils.MAINNET, nopLog)
network := utils.MAINNET
chain := blockchain.New(pebble.NewMemTest(), network, nopLog)
client := NewClient(nil, chain, 3, nopLog)

err := errors.New("test err")
Expand Down Expand Up @@ -390,6 +398,12 @@ func TestUnreliableSubscription(t *testing.T) {
Times(1).
After(failedUpdateCall)

subscriber.
EXPECT().
ChainID(gomock.Any()).
Return(network.DefaultL1ChainID(), nil).
Times(1)

failedHeaderSub := newFakeSubscription(err)
failedHeaderCall := subscriber.
EXPECT().
Expand Down
71 changes: 69 additions & 2 deletions l1/l1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"math/big"
"testing"
"time"

"github.com/NethermindEth/juno/blockchain"
"github.com/NethermindEth/juno/db/pebble"
Expand Down Expand Up @@ -38,7 +39,7 @@ func (s *fakeSubscription) Unsubscribe() {
}
}

func TestGracefulErrorHandling(t *testing.T) {
func TestFailedSubscription(t *testing.T) {
t.Parallel()

err := errors.New("test error")
Expand All @@ -65,9 +66,10 @@ func TestGracefulErrorHandling(t *testing.T) {
t.Run(tt.description, func(t *testing.T) {
t.Parallel()

network := utils.MAINNET
ctrl := gomock.NewController(t)
nopLog := utils.NewNopZapLogger()
chain := blockchain.New(pebble.NewMemTest(), utils.MAINNET, nopLog)
chain := blockchain.New(pebble.NewMemTest(), network, nopLog)

subscriber := mocks.NewMockSubscriber(ctrl)

Expand All @@ -88,9 +90,74 @@ func TestGracefulErrorHandling(t *testing.T) {
Return(tt.watchHeaderRets...).
AnyTimes()

subscriber.
EXPECT().
ChainID(gomock.Any()).
Return(network.DefaultL1ChainID(), nil).
Times(1)

client := l1.NewClient(subscriber, chain, 0, nopLog)

require.ErrorIs(t, client.Run(context.Background()), err)
})
}
}

func TestChainID(t *testing.T) {
t.Parallel()

helper := func(t *testing.T, matching bool) error {
t.Helper()

network := utils.MAINNET
ctrl := gomock.NewController(t)
nopLog := utils.NewNopZapLogger()
chain := blockchain.New(pebble.NewMemTest(), network, nopLog)

subscriber := mocks.NewMockSubscriber(ctrl)

subscriber.
EXPECT().
WatchLogStateUpdate(gomock.Any(), gomock.Any()).
Return(newFakeSubscription(), nil).
AnyTimes()

subscriber.
EXPECT().
WatchHeader(gomock.Any(), gomock.Any()).
Do(func(_ context.Context, sink chan<- *types.Header) {
sink <- &types.Header{
Number: new(big.Int),
}
}).
Return(newFakeSubscription(), nil).
AnyTimes()

l1ChainID := new(big.Int)
if matching {
l1ChainID.Set(network.DefaultL1ChainID())
}

subscriber.
EXPECT().
ChainID(gomock.Any()).
Return(l1ChainID, nil).
Times(1)

client := l1.NewClient(subscriber, chain, 0, nopLog)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
return client.Run(ctx)
}

t.Run("matching chain IDs", func(t *testing.T) {
t.Parallel()
require.NoError(t, helper(t, true))
})

t.Run("mismatched chain IDs", func(t *testing.T) {
t.Parallel()
require.Error(t, helper(t, false))
})
}
16 changes: 16 additions & 0 deletions l1/mocks/mock_subscriber.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions utils/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package utils
import (
"encoding"
"errors"
"math/big"

"github.com/NethermindEth/juno/core/felt"
"github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -108,6 +109,20 @@ func (n Network) ChainID() *felt.Felt {
}
}

func (n Network) DefaultL1ChainID() *big.Int {
var chainID int64
switch n {
case MAINNET:
chainID = 1
case GOERLI, GOERLI2, INTEGRATION:
chainID = 5
default:
// Should not happen.
panic(ErrUnknownNetwork)
}
return big.NewInt(chainID)
}

func (n Network) CoreContractAddress() (common.Address, error) {
var address common.Address
// The docs states the addresses for each network: https://docs.starknet.io/documentation/useful_info/
Expand Down
14 changes: 14 additions & 0 deletions utils/network_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package utils_test

import (
"math/big"
"strings"
"testing"

Expand Down Expand Up @@ -54,6 +55,19 @@ func TestNetwork(t *testing.T) {
}
}
})
t.Run("default L1 chainId", func(t *testing.T) {
for n := range networkStrings {
got := n.DefaultL1ChainID()
switch n {
case utils.MAINNET:
assert.Equal(t, big.NewInt(1), got)
case utils.GOERLI, utils.GOERLI2, utils.INTEGRATION:
assert.Equal(t, big.NewInt(5), got)
default:
assert.Fail(t, "unexpected network")
}
}
})
}

//nolint:dupl // see comment in utils/log_test.go
Expand Down

0 comments on commit 0ec74ea

Please sign in to comment.