diff --git a/lib/runtime/wazero/imports_test.go b/lib/runtime/wazero/imports_test.go index e2421396ef..64234ca9f1 100644 --- a/lib/runtime/wazero/imports_test.go +++ b/lib/runtime/wazero/imports_test.go @@ -920,9 +920,10 @@ func Test_ext_default_child_storage_clear_version_1(t *testing.T) { _, err = inst.Exec("rtm_ext_default_child_storage_clear_version_1", append(encChildKey, encKey...)) require.NoError(t, err) - val, err = inst.Context.Storage.GetChildStorage(testChildKey, testKey) - require.NoError(t, err) - require.Nil(t, val) + _, err = inst.Context.Storage.GetChildStorage(testChildKey, testKey) + require.ErrorIs(t, err, trie.ErrChildTrieDoesNotExist) + require.EqualError(t, err, "child trie does not exist at key "+ + "0x3a6368696c645f73746f726167653a64656661756c743a6368696c644b6579") } func Test_ext_default_child_storage_clear_prefix_version_1(t *testing.T) { diff --git a/lib/trie/child_storage.go b/lib/trie/child_storage.go index 498a7a3137..00dd28d20b 100644 --- a/lib/trie/child_storage.go +++ b/lib/trie/child_storage.go @@ -111,14 +111,25 @@ func (t *Trie) ClearFromChild(keyToChild, key []byte) error { if err != nil { return err } + if child == nil { return fmt.Errorf("%w at key 0x%x%x", ErrChildTrieDoesNotExist, ChildStorageKeyPrefix, keyToChild) } + origChildHash, err := child.Hash() + if err != nil { + return err + } + err = child.Delete(key) if err != nil { return fmt.Errorf("deleting from child trie located at key 0x%x: %w", keyToChild, err) } - return nil + delete(t.childTries, origChildHash) + if child.root == nil { + return t.DeleteChild(keyToChild) + } + + return t.SetChild(keyToChild, child) } diff --git a/lib/trie/child_storage_test.go b/lib/trie/child_storage_test.go index f8a53a82d3..eb922f102f 100644 --- a/lib/trie/child_storage_test.go +++ b/lib/trie/child_storage_test.go @@ -5,8 +5,11 @@ package trie import ( "bytes" + "encoding/binary" "reflect" "testing" + + "github.com/stretchr/testify/require" ) func TestPutAndGetChild(t *testing.T) { @@ -71,3 +74,35 @@ func TestPutAndGetFromChild(t *testing.T) { t.Fatalf("Fail: got %x expected %x", valueRes, testValue) } } + +func TestChildTrieHashAfterClear(t *testing.T) { + trieThatHoldsAChildTrie := NewEmptyTrie() + originalEmptyHash := trieThatHoldsAChildTrie.MustHash() + + keyToChild := []byte("crowdloan") + keyInChild := []byte("account-alice") + contributed := uint64(1000) + contributedWith := make([]byte, 8) + binary.BigEndian.PutUint64(contributedWith, contributed) + + err := trieThatHoldsAChildTrie.PutIntoChild(keyToChild, keyInChild, contributedWith) + require.NoError(t, err) + + // the parent trie hash SHOULT NOT BE EQUAL to the original + // empty hash since it contains a value + require.NotEqual(t, originalEmptyHash, trieThatHoldsAChildTrie.MustHash()) + + // ensure the value is inside the child trie + valueStored, err := trieThatHoldsAChildTrie.GetFromChild(keyToChild, keyInChild) + require.NoError(t, err) + require.Equal(t, contributed, binary.BigEndian.Uint64(valueStored)) + + // clear child trie key value + err = trieThatHoldsAChildTrie.ClearFromChild(keyToChild, keyInChild) + require.NoError(t, err) + + // the parent trie hash SHOULD BE EQUAL to the original + // empty hash since now it does not have any other value in it + require.Equal(t, originalEmptyHash, trieThatHoldsAChildTrie.MustHash()) + +}