From 07a5b2b3aa11ae7edc4ea0f66e8a9077fdfc91bc Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Tue, 30 Jul 2019 10:26:18 -0700 Subject: [PATCH] Update rebind_bitvector interface to match rebind --- hwtypes/adt_util.py | 13 ++++++++----- tests/test_rebind.py | 4 ++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/hwtypes/adt_util.py b/hwtypes/adt_util.py index 9431be3..f7404eb 100644 --- a/hwtypes/adt_util.py +++ b/hwtypes/adt_util.py @@ -3,16 +3,19 @@ from .util import _issubclass -def rebind_bitvector(adt, bv_type: AbstractBitVectorMeta): - if _issubclass(adt, AbstractBitVector): +def rebind_bitvector( + adt, + bv_type_0: AbstractBitVectorMeta, + bv_type_1: AbstractBitVectorMeta): + if _issubclass(adt, bv_type_0): if adt.is_sized: - return bv_type[adt.size] + return bv_type_1[adt.size] else: - return bv_type + return bv_type_1 elif isinstance(adt, BoundMeta): new_adt = adt for field in adt.fields: - new_field = rebind_bitvector(field, bv_type) + new_field = rebind_bitvector(field, bv_type_0, bv_type_1) new_adt = new_adt.rebind(field, new_field) return new_adt else: diff --git a/tests/test_rebind.py b/tests/test_rebind.py index b01726c..559e33d 100644 --- a/tests/test_rebind.py +++ b/tests/test_rebind.py @@ -130,13 +130,13 @@ class F(Product): def test_rebind_bv(): - P_bound = rebind_bitvector(P, BitVector) + P_bound = rebind_bitvector(P, AbstractBitVector, BitVector) assert P_bound.X == BitVector[16] assert P_bound.S == Sum[BitVector[4], BitVector[8]] assert P_bound.T[0] == BitVector[32] assert P_bound.F.Y == BitVector - P_unbound = rebind_bitvector(P_bound, AbstractBitVector) + P_unbound = rebind_bitvector(P_bound, BitVector, AbstractBitVector) assert P_unbound.X == AbstractBitVector[16] assert P_unbound.S == Sum[AbstractBitVector[4], AbstractBitVector[8]] assert P_unbound.T[0] == AbstractBitVector[32]