Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make polymorhic bitvector #127

Merged
merged 4 commits into from
Mar 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 17 additions & 34 deletions hwtypes/bit_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,30 +82,24 @@ def __xor__(self, other):
def ite(self, t_branch, f_branch):
'''
typing works as follows:
given cls is type(self)
and BV is cls.get_family().BitVector
if t_branch and f_branch are both Bit[Vector] types
from the same family, and type(t_branch) is type(f_branch),
then return type is t_branch.

if both branches are subclasses of cls and one is a subclass of the
other return type is the parent type
elif t_branch and f_branch are both Bit[Vector] types
from the same family, then return type is a polymorphic
type.

if both branches are subclasses of BV and one is a subclass of the
other return type is the parent type
elif t_brand and f_branch are tuples of the
same length, then these rules are applied recursively

if one branch is a subclass of cls try to cast the other branch to that
type and return that type
else there is an error

if one branch is a subclass of BV try to cast the other branch to that
type and return that type

if both branches are tuples of the same length, then these tests are
applied recursively to each pair of elements

all other cases are errors
'''
def _ite(t_branch, f_branch):
return t_branch if self else f_branch
def _ite(select, t_branch, f_branch):
return t_branch if select else f_branch

return build_ite(_ite, type(self), t_branch, f_branch)
return build_ite(_ite, self, t_branch, f_branch)

def __bool__(self) -> bool:
return self._value
Expand All @@ -114,7 +108,7 @@ def __int__(self) -> int:
return int(self._value)

def __repr__(self) -> str:
return 'Bit({})'.format(self._value)
return f'{type(self).__name__}({self._value})'

def __hash__(self) -> int:
return hash(self._value)
Expand Down Expand Up @@ -178,7 +172,7 @@ def __str__(self):
return str(int(self))

def __repr__(self):
return "BitVector[{size}]({value})".format(value=self._value, size=self.size)
return f'{type(self).__name__}({self._value})'

