Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert revert of 32-to-64-bit update #1456

Merged
merged 4 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pytket/binders/include/UnitRegister.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ conventions defined here:
registers are up to _TKET_REG_WIDTH wide in bits and are interpreted as
equivalent to the C++ type _tket_uint_t
*/
#define _TKET_REG_WIDTH 32
typedef uint32_t _tket_uint_t;
#define _TKET_REG_WIDTH 64
typedef uint64_t _tket_uint_t;

template <typename T>
class UnitRegister {
Expand Down
2 changes: 1 addition & 1 deletion pytket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def package(self):
cmake.install()

def requirements(self):
self.requires("tket/1.3.10@tket/stable")
self.requires("tket/1.3.11@tket/stable")
self.requires("tklog/0.3.3@tket/stable")
self.requires("tkrng/0.3.3@tket/stable")
self.requires("tkassert/0.3.4@tket/stable")
Expand Down
6 changes: 6 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changelog
=========

Unreleased
----------

* Support classical transforms and predicates, and QASM registers, with up to 64
bits. Add an attribute to the pytket module to assert this.

1.29.2 (June 2024)
------------------

Expand Down
6 changes: 6 additions & 0 deletions pytket/pytket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@
config.write_file(pytket_config_file)

__path__ = __import__("pkgutil").extend_path(__path__, __name__)

"""Flag indicating 64-bit support.
If True, classical transforms and predicates, and QASM registers, with up to 64
bits are supported."""
bit_width_64 = True
2 changes: 1 addition & 1 deletion pytket/pytket/_tket/unit_id.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -411,4 +411,4 @@ _DEBUG_ONE_REG_PREFIX: str = 'tk_DEBUG_ONE_REG'
_DEBUG_ZERO_REG_PREFIX: str = 'tk_DEBUG_ZERO_REG'
_TEMP_BIT_NAME: str = 'tk_SCRATCH_BIT'
_TEMP_BIT_REG_BASE: str = 'tk_SCRATCH_BITREG'
_TEMP_REG_SIZE: int = 32
_TEMP_REG_SIZE: int = 64
2 changes: 1 addition & 1 deletion pytket/pytket/circuit/add_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _add_condition(
target_bits = pred_exp.to_list()

minval = 0
maxval = (1 << 32) - 1
maxval = (1 << 64) - 1
if isinstance(condition, RegLt):
maxval = pred_val - 1
elif isinstance(condition, RegGt):
Expand Down
11 changes: 11 additions & 0 deletions pytket/pytket/qasm/qasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,18 @@ def _retrieve_registers(


def _parse_range(minval: int, maxval: int, maxwidth: int) -> Tuple[str, int]:
if maxwidth > 64:
raise NotImplementedError("Register width exceeds maximum of 64.")

REGMAX = (1 << maxwidth) - 1

if minval > REGMAX:
raise NotImplementedError("Range's lower bound exceeds register capacity.")
elif minval > maxval:
raise NotImplementedError("Range's lower bound exceeds upper bound.")
elif maxval > REGMAX:
maxval = REGMAX

if minval == maxval:
return ("==", minval)
elif minval == 0:
Expand Down
26 changes: 13 additions & 13 deletions pytket/tests/classical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@

from pytket.passes import DecomposeClassicalExp, FlattenRegisters

from strategies import reg_name_regex, binary_digits, uint32 # type: ignore
from strategies import reg_name_regex, binary_digits, uint32, uint64 # type: ignore

curr_file_path = Path(__file__).resolve().parent

Expand Down Expand Up @@ -840,7 +840,7 @@ def primitive_reg_logic_exps(
RegGeq,
),
):
const_compare = draw(uint32)
const_compare = draw(uint64)
args.append(const_compare)
else:
args.append(draw(bit_regs))
Expand All @@ -854,8 +854,8 @@ def primitive_reg_logic_exps(
@given(
reg_exp=primitive_reg_logic_exps(),
constants=strategies.tuples(
uint32,
uint32,
uint64,
uint64,
),
)
def test_reg_exp(reg_exp: RegLogicExp, constants: Tuple[int, int]) -> None:
Expand Down Expand Up @@ -929,7 +929,7 @@ def composite_bit_logic_exps(
def composite_reg_logic_exps(
draw: DrawType,
regs: SearchStrategy[BitRegister] = bit_register(),
constants: SearchStrategy[int] = uint32,
constants: SearchStrategy[int] = uint64,
operators: SearchStrategy[Callable] = strategies.sampled_from(
[
operator.and_,
Expand Down Expand Up @@ -979,7 +979,7 @@ def reg_const_predicates(
operators: SearchStrategy[
Callable[[Union[RegLogicExp, BitRegister], int], PredicateExp]
] = strategies.sampled_from([reg_eq, reg_neq, reg_lt, reg_gt, reg_leq, reg_geq]),
constants: SearchStrategy[int] = uint32,
constants: SearchStrategy[int] = uint64,
) -> PredicateExp:
return draw(operators)(draw(exp), draw(constants)) # type: ignore

Expand Down Expand Up @@ -1131,10 +1131,10 @@ def test_decomposition_known() -> None:
)
check_serialization_roundtrip(circ)

temp_bits = BitRegister(_TEMP_BIT_NAME, 32)
temp_bits = BitRegister(_TEMP_BIT_NAME, 64)

def temp_reg(i: int) -> BitRegister:
return BitRegister(f"{_TEMP_BIT_REG_BASE}_{i}", 32)
return BitRegister(f"{_TEMP_BIT_REG_BASE}_{i}", 64)

for b in (temp_bits[i] for i in range(0, 10)):
conditioned_circ.add_bit(b)
Expand Down Expand Up @@ -1170,13 +1170,13 @@ def temp_reg(i: int) -> BitRegister:
conditioned_circ.add_c_range_predicate(5, 5, registers_lists[2], temp_bits[5])
conditioned_circ.Y(qreg[4], condition_bits=[temp_bits[5]], condition_value=0)
conditioned_circ.add_c_range_predicate(
4, 4294967295, registers_lists[3], temp_bits[6]
4, 18446744073709551615, registers_lists[3], temp_bits[6]
)
conditioned_circ.Z(qreg[5], condition_bits=[temp_bits[6]], condition_value=1)
conditioned_circ.add_c_range_predicate(0, 6, registers_lists[4], temp_bits[7])
conditioned_circ.S(qreg[6], condition_bits=[temp_bits[7]], condition_value=1)
conditioned_circ.add_c_range_predicate(
3, 4294967295, registers_lists[5], temp_bits[8]
3, 18446744073709551615, registers_lists[5], temp_bits[8]
)
conditioned_circ.T(qreg[7], condition_bits=[temp_bits[8]], condition_value=1)

Expand All @@ -1196,7 +1196,7 @@ def temp_reg(i: int) -> BitRegister:
decomposed_circ.add_bit(b)

decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_0", 3))
decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_1", 32))
decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_1", 64))

decomposed_circ.H(qreg[0], condition_bits=[bits[0]], condition_value=1)
decomposed_circ.X(qreg[0], condition_bits=[bits[1]], condition_value=1)
Expand All @@ -1211,11 +1211,11 @@ def temp_reg(i: int) -> BitRegister:
decomposed_circ.add_c_range_predicate(0, 5, registers_lists[1], temp_bits[4])
decomposed_circ.add_c_range_predicate(5, 5, registers_lists[2], temp_bits[5])
decomposed_circ.add_c_range_predicate(
4, 4294967295, registers_lists[3], temp_bits[6]
4, 18446744073709551615, registers_lists[3], temp_bits[6]
)
decomposed_circ.add_c_range_predicate(0, 6, registers_lists[4], temp_bits[7])
decomposed_circ.add_c_range_predicate(
3, 4294967295, registers_lists[5], temp_bits[8]
3, 18446744073709551615, registers_lists[5], temp_bits[8]
)

decomposed_circ.add_c_xor(bits[5], bits[6], temp_bits[2])
Expand Down
8 changes: 8 additions & 0 deletions pytket/tests/qasm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,14 @@ def test_const_condition() -> None:
)


def test_range_with_maxwidth() -> None:
c = Circuit(1)
a = c.add_c_register("a", 8)
c.X(0, condition=reg_geq(a, 1))
qasm = circuit_to_qasm_str(c, header="hqslib1", maxwidth=63)
assert "if(a>=1) x q[0];" in qasm


if __name__ == "__main__":
test_qasm_correct()
test_qasm_qubit()
Expand Down
1 change: 1 addition & 0 deletions pytket/tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

binary_digits = st.sampled_from((0, 1))
uint32 = st.integers(min_value=1, max_value=1 << 32 - 1)
uint64 = st.integers(min_value=1, max_value=1 << 64 - 1)
reg_name_regex = re.compile("[a-z][a-zA-Z0-9_]*")


Expand Down
10 changes: 5 additions & 5 deletions schemas/circuit_v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@
},
"lower": {
"type": "integer",
"maximum": 4294967295,
"description": "The inclusive minimum of the RangePredicate as a uint32."
"maximum": 18446744073709551615,
"description": "The inclusive minimum of the RangePredicate as a uint64."
},
"upper": {
"type": "integer",
"maximum": 4294967295,
"description": "The inclusive maximum of the RangePredicate as a uint32."
"maximum": 18446744073709551615,
"description": "The inclusive maximum of the RangePredicate as a uint64."
}
},
"required": [
Expand Down Expand Up @@ -1179,4 +1179,4 @@
"additionalProperties": false
}
}
}
}
2 changes: 1 addition & 1 deletion tket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class TketConan(ConanFile):
name = "tket"
version = "1.3.10"
version = "1.3.11"
package_type = "library"
license = "Apache 2"
homepage = "https://github.com/CQCL/tket"
Expand Down
22 changes: 11 additions & 11 deletions tket/include/tket/Ops/ClassicalOps.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,18 @@ class ClassicalTransformOp : public ClassicalEvalOp {
* @param values table of binary-encoded values
* @param name name of operation
*
* @pre n <= 32
* @pre n <= 64
*/
ClassicalTransformOp(
unsigned n, const std::vector<uint32_t> &values,
unsigned n, const std::vector<uint64_t> &values,
const std::string &name = "ClassicalTransform");

std::vector<bool> eval(const std::vector<bool> &x) const override;

std::vector<uint32_t> get_values() const { return values_; }
std::vector<uint64_t> get_values() const { return values_; }

private:
const std::vector<uint32_t> values_;
const std::vector<uint64_t> values_;
};

