Skip to content

Commit

Permalink
Refactor commitment parallel processing (#2169)
Browse files Browse the repository at this point in the history
  • Loading branch information
AnkushinDaniil authored Oct 1, 2024
1 parent 6b683d8 commit 0c0700c
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 166 deletions.
28 changes: 22 additions & 6 deletions core/receipt.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,31 @@ func messagesSentHash(messages []*L2ToL1Message) *felt.Felt {
}

func receiptCommitment(receipts []*TransactionReceipt) (*felt.Felt, error) {
return calculateCommitment(
receipts,
trie.RunOnTempTriePoseidon,
func(receipt *TransactionReceipt) *felt.Felt {
return receipt.hash()
},
)
}

type (
onTempTrieFunc func(uint8, func(*trie.Trie) error) error
processFunc[T any] func(T) *felt.Felt
)

// General function for parallel processing of items and calculation of a commitment
func calculateCommitment[T any](items []T, runOnTempTrie onTempTrieFunc, process processFunc[T]) (*felt.Felt, error) {
var commitment *felt.Felt
return commitment, trie.RunOnTempTriePoseidon(commitmentTrieHeight, func(trie *trie.Trie) error {
numWorkers := min(runtime.GOMAXPROCS(0), len(receipts))
results := make([]*felt.Felt, len(receipts))
return commitment, runOnTempTrie(commitmentTrieHeight, func(trie *trie.Trie) error {
numWorkers := min(runtime.GOMAXPROCS(0), len(items))
results := make([]*felt.Felt, len(items))
var wg sync.WaitGroup
wg.Add(numWorkers)

jobs := make(chan int, len(receipts))
for idx := range receipts {
jobs := make(chan int, len(items))
for idx := range items {
jobs <- idx
}
close(jobs)
Expand All @@ -81,7 +97,7 @@ func receiptCommitment(receipts []*TransactionReceipt) (*felt.Felt, error) {
go func() {
defer wg.Done()
for i := range jobs {
results[i] = receipts[i].hash()
results[i] = process(items[i])
}
}()
}
Expand Down
233 changes: 73 additions & 160 deletions core/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ import (
"errors"
"fmt"
"math/big"
"runtime"
"slices"
"strings"
"sync"

"github.com/Masterminds/semver/v3"
"github.com/NethermindEth/juno/core/crypto"
Expand All @@ -18,7 +16,6 @@ import (
"github.com/bits-and-blooms/bloom/v3"
"github.com/ethereum/go-ethereum/common"
"github.com/fxamacker/cbor/v2"
"github.com/sourcegraph/conc/pool"
"golang.org/x/crypto/sha3"
)

Expand Down Expand Up @@ -632,62 +629,43 @@ const commitmentTrieHeight = 64
// transactionCommitmentPedersen is the root of a height 64 binary Merkle Patricia tree of the
// transaction hashes and signatures in a block.
func transactionCommitmentPedersen(transactions []Transaction, protocolVersion string) (*felt.Felt, error) {
var commitment *felt.Felt
blockVersion, err := ParseBlockVersion(protocolVersion)
if err != nil {
return nil, err
}

v0_11_1 := semver.MustParse("0.11.1")
return commitment, trie.RunOnTempTriePedersen(commitmentTrieHeight, func(trie *trie.Trie) error {
blockVersion, err := ParseBlockVersion(protocolVersion)
if err != nil {
return err
var hashFunc processFunc[Transaction]
if blockVersion.GreaterThanEqual(v0_11_1) {
hashFunc = func(transaction Transaction) *felt.Felt {
signatureHash := crypto.PedersenArray(transaction.Signature()...)
return crypto.Pedersen(transaction.Hash(), signatureHash)
}

for i, transaction := range transactions {
} else {
hashFunc = func(transaction Transaction) *felt.Felt {
signatureHash := crypto.PedersenArray()

// blockVersion >= 0.11.1
if blockVersion.GreaterThanEqual(v0_11_1) {
signatureHash = crypto.PedersenArray(transaction.Signature()...)
} else if _, ok := transaction.(*InvokeTransaction); ok {
if _, ok := transaction.(*InvokeTransaction); ok {
signatureHash = crypto.PedersenArray(transaction.Signature()...)
}

if _, err = trie.Put(new(felt.Felt).SetUint64(uint64(i)),
crypto.Pedersen(transaction.Hash(), signatureHash)); err != nil {
return err
}
}
root, err := trie.Root()
if err != nil {
return err
return crypto.Pedersen(transaction.Hash(), signatureHash)
}
commitment = root
return nil
})
}
return calculateCommitment(transactions, trie.RunOnTempTriePedersen, hashFunc)
}

func transactionCommitmentPoseidon(transactions []Transaction) (*felt.Felt, error) {
var commitment *felt.Felt
return commitment, trie.RunOnTempTriePoseidon(commitmentTrieHeight, func(trie *trie.Trie) error {
for i, transaction := range transactions {
var digest crypto.PoseidonDigest
digest.Update(transaction.Hash())

switch transaction.(type) {
case *DeployTransaction, *L1HandlerTransaction:
digest.Update(&felt.Zero)
default:
digest.Update(transaction.Signature()...)
}

if _, err := trie.Put(new(felt.Felt).SetUint64(uint64(i)), digest.Finish()); err != nil {
return err
}
return calculateCommitment(transactions, trie.RunOnTempTriePoseidon, func(transaction Transaction) *felt.Felt {
var digest crypto.PoseidonDigest
digest.Update(transaction.Hash())

switch transaction.(type) {
case *DeployTransaction, *L1HandlerTransaction:
digest.Update(&felt.Zero)
default:
digest.Update(transaction.Signature()...)
}
root, err := trie.Root()
if err != nil {
return err
}
commitment = root
return nil

return digest.Finish()
})
}

Expand All @@ -706,125 +684,60 @@ func ParseBlockVersion(protocolVersion string) (*semver.Version, error) {
return semver.NewVersion(strings.Join(digits[:3], sep))
}

type eventWithTxHash struct {
Event *Event
TxHash *felt.Felt
}

// eventCommitmentPoseidon computes the event commitment for a block.
func eventCommitmentPoseidon(receipts []*TransactionReceipt) (*felt.Felt, error) {
var commitment *felt.Felt
return commitment, trie.RunOnTempTriePoseidon(commitmentTrieHeight, func(trie *trie.Trie) error {
eventCount := uint64(0)
numWorkers := runtime.GOMAXPROCS(0)
receiptPerWorker := len(receipts) / numWorkers
if receiptPerWorker == 0 {
receiptPerWorker = 1
}
workerPool := pool.New().WithErrors().WithMaxGoroutines(numWorkers)
var trieMutex sync.Mutex

for receiptIdx := range receipts {
if receiptIdx%receiptPerWorker == 0 {
curReceiptIdx := receiptIdx
curEventIdx := eventCount

workerPool.Go(func() error {
maxIndex := curReceiptIdx + receiptPerWorker
if maxIndex > len(receipts) {
maxIndex = len(receipts)
}
receiptsSliced := receipts[curReceiptIdx:maxIndex]

for _, receipt := range receiptsSliced {
for _, event := range receipt.Events {
hashElems := []*felt.Felt{event.From, receipt.TransactionHash}
hashElems = append(hashElems, new(felt.Felt).SetUint64(uint64(len(event.Keys))))
hashElems = append(hashElems, event.Keys...)
hashElems = append(hashElems, new(felt.Felt).SetUint64(uint64(len(event.Data))))
hashElems = append(hashElems, event.Data...)

eventHash := crypto.PoseidonArray(hashElems...)

eventTrieKey := new(felt.Felt).SetUint64(curEventIdx)
trieMutex.Lock()
_, err := trie.Put(eventTrieKey, eventHash)
trieMutex.Unlock()
if err != nil {
return err
}
curEventIdx++
}
}
return nil
})
}
eventCount += uint64(len(receipts[receiptIdx].Events))
}
if err := workerPool.Wait(); err != nil {
return err
}
root, err := trie.Root()
if err != nil {
return err
eventCounter := 0
for _, receipt := range receipts {
eventCounter += len(receipt.Events)
}
items := make([]*eventWithTxHash, 0, eventCounter)
for _, receipt := range receipts {
for _, event := range receipt.Events {
items = append(items, &eventWithTxHash{
Event: event,
TxHash: receipt.TransactionHash,
})
}
commitment = root
return nil
}
return calculateCommitment(items, trie.RunOnTempTriePoseidon, func(item *eventWithTxHash) *felt.Felt {
return crypto.PoseidonArray(
slices.Concat(
[]*felt.Felt{
item.Event.From,
item.TxHash,
new(felt.Felt).SetUint64(uint64(len(item.Event.Keys))),
},
item.Event.Keys,
[]*felt.Felt{
new(felt.Felt).SetUint64(uint64(len(item.Event.Data))),
},
item.Event.Data,
)...,
)
})
}

// eventCommitmentPedersen computes the event commitment for a block.
func eventCommitmentPedersen(receipts []*TransactionReceipt) (*felt.Felt, error) {
var commitment *felt.Felt
return commitment, trie.RunOnTempTriePedersen(commitmentTrieHeight, func(trie *trie.Trie) error {
eventCount := uint64(0)
numWorkers := runtime.GOMAXPROCS(0)
receiptPerWorker := len(receipts) / numWorkers
if receiptPerWorker == 0 {
receiptPerWorker = 1
}
workerPool := pool.New().WithErrors().WithMaxGoroutines(numWorkers)
var trieMutex sync.Mutex

for receiptIdx := range receipts {
if receiptIdx%receiptPerWorker == 0 {
curReceiptIdx := receiptIdx
curEventIdx := eventCount

workerPool.Go(func() error {
maxIndex := curReceiptIdx + receiptPerWorker
if maxIndex > len(receipts) {
maxIndex = len(receipts)
}
receiptsSliced := receipts[curReceiptIdx:maxIndex]

for _, receipt := range receiptsSliced {
for _, event := range receipt.Events {
eventHash := crypto.PedersenArray(
event.From,
crypto.PedersenArray(event.Keys...),
crypto.PedersenArray(event.Data...),
)

eventTrieKey := new(felt.Felt).SetUint64(curEventIdx)
trieMutex.Lock()
_, err := trie.Put(eventTrieKey, eventHash)
trieMutex.Unlock()
if err != nil {
return err
}
curEventIdx++
}
}
return nil
})
}
eventCount += uint64(len(receipts[receiptIdx].Events))
}
if err := workerPool.Wait(); err != nil {
return err
}
root, err := trie.Root()
if err != nil {
return err
}
commitment = root
return nil
eventCounter := 0
for _, receipt := range receipts {
eventCounter += len(receipt.Events)
}
events := make([]*Event, 0, eventCounter)
for _, receipt := range receipts {
events = append(events, receipt.Events...)
}
return calculateCommitment(events, trie.RunOnTempTriePedersen, func(event *Event) *felt.Felt {
return crypto.PedersenArray(
event.From,
crypto.PedersenArray(event.Keys...),
crypto.PedersenArray(event.Data...),
)
})
}

Expand Down

0 comments on commit 0c0700c

Please sign in to comment.