From 7ac4d93ea7d8c7174e0c7b924bdda39062f9c4bc Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Tue, 10 Mar 2020 14:58:19 -0700 Subject: [PATCH 1/4] Make polymorhic bitvector --- hwtypes/bit_vector.py | 48 ++++-------- hwtypes/bit_vector_abc.py | 3 +- hwtypes/bit_vector_util.py | 150 ++++++++++++++++++++++++++----------- hwtypes/smt_bit_vector.py | 28 +------ tests/test_optypes.py | 42 +++++++---- 5 files changed, 155 insertions(+), 116 deletions(-) diff --git a/hwtypes/bit_vector.py b/hwtypes/bit_vector.py index bae29c4..088bcd8 100644 --- a/hwtypes/bit_vector.py +++ b/hwtypes/bit_vector.py @@ -82,30 +82,24 @@ def __xor__(self, other): def ite(self, t_branch, f_branch): ''' typing works as follows: - given cls is type(self) - and BV is cls.get_family().BitVector + if t_branch and f_branch are both Bit[Vector] types + from the same family, and type(t_branch) is type(f_branch), + then return type is t_branch. - if both branches are subclasses of cls and one is a subclass of the - other return type is the parent type + elif t_branch and f_branch are both Bit[Vector] types + from the same family, then return type is a polymorphic + type. - if both branches are subclasses of BV and one is a subclass of the - other return type is the parent type + elif t_brand and f_branch are tuples of the + same length, then these rules are applied recursively - if one branch is a subclass of cls try to cast the other branch to that - type and return that type + else there is an error - if one branch is a subclass of BV try to cast the other branch to that - type and return that type - - if both branches are tuples of the same length, then these tests are - applied recursively to each pair of elements - - all other cases are errors ''' - def _ite(t_branch, f_branch): - return t_branch if self else f_branch + def _ite(select, t_branch, f_branch): + return t_branch if select else f_branch - return build_ite(_ite, type(self), t_branch, f_branch) + return build_ite(_ite, self, t_branch, f_branch) def __bool__(self) -> bool: return self._value @@ -114,7 +108,7 @@ def __int__(self) -> int: return int(self._value) def __repr__(self) -> str: - return 'Bit({})'.format(self._value) + return f'{type(self).__name__}({self._value})' def __hash__(self) -> int: return hash(self._value) @@ -178,7 +172,7 @@ def __str__(self): return str(int(self)) def __repr__(self): - return "BitVector[{size}]({value})".format(value=self._value, size=self.size) + return f'{type(self).__name__}({self._value})' @property def value(self): @@ -476,11 +470,9 @@ def __lt__(self, other): return self.bvult(other) except InconsistentSizeError as e: raise e from None - except TypeError: + except TypeError as e: return NotImplemented - - def as_uint(self): return self._value @@ -541,29 +533,21 @@ def random(width): return BitVector[width](random.randint(0, (1 << width) - 1)) - - class NumVector(BitVector): __hash__ = BitVector.__hash__ + class UIntVector(NumVector): __hash__ = NumVector.__hash__ - def __repr__(self): - return "UIntVector[{size}]({value})".format(value=self._value, size=self.size) - @staticmethod def random(width): return UIntVector[width](random.randint(0, (1 << width) - 1)) - class SIntVector(NumVector): __hash__ = NumVector.__hash__ - def __repr__(self): - return "SIntVector[{size}]({value})".format(value=self._value, size=self.size) - def __int__(self): return self.as_sint() diff --git a/hwtypes/bit_vector_abc.py b/hwtypes/bit_vector_abc.py index e4f104d..2799213 100644 --- a/hwtypes/bit_vector_abc.py +++ b/hwtypes/bit_vector_abc.py @@ -16,7 +16,7 @@ class InconsistentSizeError(TypeError): pass #I want to be able differentiate an old style call #BitVector(val, None) from BitVector(val) _MISSING = object() -class AbstractBitVectorMeta(ABCMeta): +class AbstractBitVectorMeta(type): #:(ABCMeta): # BitVectorType, size : BitVectorType[size] _class_cache = weakref.WeakValueDictionary() @@ -319,5 +319,6 @@ def ext(self, other) -> 'AbstractBitVector': def zext(self, other) -> 'AbstractBitVector': pass +BitVectorMeta = AbstractBitVectorMeta _Family_ = TypeFamily(AbstractBit, AbstractBitVector, None, None) diff --git a/hwtypes/bit_vector_util.py b/hwtypes/bit_vector_util.py index e5bc8c5..2efe022 100644 --- a/hwtypes/bit_vector_util.py +++ b/hwtypes/bit_vector_util.py @@ -1,6 +1,98 @@ import functools as ft import itertools as it +import inspect +import types + from .bit_vector_abc import InconsistentSizeError +from .bit_vector_abc import BitVectorMeta, AbstractBitVector, AbstractBit + +def _get_common_bases(t, s): + if issubclass(t, s): + return (s,) + elif issubclass(s, t): + return (t,) + else: + bases = set() + for t_ in t.__bases__: + bases.update(_get_common_bases(t_, s)) + + for s_ in s.__bases__: + bases.update(_get_common_bases(t, s_)) + + return tuple(bases) + +class PolyType(type): + _type_cache = {} + def __getitem__(cls, args): + try: + return cls._type_cache[args] + except KeyError: + pass + + # In terms of typing it would make more sense to make + # select a instance paramater not a type paramater + # however for engineering reasons its a lot more + # convient to make it a type parameter. + T0, T1, select = args + + if not cls._type_check(T0, T1): + raise TypeError(f'Cannot construct {cls} from {T0} and {T1}') + if T0.get_family() is not T1.get_family(): + raise TypeError('Cannot construct PolyTypes across families') + elif not isinstance(select, AbstractBit): + raise TypeError('select must be a Bit') + elif select.get_family() is not T0.get_family(): + raise TypeError('Cannot construct PolyTypes across families') + + bases = _get_common_bases(T0, T1) + class_name = f'{cls.__name__}[{T0.__name__}, {T1.__name__}, {select}]' + meta, namespace, _ = types.prepare_class(class_name, bases) + + d0 = dict(inspect.getmembers(T0)) + d1 = dict(inspect.getmembers(T1)) + + attrs = d0.keys() & d1.keys() + for k in attrs: + if k in {'_info_', '__int__', '__repr__', '__str__'}: + continue + + m0 = inspect.getattr_static(T0, k) + m1 = inspect.getattr_static(T1, k) + namespace[k] = build_VCall(select, [m0, m1]) + + new_cls = meta(class_name, bases, namespace) + return cls._type_cache.setdefault(args, new_cls) + +class PolyVector(metaclass=PolyType): + @classmethod + def _type_check(cls, T0, T1): + if (issubclass(T0, AbstractBitVector) + and issubclass(T1, AbstractBitVector)): + if T0.size != T1.size: + raise InconsistentSizeError(f'Cannot construct {cls} from {T0} and {T1}') + else: + return True + else: + return False + +class PolyBit(metaclass=PolyType): + @classmethod + def _type_check(cls, T0, T1): + return (issubclass(T0, AbstractBit) + and issubclass(T1, AbstractBit)) + +def build_VCall(select, methods): + if methods[0] is methods[1]: + return methods[0] + else: + def VCall(*args, **kwargs): + v0 = methods[0](*args, **kwargs) + v1 = methods[1](*args, **kwargs) + if v0 is NotImplemented or v0 is NotImplemented: + return NotImplemented + return select.ite(v0, v1) + return VCall + def get_branch_type(branch): if isinstance(branch, tuple): @@ -8,7 +100,7 @@ def get_branch_type(branch): else: return type(branch) -def determine_return_type(bit_t, bv_t, t_branch, f_branch): +def determine_return_type(select, t_branch, f_branch): def _recurse(t_branch, f_branch): tb_t = get_branch_type(t_branch) fb_t = get_branch_type(f_branch) @@ -29,33 +121,17 @@ def _recurse(t_branch, f_branch): elif (isinstance(tb_t, tuple) or isinstance(fb_t, tuple)): raise TypeError(f'Branches have inconsistent types: {tb_t} and {fb_t}') - elif isinstance(t_branch, bit_t) and isinstance(f_branch, bit_t): - if issubclass(tb_t, fb_t): - return fb_t - elif issubclass(fb_t, tb_t): + elif issubclass(tb_t, AbstractBit) and issubclass(fb_t, AbstractBit): + if tb_t is fb_t: return tb_t - else: - raise TypeError(f'Branches have inconsistent types: {tb_t} and {fb_t}') - elif isinstance(t_branch, bv_t) and isinstance(f_branch, bv_t): - if tb_t.size != fb_t.size: - raise InconsistentSizeError('Both branches must have the same size') - elif issubclass(tb_t, fb_t): - return fb_t - elif issubclass(fb_t, tb_t): + return PolyBit[tb_t, fb_t, select] + elif issubclass(tb_t, AbstractBitVector) and issubclass(fb_t, AbstractBitVector): + if tb_t is fb_t: return tb_t - else: - raise TypeError(f'Branches have inconsistent types: {tb_t} and {fb_t}') - elif isinstance(t_branch, bit_t): - return tb_t - elif isinstance(f_branch, bit_t): - return fb_t - elif isinstance(t_branch, bv_t): - return tb_t - elif isinstance(f_branch, bv_t): - return fb_t + return PolyVector[tb_t, fb_t, select] else: - raise TypeError(f'Cannot infer return type. ' - f'Atleast one branch must be a {bv_t}, {bit_t}, or tuples') + raise TypeError(f'tb_t: {tb_t}, fb_t: {fb_t}') + return _recurse(t_branch, f_branch) def coerce_branch(r_type, branch): @@ -66,7 +142,7 @@ def coerce_branch(r_type, branch): else: return r_type(branch) -def push_ite(ite, t_branch, f_branch): +def push_ite(ite, select, t_branch, f_branch): def _recurse(t_branch, f_branch): if isinstance(t_branch, tuple): assert isinstance(f_branch, tuple) @@ -76,23 +152,11 @@ def _recurse(t_branch, f_branch): zip(t_branch, f_branch) )) else: - return ite(t_branch, f_branch) + return ite(select, t_branch, f_branch) return _recurse(t_branch, f_branch) -def build_ite(ite, bit_t, t_branch, f_branch, - push_ite_to_leaves=False, - cast_return=False): - bv_t = bit_t.get_family().BitVector - r_type = determine_return_type(bit_t, bv_t, t_branch, f_branch) - t_branch = coerce_branch(r_type, t_branch) - f_branch = coerce_branch(r_type, f_branch) - - if push_ite_to_leaves: - r_val = push_ite(ite, t_branch, f_branch) - else: - r_val = ite(t_branch, f_branch) - - if cast_return: - r_val = coerce_branch(r_type, r_val) - +def build_ite(ite, select, t_branch, f_branch): + r_type = determine_return_type(select, t_branch, f_branch) + r_val = push_ite(ite, select, t_branch, f_branch) + r_val = coerce_branch(r_type, r_val) return r_val diff --git a/hwtypes/smt_bit_vector.py b/hwtypes/smt_bit_vector.py index 6f56a21..bd513da 100644 --- a/hwtypes/smt_bit_vector.py +++ b/hwtypes/smt_bit_vector.py @@ -135,33 +135,11 @@ def __xor__(self, other : 'SMTBit') -> 'SMTBit': return type(self)(smt.Xor(self.value, other.value)) def ite(self, t_branch, f_branch): - ''' - typing works as follows: - given cls is type(self) - and BV is cls.get_family().BitVector + def _ite(select, t_branch, f_branch): + return smt.Ite(select.value, t_branch.value, f_branch.value) - if both branches are subclasses of cls and one is a subclass of the - other return type is the parent type - if both branches are subclasses of BV and one is a subclass of the - other return type is the parent type - - if one branch is a subclass of cls try to cast the other branch to that - type and return that type - - if one branch is a subclass of BV try to cast the other branch to that - type and return that type - - if both branches are tuples of the same length, then these tests are - applied recursively to each pair of elements - - all other cases are errors - ''' - def _ite(t_branch, f_branch): - return smt.Ite(self.value, t_branch.value, f_branch.value) - - - return build_ite(_ite, type(self), t_branch, f_branch, True, True) + return build_ite(_ite, self, t_branch, f_branch) def substitute(self, *subs : tp.List[tp.Tuple['SMTBit', 'SMTBit']]): return SMTBit( diff --git a/tests/test_optypes.py b/tests/test_optypes.py index 3ffb9da..938d6f2 100644 --- a/tests/test_optypes.py +++ b/tests/test_optypes.py @@ -3,12 +3,17 @@ import random from itertools import product -from hwtypes import BitVector, Bit +from hwtypes import SIntVector, BitVector, Bit from hwtypes.bit_vector_abc import InconsistentSizeError +from hwtypes.bit_vector_util import PolyVector def _rand_bv(width): return BitVector[width](random.randint(0, (1 << width) - 1)) +def _rand_signed(width): + return SIntVector[width](random.randint(0, (1 << width) - 1)) + + def _rand_int(width): return random.randint(0, (1 << width) - 1) @@ -69,29 +74,36 @@ def test_comp(op, width1, width2, use_int): assert type(res) is Bit -@pytest.mark.parametrize("t_constructor", (_rand_bv, _rand_int)) -@pytest.mark.parametrize("t_size", (1, 2, 4, 8)) -@pytest.mark.parametrize("f_constructor", (_rand_bv, _rand_int)) -@pytest.mark.parametrize("f_size", (1, 2, 4, 8)) +@pytest.mark.parametrize("t_constructor", (_rand_bv, _rand_signed, _rand_int)) +@pytest.mark.parametrize("t_size", (1, 2, 4)) +@pytest.mark.parametrize("f_constructor", (_rand_bv, _rand_signed, _rand_int)) +@pytest.mark.parametrize("f_size", (1, 2, 4)) def test_ite(t_constructor, t_size, f_constructor, f_size): pred = Bit(_rand_int(1)) t = t_constructor(t_size) f = f_constructor(f_size) - if t_constructor is f_constructor is _rand_bv and t_size == f_size: + t_is_bv_constructor = t_constructor in {_rand_signed, _rand_bv} + f_is_bv_constructor = f_constructor in {_rand_signed, _rand_bv} + sizes_equal = t_size == f_size + + if (t_constructor is f_constructor and t_is_bv_constructor and sizes_equal): + # The same bv_constructor res = pred.ite(t, f) assert type(res) is type(t) - elif t_constructor is f_constructor is _rand_bv: + elif t_is_bv_constructor and f_is_bv_constructor and sizes_equal: + # Different bv_constuctor + res = pred.ite(t, f) + assert type(res) is PolyVector[type(t), type(f), pred] + # The bases should be the most specific types that are common + # to both branches. As SIntVect[size] is a subclass of + # BitVector[size], BitVector[size] is such a type. + assert type(res).__bases__ == (BitVector[t_size],) + elif t_is_bv_constructor and f_is_bv_constructor and not sizes_equal: # BV with different size with pytest.raises(InconsistentSizeError): res = pred.ite(t, f) - elif t_constructor is f_constructor: - # both int + else: + # Trying to coerce an int with pytest.raises(TypeError): res = pred.ite(t, f) - elif t_constructor is _rand_bv: - res = pred.ite(t, f) - assert type(res) is BitVector[t_size] - else: #t_constructor is _rand_int - res = pred.ite(t, f) - assert type(res) is BitVector[f_size] From 689f7e00c95d857b1466f09107ed1f436880a5fe Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Tue, 10 Mar 2020 16:04:38 -0700 Subject: [PATCH 2/4] Fix some bugs --- hwtypes/bit_vector_util.py | 65 +++++++++++++++++++++++--------------- hwtypes/smt_bit_vector.py | 1 - tests/test_optypes.py | 8 ++--- 3 files changed, 44 insertions(+), 30 deletions(-) diff --git a/hwtypes/bit_vector_util.py b/hwtypes/bit_vector_util.py index 2efe022..4a9ce2d 100644 --- a/hwtypes/bit_vector_util.py +++ b/hwtypes/bit_vector_util.py @@ -19,33 +19,33 @@ def _get_common_bases(t, s): for s_ in s.__bases__: bases.update(_get_common_bases(t, s_)) - return tuple(bases) + # Filter to most specific types + bases_ = set() + for bi in bases: + if not any(issubclass(bj, bi) for bj in bases if bi is not bj): + bases_.add(bi) + + return tuple(bases_) + +# used as a tag +class PolyBase: pass class PolyType(type): _type_cache = {} def __getitem__(cls, args): + T0, T1 = args try: return cls._type_cache[args] except KeyError: pass - # In terms of typing it would make more sense to make - # select a instance paramater not a type paramater - # however for engineering reasons its a lot more - # convient to make it a type parameter. - T0, T1, select = args - if not cls._type_check(T0, T1): raise TypeError(f'Cannot construct {cls} from {T0} and {T1}') if T0.get_family() is not T1.get_family(): raise TypeError('Cannot construct PolyTypes across families') - elif not isinstance(select, AbstractBit): - raise TypeError('select must be a Bit') - elif select.get_family() is not T0.get_family(): - raise TypeError('Cannot construct PolyTypes across families') - bases = _get_common_bases(T0, T1) - class_name = f'{cls.__name__}[{T0.__name__}, {T1.__name__}, {select}]' + bases = *_get_common_bases(T0, T1), PolyBase + class_name = f'{cls.__name__}[{T0.__name__}, {T1.__name__}]' meta, namespace, _ = types.prepare_class(class_name, bases) d0 = dict(inspect.getmembers(T0)) @@ -58,9 +58,22 @@ def __getitem__(cls, args): m0 = inspect.getattr_static(T0, k) m1 = inspect.getattr_static(T1, k) - namespace[k] = build_VCall(select, [m0, m1]) + namespace[k] = build_VCall([m0, m1]) new_cls = meta(class_name, bases, namespace) + + genv = {'base': new_cls} + lenv = {} + + # build __init__ + __init__ = f''' +def __init__(self, *args, _poly_select_, **kwargs): + self._poly_select_ = _poly_select_ + return super(base, self).__init__(*args, **kwargs) +''' + exec(__init__, genv, lenv) + new_cls.__init__ = lenv['__init__'] + return cls._type_cache.setdefault(args, new_cls) class PolyVector(metaclass=PolyType): @@ -81,13 +94,13 @@ def _type_check(cls, T0, T1): return (issubclass(T0, AbstractBit) and issubclass(T1, AbstractBit)) -def build_VCall(select, methods): +def build_VCall(methods): if methods[0] is methods[1]: return methods[0] else: - def VCall(*args, **kwargs): - v0 = methods[0](*args, **kwargs) - v1 = methods[1](*args, **kwargs) + def VCall(self, *args, **kwargs): + v0 = methods[0](self, *args, **kwargs) + v1 = methods[1](self, *args, **kwargs) if v0 is NotImplemented or v0 is NotImplemented: return NotImplemented return select.ite(v0, v1) @@ -100,7 +113,7 @@ def get_branch_type(branch): else: return type(branch) -def determine_return_type(select, t_branch, f_branch): +def determine_return_type(t_branch, f_branch): def _recurse(t_branch, f_branch): tb_t = get_branch_type(t_branch) fb_t = get_branch_type(f_branch) @@ -124,21 +137,23 @@ def _recurse(t_branch, f_branch): elif issubclass(tb_t, AbstractBit) and issubclass(fb_t, AbstractBit): if tb_t is fb_t: return tb_t - return PolyBit[tb_t, fb_t, select] + return PolyBit[tb_t, fb_t] elif issubclass(tb_t, AbstractBitVector) and issubclass(fb_t, AbstractBitVector): if tb_t is fb_t: return tb_t - return PolyVector[tb_t, fb_t, select] + return PolyVector[tb_t, fb_t] else: raise TypeError(f'tb_t: {tb_t}, fb_t: {fb_t}') return _recurse(t_branch, f_branch) -def coerce_branch(r_type, branch): +def coerce_branch(r_type, select, branch): if isinstance(r_type, tuple): assert isinstance(branch, tuple) assert len(r_type) == len(branch) - return tuple(coerce_branch(t, arg) for t, arg in zip(r_type, branch)) + return tuple(coerce_branch(t, select, arg) for t, arg in zip(r_type, branch)) + elif issubclass(r_type, PolyBase): + return r_type(branch, _poly_select_ = select) else: return r_type(branch) @@ -156,7 +171,7 @@ def _recurse(t_branch, f_branch): return _recurse(t_branch, f_branch) def build_ite(ite, select, t_branch, f_branch): - r_type = determine_return_type(select, t_branch, f_branch) + r_type = determine_return_type(t_branch, f_branch) r_val = push_ite(ite, select, t_branch, f_branch) - r_val = coerce_branch(r_type, r_val) + r_val = coerce_branch(r_type, select, r_val) return r_val diff --git a/hwtypes/smt_bit_vector.py b/hwtypes/smt_bit_vector.py index bd513da..e195c25 100644 --- a/hwtypes/smt_bit_vector.py +++ b/hwtypes/smt_bit_vector.py @@ -643,7 +643,6 @@ def substitute(self, *subs : tp.List[tp.Tuple["SBV", "SBV"]]): class SMTNumVector(SMTBitVector): pass - class SMTUIntVector(SMTNumVector): pass diff --git a/tests/test_optypes.py b/tests/test_optypes.py index 938d6f2..f65bdb8 100644 --- a/tests/test_optypes.py +++ b/tests/test_optypes.py @@ -5,7 +5,7 @@ from hwtypes import SIntVector, BitVector, Bit from hwtypes.bit_vector_abc import InconsistentSizeError -from hwtypes.bit_vector_util import PolyVector +from hwtypes.bit_vector_util import PolyVector, PolyBase def _rand_bv(width): return BitVector[width](random.randint(0, (1 << width) - 1)) @@ -94,11 +94,11 @@ def test_ite(t_constructor, t_size, f_constructor, f_size): elif t_is_bv_constructor and f_is_bv_constructor and sizes_equal: # Different bv_constuctor res = pred.ite(t, f) - assert type(res) is PolyVector[type(t), type(f), pred] + assert type(res) is PolyVector[type(t), type(f)] # The bases should be the most specific types that are common - # to both branches. As SIntVect[size] is a subclass of + # to both branches and PolyBase. As SIntVect[size] is a subclass of # BitVector[size], BitVector[size] is such a type. - assert type(res).__bases__ == (BitVector[t_size],) + assert type(res).__bases__ == (BitVector[t_size], PolyBase) elif t_is_bv_constructor and f_is_bv_constructor and not sizes_equal: # BV with different size with pytest.raises(InconsistentSizeError): From baba2e4563839ba253867334bab9b6aaec173ef4 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Tue, 10 Mar 2020 18:42:40 -0700 Subject: [PATCH 3/4] Fix a bunch of bugs --- hwtypes/bit_vector.py | 1 - hwtypes/bit_vector_util.py | 137 ++++++++++++++++++++++--------------- hwtypes/smt_bit_vector.py | 6 +- tests/test_optypes.py | 7 +- 4 files changed, 87 insertions(+), 64 deletions(-) diff --git a/hwtypes/bit_vector.py b/hwtypes/bit_vector.py index 088bcd8..5909326 100644 --- a/hwtypes/bit_vector.py +++ b/hwtypes/bit_vector.py @@ -612,7 +612,6 @@ def random(width): w = width - 1 return SIntVector[width](random.randint(-(1 << w), (1 << w) - 1)) - @bv_cast def ext(self, other): return self.sext(other) diff --git a/hwtypes/bit_vector_util.py b/hwtypes/bit_vector_util.py index 4a9ce2d..2afacbf 100644 --- a/hwtypes/bit_vector_util.py +++ b/hwtypes/bit_vector_util.py @@ -6,46 +6,33 @@ from .bit_vector_abc import InconsistentSizeError from .bit_vector_abc import BitVectorMeta, AbstractBitVector, AbstractBit -def _get_common_bases(t, s): - if issubclass(t, s): - return (s,) - elif issubclass(s, t): - return (t,) - else: - bases = set() - for t_ in t.__bases__: - bases.update(_get_common_bases(t_, s)) - - for s_ in s.__bases__: - bases.update(_get_common_bases(t, s_)) - - # Filter to most specific types - bases_ = set() - for bi in bases: - if not any(issubclass(bj, bi) for bj in bases if bi is not bj): - bases_.add(bi) - - return tuple(bases_) - # used as a tag class PolyBase: pass class PolyType(type): - _type_cache = {} def __getitem__(cls, args): - T0, T1 = args - try: - return cls._type_cache[args] - except KeyError: - pass + # From a typing perspective it would be better to make select an + # argument to init instead of making it a type param. This would + # allow types to be cached etc... However, making it an init arg + # means type(self)(val) is no longer sufficient to to cast val. + # Instead one would need to write type(self)(val, self._select_) or + # equivalent and hence would require a major change in the engineering + # of bitvector types (they would need to be aware of polymorphism). + # Note we can't cache as select is not necessarily hashable. + + T0, T1, select = args if not cls._type_check(T0, T1): raise TypeError(f'Cannot construct {cls} from {T0} and {T1}') - if T0.get_family() is not T1.get_family(): + if not isinstance(select, AbstractBit): + raise TypeError('select must be a Bit') + if (T0.get_family() is not T1.get_family() + or T0.get_family() is not select.get_family()): raise TypeError('Cannot construct PolyTypes across families') - bases = *_get_common_bases(T0, T1), PolyBase - class_name = f'{cls.__name__}[{T0.__name__}, {T1.__name__}]' + + bases = *cls._get_bases(T0, T1), PolyBase + class_name = f'{cls.__name__}[{T0.__name__}, {T1.__name__}, {select}]' meta, namespace, _ = types.prepare_class(class_name, bases) d0 = dict(inspect.getmembers(T0)) @@ -58,23 +45,34 @@ def __getitem__(cls, args): m0 = inspect.getattr_static(T0, k) m1 = inspect.getattr_static(T1, k) - namespace[k] = build_VCall([m0, m1]) + namespace[k] = build_VCall(select, m0, m1) + new_cls = meta(class_name, bases, namespace) + final = cls._finalize(new_cls, T0, T1) + return final - genv = {'base': new_cls} - lenv = {} - # build __init__ - __init__ = f''' -def __init__(self, *args, _poly_select_, **kwargs): - self._poly_select_ = _poly_select_ - return super(base, self).__init__(*args, **kwargs) -''' - exec(__init__, genv, lenv) - new_cls.__init__ = lenv['__init__'] +def _get_common_bases(T0, T1): + if issubclass(T0, T1): + return T1, + elif issubclass(T1, T0): + return T0, + else: + bases = set() + for t in T0.__bases__: + bases.update(_get_common_bases(t, T1)) - return cls._type_cache.setdefault(args, new_cls) + for t in T1.__bases__: + bases.update(_get_common_bases(t, T0)) + + # Filter to most specific types + bases_ = set() + for bi in bases: + if not any(issubclass(bj, bi) for bj in bases if bi is not bj): + bases_.add(bi) + + return tuple(bases_) class PolyVector(metaclass=PolyType): @classmethod @@ -88,19 +86,46 @@ def _type_check(cls, T0, T1): else: return False + @classmethod + def _get_bases(cls, T0, T1): + bases = _get_common_bases(T0, T1) + + # get the unsized versions + bases_ = set() + for base in bases: + try: + bases_.add(base.unsized_t) + except AttributeError: + bases_.add(base) + + return tuple(bases_) + + + @classmethod + def _finalize(cls, new_class, T0, T1): + return new_class[T0.size] + class PolyBit(metaclass=PolyType): @classmethod def _type_check(cls, T0, T1): return (issubclass(T0, AbstractBit) and issubclass(T1, AbstractBit)) -def build_VCall(methods): - if methods[0] is methods[1]: - return methods[0] + @classmethod + def _get_bases(cls, T0, T1): + return _get_common_bases(T0, T1) + + @classmethod + def _finalize(cls, new_class, T0, T1): + return new_class + +def build_VCall(select, m0, m1): + if m0 is m1: + return m0 else: - def VCall(self, *args, **kwargs): - v0 = methods[0](self, *args, **kwargs) - v1 = methods[1](self, *args, **kwargs) + def VCall(*args, **kwargs): + v0 = m0(*args, **kwargs) + v1 = m1(*args, **kwargs) if v0 is NotImplemented or v0 is NotImplemented: return NotImplemented return select.ite(v0, v1) @@ -113,7 +138,7 @@ def get_branch_type(branch): else: return type(branch) -def determine_return_type(t_branch, f_branch): +def determine_return_type(select, t_branch, f_branch): def _recurse(t_branch, f_branch): tb_t = get_branch_type(t_branch) fb_t = get_branch_type(f_branch) @@ -137,23 +162,21 @@ def _recurse(t_branch, f_branch): elif issubclass(tb_t, AbstractBit) and issubclass(fb_t, AbstractBit): if tb_t is fb_t: return tb_t - return PolyBit[tb_t, fb_t] + return PolyBit[tb_t, fb_t, select] elif issubclass(tb_t, AbstractBitVector) and issubclass(fb_t, AbstractBitVector): if tb_t is fb_t: return tb_t - return PolyVector[tb_t, fb_t] + return PolyVector[tb_t, fb_t, select] else: raise TypeError(f'tb_t: {tb_t}, fb_t: {fb_t}') return _recurse(t_branch, f_branch) -def coerce_branch(r_type, select, branch): +def coerce_branch(r_type, branch): if isinstance(r_type, tuple): assert isinstance(branch, tuple) assert len(r_type) == len(branch) - return tuple(coerce_branch(t, select, arg) for t, arg in zip(r_type, branch)) - elif issubclass(r_type, PolyBase): - return r_type(branch, _poly_select_ = select) + return tuple(coerce_branch(t, arg) for t, arg in zip(r_type, branch)) else: return r_type(branch) @@ -171,7 +194,7 @@ def _recurse(t_branch, f_branch): return _recurse(t_branch, f_branch) def build_ite(ite, select, t_branch, f_branch): - r_type = determine_return_type(t_branch, f_branch) + r_type = determine_return_type(select, t_branch, f_branch) r_val = push_ite(ite, select, t_branch, f_branch) - r_val = coerce_branch(r_type, select, r_val) + r_val = coerce_branch(r_type, r_val) return r_val diff --git a/hwtypes/smt_bit_vector.py b/hwtypes/smt_bit_vector.py index e195c25..715d10d 100644 --- a/hwtypes/smt_bit_vector.py +++ b/hwtypes/smt_bit_vector.py @@ -593,7 +593,7 @@ def __lt__(self, other): return self.bvult(other) except InconsistentSizeError as e: raise e from None - except TypeError: + except TypeError as e: return NotImplemented @@ -692,7 +692,7 @@ def __lt__(self, other): return self.bvslt(other) except InconsistentSizeError as e: raise e from None - except TypeError: + except TypeError as e: return NotImplemented def __le__(self, other): @@ -703,6 +703,8 @@ def __le__(self, other): except TypeError: return NotImplemented + def ext(self, other): + return self.sext(other) diff --git a/tests/test_optypes.py b/tests/test_optypes.py index f65bdb8..a9c8b56 100644 --- a/tests/test_optypes.py +++ b/tests/test_optypes.py @@ -94,11 +94,10 @@ def test_ite(t_constructor, t_size, f_constructor, f_size): elif t_is_bv_constructor and f_is_bv_constructor and sizes_equal: # Different bv_constuctor res = pred.ite(t, f) - assert type(res) is PolyVector[type(t), type(f)] # The bases should be the most specific types that are common - # to both branches and PolyBase. As SIntVect[size] is a subclass of - # BitVector[size], BitVector[size] is such a type. - assert type(res).__bases__ == (BitVector[t_size], PolyBase) + # to both branches and PolyBase. + assert isinstance(res, PolyBase) + assert isinstance(res, BitVector[t_size]) elif t_is_bv_constructor and f_is_bv_constructor and not sizes_equal: # BV with different size with pytest.raises(InconsistentSizeError): From ded09e45817f4eb5b7f673d342f08b6bef6ee57a Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Wed, 11 Mar 2020 13:31:33 -0700 Subject: [PATCH 4/4] Fix a couple bugs, add tests --- hwtypes/bit_vector.py | 2 +- hwtypes/bit_vector_util.py | 4 +- tests/test_poly.py | 92 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 tests/test_poly.py diff --git a/hwtypes/bit_vector.py b/hwtypes/bit_vector.py index 5909326..8210cd8 100644 --- a/hwtypes/bit_vector.py +++ b/hwtypes/bit_vector.py @@ -518,7 +518,7 @@ def sext(self, ext): return self.concat(T[1](self[-1]).repeat(ext)) def ext(self, ext): - return self.zext(other) + return self.zext(ext) def zext(self, ext): ext = int(ext) diff --git a/hwtypes/bit_vector_util.py b/hwtypes/bit_vector_util.py index 2afacbf..2b05c60 100644 --- a/hwtypes/bit_vector_util.py +++ b/hwtypes/bit_vector_util.py @@ -31,7 +31,9 @@ def __getitem__(cls, args): raise TypeError('Cannot construct PolyTypes across families') - bases = *cls._get_bases(T0, T1), PolyBase + # stupid generator to make sure PolyBase is not replicated + # and always comes last + bases = *(b for b in cls._get_bases(T0, T1) if b is not PolyBase), PolyBase class_name = f'{cls.__name__}[{T0.__name__}, {T1.__name__}, {select}]' meta, namespace, _ = types.prepare_class(class_name, bases) diff --git a/tests/test_poly.py b/tests/test_poly.py new file mode 100644 index 0000000..9b5d058 --- /dev/null +++ b/tests/test_poly.py @@ -0,0 +1,92 @@ +import pytest + +from pysmt import shortcuts as sc + +from hwtypes import BitVector, SIntVector, UIntVector +from hwtypes import Bit + +from hwtypes import SMTBit, SMTBitVector +from hwtypes import SMTUIntVector, SMTSIntVector + +@pytest.mark.parametrize("cond_0", [Bit(0), Bit(1)]) +@pytest.mark.parametrize("cond_1", [Bit(0), Bit(1)]) +def test_poly_bv(cond_0, cond_1): + S = SIntVector[8] + U = UIntVector[8] + val = cond_0.ite(S(0), U(0)) - 1 + + assert val < 0 if cond_0 else val > 0 + val2 = cond_1.ite(S(-1), val) + val2 = val2.ext(1) + assert val2 == val.sext(1) if cond_0 or cond_1 else val2 == val.zext(1) + + val3 = cond_1.ite(cond_0.ite(U(0), S(1)), cond_0.ite(S(-1), U(2))) + + if cond_1: + if cond_0: + assert val3 == 0 + assert val3 - 1 > 0 + else: + assert val3 == 1 + assert val3 - 2 < 0 + else: + if cond_0: + assert val3 == -1 + assert val3 < 0 + else: + assert val3 == 2 + assert val3 - 3 > 0 + +def test_poly_smt(): + S = SMTSIntVector[8] + U = SMTUIntVector[8] + + c1 = SMTBit(name='c1') + u1 = U(name='u1') + u2 = U(name='u2') + s1 = S(name='s1') + s2 = S(name='s2') + + # NOTE: __eq__ on pysmt terms is strict structural equivalence + # for example: + assert u1.value == u1.value # .value extract pysmt term + assert u1.value != u2.value + assert (u1 * 2).value != (u1 + u1).value + assert (u1 + u2).value == (u1 + u2).value + assert (u1 + u2).value != (u2 + u1).value + + # On to the real test + expr = c1.ite(u1, s1) < 1 + # get the pysmt values + _c1, _u1, _s1 = c1.value, u1.value, s1.value + e1 = sc.Ite(_c1, _u1, _s1) + one = sc.BV(1, 8) + # Here we see that `< 1` dispatches symbolically + f = sc.Ite(_c1, sc.BVULT(e1, one), sc.BVSLT(e1, one)) + assert expr.value == f + + expr = expr.ite(c1.ite(u1, s1), c1.ite(s2, u2)).ext(1) + + e2 = sc.Ite(_c1, s2.value, u2.value) + e3 = sc.Ite(f, e1, e2) + + se = sc.BVSExt(e3, 1) + ze = sc.BVZExt(e3, 1) + + + g = sc.Ite( + f, + sc.Ite(_c1, ze, se), + sc.Ite(_c1, se, ze) + ) + # Here we see that ext dispatches symbolically / recursively + assert expr.value == g + + + # Here we see that polymorphic types only build muxes if they need to + expr = c1.ite(u1, s1) + 1 + assert expr.value == sc.BVAdd(e1, one) + # Note how it is not: + assert expr.value != sc.Ite(_c1, sc.BVAdd(e1, one), sc.BVAdd(e1, one)) + # which was the pattern for sign dependent operators +