Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix QasmWriter.add_multi_bit() for non-register-aligned arguments #1572

Merged
merged 5 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Features:

* DecomposeTK2 pass and transform can now accept a float for ZZPhase_fidelity.
* DecomposeTK2 pass now has a json representation when it contains no functions.
* Fix QASM conversion of non-register-aligned `MultiBitOp`.

1.32.0 (September 2024)
-----------------------
Expand Down
48 changes: 31 additions & 17 deletions pytket/pytket/qasm/qasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,7 +1290,7 @@ def make_params_str(params: Optional[List[Union[float, Expr]]]) -> str:
return s


def make_args_str(args: List[UnitID]) -> str:
def make_args_str(args: Sequence[UnitID]) -> str:
s = ""
for i in range(len(args)):
s += f"{args[i]}"
Expand Down Expand Up @@ -1397,6 +1397,8 @@ def __init__(
self.cregs = {}
self.qregs = {}

self.cregs_as_bitseqs = set(tuple(creg) for creg in self.cregs.values())

# for holding condition values when writing Conditional blocks
# the size changes when adding and removing scratch bits
self.scratch_reg = BitRegister(
Expand All @@ -1423,7 +1425,7 @@ def write_params(self, params: Optional[List[Union[float, Expr]]]) -> None:
params_str = make_params_str(params)
self.strings.add_string(params_str)

def write_args(self, args: List[UnitID]) -> None:
def write_args(self, args: Sequence[UnitID]) -> None:
args_str = make_args_str(args)
self.strings.add_string(args_str)

Expand Down Expand Up @@ -1577,7 +1579,7 @@ def remove_unused_predicate(self, pred_label: int) -> bool:
self.strings.del_string(pred_label)
return True

def add_conditional(self, op: Conditional, args: List[UnitID]) -> None:
def add_conditional(self, op: Conditional, args: Sequence[UnitID]) -> None:
control_bits = args[: op.width]
if op.width == 1 and hqs_header(self.header):
variable = str(control_bits[0])
Expand Down Expand Up @@ -1666,12 +1668,24 @@ def add_copy_bits(self, op: CopyBitsOp, args: List[Bit]) -> None:
self.mark_as_written(label, f"{bit_l}")

def add_multi_bit(self, op: MultiBitOp, args: List[Bit]) -> None:
assert len(args) >= 2
registers_involved = [arg.reg_name for arg in args[:2]]
if len(args) > 2 and args[2].reg_name not in registers_involved:
# there is a distinct output register
registers_involved.append(args[2].reg_name)
self.add_op(op.basic_op, [self.cregs[name] for name in registers_involved]) # type: ignore
basic_op = op.basic_op
basic_n = basic_op.n_inputs + basic_op.n_outputs + basic_op.n_input_outputs
n_args = len(args)
assert n_args % basic_n == 0
arity = n_args // basic_n

# If the operation is register-aligned we can write it more succinctly.
poss_regs = [
tuple(args[basic_n * i + j] for i in range(arity)) for j in range(basic_n)
]
if all(poss_reg in self.cregs_as_bitseqs for poss_reg in poss_regs):
# The operation is register-aligned.
self.add_op(basic_op, [poss_regs[j][0].reg_name for j in range(basic_n)]) # type: ignore
else:
# The operation is not register-aligned.
for i in range(arity):
basic_args = args[basic_n * i : basic_n * (i + 1)]
self.add_op(basic_op, basic_args)

def add_explicit_op(self, op: Op, args: List[Bit]) -> None:
# &, ^ and | gates
Expand Down Expand Up @@ -1721,11 +1735,11 @@ def add_wasm(self, op: WASMOp, args: List[Bit]) -> None:
for variable in outputs:
self.mark_as_written(label, variable)

def add_measure(self, args: List[UnitID]) -> None:
def add_measure(self, args: Sequence[UnitID]) -> None:
label = self.strings.add_string(f"measure {args[0]} -> {args[1]};\n")
self.mark_as_written(label, f"{args[1]}")

def add_zzphase(self, param: Union[float, Expr], args: List[UnitID]) -> None:
def add_zzphase(self, param: Union[float, Expr], args: Sequence[UnitID]) -> None:
# as op.params returns reduced parameters, we can assume
# that 0 <= param < 4
if param > 1:
Expand All @@ -1739,7 +1753,7 @@ def add_zzphase(self, param: Union[float, Expr], args: List[UnitID]) -> None:
self.write_params([param])
self.write_args(args)

def add_data(self, op: BarrierOp, args: List[UnitID]) -> None:
def add_data(self, op: BarrierOp, args: Sequence[UnitID]) -> None:
if op.data == "":
opstr = _tk_to_qasm_noparams[OpType.Barrier]
else:
Expand All @@ -1748,18 +1762,18 @@ def add_data(self, op: BarrierOp, args: List[UnitID]) -> None:
self.strings.add_string(" ")
self.write_args(args)

def add_gate_noparams(self, op: Op, args: List[UnitID]) -> None:
def add_gate_noparams(self, op: Op, args: Sequence[UnitID]) -> None:
self.strings.add_string(_tk_to_qasm_noparams[op.type])
self.strings.add_string(" ")
self.write_args(args)

def add_gate_params(self, op: Op, args: List[UnitID]) -> None:
def add_gate_params(self, op: Op, args: Sequence[UnitID]) -> None:
optype, params = _get_optype_and_params(op)
self.strings.add_string(_tk_to_qasm_params[optype])
self.write_params(params)
self.write_args(args)

def add_extra_noparams(self, op: Op, args: List[UnitID]) -> Tuple[str, str]:
def add_extra_noparams(self, op: Op, args: Sequence[UnitID]) -> Tuple[str, str]:
optype = op.type
opstr = _tk_to_qasm_extra_noparams[optype]
gatedefstr = ""
Expand All @@ -1769,7 +1783,7 @@ def add_extra_noparams(self, op: Op, args: List[UnitID]) -> Tuple[str, str]:
mainstr = opstr + " " + make_args_str(args)
return gatedefstr, mainstr

def add_extra_params(self, op: Op, args: List[UnitID]) -> Tuple[str, str]:
def add_extra_params(self, op: Op, args: Sequence[UnitID]) -> Tuple[str, str]:
optype, params = _get_optype_and_params(op)
assert params is not None
opstr = _tk_to_qasm_extra_params[optype]
Expand All @@ -1782,7 +1796,7 @@ def add_extra_params(self, op: Op, args: List[UnitID]) -> Tuple[str, str]:
mainstr = opstr + make_params_str(params) + make_args_str(args)
return gatedefstr, mainstr

def add_op(self, op: Op, args: List[UnitID]) -> None:
def add_op(self, op: Op, args: Sequence[UnitID]) -> None:
optype, _params = _get_optype_and_params(op)
if optype == OpType.RangePredicate:
assert isinstance(op, RangePredicateOp)
Expand Down
25 changes: 25 additions & 0 deletions pytket/tests/qasm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,31 @@ def test_range_predicates_with_non_scratch_bits() -> None:
)


def test_multibitop() -> None:
# https://github.com/CQCL/tket/issues/1327
c = Circuit()
areg = c.add_c_register("a", 2)
breg = c.add_c_register("b", 2)
creg = c.add_c_register("c", 2)
c.add_c_and_to_registers(areg, breg, creg)
mbop = c.get_commands()[0].op
c.add_gate(mbop, [areg[0], areg[1], breg[0], breg[1], creg[0], creg[1]])
qasm = circuit_to_qasm_str(c, header="hqslib1")
assert (
qasm
== """OPENQASM 2.0;
include "hqslib1.inc";

creg a[2];
creg b[2];
creg c[2];
c = a & b;
b[0] = a[0] & a[1];
c[1] = b[1] & c[0];
"""
)


if __name__ == "__main__":
test_qasm_correct()
test_qasm_qubit()
Expand Down
5 changes: 3 additions & 2 deletions pytket/tests/qasm_test_files/test18_output.qasm
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ c[0] = a[0];
c[1] = a[1];
if(b!=2) c[1] = b[1] & a[1];
if(b!=2) c[1] = a[0] | c[1];
c = b & a;
c = d | c;
c[0] = b[0] & a[0];
c[1] = b[1] & a[1];
c[0] = d[0] | c[0];
d = 1;
d[0] = a[0] ^ d[0];
if(c>=2) h q[0];
Expand Down
Loading