Skip to content

Commit

Permalink
transaction: introduce TxOutput namedtuple
Browse files Browse the repository at this point in the history
Using a namedtuple is an intermediate step before creating a proper class for this. It encapsulate the 3-tuples and makes the code easier to read (attributes rather than indices)

backport of spesmilo#4596 and removal of a bunch of code Transaction methods that are unused.
  • Loading branch information
PiRK committed May 19, 2023
1 parent f3acde0 commit 231c003
Show file tree
Hide file tree
Showing 16 changed files with 152 additions and 125 deletions.
15 changes: 12 additions & 3 deletions electrumabc/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# Many of the functions in this file are copied from ElectrumX
from __future__ import annotations

import abc
import hashlib
import struct
from collections import namedtuple
Expand All @@ -47,6 +48,14 @@
hex_to_bytes = bytes.fromhex


class DestinationType(abc.ABC):
"""Base class for TxOutput destination types"""

@abc.abstractmethod
def to_ui_string(self) -> str:
pass


class AddressError(Exception):
"""Exception used for Address errors."""

Expand Down Expand Up @@ -106,7 +115,7 @@ def double_sha256(x):
return sha256(sha256(x))


class UnknownAddress:
class UnknownAddress(DestinationType):
def to_ui_string(self):
return "<UnknownAddress>"

Expand Down Expand Up @@ -233,7 +242,7 @@ def __repr__(self):
return "<PubKey {}>".format(self.__str__())


class ScriptOutput(namedtuple("ScriptAddressTuple", "script")):
class ScriptOutput(namedtuple("ScriptAddressTuple", "script"), DestinationType):
@classmethod
def from_string(self, string):
"""Instantiate from a mixture of opcodes and raw data."""
Expand Down Expand Up @@ -341,7 +350,7 @@ def protocol_factory(script):


# A namedtuple for easy comparison and unique hashing
class Address(namedtuple("AddressTuple", "hash160 kind")):
class Address(namedtuple("AddressTuple", "hash160 kind"), DestinationType):
# Address kinds
ADDR_P2PKH = 0
ADDR_P2SH = 1
Expand Down
5 changes: 3 additions & 2 deletions electrumabc/coinchooser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from .bitcoin import CASH, TYPE_ADDRESS, sha256
from .printerror import PrintError
from .transaction import Transaction
from .transaction import Transaction, TxOutput
from .util import NotEnoughFunds


Expand Down Expand Up @@ -167,7 +167,8 @@ def change_outputs(self, tx, change_addrs, fee_estimator, dust_threshold):
dust = sum(amount for amount in amounts if amount < dust_threshold)
amounts = [amount for amount in amounts if amount >= dust_threshold]
change = [
(TYPE_ADDRESS, addr, amount) for addr, amount in zip(change_addrs, amounts)
TxOutput(TYPE_ADDRESS, addr, amount)
for addr, amount in zip(change_addrs, amounts)
]
self.print_error("change:", change)
if dust:
Expand Down
6 changes: 3 additions & 3 deletions electrumabc/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from .plugins import run_hook
from .printerror import print_error
from .simple_config import SimpleConfig
from .transaction import OPReturn, Transaction, multisig_script, tx_from_str
from .transaction import OPReturn, Transaction, TxOutput, multisig_script, tx_from_str
from .util import format_satoshis, json_decode, to_bytes
from .version import PACKAGE_VERSION
from .wallet import create_new_wallet, restore_wallet_from_text
Expand Down Expand Up @@ -427,7 +427,7 @@ def serialize(self, jsontx):
txin["num_sig"] = 1

