Skip to content

Commit

Permalink
Fix estimations in v0.6 RPC methods (#1740)
Browse files Browse the repository at this point in the history
  • Loading branch information
kirugan authored Mar 12, 2024
1 parent 6026484 commit e7469e9
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 64 deletions.
16 changes: 8 additions & 8 deletions mocks/mock_vm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions node/throttled_vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,26 @@ func NewThrottledVM(res vm.VM, concurrenyBudget uint, maxQueueLen int32) *Thrott
}

func (tvm *ThrottledVM) Call(callInfo *vm.CallInfo, blockInfo *vm.BlockInfo, state core.StateReader,
network *utils.Network, maxSteps uint64,
network *utils.Network, maxSteps uint64, useBlobData bool,
) ([]*felt.Felt, error) {
var ret []*felt.Felt
return ret, tvm.Do(func(vm *vm.VM) error {
var err error
ret, err = (*vm).Call(callInfo, blockInfo, state, network, maxSteps)
ret, err = (*vm).Call(callInfo, blockInfo, state, network, maxSteps, useBlobData)
return err
})
}

func (tvm *ThrottledVM) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt,
blockInfo *vm.BlockInfo, state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert bool,
blockInfo *vm.BlockInfo, state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, useBlobData bool,
) ([]*felt.Felt, []*felt.Felt, []vm.TransactionTrace, error) {
var ret []*felt.Felt
var traces []vm.TransactionTrace
var dataGasConsumed []*felt.Felt
return ret, dataGasConsumed, traces, tvm.Do(func(vm *vm.VM) error {
var err error
ret, dataGasConsumed, traces, err = (*vm).Execute(txns, declaredClasses, paidFeesOnL1, blockInfo, state, network,
skipChargeFee, skipValidate, errOnRevert)
skipChargeFee, skipValidate, errOnRevert, useBlobData)
return err
})
}
73 changes: 48 additions & 25 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,15 @@ func (h *Handler) Version() (string, *jsonrpc.Error) {
}

