diff --git a/core/state.go b/core/state.go index 32ae9d908c..c11b15bfd5 100644 --- a/core/state.go +++ b/core/state.go @@ -42,6 +42,8 @@ type StateReader interface { ContractNonce(addr *felt.Felt) (*felt.Felt, error) ContractStorage(addr, key *felt.Felt) (*felt.Felt, error) Class(classHash *felt.Felt) (*DeclaredClass, error) + + // NOTE: Not a best way to add them here - it assumes current state and atm cannot be implemented for hitsrical states ClassTrie() (*trie.Trie, func() error, error) StorageTrie() (*trie.Trie, func() error, error) StorageTrieForAddr(addr *felt.Felt) (*trie.Trie, error) diff --git a/mocks/mock_state.go b/mocks/mock_state.go index 8994085984..4842ee4516 100644 --- a/mocks/mock_state.go +++ b/mocks/mock_state.go @@ -14,6 +14,7 @@ import ( core "github.com/NethermindEth/juno/core" felt "github.com/NethermindEth/juno/core/felt" + trie "github.com/NethermindEth/juno/core/trie" gomock "go.uber.org/mock/gomock" ) @@ -55,6 +56,22 @@ func (mr *MockStateHistoryReaderMockRecorder) Class(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Class", reflect.TypeOf((*MockStateHistoryReader)(nil).Class), arg0) } +// ClassTrie mocks base method. +func (m *MockStateHistoryReader) ClassTrie() (*trie.Trie, func() error, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClassTrie") + ret0, _ := ret[0].(*trie.Trie) + ret1, _ := ret[1].(func() error) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ClassTrie indicates an expected call of ClassTrie. +func (mr *MockStateHistoryReaderMockRecorder) ClassTrie() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClassTrie", reflect.TypeOf((*MockStateHistoryReader)(nil).ClassTrie)) +} + // ContractClassHash mocks base method. func (m *MockStateHistoryReader) ContractClassHash(arg0 *felt.Felt) (*felt.Felt, error) { m.ctrl.T.Helper() @@ -159,3 +176,50 @@ func (mr *MockStateHistoryReaderMockRecorder) ContractStorageAt(arg0, arg1, arg2 mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContractStorageAt", reflect.TypeOf((*MockStateHistoryReader)(nil).ContractStorageAt), arg0, arg1, arg2) } + +// StateAndClassRoot mocks base method. +func (m *MockStateHistoryReader) StateAndClassRoot() (*felt.Felt, *felt.Felt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StateAndClassRoot") + ret0, _ := ret[0].(*felt.Felt) + ret1, _ := ret[1].(*felt.Felt) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// StateAndClassRoot indicates an expected call of StateAndClassRoot. +func (mr *MockStateHistoryReaderMockRecorder) StateAndClassRoot() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StateAndClassRoot", reflect.TypeOf((*MockStateHistoryReader)(nil).StateAndClassRoot)) +} + +// StorageTrie mocks base method. +func (m *MockStateHistoryReader) StorageTrie() (*trie.Trie, func() error, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StorageTrie") + ret0, _ := ret[0].(*trie.Trie) + ret1, _ := ret[1].(func() error) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// StorageTrie indicates an expected call of StorageTrie. +func (mr *MockStateHistoryReaderMockRecorder) StorageTrie() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StorageTrie", reflect.TypeOf((*MockStateHistoryReader)(nil).StorageTrie)) +} + +// StorageTrieForAddr mocks base method. +func (m *MockStateHistoryReader) StorageTrieForAddr(arg0 *felt.Felt) (*trie.Trie, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StorageTrieForAddr", arg0) + ret0, _ := ret[0].(*trie.Trie) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// StorageTrieForAddr indicates an expected call of StorageTrieForAddr. +func (mr *MockStateHistoryReaderMockRecorder) StorageTrieForAddr(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StorageTrieForAddr", reflect.TypeOf((*MockStateHistoryReader)(nil).StorageTrieForAddr), arg0) +} diff --git a/rpc/handlers.go b/rpc/handlers.go index 8704abdb2e..19f3ce7d6a 100644 --- a/rpc/handlers.go +++ b/rpc/handlers.go @@ -58,6 +58,9 @@ var ( // These errors can be only be returned by Juno-specific methods. ErrSubscriptionNotFound = &jsonrpc.Error{Code: 100, Message: "Subscription not found"} + + // TODO[pnowosie]: Update the error while specification describe it + ErrBlockNotRecentForProof = &jsonrpc.Error{Code: 1001, Message: "Block is not sufficiently recent for storage proofs"} ) const ( diff --git a/rpc/storage.go b/rpc/storage.go index b00024c80b..a70c4d8ffc 100644 --- a/rpc/storage.go +++ b/rpc/storage.go @@ -1,9 +1,13 @@ package rpc import ( + "errors" + "fmt" + "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" ) @@ -46,6 +50,10 @@ func (h *Handler) StorageProof(id BlockID, classes, contracts []felt.Felt, stora return nil, ErrInternal.CloneWithData(err) } + if !(id.Latest || head.Number == id.Number) { + return nil, ErrBlockNotRecentForProof + } + storageRoot, classRoot, err := stateReader.StateAndClassRoot() if err != nil { return nil, ErrInternal.CloneWithData(err) @@ -166,12 +174,7 @@ func getContractsProof(reader core.StateReader, contracts []felt.Felt) (*Contrac } for _, contract := range contracts { - leafData := &LeafData{} - leafData.Nonce, err = reader.ContractNonce(&contract) - if err != nil { - return nil, err - } - leafData.ClassHash, err = reader.ContractClassHash(&contract) + leafData, err := addLeafDataIfExists(reader, &contract) if err != nil { return nil, err } @@ -187,6 +190,25 @@ func getContractsProof(reader core.StateReader, contracts []felt.Felt) (*Contrac return result, nil } +func addLeafDataIfExists(reader core.StateReader, contract *felt.Felt) (*LeafData, error) { + nonce, err := reader.ContractNonce(contract) + if errors.Is(err, db.ErrKeyNotFound) { + return nil, nil + } + if err != nil { + return nil, err + } + classHash, err := reader.ContractClassHash(contract) + if err != nil { + return nil, err + } + + return &LeafData{ + Nonce: nonce, + ClassHash: classHash, + }, nil +} + func getContractsStorageProofs(reader core.StateReader, keys []StorageKeys) ([][]*HashToNode, error) { result := make([][]*HashToNode, 0, len(keys)) @@ -196,8 +218,8 @@ func getContractsStorageProofs(reader core.StateReader, keys []StorageKeys) ([][ // Note: if contract does not exist, `StorageTrieForAddr()` returns an empty trie, not an error return nil, err } + nodes := []*HashToNode{} - result = append(result, nodes) for _, slot := range key.Keys { proof, err := getProof(cstrie, &slot) if err != nil { @@ -205,6 +227,7 @@ func getContractsStorageProofs(reader core.StateReader, keys []StorageKeys) ([][ } nodes = append(nodes, proof...) } + result = append(result, nodes) } return result, nil @@ -214,6 +237,10 @@ func getProof(t *trie.Trie, elt *felt.Felt) ([]*HashToNode, error) { feltBytes := elt.Bytes() key := trie.NewKey(core.ContractStorageTrieHeight, feltBytes[:]) nodes, err := trie.GetProof(&key, t) + for i, n := range nodes { + fmt.Printf("[%d]", i) + n.PrettyPrint() + } if err != nil { return nil, err } @@ -230,7 +257,8 @@ func getProof(t *trie.Trie, elt *felt.Felt) ([]*HashToNode, error) { } } if edge, ok := node.(*trie.Edge); ok { - f := edge.Path.Felt() + path := edge.Path + f := path.Felt() merkle = &MerkleEdgeNode{ Path: &f, // TODO[pnowosie]: specs says its int Length: int(edge.Len()), diff --git a/rpc/storage_test.go b/rpc/storage_test.go index da657867f4..7a2bd9b436 100644 --- a/rpc/storage_test.go +++ b/rpc/storage_test.go @@ -1,16 +1,22 @@ package rpc_test import ( + "encoding/json" "errors" + "fmt" + "testing" + + "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/pebble" "github.com/NethermindEth/juno/mocks" "github.com/NethermindEth/juno/rpc" "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "testing" ) func TestStorageAt(t *testing.T) { @@ -96,85 +102,236 @@ func TestStorageAt(t *testing.T) { } func TestStorageProof(t *testing.T) { + // dummy values + var ( + blkHash = utils.HexToFelt(t, "0x11ead") + clsRoot = utils.HexToFelt(t, "0xc1a55") + stgRoor = utils.HexToFelt(t, "0xc0ffee") + key = new(felt.Felt).SetUint64(1) + noSuchKey = new(felt.Felt).SetUint64(0) + value = new(felt.Felt).SetUint64(51) + blockNumber = uint64(1313) + nopCloser = func() error { + return nil + } + ) + + tempTrie := emptyTrie(t) + + _, err := tempTrie.Put(key, value) + require.NoError(t, err) + _, err = tempTrie.Put(new(felt.Felt).SetUint64(8), new(felt.Felt).SetUint64(59)) + require.NoError(t, err) + require.NoError(t, tempTrie.Commit()) + + // TODO[pnowosie]: There is smth wrong with proof verification, see `Trie proofs sanity check` test + //verifyIf := func(proof []*rpc.HashToNode, key *felt.Felt, value *felt.Felt) bool { + // root, err := tempTrie.Root() + // require.NoError(t, err) + // fmt.Println("root", root) + // + // pnodes := []trie.ProofNode{} + // for _, hn := range proof { + // pnodes = append(pnodes, NodeToProofNode(hn)) + // } + // + // kbs := key.Bytes() + // kkey := trie.NewKey(251, kbs[:]) + // result := trie.VerifyProof(root, &kkey, value, pnodes, tempTrie.HashFunc()) + // fmt.Println("result", result) + // + // return result + //} + mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) mockReader := mocks.NewMockReader(mockCtrl) + mockState := mocks.NewMockStateHistoryReader(mockCtrl) + + mockReader.EXPECT().HeadState().Return(mockState, func() error { + return nil + }, nil).AnyTimes() + mockReader.EXPECT().Head().Return(&core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}}, nil).AnyTimes() + mockState.EXPECT().StateAndClassRoot().Return(stgRoor, clsRoot, nil).AnyTimes() + mockState.EXPECT().ClassTrie().Return(tempTrie, nopCloser, nil).AnyTimes() + mockState.EXPECT().StorageTrie().Return(tempTrie, nopCloser, nil).AnyTimes() + log := utils.NewNopZapLogger() handler := rpc.New(mockReader, nil, nil, "", log) blockLatest := rpc.BlockID{Latest: true} - t.Run("empty blockchain", func(t *testing.T) { - //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - + t.Run("Trie proofs sanity check", func(t *testing.T) { + t.Skip("It is not working as (I'd) expected") + kbs := key.Bytes() + kKey := trie.NewKey(251, kbs[:]) + proof, err := trie.GetProof(&kKey, tempTrie) + require.NoError(t, err) + root, err := tempTrie.Root() + require.NoError(t, err) + require.True(t, trie.VerifyProof(root, &kKey, value, proof, tempTrie.HashFunc())) + }) + t.Run("global roots are filled", func(t *testing.T) { proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) - require.Nil(t, proof) + require.Nil(t, rpcErr) - assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) + require.NotNil(t, proof) + require.NotNil(t, proof.GlobalRoots) + require.Equal(t, blkHash, proof.GlobalRoots.BlockHash) + require.Equal(t, clsRoot, proof.GlobalRoots.ClassesTreeRoot) + require.Equal(t, stgRoor, proof.GlobalRoots.ContractsTreeRoot) }) - t.Run("class trie hash does not exist in a trie", func(t *testing.T) { - //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - - proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) + t.Run("error is returned whenever not latest block is requested", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(rpc.BlockID{Number: 1}, nil, nil, nil) + assert.Equal(t, rpc.ErrBlockNotRecentForProof, rpcErr) require.Nil(t, proof) - - assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) + }) + t.Run("no error when blknum matches head", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(rpc.BlockID{Number: blockNumber}, nil, nil, nil) + assert.Nil(t, rpcErr) + require.NotNil(t, proof) + }) + t.Run("class trie hash does not exist in a trie", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*noSuchKey}, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + require.NotNil(t, proof.ClassesProof) + require.True(t, len(proof.ClassesProof) > 0) }) t.Run("class trie hash exists in a trie", func(t *testing.T) { - //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - - proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) - require.Nil(t, proof) - - assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + require.True(t, len(proof.ClassesProof) > 0) + require.Len(t, proof.ContractsStorageProofs, 0) + require.NotNil(t, proof.ContractsProof) + require.Len(t, proof.ContractsProof.Nodes, 0) + jsonStr, err := json.Marshal(proof) + require.NoError(t, err) + fmt.Println(string(jsonStr)) }) t.Run("storage trie address does not exist in a trie", func(t *testing.T) { - //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - - proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) - require.Nil(t, proof) + mockState.EXPECT().ContractNonce(noSuchKey).Return(nil, db.ErrKeyNotFound) + mockState.EXPECT().ContractClassHash(noSuchKey).Return(nil, db.ErrKeyNotFound) - assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) + proof, rpcErr := handler.StorageProof(blockLatest, nil, []felt.Felt{*noSuchKey}, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + require.Len(t, proof.ClassesProof, 0) + require.Len(t, proof.ContractsStorageProofs, 0) + require.NotNil(t, proof.ContractsProof) + require.True(t, len(proof.ContractsProof.Nodes) > 0) + require.Len(t, proof.ContractsProof.LeavesData, 1) + require.Nil(t, proof.ContractsProof.LeavesData[0]) }) t.Run("storage trie address exists in a trie", func(t *testing.T) { - //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) + nonce := new(felt.Felt).SetUint64(121) + mockState.EXPECT().ContractNonce(key).Return(nonce, nil) + classHasah := new(felt.Felt).SetUint64(1234) + mockState.EXPECT().ContractClassHash(key).Return(classHasah, nil) - proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) - require.Nil(t, proof) - - assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) + proof, rpcErr := handler.StorageProof(blockLatest, nil, []felt.Felt{*key}, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + require.Len(t, proof.ClassesProof, 0) + require.Len(t, proof.ContractsStorageProofs, 0) + require.NotNil(t, proof.ContractsProof) + require.True(t, len(proof.ContractsProof.Nodes) > 0) + require.Len(t, proof.ContractsProof.LeavesData, 1) + require.NotNil(t, proof.ContractsProof.LeavesData[0]) + ld := proof.ContractsProof.LeavesData[0] + require.Equal(t, nonce, ld.Nonce) + require.Equal(t, classHasah, ld.ClassHash) }) t.Run("contract storage trie address does not exist in a trie", func(t *testing.T) { - //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - - proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) - require.Nil(t, proof) + contract := utils.HexToFelt(t, "0xdead") + mockState.EXPECT().StorageTrieForAddr(contract).Return(emptyTrie(t), nil).Times(1) - assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) + storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*key}}} + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) + require.NotNil(t, proof) + require.Nil(t, rpcErr) + require.Len(t, proof.ClassesProof, 0) + require.NotNil(t, proof.ContractsProof) + require.Len(t, proof.ContractsProof.Nodes, 0) + require.Len(t, proof.ContractsStorageProofs, 1) + require.Len(t, proof.ContractsStorageProofs[0], 0) }) t.Run("contract storage trie key slot does not exist in a trie", func(t *testing.T) { - //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) + contract := utils.HexToFelt(t, "0xabcd") + mockState.EXPECT().StorageTrieForAddr(gomock.Any()).Return(tempTrie, nil).Times(1) - proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) - require.Nil(t, proof) - - assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) + storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*noSuchKey}}} + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) + require.NotNil(t, proof) + require.Nil(t, rpcErr) + require.Len(t, proof.ClassesProof, 0) + require.NotNil(t, proof.ContractsProof) + require.Len(t, proof.ContractsProof.Nodes, 0) + require.Len(t, proof.ContractsStorageProofs, 1) + require.True(t, len(proof.ContractsStorageProofs[0]) > 0) }) t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { - //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) + contract := utils.HexToFelt(t, "0xabcd") + mockState.EXPECT().StorageTrieForAddr(gomock.Any()).Return(tempTrie, nil).Times(1) - proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) - require.Nil(t, proof) - - assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) + storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*key}}} + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) + require.NotNil(t, proof) + require.Nil(t, rpcErr) + require.Len(t, proof.ClassesProof, 0) + require.NotNil(t, proof.ContractsProof) + require.Len(t, proof.ContractsProof.Nodes, 0) + require.Len(t, proof.ContractsStorageProofs, 1) + require.True(t, len(proof.ContractsStorageProofs[0]) > 0) }) t.Run("class & storage tries proofs requested", func(t *testing.T) { - //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - - proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) - require.Nil(t, proof) + nonce := new(felt.Felt).SetUint64(121) + mockState.EXPECT().ContractNonce(key).Return(nonce, nil) + classHasah := new(felt.Felt).SetUint64(1234) + mockState.EXPECT().ContractClassHash(key).Return(classHasah, nil) - assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, []felt.Felt{*key}, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + require.True(t, len(proof.ClassesProof) > 0) + require.Len(t, proof.ContractsStorageProofs, 0) + require.NotNil(t, proof.ContractsProof) + require.True(t, len(proof.ContractsProof.Nodes) > 0) + require.Len(t, proof.ContractsProof.LeavesData, 1) }) } + +func emptyTrie(t *testing.T) *trie.Trie { + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) + return tempTrie +} + +func NodeToProofNode(hn *rpc.HashToNode) trie.ProofNode { + var proofNode trie.ProofNode + + switch pnode := hn.Node.(type) { + case *rpc.MerkleEdgeNode: + pbs := pnode.Path.Bytes() + path := trie.NewKey(uint8(pnode.Length), pbs[:]) + proofNode = &trie.Edge{ + Path: &path, + Child: pnode.Child, + } + case *rpc.MerkleBinaryNode: + proofNode = &trie.Binary{ + LeftHash: pnode.Left, + RightHash: pnode.Right, + } + default: + panic(fmt.Errorf("unsupported node type %T", pnode)) + } + + return proofNode +}