outputs = [
(TYPE_ADDRESS, Address.from_string(x["address"]), int(x["value"]))
TxOutput(TYPE_ADDRESS, Address.from_string(x["address"]), int(x["value"]))
for x in outputs
]
tx = Transaction.from_io(
Expand Down Expand Up @@ -696,7 +696,7 @@ def _mktx(
for address, amount in outputs:
address = self._resolver(address)
amount = satoshis(amount)
final_outputs.append((TYPE_ADDRESS, address, amount))
final_outputs.append(TxOutput(TYPE_ADDRESS, address, amount))

coins = self.wallet.get_spendable_coins(domain, self.config)
if feerate is not None:
Expand Down
10 changes: 5 additions & 5 deletions electrumabc/paymentrequest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from .bitcoin import TYPE_ADDRESS
from .constants import PROJECT_NAME, PROJECT_NAME_NO_SPACES, XEC
from .printerror import PrintError, print_error
from .transaction import Transaction
from .transaction import Transaction, TxOutput
from .util import FileImportFailed, FileImportFailedEncrypted, bfh, bh2u
from .version import PACKAGE_VERSION

Expand Down Expand Up @@ -164,7 +164,7 @@ def parse(self, r):
self.outputs = []
for o in self.details.outputs:
addr = transaction.get_address_from_output_script(o.script)[1]
self.outputs.append((TYPE_ADDRESS, addr, o.amount))
self.outputs.append(TxOutput(TYPE_ADDRESS, addr, o.amount))
self.memo = self.details.memo
self.payment_url = self.details.payment_url

Expand Down Expand Up @@ -272,10 +272,10 @@ def get_expiration_date(self):
def get_amount(self):
return sum(map(lambda x: x[2], self.outputs))

def get_address(self):
def get_address(self) -> str:
o = self.outputs[0]
assert o[0] == TYPE_ADDRESS
return o[1].to_ui_string()
assert o.type == TYPE_ADDRESS
return o.destination.to_ui_string()

def get_requestor(self):
return self.requestor if self.requestor else self.get_address()
Expand Down
15 changes: 10 additions & 5 deletions electrumabc/tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,22 +113,27 @@ def test_tx_unsigned(self):
tx.as_dict(), {"hex": unsigned_blob, "complete": False, "final": True}
)
self.assertEqual(
tx.get_outputs(),
[(o.destination, o.value) for o in tx.outputs()],
[(Address.from_string("1MYXdf4moacvaEKZ57ozerpJ3t9xSeN6LK"), 20112408)],
)
self.assertEqual(
tx.get_output_addresses(),
[o.destination for o in tx.outputs()],
[Address.from_string("1MYXdf4moacvaEKZ57ozerpJ3t9xSeN6LK")],
)

def tx_has_address(addr: Address) -> bool:
return any(addr == o.destination for o in tx.outputs()) or (
addr in (inp.get("address") for inp in tx.inputs())
)

self.assertTrue(
tx.has_address(Address.from_string("1MYXdf4moacvaEKZ57ozerpJ3t9xSeN6LK"))
tx_has_address(Address.from_string("1MYXdf4moacvaEKZ57ozerpJ3t9xSeN6LK"))
)
self.assertTrue(
tx.has_address(Address.from_string("13Vp8Y3hD5Cb6sERfpxePz5vGJizXbWciN"))
tx_has_address(Address.from_string("13Vp8Y3hD5Cb6sERfpxePz5vGJizXbWciN"))
)
self.assertFalse(
tx.has_address(Address.from_string("1CQj15y1N7LDHp7wTt28eoD1QhHgFgxECH"))
tx_has_address(Address.from_string("1CQj15y1N7LDHp7wTt28eoD1QhHgFgxECH"))
)

self.assertEqual(tx.serialize(), unsigned_blob)
Expand Down
50 changes: 26 additions & 24 deletions electrumabc/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
import random
import struct
import warnings
from typing import Optional, Tuple, Union
from typing import List, NamedTuple, Optional, Tuple, Union

import ecdsa

from . import bitcoin, schnorr
from .address import (
Address,
DestinationType,
P2PKH_prefix,
P2PKH_suffix,
P2SH_prefix,
Expand Down Expand Up @@ -68,6 +69,13 @@ class InputValueMissing(ValueError):
"""thrown when the value of an input is needed but not present"""


class TxOutput(NamedTuple):
type: int
destination: DestinationType
# str when the output is set to max: '!'
value: Union[int, str]


class BCDataStream(object):
def __init__(self):
self.input = None
Expand Down Expand Up @@ -324,7 +332,7 @@ def parse_redeemScript(s):

def get_address_from_output_script(
_bytes: bytes,
) -> Tuple[int, Union[Address, PublicKey, ScriptOutput]]:
) -> Tuple[int, Union[PublicKey, DestinationType]]:
"""Return the type of the output and the address"""
scriptlen = len(_bytes)

Expand Down Expand Up @@ -408,7 +416,7 @@ def parse_input(vds):
return d


def parse_output(vds, i):
def parse_output(vds: BCDataStream, i: int):
d = {}
d["value"] = vds.read_int64()
scriptPubKey = vds.read_bytes(vds.read_compact_size())
Expand Down Expand Up @@ -463,7 +471,7 @@ def __init__(self, raw, sign_schnorr=False):
else:
raise RuntimeError("cannot initialize transaction", raw)
self._inputs = None
self._outputs = None
self._outputs: Optional[List[TxOutput]] = None
self.locktime = 0
self.version = 2
self._sign_schnorr = sign_schnorr
Expand Down Expand Up @@ -511,7 +519,7 @@ def inputs(self):
self.deserialize()
return self._inputs

def outputs(self):
def outputs(self) -> List[TxOutput]:
if self._outputs is None:
self.deserialize()
return self._outputs
Expand Down Expand Up @@ -615,7 +623,9 @@ def deserialize(self):
d = deserialize(self.raw)
self.invalidate_common_sighash_cache()
self._inputs = d["inputs"]
self._outputs = [(x["type"], x["address"], x["value"]) for x in d["outputs"]]
self._outputs = [
TxOutput(x["type"], x["address"], x["value"]) for x in d["outputs"]
]
assert all(
isinstance(output[1], (PublicKey, Address, ScriptOutput))
for output in self._outputs
Expand All @@ -625,7 +635,14 @@ def deserialize(self):
return d

@classmethod
def from_io(klass, inputs, outputs, locktime=0, sign_schnorr=False, version=None):
def from_io(
klass,
inputs,
outputs: List[TxOutput],
locktime=0,
sign_schnorr=False,
version=None,
):
assert all(
isinstance(output[1], (PublicKey, Address, ScriptOutput))
for output in outputs
Expand Down Expand Up @@ -1143,21 +1160,6 @@ def _sign_txin(self, i, j, sec, compressed, *, use_cache=False):
txin["pubkeys"][j] = pubkey # needed for fd keys
return txin

def get_outputs(self):
"""convert pubkeys to addresses"""
o = []
for type, addr, v in self.outputs():
o.append((addr, v)) # consider using yield (addr, v)
return o

def get_output_addresses(self):
return [addr for addr, val in self.get_outputs()]

def has_address(self, addr):
return (addr in self.get_output_addresses()) or (
addr in (tx.get("address") for tx in self.inputs())
)

def is_final(self):
return not any(
[
Expand Down Expand Up @@ -1641,7 +1643,7 @@ def output_for_stringdata(op_return):
op_return_payload = op_return_encoded.hex()
script = op_return_code + op_return_payload
amount = 0
return bitcoin.TYPE_SCRIPT, ScriptOutput.from_string(script), amount
return TxOutput(bitcoin.TYPE_SCRIPT, ScriptOutput.from_string(script), amount)

@staticmethod
def output_for_rawhex(op_return):
Expand All @@ -1660,7 +1662,7 @@ def output_for_rawhex(op_return):
_("OP_RETURN script too large, needs to be no longer than 223 bytes")
)
amount = 0
return (
return TxOutput(
bitcoin.TYPE_SCRIPT,
ScriptOutput.protocol_factory(op_return_script),
amount,
Expand Down
Loading

0 comments on commit 231c003

Please sign in to comment.