diff --git a/hwtypes/modifiers.py b/hwtypes/modifiers.py index 15df6ba..b10f262 100644 --- a/hwtypes/modifiers.py +++ b/hwtypes/modifiers.py @@ -1,6 +1,10 @@ import types import weakref +__ALL__ = ['new', 'make_modifier', 'is_modified', 'is_modifier', 'get_modifier', 'get_unmodified'] + + +_DEBUG = False #special sentinal value class _MISSING: pass @@ -17,24 +21,77 @@ class T(klass): pass return T class _ModifierMeta(type): + _modifier_lookup = weakref.WeakKeyDictionary() def __instancecheck__(cls, obj): - return type(obj) in cls._sub_classes.values() + if cls is AbstractModifier: + return super().__instancecheck__(obj) + else: + return type(obj) in cls._sub_classes - def __subclasscheck__(cls, typ): - return typ in cls._sub_classes.values() + def __subclasscheck__(cls, T): + if cls is AbstractModifier: + return super().__subclasscheck__(T) + else: + return T in cls._sub_classes def __call__(cls, *args): + if cls is AbstractModifier: + raise TypeError('Cannot instance or apply AbstractModifier') + if len(args) != 1: return super().__call__(*args) - sub = args[0] + + unmod_cls = args[0] try: - return cls._sub_classes[sub] + return cls._sub_class_cache[unmod_cls] except KeyError: pass - mod_sub_name = cls.__name__ + sub.__name__ - mod_sub = type(mod_sub_name, (sub,), {}) - return cls._sub_classes.setdefault(sub, mod_sub) + mod_name = cls.__name__ + unmod_cls.__name__ + bases = [unmod_cls] + for base in unmod_cls.__bases__: + bases.append(cls(base)) + mod_cls = type(mod_name, tuple(bases), {}) + cls._register_modified(unmod_cls, mod_cls) + return mod_cls + +class AbstractModifier(metaclass=_ModifierMeta): + def __init_subclass__(cls, **kwargs): + cls._sub_class_cache = weakref.WeakValueDictionary() + cls._sub_classes = weakref.WeakSet() + + @classmethod + def _register_modified(cls, unmod_cls, mod_cls): + type(cls)._modifier_lookup[mod_cls] = cls + cls._sub_classes.add(mod_cls) + cls._sub_class_cache[unmod_cls] = mod_cls + if _DEBUG: + # O(n) assert, but its a pretty key invariant + assert set(cls._sub_classes) == set(cls._sub_class_cache.values()) + +def is_modified(T): + return T in _ModifierMeta._modifier_lookup + +def is_modifier(T): + return issubclass(T, AbstractModifier) + +def get_modifier(T): + if is_modified(T): + return _ModifierMeta._modifier_lookup[T] + else: + raise TypeError(f'{T} has no modifiers') + +def get_unmodified(T): + if is_modified(T): + unmod = T.__bases__[0] + if _DEBUG: + # Not an expensive assert but as there is a + # already a debug guard might as well use it. + mod = get_modifier(T) + assert mod._sub_class_cache[unmod] is T + return unmod + else: + raise TypeError(f'{T} has no modifiers') _mod_cache = weakref.WeakValueDictionary() # This is a factory for type modifiers. @@ -45,7 +102,7 @@ def make_modifier(name, cache=False): except KeyError: pass - ModType = _ModifierMeta(name, (), {'_sub_classes' : weakref.WeakValueDictionary()}) + ModType = _ModifierMeta(name, (AbstractModifier,), {}) if cache: return _mod_cache.setdefault(name, ModType) diff --git a/tests/test_modifiers.py b/tests/test_modifiers.py index e97896f..e979ef0 100644 --- a/tests/test_modifiers.py +++ b/tests/test_modifiers.py @@ -1,14 +1,22 @@ -from hwtypes.modifiers import make_modifier +import pytest + from hwtypes import Bit, AbstractBit +import hwtypes.modifiers as modifiers +from hwtypes.modifiers import make_modifier, is_modified, is_modifier +from hwtypes.modifiers import get_modifier, get_unmodified +modifiers._DEBUG = True def test_basic(): Global = make_modifier("Global") GlobalBit = Global(Bit) + assert GlobalBit is Global(Bit) + assert issubclass(GlobalBit, Bit) assert issubclass(GlobalBit, AbstractBit) assert issubclass(GlobalBit, Global) + assert issubclass(GlobalBit, Global(AbstractBit)) global_bit = GlobalBit(0) @@ -16,6 +24,23 @@ def test_basic(): assert isinstance(global_bit, Bit) assert isinstance(global_bit, AbstractBit) assert isinstance(global_bit, Global) + assert isinstance(global_bit, Global(AbstractBit)) + + assert is_modifier(Global) + assert is_modified(GlobalBit) + assert not is_modifier(Bit) + assert not is_modified(Bit) + assert not is_modified(Global) + + assert get_modifier(GlobalBit) is Global + assert get_unmodified(GlobalBit) is Bit + + with pytest.raises(TypeError): + get_modifier(Bit) + + with pytest.raises(TypeError): + get_unmodified(Bit) + def test_cache(): G1 = make_modifier("Global", cache=True) @@ -24,4 +49,3 @@ def test_cache(): assert G1 is G2 assert G1 is not G3 -