// https://github.com/starkware-libs/starknet-specs/blob/e0b76ed0d8d8eba405e182371f9edac8b2bcbc5a/api/starknet_api_openrpc.json#L401-L445
func (h *Handler) Call(call FunctionCall, id BlockID) ([]*felt.Felt, *jsonrpc.Error) { //nolint:gocritic
func (h *Handler) Call(funcCall FunctionCall, id BlockID) ([]*felt.Felt, *jsonrpc.Error) { //nolint:gocritic
return h.call(funcCall, id, true)
}

func (h *Handler) CallV0_6(call FunctionCall, id BlockID) ([]*felt.Felt, *jsonrpc.Error) { //nolint:gocritic
return h.call(call, id, false)
}

func (h *Handler) call(funcCall FunctionCall, id BlockID, useBlobData bool) ([]*felt.Felt, *jsonrpc.Error) { //nolint:gocritic
state, closer, rpcErr := h.stateByBlockID(&id)
if rpcErr != nil {
return nil, rpcErr
Expand All @@ -1330,7 +1338,7 @@ func (h *Handler) Call(call FunctionCall, id BlockID) ([]*felt.Felt, *jsonrpc.Er
return nil, rpcErr
}

classHash, err := state.ContractClassHash(&call.ContractAddress)
classHash, err := state.ContractClassHash(&funcCall.ContractAddress)
if err != nil {
return nil, ErrContractNotFound
}
Expand All @@ -1341,14 +1349,14 @@ func (h *Handler) Call(call FunctionCall, id BlockID) ([]*felt.Felt, *jsonrpc.Er
}

res, err := h.vm.Call(&vm.CallInfo{
ContractAddress: &call.ContractAddress,
Selector: &call.EntryPointSelector,
Calldata: call.Calldata,
ContractAddress: &funcCall.ContractAddress,
Selector: &funcCall.EntryPointSelector,
Calldata: funcCall.Calldata,
ClassHash: classHash,
}, &vm.BlockInfo{
Header: header,
BlockHashToBeRevealed: blockHashToBeRevealed,
}, state, h.bcReader.Network(), h.callMaxSteps)
}, state, h.bcReader.Network(), h.callMaxSteps, useBlobData)
if err != nil {
if errors.Is(err, utils.ErrResourceBusy) {
return nil, ErrInternal.CloneWithData(throttledVMErr)
Expand Down Expand Up @@ -1451,6 +1459,27 @@ func (h *Handler) EstimateFeeV0_6(broadcastedTxns []BroadcastedTransaction,
}

func (h *Handler) EstimateMessageFee(msg MsgFromL1, id BlockID) (*FeeEstimate, *jsonrpc.Error) { //nolint:gocritic
return h.estimateMessageFee(msg, id, h.EstimateFee)
}

func (h *Handler) EstimateMessageFeeV0_6(msg MsgFromL1, id BlockID) (*FeeEstimate, *jsonrpc.Error) { //nolint:gocritic
feeEstimate, rpcErr := h.estimateMessageFee(msg, id, h.EstimateFeeV0_6)
if rpcErr != nil {
return nil, rpcErr
}

feeEstimate.v0_6Response = true
feeEstimate.DataGasPrice = nil
feeEstimate.DataGasConsumed = nil

return feeEstimate, nil
}

type estimateFeeHandler func(broadcastedTxns []BroadcastedTransaction,
simulationFlags []SimulationFlag, id BlockID,
) ([]FeeEstimate, *jsonrpc.Error)

func (h *Handler) estimateMessageFee(msg MsgFromL1, id BlockID, f estimateFeeHandler) (*FeeEstimate, *jsonrpc.Error) { //nolint:gocritic
calldata := make([]*felt.Felt, 0, len(msg.Payload)+1)
// The order of the calldata parameters matters. msg.From must be prepended.
calldata = append(calldata, new(felt.Felt).SetBytes(msg.From.Bytes()))
Expand All @@ -1470,7 +1499,7 @@ func (h *Handler) EstimateMessageFee(msg MsgFromL1, id BlockID) (*FeeEstimate, *
// Must be greater than zero to successfully execute transaction.
PaidFeeOnL1: new(felt.Felt).SetUint64(1),
}
estimates, rpcErr := h.EstimateFee([]BroadcastedTransaction{tx}, nil, id)
estimates, rpcErr := f([]BroadcastedTransaction{tx}, nil, id)
if rpcErr != nil {
if rpcErr.Code == ErrTransactionExecutionError.Code {
data := rpcErr.Data.(TransactionExecutionErrorData)
Expand All @@ -1481,19 +1510,6 @@ func (h *Handler) EstimateMessageFee(msg MsgFromL1, id BlockID) (*FeeEstimate, *
return &estimates[0], nil
}

func (h *Handler) EstimateMessageFeeV0_6(msg MsgFromL1, id BlockID) (*FeeEstimate, *jsonrpc.Error) { //nolint:gocritic
feeEstimate, err := h.EstimateMessageFee(msg, id)
if err != nil {
return nil, err
}

feeEstimate.v0_6Response = true
feeEstimate.DataGasPrice = nil
feeEstimate.DataGasConsumed = nil

return feeEstimate, nil
}

// TraceTransaction returns the trace for a given executed transaction, including internal calls
//
// It follows the specification defined here:
Expand Down Expand Up @@ -1591,8 +1607,9 @@ func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTra
Header: header,
BlockHashToBeRevealed: blockHashToBeRevealed,
}
useBlobData := !v0_6Response
overallFees, dataGasConsumed, traces, err := h.vm.Execute(txns, classes, paidFeesOnL1, &blockInfo,
state, h.bcReader.Network(), skipFeeCharge, skipValidate, errOnRevert)
state, h.bcReader.Network(), skipFeeCharge, skipValidate, errOnRevert, useBlobData)
if err != nil {
if errors.Is(err, utils.ErrResourceBusy) {
return nil, ErrInternal.CloneWithData(throttledVMErr)
Expand Down Expand Up @@ -1625,8 +1642,13 @@ func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTra
}
}

dataGasFee := new(felt.Felt).Mul(dataGasConsumed[i], dataGasPrice)
gasConsumed := new(felt.Felt).Sub(overallFee, dataGasFee)
var gasConsumed *felt.Felt
if !v0_6Response {
dataGasFee := new(felt.Felt).Mul(dataGasConsumed[i], dataGasPrice)
gasConsumed = new(felt.Felt).Sub(overallFee, dataGasFee)
} else {
gasConsumed = overallFee.Clone()
}
gasConsumed = gasConsumed.Div(gasConsumed, gasPrice) // division by zero felt is zero felt

estimate := FeeEstimate{
Expand Down Expand Up @@ -1758,8 +1780,9 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block,
BlockHashToBeRevealed: blockHashToBeRevealed,
}

useBlobData := !v0_6Response
overallFees, dataGasConsumed, traces, err := h.vm.Execute(block.Transactions, classes, paidFeesOnL1, &blockInfo, state, network, false,
false, false)
false, false, useBlobData)
if err != nil {
if errors.Is(err, utils.ErrResourceBusy) {
return nil, ErrInternal.CloneWithData(throttledVMErr)
Expand Down Expand Up @@ -2197,7 +2220,7 @@ func (h *Handler) MethodsV0_6() ([]jsonrpc.Method, string) { //nolint: funlen
{
Name: "starknet_call",
Params: []jsonrpc.Parameter{{Name: "request"}, {Name: "block_id"}},
Handler: h.Call,
Handler: h.CallV0_6,
},
{
Name: "starknet_estimateFee",
Expand Down
26 changes: 13 additions & 13 deletions rpc/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2905,7 +2905,7 @@ func TestCall(t *testing.T) {
ClassHash: classHash,
Selector: selector,
Calldata: calldata,
}, &vm.BlockInfo{Header: headsHeader}, gomock.Any(), &utils.Mainnet, uint64(1337)).Return(expectedRes, nil)
}, &vm.BlockInfo{Header: headsHeader}, gomock.Any(), &utils.Mainnet, uint64(1337), true).Return(expectedRes, nil)

res, rpcErr := handler.Call(rpc.FunctionCall{
ContractAddress: *contractAddr,
Expand Down Expand Up @@ -2953,7 +2953,7 @@ func TestEstimateMessageFee(t *testing.T) {
expectedGasConsumed := new(felt.Felt).SetUint64(37)
mockVM.EXPECT().Execute(gomock.Any(), gomock.Any(), gomock.Any(), &vm.BlockInfo{
Header: latestHeader,
}, gomock.Any(), &utils.Mainnet, gomock.Any(), false, true).DoAndReturn(
}, gomock.Any(), &utils.Mainnet, gomock.Any(), false, true, false).DoAndReturn(
func(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo,
state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert bool,
) ([]*felt.Felt, []*felt.Felt, []vm.TransactionTrace, error) {
Expand Down Expand Up @@ -3055,7 +3055,7 @@ func TestTraceTransaction(t *testing.T) {
vmTrace := new(vm.TransactionTrace)
require.NoError(t, json.Unmarshal(vmTraceJSON, vmTrace))
mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, []*felt.Felt{},
&vm.BlockInfo{Header: header}, gomock.Any(), &utils.Mainnet, false, false, false).Return(nil, []vm.TransactionTrace{*vmTrace}, nil)
&vm.BlockInfo{Header: header}, gomock.Any(), &utils.Mainnet, false, false, false, false).Return(nil, []vm.TransactionTrace{*vmTrace}, nil)

trace, err := handler.TraceTransaction(context.Background(), *hash)
require.Nil(t, err)
Expand Down Expand Up @@ -3085,7 +3085,7 @@ func TestSimulateTransactions(t *testing.T) {
t.Run("ok with zero values, skip fee", func(t *testing.T) {
mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &vm.BlockInfo{
Header: headsHeader,
}, mockState, &network, true, false, false).
}, mockState, &network, true, false, false, false).
Return([]*felt.Felt{}, []vm.TransactionTrace{}, nil)

_, err := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag})
Expand All @@ -3095,7 +3095,7 @@ func TestSimulateTransactions(t *testing.T) {
t.Run("ok with zero values, skip validate", func(t *testing.T) {
mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &vm.BlockInfo{
Header: headsHeader,
}, mockState, &network, false, false, false).
}, mockState, &network, false, false, false, false).
Return([]*felt.Felt{}, []vm.TransactionTrace{}, nil)

_, err := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag})
Expand All @@ -3105,7 +3105,7 @@ func TestSimulateTransactions(t *testing.T) {
t.Run("transaction execution error", func(t *testing.T) {
mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &vm.BlockInfo{
Header: headsHeader,
}, mockState, &network, false, false, false).
}, mockState, &network, false, false, false, false).
Return(nil, nil, vm.TransactionExecutionError{
Index: 44,
Cause: errors.New("oops"),
Expand All @@ -3119,7 +3119,7 @@ func TestSimulateTransactions(t *testing.T) {

mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &vm.BlockInfo{
Header: headsHeader,
}, mockState, &network, false, true, true).
}, mockState, &network, false, true, true, false).
Return(nil, nil, vm.TransactionExecutionError{
Index: 44,
Cause: errors.New("oops"),
Expand Down Expand Up @@ -3215,7 +3215,7 @@ func TestTraceBlockTransactions(t *testing.T) {
vmTrace := vm.TransactionTrace{}
require.NoError(t, json.Unmarshal(vmTraceJSON, &vmTrace))
mockVM.EXPECT().Execute(block.Transactions, []core.Class{declaredClass.Class}, paidL1Fees, &vm.BlockInfo{Header: header},
gomock.Any(), &network, false, false, false).Return(nil, []vm.TransactionTrace{vmTrace, vmTrace}, nil)
gomock.Any(), &network, false, false, false, false).Return(nil, []vm.TransactionTrace{vmTrace, vmTrace}, nil)

result, err := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Hash: blockHash})
require.Nil(t, err)
Expand Down Expand Up @@ -3281,7 +3281,7 @@ func TestTraceBlockTransactions(t *testing.T) {
vmTrace := vm.TransactionTrace{}
require.NoError(t, json.Unmarshal(vmTraceJSON, &vmTrace))
mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, []*felt.Felt{}, &vm.BlockInfo{Header: header},
gomock.Any(), &network, false, false, false).Return(nil, []vm.TransactionTrace{vmTrace}, nil)
gomock.Any(), &network, false, false, false, false).Return(nil, []vm.TransactionTrace{vmTrace}, nil)

expectedResult := []rpc.TracedBlockTransaction{
{
Expand Down Expand Up @@ -3651,23 +3651,23 @@ func TestEstimateFee(t *testing.T) {

blockInfo := vm.BlockInfo{Header: &core.Header{}}
t.Run("ok with zero values", func(t *testing.T) {
mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &blockInfo, mockState, &network, true, true, false).
mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &blockInfo, mockState, &network, true, true, false, false).
Return([]*felt.Felt{}, []vm.TransactionTrace{}, nil)

_, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{}, rpc.BlockID{Latest: true})
require.Nil(t, err)
})

t.Run("ok with zero values, skip validate", func(t *testing.T) {
mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &blockInfo, mockState, &network, true, true, false).
mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &blockInfo, mockState, &network, true, true, false, false).
Return([]*felt.Felt{}, []vm.TransactionTrace{}, nil)

_, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}, rpc.BlockID{Latest: true})
require.Nil(t, err)
})

t.Run("transaction execution error", func(t *testing.T) {
mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &blockInfo, mockState, &network, true, true, false).
mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &blockInfo, mockState, &network, true, true, false, false).
Return(nil, nil, vm.TransactionExecutionError{
Index: 44,
Cause: errors.New("oops"),
Expand All @@ -3679,7 +3679,7 @@ func TestEstimateFee(t *testing.T) {
ExecutionError: "oops",
}), err)

mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &blockInfo, mockState, &network, false, true, true).
mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &blockInfo, mockState, &network, false, true, true, false).
Return(nil, nil, vm.TransactionExecutionError{
Index: 44,
Cause: errors.New("oops"),
Expand Down
Loading

0 comments on commit e7469e9

Please sign in to comment.