@property
def value(self):
Expand Down Expand Up @@ -476,11 +470,9 @@ def __lt__(self, other):
return self.bvult(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
except TypeError as e:
return NotImplemented



def as_uint(self):
return self._value

Expand Down Expand Up @@ -526,7 +518,7 @@ def sext(self, ext):
return self.concat(T[1](self[-1]).repeat(ext))

def ext(self, ext):
return self.zext(other)
return self.zext(ext)

def zext(self, ext):
ext = int(ext)
Expand All @@ -541,29 +533,21 @@ def random(width):
return BitVector[width](random.randint(0, (1 << width) - 1))




class NumVector(BitVector):
__hash__ = BitVector.__hash__


class UIntVector(NumVector):
__hash__ = NumVector.__hash__

def __repr__(self):
return "UIntVector[{size}]({value})".format(value=self._value, size=self.size)

@staticmethod
def random(width):
return UIntVector[width](random.randint(0, (1 << width) - 1))



class SIntVector(NumVector):
__hash__ = NumVector.__hash__

def __repr__(self):
return "SIntVector[{size}]({value})".format(value=self._value, size=self.size)

def __int__(self):
return self.as_sint()

Expand Down Expand Up @@ -628,7 +612,6 @@ def random(width):
w = width - 1
return SIntVector[width](random.randint(-(1 << w), (1 << w) - 1))

@bv_cast
def ext(self, other):
return self.sext(other)

Expand Down
3 changes: 2 additions & 1 deletion hwtypes/bit_vector_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class InconsistentSizeError(TypeError): pass
#I want to be able differentiate an old style call
#BitVector(val, None) from BitVector(val)
_MISSING = object()
class AbstractBitVectorMeta(ABCMeta):
class AbstractBitVectorMeta(type): #:(ABCMeta):
# BitVectorType, size : BitVectorType[size]
_class_cache = weakref.WeakValueDictionary()

Expand Down Expand Up @@ -319,5 +319,6 @@ def ext(self, other) -> 'AbstractBitVector':
def zext(self, other) -> 'AbstractBitVector':
pass

BitVectorMeta = AbstractBitVectorMeta

_Family_ = TypeFamily(AbstractBit, AbstractBitVector, None, None)
190 changes: 147 additions & 43 deletions hwtypes/bit_vector_util.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,146 @@
import functools as ft
import itertools as it
import inspect
import types

from .bit_vector_abc import InconsistentSizeError
from .bit_vector_abc import BitVectorMeta, AbstractBitVector, AbstractBit

# used as a tag
class PolyBase: pass

class PolyType(type):
def __getitem__(cls, args):
# From a typing perspective it would be better to make select an
# argument to init instead of making it a type param. This would
# allow types to be cached etc... However, making it an init arg
# means type(self)(val) is no longer sufficient to to cast val.
# Instead one would need to write type(self)(val, self._select_) or
# equivalent and hence would require a major change in the engineering
# of bitvector types (they would need to be aware of polymorphism).
# Note we can't cache as select is not necessarily hashable.

T0, T1, select = args

if not cls._type_check(T0, T1):
raise TypeError(f'Cannot construct {cls} from {T0} and {T1}')
if not isinstance(select, AbstractBit):
raise TypeError('select must be a Bit')
if (T0.get_family() is not T1.get_family()
or T0.get_family() is not select.get_family()):
raise TypeError('Cannot construct PolyTypes across families')


# stupid generator to make sure PolyBase is not replicated
# and always comes last
bases = *(b for b in cls._get_bases(T0, T1) if b is not PolyBase), PolyBase
class_name = f'{cls.__name__}[{T0.__name__}, {T1.__name__}, {select}]'
meta, namespace, _ = types.prepare_class(class_name, bases)

d0 = dict(inspect.getmembers(T0))
d1 = dict(inspect.getmembers(T1))

attrs = d0.keys() & d1.keys()
for k in attrs:
if k in {'_info_', '__int__', '__repr__', '__str__'}:
continue

m0 = inspect.getattr_static(T0, k)
m1 = inspect.getattr_static(T1, k)
namespace[k] = build_VCall(select, m0, m1)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be the magic, a PolyType store's it's select condition, operators are dispatched based on that condition when invoked.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory this might work with magma types?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually not quite, Poly would need to implement the magma protocol, but then I think it could work, maybe just by concatenating the bits of the two subtypes together as the bits representation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I plan on handling magma with my MagmaBV proposal. phanrahan/magma#587

MagmaBV will use the same machinery to build ites as BitVector / SMTBitVector and implement the magma protocol.



new_cls = meta(class_name, bases, namespace)
final = cls._finalize(new_cls, T0, T1)
return final


def _get_common_bases(T0, T1):
if issubclass(T0, T1):
return T1,
elif issubclass(T1, T0):
return T0,
else:
bases = set()
for t in T0.__bases__:
bases.update(_get_common_bases(t, T1))

for t in T1.__bases__:
bases.update(_get_common_bases(t, T0))

# Filter to most specific types
bases_ = set()
for bi in bases:
if not any(issubclass(bj, bi) for bj in bases if bi is not bj):
bases_.add(bi)

return tuple(bases_)

class PolyVector(metaclass=PolyType):
@classmethod
def _type_check(cls, T0, T1):
if (issubclass(T0, AbstractBitVector)
and issubclass(T1, AbstractBitVector)):
if T0.size != T1.size:
raise InconsistentSizeError(f'Cannot construct {cls} from {T0} and {T1}')
else:
return True
else:
return False

@classmethod
def _get_bases(cls, T0, T1):
bases = _get_common_bases(T0, T1)

# get the unsized versions
bases_ = set()
for base in bases:
try:
bases_.add(base.unsized_t)
except AttributeError:
bases_.add(base)

return tuple(bases_)


@classmethod
def _finalize(cls, new_class, T0, T1):
return new_class[T0.size]

class PolyBit(metaclass=PolyType):
@classmethod
def _type_check(cls, T0, T1):
return (issubclass(T0, AbstractBit)
and issubclass(T1, AbstractBit))

@classmethod
def _get_bases(cls, T0, T1):
return _get_common_bases(T0, T1)

@classmethod
def _finalize(cls, new_class, T0, T1):
return new_class

def build_VCall(select, m0, m1):
if m0 is m1:
return m0
else:
def VCall(*args, **kwargs):
v0 = m0(*args, **kwargs)
v1 = m1(*args, **kwargs)
if v0 is NotImplemented or v0 is NotImplemented:
return NotImplemented
return select.ite(v0, v1)
return VCall


def get_branch_type(branch):
if isinstance(branch, tuple):
return tuple(map(get_branch_type, branch))
else:
return type(branch)

def determine_return_type(bit_t, bv_t, t_branch, f_branch):
def determine_return_type(select, t_branch, f_branch):
def _recurse(t_branch, f_branch):
tb_t = get_branch_type(t_branch)
fb_t = get_branch_type(f_branch)
Expand All @@ -29,33 +161,17 @@ def _recurse(t_branch, f_branch):
elif (isinstance(tb_t, tuple)
or isinstance(fb_t, tuple)):
raise TypeError(f'Branches have inconsistent types: {tb_t} and {fb_t}')
elif isinstance(t_branch, bit_t) and isinstance(f_branch, bit_t):
if issubclass(tb_t, fb_t):
return fb_t
elif issubclass(fb_t, tb_t):
elif issubclass(tb_t, AbstractBit) and issubclass(fb_t, AbstractBit):
if tb_t is fb_t:
return tb_t
else:
raise TypeError(f'Branches have inconsistent types: {tb_t} and {fb_t}')
elif isinstance(t_branch, bv_t) and isinstance(f_branch, bv_t):
if tb_t.size != fb_t.size:
raise InconsistentSizeError('Both branches must have the same size')
elif issubclass(tb_t, fb_t):
return fb_t
elif issubclass(fb_t, tb_t):
return PolyBit[tb_t, fb_t, select]
elif issubclass(tb_t, AbstractBitVector) and issubclass(fb_t, AbstractBitVector):
if tb_t is fb_t:
return tb_t
else:
raise TypeError(f'Branches have inconsistent types: {tb_t} and {fb_t}')
elif isinstance(t_branch, bit_t):
return tb_t
elif isinstance(f_branch, bit_t):
return fb_t
elif isinstance(t_branch, bv_t):
return tb_t
elif isinstance(f_branch, bv_t):
return fb_t
return PolyVector[tb_t, fb_t, select]
else:
raise TypeError(f'Cannot infer return type. '
f'Atleast one branch must be a {bv_t}, {bit_t}, or tuples')
raise TypeError(f'tb_t: {tb_t}, fb_t: {fb_t}')

return _recurse(t_branch, f_branch)

def coerce_branch(r_type, branch):
Expand All @@ -66,7 +182,7 @@ def coerce_branch(r_type, branch):
else:
return r_type(branch)

def push_ite(ite, t_branch, f_branch):
def push_ite(ite, select, t_branch, f_branch):
def _recurse(t_branch, f_branch):
if isinstance(t_branch, tuple):
assert isinstance(f_branch, tuple)
Expand All @@ -76,23 +192,11 @@ def _recurse(t_branch, f_branch):
zip(t_branch, f_branch)
))
else:
return ite(t_branch, f_branch)
return ite(select, t_branch, f_branch)
return _recurse(t_branch, f_branch)

def build_ite(ite, bit_t, t_branch, f_branch,
push_ite_to_leaves=False,
cast_return=False):
bv_t = bit_t.get_family().BitVector
r_type = determine_return_type(bit_t, bv_t, t_branch, f_branch)
t_branch = coerce_branch(r_type, t_branch)
f_branch = coerce_branch(r_type, f_branch)

if push_ite_to_leaves:
r_val = push_ite(ite, t_branch, f_branch)
else:
r_val = ite(t_branch, f_branch)

if cast_return:
r_val = coerce_branch(r_type, r_val)

def build_ite(ite, select, t_branch, f_branch):
r_type = determine_return_type(select, t_branch, f_branch)
r_val = push_ite(ite, select, t_branch, f_branch)
r_val = coerce_branch(r_type, r_val)
return r_val
Loading