Skip to content

Commit

Permalink
fix: revert cairo tx when evm tx with cairo precompile fails (kkrt-la…
Browse files Browse the repository at this point in the history
…bs#1368)

<!--- Please provide a general summary of your changes in the title
above -->

<!-- Give an estimate of the time you spent on this PR in terms of work
days.
Did you spend 0.5 days on this PR or rather 2 days?  -->

Time spent on this PR:

## Pull request type

<!-- Please try to limit your pull request to one type,
submit multiple pull requests if needed. -->

Please check the type of change your PR introduces:

- [x] Bugfix
- [ ] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no api changes)
- [ ] Build related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

<!-- Please describe the current behavior that you are modifying,
or link to a relevant issue. -->

Resolves kkrt-labs#1364

## What is the new behavior?
- revert cairo tx when evm tx fails with cairo precompile called
- fix error send back for input validation in cairo precompile

<!-- Reviewable:start -->
- - -
This change is [<img src="https://reviewable.io/review_button.svg"
height="34" align="absmiddle"
alt="Reviewable"/>](https://reviewable.io/reviews/kkrt-labs/kakarot/1368)
<!-- Reviewable:end -->
  • Loading branch information
obatirou committed Aug 30, 2024
1 parent 2395950 commit f1e7ebb
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 13 deletions.
6 changes: 6 additions & 0 deletions kakarot_scripts/utils/kakarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ class EvmTransactionError(Exception):
pass


class StarknetTransactionError(Exception):
pass


@functools.lru_cache()
def get_solidity_artifacts(
contract_app: str,
Expand Down Expand Up @@ -654,6 +658,8 @@ async def send_starknet_transaction(
if event.from_address == evm_account.address
and event.keys[0] == starknet_keccak(b"transaction_executed")
]
if receipt.execution_status.name == "REVERTED":
raise StarknetTransactionError(f"Starknet tx reverted: {receipt.revert_reason}")
if len(transaction_events) != 1:
raise ValueError("Cannot locate the single event giving the actual tx status")
(
Expand Down
12 changes: 9 additions & 3 deletions kakarot_scripts/utils/starknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from starknet_py.hash.transaction import TransactionHashPrefix, compute_transaction_hash
from starknet_py.hash.utils import message_signature
from starknet_py.net.account.account import Account
from starknet_py.net.client_errors import ClientError
from starknet_py.net.client_models import Call, DeclareTransactionResponse
from starknet_py.net.full_node_client import _create_broadcasted_txn
from starknet_py.net.models.transaction import DeclareV1
Expand Down Expand Up @@ -473,10 +474,15 @@ async def upgrade(contract_name, *args):

logger.info(f"ℹ️ {contract_name} already deployed, checking version.")
class_hash = get_declarations()
try:
deployed_class_hash = await RPC_CLIENT.get_class_hash_at(
deployments[contract_name]["address"]
)
except ClientError as e:
if "Contract not found" in str(e):
logger.info(f"ℹ️ deploying {contract_name}.")
return await deploy(contract_name, *args)

deployed_class_hash = await RPC_CLIENT.get_class_hash_at(
deployments[contract_name]["address"]
)
if deployed_class_hash != class_hash[contract_name]:
logger.info(f"ℹ️ redeploying {contract_name}.")
return await deploy(contract_name, *args)
Expand Down
2 changes: 2 additions & 0 deletions src/kakarot/evm.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ namespace EVM {
is_create=evm.message.is_create,
depth=evm.message.depth,
env=evm.message.env,
cairo_precompile_called=evm.message.cairo_precompile_called,
);

tempvar evm = new model.EVM(
Expand Down Expand Up @@ -279,6 +280,7 @@ namespace EVM {
is_create=self.message.is_create,
depth=self.message.depth,
env=self.message.env,
cairo_precompile_called=self.message.cairo_precompile_called,
);

if (is_valid_jumpdest == FALSE) {
Expand Down
47 changes: 42 additions & 5 deletions src/kakarot/instructions/system_operations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ namespace SystemOperations {
is_create=TRUE,
depth=evm.message.depth + 1,
env=evm.message.env,
cairo_precompile_called=evm.message.cairo_precompile_called,
);
let child_evm = EVM.init(message, gas_limit);
let stack = Stack.init();
Expand Down Expand Up @@ -900,6 +901,7 @@ namespace CallHelper {
is_create=FALSE,
depth=evm.message.depth + 1,
env=evm.message.env,
cairo_precompile_called=evm.message.cairo_precompile_called,
);

let child_evm = EVM.init(message, gas);
Expand Down Expand Up @@ -945,12 +947,29 @@ namespace CallHelper {
code_account, evm.message.valid_jumpdests_start, evm.message.valid_jumpdests
);
State.update_account(code_account);

tempvar message = new model.Message(
bytecode=evm.message.parent.evm.message.bytecode,
bytecode_len=evm.message.parent.evm.message.bytecode_len,
valid_jumpdests_start=evm.message.parent.evm.message.valid_jumpdests_start,
valid_jumpdests=evm.message.parent.evm.message.valid_jumpdests,
calldata=evm.message.parent.evm.message.calldata,
calldata_len=evm.message.parent.evm.message.calldata_len,
value=evm.message.parent.evm.message.value,
caller=evm.message.parent.evm.message.caller,
parent=evm.message.parent.evm.message.parent,
address=evm.message.parent.evm.message.address,
code_address=evm.message.parent.evm.message.code_address,
read_only=evm.message.parent.evm.message.read_only,
is_create=evm.message.parent.evm.message.is_create,
depth=evm.message.parent.evm.message.depth,
env=evm.message.parent.evm.message.env,
cairo_precompile_called=evm.message.cairo_precompile_called,
);
if (evm.reverted == Errors.EXCEPTIONAL_HALT) {
// If the call has halted exceptionnaly, the return_data is empty
// and nothing is copied to memory, and the gas is not returned;
tempvar evm = new model.EVM(
message=evm.message.parent.evm.message,
message=message,
return_data_len=0,
return_data=evm.return_data,
program_counter=evm.message.parent.evm.program_counter + 1,
Expand All @@ -969,7 +988,7 @@ namespace CallHelper {
Memory.store_n(actual_output_size, evm.return_data, ret_offset.low);

tempvar evm = new model.EVM(
message=evm.message.parent.evm.message,
message=message,
return_data_len=evm.return_data_len,
return_data=evm.return_data,
program_counter=evm.message.parent.evm.program_counter + 1,
Expand Down Expand Up @@ -1146,6 +1165,24 @@ namespace CreateHelper {
}(evm: model.EVM*) -> model.EVM* {
alloc_locals;

tempvar message = new model.Message(
bytecode=evm.message.parent.evm.message.bytecode,
bytecode_len=evm.message.parent.evm.message.bytecode_len,
valid_jumpdests_start=evm.message.parent.evm.message.valid_jumpdests_start,
valid_jumpdests=evm.message.parent.evm.message.valid_jumpdests,
calldata=evm.message.parent.evm.message.calldata,
calldata_len=evm.message.parent.evm.message.calldata_len,
value=evm.message.parent.evm.message.value,
caller=evm.message.parent.evm.message.caller,
parent=evm.message.parent.evm.message.parent,
address=evm.message.parent.evm.message.address,
code_address=evm.message.parent.evm.message.code_address,
read_only=evm.message.parent.evm.message.read_only,
is_create=evm.message.parent.evm.message.is_create,
depth=evm.message.parent.evm.message.depth,
env=evm.message.parent.evm.message.env,
cairo_precompile_called=evm.message.cairo_precompile_called,
);
// Reverted during execution - either REVERT or exceptional
if (evm.reverted != FALSE) {
let is_exceptional_revert = is_not_zero(Errors.REVERT - evm.reverted);
Expand All @@ -1161,7 +1198,7 @@ namespace CreateHelper {
tempvar state = evm.message.parent.state;
tempvar evm = new model.EVM(
message=evm.message.parent.evm.message,
message=message,
return_data_len=return_data_len,
return_data=evm.return_data,
program_counter=evm.message.parent.evm.program_counter + 1,
Expand Down Expand Up @@ -1196,7 +1233,7 @@ namespace CreateHelper {
tempvar state = evm.message.parent.state;
tempvar evm = new model.EVM(
message=evm.message.parent.evm.message,
message=message,
return_data_len=0,
return_data=evm.return_data,
program_counter=evm.message.parent.evm.program_counter + 1,
Expand Down
36 changes: 36 additions & 0 deletions src/kakarot/interpreter.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,37 @@ namespace Interpreter {
let evm_reverted = is_not_zero(evm.reverted);
let success = (1 - precompile_reverted) * (1 - evm_reverted);
let evm = EVM.stop(evm, output_len, output, 1 - success);
let is_cairo_precompile_called = PrecompilesHelpers.is_kakarot_precompile(
evm.message.code_address.evm
);
tempvar message = new model.Message(
bytecode=evm.message.bytecode,
bytecode_len=evm.message.bytecode_len,
valid_jumpdests_start=evm.message.valid_jumpdests_start,
valid_jumpdests=evm.message.valid_jumpdests,
calldata=evm.message.calldata,
calldata_len=evm.message.calldata_len,
value=evm.message.value,
caller=evm.message.caller,
parent=evm.message.parent,
address=evm.message.address,
code_address=evm.message.code_address,
read_only=evm.message.read_only,
is_create=evm.message.is_create,
depth=evm.message.depth,
env=evm.message.env,
cairo_precompile_called=is_cairo_precompile_called,
);
tempvar evm = new model.EVM(
message=message,
return_data_len=evm.return_data_len,
return_data=evm.return_data,
program_counter=evm.program_counter,
stopped=evm.stopped,
gas_left=evm.gas_left,
gas_refund=evm.gas_refund,
reverted=evm.reverted,
);
return evm;
} else {
let (return_data: felt*) = alloc();
Expand Down Expand Up @@ -862,6 +893,7 @@ namespace Interpreter {
is_create=is_deploy_tx,
depth=0,
env=env,
cairo_precompile_called=FALSE,
);

let stack = Stack.init();
Expand Down Expand Up @@ -954,6 +986,10 @@ namespace Interpreter {
// Only the gas fee paid will be committed.
State.finalize{state=state}();
if (evm.reverted != 0) {
with_attr error_message(
"EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles") {
assert evm.message.cairo_precompile_called = FALSE;
}
tempvar state = State.init();
} else {
tempvar state = state;
Expand Down
1 change: 1 addition & 0 deletions src/kakarot/model.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ namespace model {
is_create: felt,
depth: felt,
env: Environment*,
cairo_precompile_called: felt,
}

// @dev Stores all data relevant to the current execution context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from kakarot_scripts.utils.kakarot import deploy
from kakarot_scripts.utils.starknet import get_contract, get_deployments, invoke
from tests.utils.errors import evm_error
from tests.utils.errors import cairo_error

ENTRY_TYPE_INDEX = {"SpotEntry": 0, "FutureEntry": 1, "GenericEntry": 2}

Expand Down Expand Up @@ -150,5 +150,7 @@ async def test_should_fail_unauthorized_caller(self, pragma_caller, data_type):
)
solidity_input = serialize_data_type(data_type)

with evm_error("CairoLib: call_contract failed"):
with cairo_error(
"EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles"
):
await pragma_caller.getDataMedianSpot(solidity_input)
6 changes: 4 additions & 2 deletions tests/end_to_end/CairoPrecompiles/test_cairo_precompiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from kakarot_scripts.utils.kakarot import deploy, get_eoa
from kakarot_scripts.utils.starknet import get_contract, invoke, wait_for_transaction
from tests.utils.errors import evm_error
from tests.utils.errors import cairo_error


@pytest_asyncio.fixture()
Expand Down Expand Up @@ -69,7 +69,9 @@ async def test_should_fail_precompile_caller_not_whitelisted(
cairo_counter_caller = await deploy(
"CairoPrecompiles", "CairoCounterCaller", cairo_counter.address
)
with evm_error("CairoLib: call_contract failed"):
with cairo_error(
"EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles"
):
await cairo_counter_caller.incrementCairoCounter()

async def test_last_caller_address_should_be_eoa(self, cairo_counter_caller):
Expand Down
11 changes: 11 additions & 0 deletions tests/end_to_end/CairoPrecompiles/test_dual_vm_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from kakarot_scripts.utils.starknet import deploy as deploy_starknet
from kakarot_scripts.utils.starknet import get_contract as get_contract_starknet
from kakarot_scripts.utils.starknet import invoke
from tests.utils.errors import cairo_error


@pytest_asyncio.fixture()
Expand Down Expand Up @@ -115,3 +116,13 @@ async def test_should_transfer_from(

assert balance_owner_before - amount == balance_owner_after
assert balance_other_before + amount == balance_other_after

async def test_should_revert_tx_cairo_precompiles(
self, starknet_token, dual_vm_token, owner, other
):
with cairo_error(
"EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles"
):
await dual_vm_token.transfer(
other.address, 1, gas_limit=45_000
) # fails with out of gas
6 changes: 5 additions & 1 deletion tests/utils/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from contextlib import contextmanager

import pytest
from starknet_py.net.client_errors import ClientError
from web3 import Web3


Expand Down Expand Up @@ -33,7 +34,10 @@ def cairo_error(message=None):
yield e
if message is None:
return
error = re.search(r"Error message: (.*)", str(e.value))
if type(e.value) == ClientError:
error = re.search(r"Error message: (.*)", str(e.value.data["revert_error"]))
else:
error = re.search(r"Error message: (.*)", str(e.value))
error = error.group(1) if error else str(e.value)
assert message == error, f"Expected {message}, got {error}"
finally:
Expand Down
1 change: 1 addition & 0 deletions tests/utils/helpers.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ namespace TestHelpers {
is_create=FALSE,
depth=0,
env=env,
cairo_precompile_called=FALSE,
);
let evm: model.EVM* = EVM.init(message, 1000000);
return evm;
Expand Down

0 comments on commit f1e7ebb

Please sign in to comment.