From 5dd96d24c7df855990b23f01528fdd7127a3085b Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Thu, 1 Aug 2019 09:29:03 -0700 Subject: [PATCH 1/2] Add InconsistentSizeError --- hwtypes/bit_vector.py | 8 +++++--- hwtypes/bit_vector_abc.py | 4 ++++ hwtypes/smt_bit_vector.py | 8 +++++--- tests/test_optypes.py | 11 ++++++++--- 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/hwtypes/bit_vector.py b/hwtypes/bit_vector.py index 093e0d3..c9a1cae 100644 --- a/hwtypes/bit_vector.py +++ b/hwtypes/bit_vector.py @@ -1,5 +1,5 @@ import typing as tp -from .bit_vector_abc import AbstractBitVector, AbstractBit, TypeFamily +from .bit_vector_abc import AbstractBitVector, AbstractBit, TypeFamily, InconsistentSizeError from .compatibility import IntegerTypes, StringTypes import functools @@ -79,7 +79,9 @@ def ite(self, t_branch, f_branch): fb_t = type(f_branch) BV_t = self.get_family().BitVector if isinstance(t_branch, BV_t) and isinstance(f_branch, BV_t): - if tb_t is not fb_t: + if tb_t.size != fb_t.size: + raise InconsistentSizeError('Both branches must have the same size') + elif tb_t is not fb_t: raise TypeError('Both branches must have the same type') T = tb_t elif isinstance(t_branch, BV_t): @@ -110,7 +112,7 @@ def _coerce(T : tp.Type['BitVector'], val : tp.Any) -> 'BitVector': if not isinstance(val, BitVector): return T(val) elif val.size != T.size: - raise TypeError('Inconsistent size') + raise InconsistentSizeError('Inconsistent size') else: return val diff --git a/hwtypes/bit_vector_abc.py b/hwtypes/bit_vector_abc.py index 96241a3..e4f104d 100644 --- a/hwtypes/bit_vector_abc.py +++ b/hwtypes/bit_vector_abc.py @@ -9,6 +9,10 @@ TypeFamily = namedtuple('TypeFamily', ['Bit', 'BitVector', 'Unsigned', 'Signed']) +# Should be raised when bv[k].op(bv[j]) and j != k + +class InconsistentSizeError(TypeError): pass + #I want to be able differentiate an old style call #BitVector(val, None) from BitVector(val) _MISSING = object() diff --git a/hwtypes/smt_bit_vector.py b/hwtypes/smt_bit_vector.py index 1500a2e..f7b6283 100644 --- a/hwtypes/smt_bit_vector.py +++ b/hwtypes/smt_bit_vector.py @@ -1,7 +1,7 @@ import typing as tp import itertools as it import functools as ft -from .bit_vector_abc import AbstractBitVector, AbstractBit, TypeFamily +from .bit_vector_abc import AbstractBitVector, AbstractBit, TypeFamily, InconsistentSizeError from abc import abstractmethod @@ -134,7 +134,9 @@ def ite(self, t_branch, f_branch): fb_t = type(f_branch) BV_t = self.get_family().BitVector if isinstance(t_branch, BV_t) and isinstance(f_branch, BV_t): - if tb_t is not fb_t: + if tb_t.size != fb_t.size: + raise InconsistentSizeError('Both branches must have the same size') + elif tb_t is not fb_t: raise TypeError('Both branches must have the same type') T = tb_t elif isinstance(t_branch, BV_t): @@ -161,7 +163,7 @@ def _coerce(T : tp.Type['SMTBitVector'], val : tp.Any) -> 'SMTBitVector': if not isinstance(val, SMTBitVector): return T(val) elif val.size != T.size: - raise TypeError('Inconsistent size') + raise InconsistentSizeError('Inconsistent size') else: return val diff --git a/tests/test_optypes.py b/tests/test_optypes.py index c99da89..3ffb9da 100644 --- a/tests/test_optypes.py +++ b/tests/test_optypes.py @@ -4,6 +4,7 @@ from itertools import product from hwtypes import BitVector, Bit +from hwtypes.bit_vector_abc import InconsistentSizeError def _rand_bv(width): return BitVector[width](random.randint(0, (1 << width) - 1)) @@ -31,7 +32,7 @@ def test_bin(op, width1, width2, use_int): y = _rand_bv(width2) if width1 != width2: assert type(x) is not type(y) - with pytest.raises(TypeError): + with pytest.raises(InconsistentSizeError): op(x, y) else: assert type(x) is type(y) @@ -60,7 +61,7 @@ def test_comp(op, width1, width2, use_int): y = _rand_bv(width2) if width1 != width2: assert type(x) is not type(y) - with pytest.raises(TypeError): + with pytest.raises(InconsistentSizeError): op(x, y) else: assert type(x) is type(y) @@ -80,8 +81,12 @@ def test_ite(t_constructor, t_size, f_constructor, f_size): if t_constructor is f_constructor is _rand_bv and t_size == f_size: res = pred.ite(t, f) assert type(res) is type(t) + elif t_constructor is f_constructor is _rand_bv: + # BV with different size + with pytest.raises(InconsistentSizeError): + res = pred.ite(t, f) elif t_constructor is f_constructor: - # either both ints or BV with different size + # both int with pytest.raises(TypeError): res = pred.ite(t, f) elif t_constructor is _rand_bv: From 804130210de34e057c7d7e6150aeac0c898b366f Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Thu, 1 Aug 2019 09:30:32 -0700 Subject: [PATCH 2/2] return NotImplemented from operators --- hwtypes/bit_vector.py | 202 ++++++++++++++++++++++++++++++++----- hwtypes/smt_bit_vector.py | 203 +++++++++++++++++++++++++++++++++----- 2 files changed, 353 insertions(+), 52 deletions(-) diff --git a/hwtypes/bit_vector.py b/hwtypes/bit_vector.py index c9a1cae..2c222f6 100644 --- a/hwtypes/bit_vector.py +++ b/hwtypes/bit_vector.py @@ -28,7 +28,11 @@ def wrapped(self : 'Bit', other : tp.Union['Bit', bool]) -> 'Bit': if isinstance(other, Bit): return fn(self, other) else: - return fn(self, Bit(other)) + try: + other = Bit(other) + except TypeError: + return NotImplemented + return fn(self, other) return wrapped @@ -335,26 +339,139 @@ def bvsrem(self, other): # bvsmod def __invert__(self): return self.bvnot() - def __and__(self, other): return self.bvand(other) - def __or__(self, other): return self.bvor(other) - def __xor__(self, other): return self.bvxor(other) - def __lshift__(self, other): return self.bvshl(other) - def __rshift__(self, other): return self.bvlshr(other) + def __and__(self, other): + try: + return self.bvand(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __or__(self, other): + try: + return self.bvor(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __xor__(self, other): + try: + return self.bvxor(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + + def __lshift__(self, other): + try: + return self.bvshl(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __rshift__(self, other): + try: + return self.bvlshr(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __neg__(self): return self.bvneg() - def __add__(self, other): return self.bvadd(other) - def __sub__(self, other): return self.bvsub(other) - def __mul__(self, other): return self.bvmul(other) - def __floordiv__(self, other): return self.bvudiv(other) - def __mod__(self, other): return self.bvurem(other) - def __eq__(self, other): return self.bveq(other) - def __ne__(self, other): return self.bvne(other) - def __ge__(self, other): return self.bvuge(other) - def __gt__(self, other): return self.bvugt(other) - def __le__(self, other): return self.bvule(other) - def __lt__(self, other): return self.bvult(other) + def __add__(self, other): + try: + return self.bvadd(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __sub__(self, other): + try: + return self.bvsub(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __mul__(self, other): + try: + return self.bvmul(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __floordiv__(self, other): + try: + return self.bvudiv(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __mod__(self, other): + try: + return self.bvurem(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + + def __eq__(self, other): + try: + return self.bveq(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __ne__(self, other): + try: + return self.bvne(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __ge__(self, other): + try: + return self.bvuge(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __gt__(self, other): + try: + return self.bvugt(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __le__(self, other): + try: + return self.bvule(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __lt__(self, other): + try: + return self.bvult(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + def as_uint(self): @@ -444,27 +561,60 @@ def __int__(self): return self.as_sint() def __rshift__(self, other): - return self.bvashr(other) + try: + return self.bvashr(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __floordiv__(self, other): - return self.bvsdiv(other) + try: + return self.bvsdiv(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __mod__(self, other): - return self.bvsrem(other) + try: + return self.bvsrem(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __ge__(self, other): - return self.bvsge(other) + try: + return self.bvsge(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __gt__(self, other): - return self.bvsgt(other) + try: + return self.bvsgt(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __lt__(self, other): - - return self.bvslt(other) + try: + return self.bvslt(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __le__(self, other): - return self.bvsle(other) - + try: + return self.bvsle(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented @staticmethod def random(width): diff --git a/hwtypes/smt_bit_vector.py b/hwtypes/smt_bit_vector.py index f7b6283..b086138 100644 --- a/hwtypes/smt_bit_vector.py +++ b/hwtypes/smt_bit_vector.py @@ -48,7 +48,11 @@ def wrapped(self, other): if isinstance(other, SMTBit): return fn(self, other) else: - return fn(self, SMTBit(other)) + try: + other = SMTBit(other) + except TypeError: + return NotImplemented + return fn(self, other) return wrapped class SMTBit(AbstractBit): @@ -472,28 +476,139 @@ def bvsdiv(self, other): def bvsrem(self, other): return type(self)(smt.BVSRem(self.value, other.value)) - __invert__ = bvnot - __and__ = bvand - __or__ = bvor - __xor__ = bvxor + def __invert__(self): return self.bvnot() + + def __and__(self, other): + try: + return self.bvand(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __or__(self, other): + try: + return self.bvor(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __xor__(self, other): + try: + return self.bvxor(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + + def __lshift__(self, other): + try: + return self.bvshl(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __rshift__(self, other): + try: + return self.bvlshr(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __neg__(self): return self.bvneg() + + def __add__(self, other): + try: + return self.bvadd(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __sub__(self, other): + try: + return self.bvsub(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __mul__(self, other): + try: + return self.bvmul(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __floordiv__(self, other): + try: + return self.bvudiv(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __mod__(self, other): + try: + return self.bvurem(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + + def __eq__(self, other): + try: + return self.bveq(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __ne__(self, other): + try: + return self.bvne(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented - __lshift__ = bvshl - __rshift__ = bvlshr + def __ge__(self, other): + try: + return self.bvuge(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented - __neg__ = bvneg - __add__ = bvadd - __sub__ = bvsub - __mul__ = bvmul - __floordiv__ = bvudiv - __mod__ = bvurem + def __gt__(self, other): + try: + return self.bvugt(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented - __eq__ = bveq - __ne__ = bvne - __ge__ = bvuge - __gt__ = bvugt - __le__ = bvule - __lt__ = bvult + def __le__(self, other): + try: + return self.bvule(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + def __lt__(self, other): + try: + return self.bvult(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented @int_cast @@ -540,24 +655,60 @@ class SMTUIntVector(SMTNumVector): class SMTSIntVector(SMTNumVector): def __rshift__(self, other): - return self.bvashr(other) + try: + return self.bvashr(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __floordiv__(self, other): - return self.bvsdiv(other) + try: + return self.bvsdiv(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __mod__(self, other): - return self.bvsrem(other) + try: + return self.bvsrem(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __ge__(self, other): - return self.bvsge(other) + try: + return self.bvsge(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __gt__(self, other): - return self.bvsgt(other) + try: + return self.bvsgt(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __lt__(self, other): - return self.bvslt(other) + try: + return self.bvslt(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __le__(self, other): - return self.bvsle(other) + try: + return self.bvsle(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + _Family_ = TypeFamily(SMTBit, SMTBitVector, SMTUIntVector, SMTSIntVector)