Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VM concurrency support #2059

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/juno/juno.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ const (
callMaxStepsF = "rpc-call-max-steps"
corsEnableF = "rpc-cors-enable"
versionedConstantsFileF = "versioned-constants-file"
vmConcurrencyModeF = "vm-concurrency-mode"

defaultConfig = ""
defaulHost = "localhost"
Expand Down Expand Up @@ -119,6 +120,7 @@ const (
defaultGwTimeout = 5 * time.Second
defaultCorsEnable = false
defaultVersionedConstantsFile = ""
defaultVMConcurrencyMode = false

configFlagUsage = "The YAML configuration file."
logLevelFlagUsage = "Options: trace, debug, info, warn, error."
Expand Down Expand Up @@ -170,6 +172,7 @@ const (
"The upper limit is 4 million steps, and any higher value will still be capped at 4 million."
corsEnableUsage = "Enable CORS on RPC endpoints"
versionedConstantsFileUsage = "Use custom versioned constants from provided file"
vmConcurrencyModeUsage = "Enable VM concurrency mode"
)

var Version string
Expand Down Expand Up @@ -355,6 +358,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr
junoCmd.Flags().Bool(corsEnableF, defaultCorsEnable, corsEnableUsage)
junoCmd.Flags().String(versionedConstantsFileF, defaultVersionedConstantsFile, versionedConstantsFileUsage)
junoCmd.MarkFlagsMutuallyExclusive(p2pFeederNodeF, p2pPeersF)
junoCmd.Flags().Bool(vmConcurrencyModeF, defaultVMConcurrencyMode, vmConcurrencyModeUsage)

junoCmd.AddCommand(GenP2PKeyPair(), DBCmd(defaultDBPath))

Expand Down
3 changes: 2 additions & 1 deletion node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type Config struct {
PendingPollInterval time.Duration `mapstructure:"pending-poll-interval"`
RemoteDB string `mapstructure:"remote-db"`
VersionedConstantsFile string `mapstructure:"versioned-constants-file"`
VMConcurrencyMode bool `mapstructure:"vm-concurrency-mode"`

Metrics bool `mapstructure:"metrics"`
MetricsHost string `mapstructure:"metrics-host"`
Expand Down Expand Up @@ -179,7 +180,7 @@ func New(cfg *Config, version string) (*Node, error) { //nolint:gocyclo,funlen
services = append(services, synchronizer)
}

throttledVM := NewThrottledVM(vm.New(false, log), cfg.MaxVMs, int32(cfg.MaxVMQueue))
throttledVM := NewThrottledVM(vm.New(cfg.VMConcurrencyMode, log), cfg.MaxVMs, int32(cfg.MaxVMQueue))

var syncReader sync.Reader = &sync.NoopSynchronizer{}
if synchronizer != nil {
Expand Down
2 changes: 1 addition & 1 deletion node/throttled_vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type ThrottledVM struct {

func NewThrottledVM(res vm.VM, concurrenyBudget uint, maxQueueLen int32) *ThrottledVM {
return &ThrottledVM{
Throttler: utils.NewThrottler[vm.VM](concurrenyBudget, &res).WithMaxQueueLen(maxQueueLen),
Throttler: utils.NewThrottler(concurrenyBudget, &res).WithMaxQueueLen(maxQueueLen),
}
}

Expand Down
3 changes: 2 additions & 1 deletion vm/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ edition = "2021"
[dependencies]
serde = "1.0.208"
serde_json = { version = "1.0.125", features = ["raw_value"] }
blockifier = "0.8.0-rc.3"
blockifier = { version = "0.8.0-rc.3", features = ["concurrency"] }
starknet_api = "0.13.0-rc.1"
cairo-vm = "=1.0.1"
starknet-types-core = { version = "0.1.5", features = ["hash", "prime-bigint"] }
Expand All @@ -18,6 +18,7 @@ once_cell = "1.19.0"
lazy_static = "1.4.0"
semver = "1.0.22"
anyhow = "1.0.81"
num_cpus = "1.15"

[lib]
crate-type = ["staticlib"]
94 changes: 68 additions & 26 deletions vm/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ use blockifier::bouncer::BouncerConfig;
use blockifier::fee::{fee_utils, gas_usage};
use blockifier::transaction::objects::GasVector;
use blockifier::{
blockifier::{
config::{ConcurrencyConfig, TransactionExecutorConfig},
transaction_executor::{TransactionExecutor, TransactionExecutorError},
},
context::{BlockContext, ChainInfo, FeeTokenAddresses, TransactionContext},
execution::{
contract_class::ClassInfo,
Expand All @@ -33,7 +37,6 @@ use blockifier::{
},
objects::{DeprecatedTransactionInfo, HasRelatedFeeType, TransactionInfo},
transaction_execution::Transaction,
transactions::ExecutableTransaction,
},
versioned_constants::VersionedConstants,
};
Expand Down Expand Up @@ -230,11 +233,26 @@ pub extern "C" fn cairoVMExecute(
None,
concurrency_mode,
);
let charge_fee = skip_charge_fee == 0;
let validate = skip_validate == 0;
let _charge_fee = skip_charge_fee == 0;
let _validate = skip_validate == 0;

let mut trace_buffer = Vec::with_capacity(10_000);

let n_workers = num_cpus::get() / 2;
// Initialize the TransactionExecutor
let config = TransactionExecutorConfig {
concurrency_config: ConcurrencyConfig {
enabled: concurrency_mode,
chunk_size: n_workers * 3,
n_workers,
},
};

let mut executor = TransactionExecutor::new(state, block_context.clone(), config);

let mut transactions: Vec<Transaction> = Vec::new();

// Prepare transactions
for (txn_index, txn_and_query_bit) in txns_and_query_bits.iter().enumerate() {
let class_info = match txn_and_query_bit.txn.clone() {
StarknetApiTransaction::Declare(_) => {
Expand Down Expand Up @@ -277,37 +295,43 @@ pub extern "C" fn cairoVMExecute(
return;
}

let mut txn_state = CachedState::create_transactional(&mut state);
let fee_type;
let minimal_l1_gas_amount_vector: Option<GasVector>;
let res = match txn.unwrap() {
Transaction::AccountTransaction(t) => {
fee_type = t.fee_type();
minimal_l1_gas_amount_vector =
Some(gas_usage::estimate_minimal_gas_vector(&block_context, &t).unwrap());
t.execute(&mut txn_state, &block_context, charge_fee, validate)
}
Transaction::L1HandlerTransaction(t) => {
fee_type = t.fee_type();
minimal_l1_gas_amount_vector = None;
t.execute(&mut txn_state, &block_context, charge_fee, validate)
match txn {
Ok(txn) => transactions.push(txn),
Err(_) => {
report_error(
reader_handle,
"failed to create transaction",
txn_index as i64,
);
return;
}
};
}
}

// Execute transactions
let results = executor.execute_txs(&transactions);
let mut block_state = executor.block_state.take().unwrap();

// Process results
for (txn_index, res) in results.into_iter().enumerate() {
match res {
Err(error) => {
let err_string = match &error {
ContractConstructorExecutionFailed(e) => format!("{error} {e}"),
ExecutionError { error: e, .. } | ValidateTransactionError { error: e, .. } => {
format!("{error} {e}")
}
TransactionExecutorError::TransactionExecutionError(err) => match err {
ContractConstructorExecutionFailed(e) => format!("{error} {e}"),
ExecutionError { error: e, .. }
| ValidateTransactionError { error: e, .. } => {
format!("{error} {e}")
}
other => other.to_string(),
},
other => other.to_string(),
};
report_error(
reader_handle,
format!(
"failed txn {} reason: {}",
txn_and_query_bit.txn_hash, err_string,
txns_and_query_bits[txn_index].txn_hash, err_string,
)
.as_str(),
txn_index as i64,
Expand All @@ -326,6 +350,20 @@ pub extern "C" fn cairoVMExecute(

// we are estimating fee, override actual fee calculation
if t.transaction_receipt.fee.0 == 0 {
let minimal_l1_gas_amount_vector: Option<GasVector>;
let fee_type;
match &transactions[txn_index] {
Transaction::AccountTransaction(at) => {
fee_type = at.fee_type();
minimal_l1_gas_amount_vector = Some(
gas_usage::estimate_minimal_gas_vector(&block_context, at).unwrap(),
);
}
Transaction::L1HandlerTransaction(ht) => {
fee_type = ht.fee_type();
minimal_l1_gas_amount_vector = None;
}
}
let minimal_l1_gas_amount_vector =
minimal_l1_gas_amount_vector.unwrap_or_default();
let gas_consumed = t
Expand Down Expand Up @@ -359,8 +397,13 @@ pub extern "C" fn cairoVMExecute(
.try_into()
.unwrap_or(u64::MAX);

let trace =
jsonrpc::new_transaction_trace(&txn_and_query_bit.txn, t, &mut txn_state);
let mut txn_state = CachedState::create_transactional(&mut block_state);

let trace = jsonrpc::new_transaction_trace(
&txns_and_query_bits[txn_index].txn,
t,
&mut txn_state,
);
if let Err(e) = trace {
report_error(
reader_handle,
Expand All @@ -381,7 +424,6 @@ pub extern "C" fn cairoVMExecute(
append_trace(reader_handle, trace.as_ref().unwrap(), &mut trace_buffer);
}
}
txn_state.commit();
}
}

Expand Down