diff --git a/hwtypes/bit_vector.py b/hwtypes/bit_vector.py index c9a1cae..2c222f6 100644 --- a/hwtypes/bit_vector.py +++ b/hwtypes/bit_vector.py @@ -28,7 +28,11 @@ def wrapped(self : 'Bit', other : tp.Union['Bit', bool]) -> 'Bit': if isinstance(other, Bit): return fn(self, other) else: - return fn(self, Bit(other)) + try: + other = Bit(other) + except TypeError: + return NotImplemented + return fn(self, other) return wrapped @@ -335,26 +339,139 @@ def bvsrem(self, other): # bvsmod def __invert__(self): return self.bvnot() - def __and__(self, other): return self.bvand(other) - def __or__(self, other): return self.bvor(other) - def __xor__(self, other): return self.bvxor(other) - def __lshift__(self, other): return self.bvshl(other) - def __rshift__(self, other): return self.bvlshr(other) + def __and__(self, other): + try: + return self.bvand(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __or__(self, other): + try: + return self.bvor(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __xor__(self, other): + try: + return self.bvxor(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + + def __lshift__(self, other): + try: + return self.bvshl(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __rshift__(self, other): + try: + return self.bvlshr(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __neg__(self): return self.bvneg() - def __add__(self, other): return self.bvadd(other) - def __sub__(self, other): return self.bvsub(other) - def __mul__(self, other): return self.bvmul(other) - def __floordiv__(self, other): return self.bvudiv(other) - def __mod__(self, other): return self.bvurem(other) - def __eq__(self, other): return self.bveq(other) - def __ne__(self, other): return self.bvne(other) - def __ge__(self, other): return self.bvuge(other) - def __gt__(self, other): return self.bvugt(other) - def __le__(self, other): return self.bvule(other) - def __lt__(self, other): return self.bvult(other) + def __add__(self, other): + try: + return self.bvadd(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __sub__(self, other): + try: + return self.bvsub(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __mul__(self, other): + try: + return self.bvmul(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __floordiv__(self, other): + try: + return self.bvudiv(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __mod__(self, other): + try: + return self.bvurem(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + + def __eq__(self, other): + try: + return self.bveq(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __ne__(self, other): + try: + return self.bvne(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __ge__(self, other): + try: + return self.bvuge(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __gt__(self, other): + try: + return self.bvugt(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __le__(self, other): + try: + return self.bvule(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __lt__(self, other): + try: + return self.bvult(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + def as_uint(self): @@ -444,27 +561,60 @@ def __int__(self): return self.as_sint() def __rshift__(self, other): - return self.bvashr(other) + try: + return self.bvashr(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __floordiv__(self, other): - return self.bvsdiv(other) + try: + return self.bvsdiv(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __mod__(self, other): - return self.bvsrem(other) + try: + return self.bvsrem(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __ge__(self, other): - return self.bvsge(other) + try: + return self.bvsge(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __gt__(self, other): - return self.bvsgt(other) + try: + return self.bvsgt(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __lt__(self, other): - - return self.bvslt(other) + try: + return self.bvslt(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __le__(self, other): - return self.bvsle(other) - + try: + return self.bvsle(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented @staticmethod def random(width): diff --git a/hwtypes/smt_bit_vector.py b/hwtypes/smt_bit_vector.py index f7b6283..b086138 100644 --- a/hwtypes/smt_bit_vector.py +++ b/hwtypes/smt_bit_vector.py @@ -48,7 +48,11 @@ def wrapped(self, other): if isinstance(other, SMTBit): return fn(self, other) else: - return fn(self, SMTBit(other)) + try: + other = SMTBit(other) + except TypeError: + return NotImplemented + return fn(self, other) return wrapped class SMTBit(AbstractBit): @@ -472,28 +476,139 @@ def bvsdiv(self, other): def bvsrem(self, other): return type(self)(smt.BVSRem(self.value, other.value)) - __invert__ = bvnot - __and__ = bvand - __or__ = bvor - __xor__ = bvxor + def __invert__(self): return self.bvnot() + + def __and__(self, other): + try: + return self.bvand(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __or__(self, other): + try: + return self.bvor(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __xor__(self, other): + try: + return self.bvxor(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + + def __lshift__(self, other): + try: + return self.bvshl(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __rshift__(self, other): + try: + return self.bvlshr(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __neg__(self): return self.bvneg() + + def __add__(self, other): + try: + return self.bvadd(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __sub__(self, other): + try: + return self.bvsub(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __mul__(self, other): + try: + return self.bvmul(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __floordiv__(self, other): + try: + return self.bvudiv(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __mod__(self, other): + try: + return self.bvurem(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + + def __eq__(self, other): + try: + return self.bveq(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + + def __ne__(self, other): + try: + return self.bvne(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented - __lshift__ = bvshl - __rshift__ = bvlshr + def __ge__(self, other): + try: + return self.bvuge(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented - __neg__ = bvneg - __add__ = bvadd - __sub__ = bvsub - __mul__ = bvmul - __floordiv__ = bvudiv - __mod__ = bvurem + def __gt__(self, other): + try: + return self.bvugt(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented - __eq__ = bveq - __ne__ = bvne - __ge__ = bvuge - __gt__ = bvugt - __le__ = bvule - __lt__ = bvult + def __le__(self, other): + try: + return self.bvule(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + def __lt__(self, other): + try: + return self.bvult(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented @int_cast @@ -540,24 +655,60 @@ class SMTUIntVector(SMTNumVector): class SMTSIntVector(SMTNumVector): def __rshift__(self, other): - return self.bvashr(other) + try: + return self.bvashr(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __floordiv__(self, other): - return self.bvsdiv(other) + try: + return self.bvsdiv(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __mod__(self, other): - return self.bvsrem(other) + try: + return self.bvsrem(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __ge__(self, other): - return self.bvsge(other) + try: + return self.bvsge(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __gt__(self, other): - return self.bvsgt(other) + try: + return self.bvsgt(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __lt__(self, other): - return self.bvslt(other) + try: + return self.bvslt(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented def __le__(self, other): - return self.bvsle(other) + try: + return self.bvsle(other) + except InconsistentSizeError as e: + raise e from None + except TypeError: + return NotImplemented + _Family_ = TypeFamily(SMTBit, SMTBitVector, SMTUIntVector, SMTSIntVector)