Skip to content

Commit

Permalink
Merge pull request #83 from leonardt/modifiers
Browse files Browse the repository at this point in the history
Add utility functions for modifiers; Better type hierarchy
  • Loading branch information
cdonovick authored Jul 30, 2019
2 parents acb4834 + 1a8a1fc commit 59d796e
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 11 deletions.
75 changes: 66 additions & 9 deletions hwtypes/modifiers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import types
import weakref

__ALL__ = ['new', 'make_modifier', 'is_modified', 'is_modifier', 'get_modifier', 'get_unmodified']


_DEBUG = False
#special sentinal value
class _MISSING: pass

Expand All @@ -17,24 +21,77 @@ class T(klass): pass
return T

class _ModifierMeta(type):
_modifier_lookup = weakref.WeakKeyDictionary()
def __instancecheck__(cls, obj):
return type(obj) in cls._sub_classes.values()
if cls is AbstractModifier:
return super().__instancecheck__(obj)
else:
return type(obj) in cls._sub_classes

def __subclasscheck__(cls, typ):
return typ in cls._sub_classes.values()
def __subclasscheck__(cls, T):
if cls is AbstractModifier:
return super().__subclasscheck__(T)
else:
return T in cls._sub_classes

def __call__(cls, *args):
if cls is AbstractModifier:
raise TypeError('Cannot instance or apply AbstractModifier')

if len(args) != 1:
return super().__call__(*args)
sub = args[0]

unmod_cls = args[0]
try:
return cls._sub_classes[sub]
return cls._sub_class_cache[unmod_cls]
except KeyError:
pass

mod_sub_name = cls.__name__ + sub.__name__
mod_sub = type(mod_sub_name, (sub,), {})
return cls._sub_classes.setdefault(sub, mod_sub)
mod_name = cls.__name__ + unmod_cls.__name__
bases = [unmod_cls]
for base in unmod_cls.__bases__:
bases.append(cls(base))
mod_cls = type(mod_name, tuple(bases), {})
cls._register_modified(unmod_cls, mod_cls)
return mod_cls

class AbstractModifier(metaclass=_ModifierMeta):
def __init_subclass__(cls, **kwargs):
cls._sub_class_cache = weakref.WeakValueDictionary()
cls._sub_classes = weakref.WeakSet()

@classmethod
def _register_modified(cls, unmod_cls, mod_cls):
type(cls)._modifier_lookup[mod_cls] = cls
cls._sub_classes.add(mod_cls)
cls._sub_class_cache[unmod_cls] = mod_cls
if _DEBUG:
# O(n) assert, but its a pretty key invariant
assert set(cls._sub_classes) == set(cls._sub_class_cache.values())

def is_modified(T):
return T in _ModifierMeta._modifier_lookup

def is_modifier(T):
return issubclass(T, AbstractModifier)

def get_modifier(T):
if is_modified(T):
return _ModifierMeta._modifier_lookup[T]
else:
raise TypeError(f'{T} has no modifiers')

def get_unmodified(T):
if is_modified(T):
unmod = T.__bases__[0]
if _DEBUG:
# Not an expensive assert but as there is a
# already a debug guard might as well use it.
mod = get_modifier(T)
assert mod._sub_class_cache[unmod] is T
return unmod
else:
raise TypeError(f'{T} has no modifiers')

_mod_cache = weakref.WeakValueDictionary()
# This is a factory for type modifiers.
Expand All @@ -45,7 +102,7 @@ def make_modifier(name, cache=False):
except KeyError:
pass

ModType = _ModifierMeta(name, (), {'_sub_classes' : weakref.WeakValueDictionary()})
ModType = _ModifierMeta(name, (AbstractModifier,), {})

if cache:
return _mod_cache.setdefault(name, ModType)
Expand Down
28 changes: 26 additions & 2 deletions tests/test_modifiers.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,46 @@
from hwtypes.modifiers import make_modifier
import pytest

from hwtypes import Bit, AbstractBit
import hwtypes.modifiers as modifiers
from hwtypes.modifiers import make_modifier, is_modified, is_modifier
from hwtypes.modifiers import get_modifier, get_unmodified

modifiers._DEBUG = True

def test_basic():
Global = make_modifier("Global")
GlobalBit = Global(Bit)

assert GlobalBit is Global(Bit)

assert issubclass(GlobalBit, Bit)
assert issubclass(GlobalBit, AbstractBit)
assert issubclass(GlobalBit, Global)
assert issubclass(GlobalBit, Global(AbstractBit))

global_bit = GlobalBit(0)

assert isinstance(global_bit, GlobalBit)
assert isinstance(global_bit, Bit)
assert isinstance(global_bit, AbstractBit)
assert isinstance(global_bit, Global)
assert isinstance(global_bit, Global(AbstractBit))

assert is_modifier(Global)
assert is_modified(GlobalBit)
assert not is_modifier(Bit)
assert not is_modified(Bit)
assert not is_modified(Global)

assert get_modifier(GlobalBit) is Global
assert get_unmodified(GlobalBit) is Bit

with pytest.raises(TypeError):
get_modifier(Bit)

with pytest.raises(TypeError):
get_unmodified(Bit)


def test_cache():
G1 = make_modifier("Global", cache=True)
Expand All @@ -24,4 +49,3 @@ def test_cache():

assert G1 is G2
assert G1 is not G3

0 comments on commit 59d796e

Please sign in to comment.