/**
Expand Down Expand Up @@ -341,15 +341,15 @@ class RangePredicateOp : public PredicateOp {
* @param b upper bound in little-endian encoding
*/
RangePredicateOp(
unsigned n, uint32_t a = 0,
uint32_t b = std::numeric_limits<uint32_t>::max())
unsigned n, uint64_t a = 0,
uint64_t b = std::numeric_limits<uint64_t>::max())
: PredicateOp(OpType::RangePredicate, n, "RangePredicate"), a(a), b(b) {}

std::string get_name(bool latex) const override;

uint32_t upper() const { return b; }
uint64_t upper() const { return b; }

uint32_t lower() const { return a; }
uint64_t lower() const { return a; }

std::vector<bool> eval(const std::vector<bool> &x) const override;

Expand All @@ -359,8 +359,8 @@ class RangePredicateOp : public PredicateOp {
bool is_equal(const Op &other) const override;

private:
uint32_t a;
uint32_t b;
uint64_t a;
uint64_t b;
};

/**
Expand All @@ -378,7 +378,7 @@ class ExplicitPredicateOp : public PredicateOp {
* @param values table of values
* @param name name of operation
*
* @pre n <= 32
* @pre n <= 64
*/
ExplicitPredicateOp(
unsigned n, const std::vector<bool> &values,
Expand Down
4 changes: 2 additions & 2 deletions tket/include/tket/Utils/HelperFunctions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ bimap_to_map(MapT& bm) {
}

/**
* Reverse bits 0,1,...,w-1 of the number v, assuming v < 2^w and w <= 32.
* Reverse bits 0,1,...,w-1 of the number v, assuming v < 2^w and w <= 64.
*/
uint32_t reverse_bits(uint32_t v, unsigned w);
uint64_t reverse_bits(uint64_t v, unsigned w);

/**
* @brief
Expand Down
Loading
Loading