Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reflected operators to bit and bv types #153

Merged
merged 1 commit into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 74 additions & 122 deletions hwtypes/bit_vector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing as tp
from .bit_vector_abc import AbstractBitVector, AbstractBit, TypeFamily, InconsistentSizeError
from .bit_vector_util import build_ite
from .util import Method
from .compatibility import IntegerTypes, StringTypes

import functools
Expand Down Expand Up @@ -71,14 +72,27 @@ def __ne__(self, other):
def __and__(self, other):
return type(self)(self._value & other._value)

@bit_cast
def __rand__(self, other):
return type(self)(other._value & self._value)

@bit_cast
def __or__(self, other):
return type(self)(self._value | other._value)

@bit_cast
def __ror__(self, other):
return type(self)(other._value | self._value)

@bit_cast
def __xor__(self, other):
return type(self)(self._value ^ other._value)

@bit_cast
def __rxor__(self, other):
return type(self)(other._value ^ self._value)


def ite(self, t_branch, f_branch):
'''
typing works as follows:
Expand Down Expand Up @@ -132,6 +146,34 @@ def wrapped(self : 'BitVector', other : tp.Any) -> tp.Any:
return fn(self, other)
return wrapped



def dispatch_oper(method: tp.MethodDescriptorType):
def oper(self, other):
try:
return method(self, other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

return Method(oper)


# A little inefficient because of double _coerce but whate;er
def dispatch_roper(method: Method):
def roper(self, other):
try:
other = _coerce(type(self), other)
except inconsistentsizeerror as e:
raise e from None
except TypeError:
return NotImplemented
return method(other, self)

return Method(roper)


class BitVector(AbstractBitVector):
@staticmethod
def get_family() -> TypeFamily:
Expand Down Expand Up @@ -328,8 +370,6 @@ def bvurem(self, other):
return self
return type(self)(self.as_uint() % other)

# bvumod

@bv_cast
def bvsdiv(self, other):
other = other.as_sint()
Expand All @@ -344,140 +384,46 @@ def bvsrem(self, other):
return self
return type(self)(self.as_sint() % other)

# bvsmod
def __invert__(self): return self.bvnot()

def __and__(self, other):
try:
return self.bvand(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__and__ = dispatch_oper(bvand)
__rand__ = dispatch_roper(__and__)

def __or__(self, other):
try:
return self.bvor(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__or__ = dispatch_oper(bvor)
__ror__ = dispatch_roper(__or__)

def __xor__(self, other):
try:
return self.bvxor(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__xor__ = dispatch_oper(bvxor)
__rxor__ = dispatch_roper(__xor__)

__lshift__ = dispatch_oper(bvshl)
__rlshift__ = dispatch_roper(__lshift__)

def __lshift__(self, other):
try:
return self.bvshl(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __rshift__(self, other):
try:
return self.bvlshr(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__rshift__ = dispatch_oper(bvlshr)
__rrshift__ = dispatch_oper(__rshift__)

def __neg__(self): return self.bvneg()

def __add__(self, other):
try:
return self.bvadd(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __sub__(self, other):
try:
return self.bvsub(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __mul__(self, other):
try:
return self.bvmul(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__add__ = dispatch_oper(bvadd)
__radd__ = dispatch_roper(__add__)

def __floordiv__(self, other):
try:
return self.bvudiv(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__sub__ = dispatch_oper(bvsub)
__rsub__ = dispatch_roper(__sub__)

def __mod__(self, other):
try:
return self.bvurem(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__mul__ = dispatch_oper(bvmul)
__rmul__ = dispatch_roper(__mul__)

__floordiv__ = dispatch_oper(bvudiv)
__rfloordiv__ = dispatch_roper(__floordiv__)

def __eq__(self, other):
try:
return self.bveq(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented
__mod__ = dispatch_oper(bvurem)
__rmod__ = dispatch_roper(__mod__)

def __ne__(self, other):
try:
return self.bvne(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __ge__(self, other):
try:
return self.bvuge(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __gt__(self, other):
try:
return self.bvugt(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __le__(self, other):
try:
return self.bvule(other)
except InconsistentSizeError as e:
raise e from None
except TypeError:
return NotImplemented

def __lt__(self, other):
try:
return self.bvult(other)
except InconsistentSizeError as e:
raise e from None
except TypeError as e:
return NotImplemented
__eq__ = dispatch_oper(bveq)
__ne__ = dispatch_oper(AbstractBitVector.bvne)
__ge__ = dispatch_oper(AbstractBitVector.bvuge)
__gt__ = dispatch_oper(AbstractBitVector.bvugt)
__le__ = dispatch_oper(AbstractBitVector.bvule)
__lt__ = dispatch_oper(bvult)

def as_uint(self):
return self._value
Expand Down Expand Up @@ -565,6 +511,8 @@ def __rshift__(self, other):
except TypeError:
return NotImplemented

__rrshift__ = dispatch_roper(__rshift__)

def __floordiv__(self, other):
try:
return self.bvsdiv(other)
Expand All @@ -573,6 +521,8 @@ def __floordiv__(self, other):
except TypeError:
return NotImplemented

__rfloordiv__ = dispatch_roper(__floordiv__)

def __mod__(self, other):
try:
return self.bvsrem(other)
Expand All @@ -581,6 +531,8 @@ def __mod__(self, other):
except TypeError:
return NotImplemented

__rmod__ = dispatch_roper(__mod__)

def __ge__(self, other):
try:
return self.bvsge(other)
Expand Down
3 changes: 2 additions & 1 deletion hwtypes/bit_vector_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import types

from .util import Method
from .bit_vector_abc import InconsistentSizeError
from .bit_vector_abc import BitVectorMeta, AbstractBitVector, AbstractBit

Expand Down Expand Up @@ -172,7 +173,7 @@ def VCall(*args, **kwargs):
if v0 is NotImplemented or v0 is NotImplemented:
return NotImplemented
return select.ite(v0, v1)
return VCall
return Method(VCall)


def get_branch_type(branch):
Expand Down
Loading
Loading