From 4948fe41f7042e5f3dc8f25a8606bea674f5b113 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Tue, 1 Aug 2023 15:05:48 -0700 Subject: [PATCH] Add reflected operators to bit and bv types --- hwtypes/bit_vector.py | 196 ++++++++++++++----------------------- hwtypes/bit_vector_util.py | 3 +- hwtypes/smt_bit_vector.py | 196 ++++++++++++++----------------------- hwtypes/util.py | 23 +++++ tests/test_bv.py | 25 +++++ 5 files changed, 199 insertions(+), 244 deletions(-) diff --git a/hwtypes/bit_vector.py b/hwtypes/bit_vector.py index 50bbedf..6744901 100644 --- a/hwtypes/bit_vector.py +++ b/hwtypes/bit_vector.py @@ -1,6 +1,7 @@ import typing as tp from .bit_vector_abc import AbstractBitVector, AbstractBit, TypeFamily, InconsistentSizeError from .bit_vector_util import build_ite +from .util import Method from .compatibility import IntegerTypes, StringTypes import functools @@ -71,14 +72,27 @@ def __ne__(self, other): def __and__(self, other): return type(self)(self._value & other._value) + @bit_cast + def __rand__(self, other): + return type(self)(other._value & self._value) + @bit_cast def __or__(self, other): return type(self)(self._value | other._value) + @bit_cast + def __ror__(self, other): + return type(self)(other._value | self._value) + @bit_cast def __xor__(self, other): return type(self)(self._value ^ other._value) + @bit_cast + def __rxor__(self, other): + return type(self)(other._value ^ self._value) + + def ite(self, t_branch, f_branch): ''' typing works as follows: @@ -132,6 +146,34 @@ def wrapped(self : 'BitVector', other : tp.Any) -> tp.Any: return fn(self, other) return wrapped + + +def dispatch_oper(method: tp.MethodDescriptorType): + def oper(self, other): + try: + return method(self, other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + return Method(oper) + + +# A little inefficient because of double _coerce but whate;er +def dispatch_roper(method: Method): + def roper(self, other): + try: + other = _coerce(type(self), other) + except inconsistentsizeerror as e: + raise e from None + except TypeError: + return NotImplemented + return method(other, self) + + return Method(roper) + + class BitVector(AbstractBitVector): @staticmethod def get_family() -> TypeFamily: @@ -328,8 +370,6 @@ def bvurem(self, other): return self return type(self)(self.as_uint() % other) - # bvumod - @bv_cast def bvsdiv(self, other): other = other.as_sint() @@ -344,140 +384,46 @@ def bvsrem(self, other): return self return type(self)(self.as_sint() % other) - # bvsmod def __invert__(self): return self.bvnot() - def __and__(self, other): - try: - return self.bvand(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __and__ = dispatch_oper(bvand) + __rand__ = dispatch_roper(__and__) - def __or__(self, other): - try: - return self.bvor(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __or__ = dispatch_oper(bvor) + __ror__ = dispatch_roper(__or__) - def __xor__(self, other): - try: - return self.bvxor(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __xor__ = dispatch_oper(bvxor) + __rxor__ = dispatch_roper(__xor__) + __lshift__ = dispatch_oper(bvshl) + __rlshift__ = dispatch_roper(__lshift__) - def __lshift__(self, other): - try: - return self.bvshl(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __rshift__(self, other): - try: - return self.bvlshr(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __rshift__ = dispatch_oper(bvlshr) + __rrshift__ = dispatch_oper(__rshift__) def __neg__(self): return self.bvneg() - def __add__(self, other): - try: - return self.bvadd(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __sub__(self, other): - try: - return self.bvsub(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __mul__(self, other): - try: - return self.bvmul(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __add__ = dispatch_oper(bvadd) + __radd__ = dispatch_roper(__add__) - def __floordiv__(self, other): - try: - return self.bvudiv(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __sub__ = dispatch_oper(bvsub) + __rsub__ = dispatch_roper(__sub__) - def __mod__(self, other): - try: - return self.bvurem(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __mul__ = dispatch_oper(bvmul) + __rmul__ = dispatch_roper(__mul__) + __floordiv__ = dispatch_oper(bvudiv) + __rfloordiv__ = dispatch_roper(__floordiv__) - def __eq__(self, other): - try: - return self.bveq(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __mod__ = dispatch_oper(bvurem) + __rmod__ = dispatch_roper(__mod__) - def __ne__(self, other): - try: - return self.bvne(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __ge__(self, other): - try: - return self.bvuge(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __gt__(self, other): - try: - return self.bvugt(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __le__(self, other): - try: - return self.bvule(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __lt__(self, other): - try: - return self.bvult(other) - except InconsistentSizeError as e: - raise e from None - except TypeError as e: - return NotImplemented + __eq__ = dispatch_oper(bveq) + __ne__ = dispatch_oper(AbstractBitVector.bvne) + __ge__ = dispatch_oper(AbstractBitVector.bvuge) + __gt__ = dispatch_oper(AbstractBitVector.bvugt) + __le__ = dispatch_oper(AbstractBitVector.bvule) + __lt__ = dispatch_oper(bvult) def as_uint(self): return self._value @@ -565,6 +511,8 @@ def __rshift__(self, other): except TypeError: return NotImplemented + __rrshift__ = dispatch_roper(__rshift__) + def __floordiv__(self, other): try: return self.bvsdiv(other) @@ -573,6 +521,8 @@ def __floordiv__(self, other): except TypeError: return NotImplemented + __rfloordiv__ = dispatch_roper(__floordiv__) + def __mod__(self, other): try: return self.bvsrem(other) @@ -581,6 +531,8 @@ def __mod__(self, other): except TypeError: return NotImplemented + __rmod__ = dispatch_roper(__mod__) + def __ge__(self, other): try: return self.bvsge(other) diff --git a/hwtypes/bit_vector_util.py b/hwtypes/bit_vector_util.py index 2a607ac..2545270 100644 --- a/hwtypes/bit_vector_util.py +++ b/hwtypes/bit_vector_util.py @@ -3,6 +3,7 @@ import inspect import types +from .util import Method from .bit_vector_abc import InconsistentSizeError from .bit_vector_abc import BitVectorMeta, AbstractBitVector, AbstractBit @@ -172,7 +173,7 @@ def VCall(*args, **kwargs): if v0 is NotImplemented or v0 is NotImplemented: return NotImplemented return select.ite(v0, v1) - return VCall + return Method(VCall) def get_branch_type(branch): diff --git a/hwtypes/smt_bit_vector.py b/hwtypes/smt_bit_vector.py index 0aa8c4a..3fb81a0 100644 --- a/hwtypes/smt_bit_vector.py +++ b/hwtypes/smt_bit_vector.py @@ -3,6 +3,7 @@ import functools as ft from .bit_vector_abc import AbstractBitVector, AbstractBit, TypeFamily, InconsistentSizeError from .bit_vector_util import build_ite +from .util import Method from abc import abstractmethod @@ -130,14 +131,26 @@ def __invert__(self) -> 'SMTBit': def __and__(self, other : 'SMTBit') -> 'SMTBit': return type(self)(smt.And(self.value, other.value)) + @bit_cast + def __rand__(self, other): + return type(self)(smt.And(other.value, self.value)) + @bit_cast def __or__(self, other : 'SMTBit') -> 'SMTBit': return type(self)(smt.Or(self.value, other.value)) + @bit_cast + def __ror__(self, other : 'SMTBit') -> 'SMTBit': + return type(self)(smt.Or(other.value, self.value)) + @bit_cast def __xor__(self, other : 'SMTBit') -> 'SMTBit': return type(self)(smt.Xor(self.value, other.value)) + @bit_cast + def __xor__(self, other : 'SMTBit') -> 'SMTBit': + return type(self)(smt.Xor(other.value, self.value)) + def ite(self, t_branch, f_branch): def _ite(select, t_branch, f_branch): return smt.Ite(select.value, t_branch.value, f_branch.value) @@ -164,6 +177,7 @@ def _coerce(T : tp.Type['SMTBitVector'], val : tp.Any) -> 'SMTBitVector': else: return val + def bv_cast(fn : tp.Callable[['SMTBitVector', 'SMTBitVector'], tp.Any]) -> tp.Callable[['SMTBitVector', tp.Any], tp.Any]: @ft.wraps(fn) def wrapped(self : 'SMTBitVector', other : tp.Any) -> tp.Any: @@ -171,6 +185,7 @@ def wrapped(self : 'SMTBitVector', other : tp.Any) -> tp.Any: return fn(self, other) return wrapped + def int_cast(fn : tp.Callable[['SMTBitVector', int], tp.Any]) -> tp.Callable[['SMTBitVector', tp.Any], tp.Any]: @ft.wraps(fn) def wrapped(self : 'SMTBitVector', other : tp.Any) -> tp.Any: @@ -178,6 +193,33 @@ def wrapped(self : 'SMTBitVector', other : tp.Any) -> tp.Any: return fn(self, other) return wrapped + +def dispatch_oper(method: tp.MethodDescriptorType): + def oper(self, other): + try: + return method(self, other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + return Method(oper) + + +# A little inefficient because of double _coerce but whate;er +def dispatch_roper(method: Method): + def roper(self, other): + try: + other = _coerce(type(self), other) + except inconsistentsizeerror as e: + raise e from None + except TypeError: + return NotImplemented + return method(other, self) + + return Method(roper) + + class SMTBitVector(AbstractBitVector): @staticmethod def get_family() -> TypeFamily: @@ -477,138 +519,44 @@ def bvsrem(self, other): def __invert__(self): return self.bvnot() - def __and__(self, other): - try: - return self.bvand(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __or__(self, other): - try: - return self.bvor(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __and__ = dispatch_oper(bvand) + __rand__ = dispatch_roper(__and__) - def __xor__(self, other): - try: - return self.bvxor(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __or__ = dispatch_oper(bvor) + __ror__ = dispatch_roper(__or__) + __xor__ = dispatch_oper(bvxor) + __rxor__ = dispatch_roper(__xor__) - def __lshift__(self, other): - try: - return self.bvshl(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __lshift__ = dispatch_oper(bvshl) + __rlshift__ = dispatch_roper(__lshift__) - def __rshift__(self, other): - try: - return self.bvlshr(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __rshift__ = dispatch_oper(bvlshr) + __rrshift__ = dispatch_oper(__rshift__) def __neg__(self): return self.bvneg() - def __add__(self, other): - try: - return self.bvadd(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __add__ = dispatch_oper(bvadd) + __radd__ = dispatch_roper(__add__) - def __sub__(self, other): - try: - return self.bvsub(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __sub__ = dispatch_oper(bvsub) + __rsub__ = dispatch_roper(__sub__) - def __mul__(self, other): - try: - return self.bvmul(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __mul__ = dispatch_oper(bvmul) + __rmul__ = dispatch_roper(__mul__) - def __floordiv__(self, other): - try: - return self.bvudiv(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __mod__(self, other): - try: - return self.bvurem(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented + __floordiv__ = dispatch_oper(bvudiv) + __rfloordiv__ = dispatch_roper(__floordiv__) + __mod__ = dispatch_oper(bvurem) + __rmod__ = dispatch_roper(__mod__) - def __eq__(self, other): - try: - return self.bveq(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __ne__(self, other): - try: - return self.bvne(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __ge__(self, other): - try: - return self.bvuge(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __gt__(self, other): - try: - return self.bvugt(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __le__(self, other): - try: - return self.bvule(other) - except InconsistentSizeError as e: - raise e from None - except TypeError: - return NotImplemented - - def __lt__(self, other): - try: - return self.bvult(other) - except InconsistentSizeError as e: - raise e from None - except TypeError as e: - return NotImplemented - + __eq__ = dispatch_oper(bveq) + __ne__ = dispatch_oper(AbstractBitVector.bvne) + __ge__ = dispatch_oper(AbstractBitVector.bvuge) + __gt__ = dispatch_oper(AbstractBitVector.bvugt) + __le__ = dispatch_oper(AbstractBitVector.bvule) + __lt__ = dispatch_oper(bvult) @int_cast def repeat(self, other): @@ -672,6 +620,8 @@ def __rshift__(self, other): except TypeError: return NotImplemented + __rrshift__ = dispatch_roper(__rshift__) + def __floordiv__(self, other): try: return self.bvsdiv(other) @@ -680,6 +630,8 @@ def __floordiv__(self, other): except TypeError: return NotImplemented + __rfloordiv__ = dispatch_roper(__floordiv__) + def __mod__(self, other): try: return self.bvsrem(other) @@ -688,6 +640,8 @@ def __mod__(self, other): except TypeError: return NotImplemented + __rmod__ = dispatch_roper(__mod__) + def __ge__(self, other): try: return self.bvsge(other) @@ -709,7 +663,7 @@ def __lt__(self, other): return self.bvslt(other) except InconsistentSizeError as e: raise e from None - except TypeError as e: + except TypeError: return NotImplemented def __le__(self, other): diff --git a/hwtypes/util.py b/hwtypes/util.py index a68bf27..cf52ebc 100644 --- a/hwtypes/util.py +++ b/hwtypes/util.py @@ -1,6 +1,7 @@ from collections import OrderedDict from collections.abc import Mapping, MutableMapping import typing as tp +import types class FrozenDict(Mapping): __slots__ = '_d', '_hash' @@ -124,6 +125,28 @@ def setter(self, fset): def deleter(self, fdel): return type(self)(self.T)(self.fget, self.fset, fdel, self.__doc__) +class Method: + ''' + Method descriptor which automatically sets the name of the bound function + ''' + def __init__(self, m): + self.m = m + + def __get__(self, obj, objtype=None): + if obj is not None: + return types.MethodType(self.m, obj) + else: + return self.m + + def __set_name__(self, owner, name): + self.m.__name__ = name + self.m.__qualname__ = owner.__qualname__ + '.' + name + + + def __call__(self, *args, **kwargs): + # HACK + # need this because of vcall works + return self.m(*args, **kwargs) def _issubclass(sub : tp.Any, parent : type) -> bool: try: diff --git a/tests/test_bv.py b/tests/test_bv.py index 1dfea3d..13dcef8 100644 --- a/tests/test_bv.py +++ b/tests/test_bv.py @@ -127,3 +127,28 @@ def test_operator_by_0(op, reference): I0, I1 = BitVector.random(5), 0 expected = unsigned(reference(int(I0), int(I1)), 5) assert expected == int(op(I0, I1)) + + + +@pytest.mark.parametrize("op", [ + operator.add, + operator.mul, + operator.sub, + operator.floordiv, + operator.mod, + operator.and_, + operator.or_, + operator.xor, + operator.lshift, + operator.rshift, + operator.eq, + operator.ne, + operator.lt, + operator.le, + operator.gt, + operator.ge, + ]) +def test_coercion(op): + a = BitVector.random(16) + b = BitVector.random(16) + assert op(a, b) == op(int(a), b) == op(a, int(b))