Skip to content

Commit

Permalink
Implement starknet_estimateMessageFee
Browse files Browse the repository at this point in the history
  • Loading branch information
joshklop committed Jul 20, 2023
1 parent cbbfd20 commit 3bc9dd7
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 21 deletions.
5 changes: 5 additions & 0 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ func makeRPC(httpPort, wsPort uint16, rpcHandler *rpc.Handler, log utils.SimpleL
Params: []jsonrpc.Parameter{{Name: "request"}, {Name: "block_id"}},
Handler: rpcHandler.EstimateFee,
},
{
Name: "starknet_estimateMessageFee",
Params: []jsonrpc.Parameter{{Name: "message"}, {Name: "block_id"}},
Handler: rpcHandler.EstimateMessageFee,
},
}

jsonrpcServer := jsonrpc.NewServer(log).WithValidator(validator.Validator())
Expand Down
31 changes: 29 additions & 2 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1082,8 +1082,9 @@ func (h *Handler) EstimateFee(broadcastedTxns []BroadcastedTransaction, id Block
var txns []core.Transaction
var classes []core.Class

var paidFeeOnL1_ *felt.Felt

Check warning on line 1085 in rpc/handlers.go

View check run for this annotation

Codecov / codecov/patch

rpc/handlers.go#L1085

Added line #L1085 was not covered by tests
for idx := range broadcastedTxns {
txn, declaredClass, aErr := adaptBroadcastedTransaction(&broadcastedTxns[idx], h.network)
txn, declaredClass, paidFeeOnL1, aErr := adaptBroadcastedTransaction(&broadcastedTxns[idx], h.network)

Check warning on line 1087 in rpc/handlers.go

View check run for this annotation

Codecov / codecov/patch

rpc/handlers.go#L1087

Added line #L1087 was not covered by tests
if aErr != nil {
return nil, jsonrpc.Err(jsonrpc.InvalidParams, aErr.Error())
}
Expand All @@ -1092,6 +1093,7 @@ func (h *Handler) EstimateFee(broadcastedTxns []BroadcastedTransaction, id Block
if declaredClass != nil {
classes = append(classes, declaredClass)
}
paidFeeOnL1_ = paidFeeOnL1

Check warning on line 1096 in rpc/handlers.go

View check run for this annotation

Codecov / codecov/patch

rpc/handlers.go#L1096

Added line #L1096 was not covered by tests
}

blockNumber := header.Number
Expand All @@ -1103,7 +1105,7 @@ func (h *Handler) EstimateFee(broadcastedTxns []BroadcastedTransaction, id Block
blockNumber = height + 1
}

gasesConsumed, err := vm.Execute(txns, classes, blockNumber, header.Timestamp, header.SequencerAddress, state, h.network)
gasesConsumed, err := vm.Execute(txns, classes, blockNumber, header.Timestamp, header.SequencerAddress, state, h.network, paidFeeOnL1_)

Check warning on line 1108 in rpc/handlers.go

View check run for this annotation

Codecov / codecov/patch

rpc/handlers.go#L1108

Added line #L1108 was not covered by tests
if err != nil {
rpcErr := *ErrContractError
rpcErr.Data = err.Error()
Expand All @@ -1120,3 +1122,28 @@ func (h *Handler) EstimateFee(broadcastedTxns []BroadcastedTransaction, id Block

return estimates, nil
}

func (h *Handler) EstimateMessageFee(msg MsgFromL1, id BlockID) (*FeeEstimate, *jsonrpc.Error) {
calldata := make([]*felt.Felt, 0, len(msg.Payload)+1)
// The order of the calldata parameters matters. msg.From must be prepended.
calldata = append(calldata, new(felt.Felt).SetBytes(msg.From.Bytes()))
calldata = append(calldata, msg.Payload...)
tx := BroadcastedTransaction{
Transaction: Transaction{
Type: TxnL1Handler,
ContractAddress: msg.To,
EntryPointSelector: msg.Selector,
CallData: &calldata,
Version: new(felt.Felt), // Needed for transaction hash calculation.
Nonce: new(felt.Felt), // Needed for transaction hash calculation.
},
// Needed to marshal to blockifier type.
// Must be greater than zero to successfully execute transaction.
PaidFeeOnL1: new(felt.Felt).SetUint64(1),
}
estimates, rpcErr := h.EstimateFee([]BroadcastedTransaction{tx}, id)
if rpcErr != nil {
return nil, rpcErr
}
return &estimates[0], nil

Check warning on line 1148 in rpc/handlers.go

View check run for this annotation

Codecov / codecov/patch

rpc/handlers.go#L1126-L1148

Added lines #L1126 - L1148 were not covered by tests
}
38 changes: 28 additions & 10 deletions rpc/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ type Transaction struct {
CompiledClassHash *felt.Felt `json:"compiled_class_hash,omitempty" validate:"required_if=Type DECLARE Version 0x2"`
}

type MsgFromL1 struct {
// The address of the L1 contract sending the message.
From common.Address `json:"from_address" validate:"required"`
// The address of the L1 contract sending the message.
To *felt.Felt `json:"to_address" validate:"required"`
// The payload of the message.
Payload []*felt.Felt `json:"payload" validate:"required"`
Selector *felt.Felt `json:"entry_point_selector" validate:"required"`
}

type MsgToL1 struct {
From *felt.Felt `json:"from_address"`
To common.Address `json:"to_address"`
Expand Down Expand Up @@ -128,6 +138,7 @@ type DeclareTxResponse struct {
type BroadcastedTransaction struct {
Transaction
ContractClass json.RawMessage `json:"contract_class,omitempty" validate:"required_if=Transaction.Type DECLARE"`
PaidFeeOnL1 *felt.Felt `json:"paid_fee_on_l1,omitempty" validate:"required_if=Transaction.Type L1_HANDLER"`
}

type FeeEstimate struct {
Expand All @@ -136,34 +147,37 @@ type FeeEstimate struct {
OverallFee *felt.Felt `json:"overall_fee"`
}

func adaptBroadcastedTransaction(broadcastedTxn *BroadcastedTransaction, network utils.Network) (core.Transaction, core.Class, error) {
//nolint:gocyclo
func adaptBroadcastedTransaction(broadcastedTxn *BroadcastedTransaction,
network utils.Network,
) (core.Transaction, core.Class, *felt.Felt, error) {

Check warning on line 153 in rpc/transaction.go

View check run for this annotation

Codecov / codecov/patch

rpc/transaction.go#L153

Added line #L153 was not covered by tests
var feederTxn feeder.Transaction
if err := copier.Copy(&feederTxn, broadcastedTxn.Transaction); err != nil {
return nil, nil, err
return nil, nil, nil, err

Check warning on line 156 in rpc/transaction.go

View check run for this annotation

Codecov / codecov/patch

rpc/transaction.go#L156

Added line #L156 was not covered by tests
}
feederTxn.Type = broadcastedTxn.Type.String()

txn, err := feeder2core.AdaptTransaction(&feederTxn)
if err != nil {
return nil, nil, err
return nil, nil, nil, err

Check warning on line 162 in rpc/transaction.go

View check run for this annotation

Codecov / codecov/patch

rpc/transaction.go#L162

Added line #L162 was not covered by tests
}

var declaredClass core.Class
if len(broadcastedTxn.ContractClass) != 0 {
declaredClass, err = adaptDeclaredClass(broadcastedTxn.ContractClass)
if err != nil {
return nil, nil, err
return nil, nil, nil, err

Check warning on line 169 in rpc/transaction.go

View check run for this annotation

Codecov / codecov/patch

rpc/transaction.go#L169

Added line #L169 was not covered by tests
}
} else if broadcastedTxn.Type == TxnDeclare {
return nil, nil, errors.New("declare without a class definition")
return nil, nil, nil, errors.New("declare without a class definition")

Check warning on line 172 in rpc/transaction.go

View check run for this annotation

Codecov / codecov/patch

rpc/transaction.go#L172

Added line #L172 was not covered by tests
}

if t, ok := txn.(*core.DeclareTransaction); ok {
switch c := declaredClass.(type) {
case *core.Cairo0Class:
t.ClassHash, err = vm.Cairo0ClassHash(c)
if err != nil {
return nil, nil, err
return nil, nil, nil, err

Check warning on line 180 in rpc/transaction.go

View check run for this annotation

Codecov / codecov/patch

rpc/transaction.go#L180

Added line #L180 was not covered by tests
}
case *core.Cairo1Class:
t.ClassHash = c.Hash()
Expand All @@ -172,22 +186,26 @@ func adaptBroadcastedTransaction(broadcastedTxn *BroadcastedTransaction, network

txnHash, err := core.TransactionHash(txn, network)
if err != nil {
return nil, nil, err
return nil, nil, nil, err

Check warning on line 189 in rpc/transaction.go

View check run for this annotation

Codecov / codecov/patch

rpc/transaction.go#L189

Added line #L189 was not covered by tests
}

var paidFeeOnL1 *felt.Felt

Check warning on line 192 in rpc/transaction.go

View check run for this annotation

Codecov / codecov/patch

rpc/transaction.go#L192

Added line #L192 was not covered by tests
switch t := txn.(type) {
case *core.DeclareTransaction:
t.TransactionHash = txnHash
case *core.InvokeTransaction:
t.TransactionHash = txnHash
case *core.DeployAccountTransaction:
t.TransactionHash = txnHash
case *core.L1HandlerTransaction:
t.TransactionHash = txnHash
paidFeeOnL1 = broadcastedTxn.PaidFeeOnL1

Check warning on line 202 in rpc/transaction.go

View check run for this annotation

Codecov / codecov/patch

rpc/transaction.go#L200-L202

Added lines #L200 - L202 were not covered by tests
default:
return nil, nil, errors.New("unsupported transaction")
return nil, nil, nil, errors.New("unsupported transaction")

Check warning on line 204 in rpc/transaction.go

View check run for this annotation

Codecov / codecov/patch

rpc/transaction.go#L204

Added line #L204 was not covered by tests
}

if txn.Hash() == nil {
return nil, nil, errors.New("deprecated transaction type")
return nil, nil, nil, errors.New("deprecated transaction type")

Check warning on line 208 in rpc/transaction.go

View check run for this annotation

Codecov / codecov/patch

rpc/transaction.go#L208

Added line #L208 was not covered by tests
}
return txn, declaredClass, nil
return txn, declaredClass, paidFeeOnL1, nil

Check warning on line 210 in rpc/transaction.go

View check run for this annotation

Codecov / codecov/patch

rpc/transaction.go#L210

Added line #L210 was not covered by tests
}
2 changes: 1 addition & 1 deletion vm/rust/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ all:
cargo build --release

clean:
rm -rf target
rm -rf target
96 changes: 91 additions & 5 deletions vm/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,16 @@ use blockifier::{
},
state::cached_state::CachedState,
transaction::{
objects::AccountTransactionContext, transaction_execution::Transaction,
transactions::ExecutableTransaction,
objects::{
AccountTransactionContext, TransactionExecutionInfo,
}, transaction_execution::Transaction,
transactions::{
ExecutableTransaction, Executable,
},
transaction_utils::calculate_tx_resources,
transaction_types::TransactionType,
},
fee::fee_utils::calculate_tx_fee,
};
use cairo_lang_starknet::casm_contract_class::CasmContractClass;
use cairo_lang_starknet::contract_class::ContractClass as SierraContractClass;
Expand All @@ -34,10 +41,11 @@ use starknet_api::{
block::{BlockNumber, BlockTimestamp},
deprecated_contract_class::EntryPointType,
hash::StarkFelt,
transaction::TransactionSignature,
};
use starknet_api::{
core::PatriciaKey,
transaction::{Calldata, Transaction as StarknetApiTransaction},
transaction::{Calldata, Transaction as StarknetApiTransaction, Fee},
};
use starknet_api::{
core::{ChainId, ContractAddress, EntryPointSelector},
Expand Down Expand Up @@ -120,6 +128,7 @@ pub extern "C" fn cairoVMExecute(
block_timestamp: c_ulonglong,
chain_id: *const c_char,
sequencer_address: *const c_uchar,
paid_fee_on_l1_str: *const c_char,
) {
let reader = JunoStateReader::new(reader_handle);
let chain_id_str = unsafe { CStr::from_ptr(chain_id) }.to_str().unwrap();
Expand Down Expand Up @@ -177,15 +186,92 @@ pub extern "C" fn cairoVMExecute(
_ => None,
};

let txn = Transaction::from_api(sn_api_txn.clone(), contract_class, None);
let paid_fee_on_l1: Option<Fee> = match sn_api_txn.clone() {
StarknetApiTransaction::L1Handler(_) => {
let paid_fee_on_l1_str = unsafe { CStr::from_ptr(paid_fee_on_l1_str) }.to_str().unwrap();
let paid_fee_on_l1 = match u128::from_str_radix(paid_fee_on_l1_str, 16) {
Ok(i) => Fee(i),
Err(e) => {
report_error(reader_handle, format!("failed to convert string to u128 reason:{:?}", e).as_str());
return;
}
}
;
Some(paid_fee_on_l1)
},
_ => None,
};


let txn = Transaction::from_api(sn_api_txn.clone(), contract_class, paid_fee_on_l1);
if txn.is_err() {
report_error(reader_handle, txn.unwrap_err().to_string().as_str());
return;
}

let res = match txn.unwrap() {
Transaction::AccountTransaction(t) => t.execute(&mut state, &block_context),
Transaction::L1HandlerTransaction(t) => t.execute(&mut state, &block_context),
Transaction::L1HandlerTransaction(t) => {
// Manually inline L1HandlerTransaction.execute and execute_raw since the
// `actual_fee` in Blockifier is zero (i.e., don't charge for L1
// Handler transactions on L2). This way we can get the full
// TransactionExecutionResult.
// https://github.com/starkware-libs/blockifier/issues/734
let mut transactional_state = CachedState::create_transactional(&mut state);
// Inlined L1HandlerTransaction.execute_raw
let execution_result = || -> Result<TransactionExecutionInfo, _> {
let tx = &t.tx;
let tx_context = AccountTransactionContext {
transaction_hash: tx.transaction_hash,
max_fee: Fee::default(),
version: tx.version,
signature: TransactionSignature::default(),
nonce: tx.nonce,
sender_address: tx.contract_address,
};
let mut resources = ExecutionResources::default();
let mut context = EntryPointExecutionContext::new(
block_context.clone(),
tx_context,
block_context.invoke_tx_max_n_steps,
);
let mut remaining_gas = Transaction::initial_gas();
let execute_call_info =
t.run_execute(&mut transactional_state, &mut resources, &mut context, &mut remaining_gas)?;

let call_infos =
if let Some(call_info) = execute_call_info.as_ref() { vec![call_info] } else { vec![] };
// The calldata includes the "from" field, which is not a part of the payload.
let l1_handler_payload_size = Some(tx.calldata.0.len() - 1);
let actual_resources = calculate_tx_resources(
resources,
&call_infos,
TransactionType::L1Handler,
&mut transactional_state,
l1_handler_payload_size,
)?;
let actual_fee = calculate_tx_fee(&actual_resources, &context.block_context)?;

Ok(TransactionExecutionInfo {
validate_call_info: None,
execute_call_info,
fee_transfer_call_info: None,
actual_fee,
actual_resources,
})
}();

match execution_result {
Ok(value) => {
transactional_state.commit();
Ok(value)
}
Err(error) => {
transactional_state.abort();
Err(error)
}
}
},
};

match res {
Expand Down
2 changes: 2 additions & 0 deletions vm/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ func marshalTxn(txn core.Transaction) (json.RawMessage, error) {
txnMap["Declare"] = map[string]any{
"V" + clearQueryBit(t.Version).Text(felt.Base10): t,
}
case *core.L1HandlerTransaction:
txnMap["L1Handler"] = t
default:
return nil, errors.New("unsupported txn type")
}
Expand Down
9 changes: 6 additions & 3 deletions vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ package vm
// char* chain_id);
//
// extern void cairoVMExecute(char* txns_json, char* classes_json, uintptr_t readerHandle, unsigned long long block_number,
// unsigned long long block_timestamp, char* chain_id, char* sequencer_address);
// unsigned long long block_timestamp, char* chain_id, char* sequencer_address, char* paid_fee_on_l1);
//
// #cgo LDFLAGS: -L./rust/target/release -ljuno_starknet_rs -lm -ldl
import "C"
Expand Down Expand Up @@ -117,7 +117,7 @@ func Call(contractAddr, selector *felt.Felt, calldata []felt.Felt, blockNumber,

// Execute executes a given transaction set and returns the gas spent per transaction
func Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeeOnL1 *felt.Felt,
) ([]*felt.Felt, error) {
context := &callContext{
state: state,
Expand All @@ -130,6 +130,7 @@ func Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber,
return nil, err
}

paidFeeOnL1CStr := C.CString(paidFeeOnL1.Text(felt.Base16))
txnsJSONCstr := C.CString(string(txnsJSON))
classesJSONCStr := C.CString(string(classesJSON))

Expand All @@ -141,9 +142,11 @@ func Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber,
C.ulonglong(blockNumber),
C.ulonglong(blockTimestamp),
chainID,
(*C.char)(unsafe.Pointer(&sequencerAddressBytes[0])))
(*C.char)(unsafe.Pointer(&sequencerAddressBytes[0])),
paidFeeOnL1CStr)

C.free(unsafe.Pointer(classesJSONCStr))
C.free(unsafe.Pointer(paidFeeOnL1CStr))
C.free(unsafe.Pointer(txnsJSONCstr))
C.free(unsafe.Pointer(chainID))

Expand Down

0 comments on commit 3bc9dd7

Please sign in to comment.