Skip to content

Commit

Permalink
Many fixes
Browse files Browse the repository at this point in the history
- Fix 2 qubit Gates with redundant --- padding
    using id gates to avoid out of order parsing
- fix parse order
- Add pennylane -> str
- better tests
- width equiliser
- phase gate support
  • Loading branch information
plutoniumm committed Mar 4, 2024
1 parent 0ef7375 commit 3fd6ade
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 75 deletions.
5 changes: 5 additions & 0 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ A tiny DSL to compile to quantum circuits. The goal is to speed up the time it t

[Qiskit](https://qiskit.org/) • [CudaQ](https://nvidia.github.io/cuda-quantum/latest/install.html) • [Pennylane](https://docs.pennylane.ai/en/stable/code/qml.html)

<!-- hooks need some preproc for qis & pennylane -->
- hooks for qiskit
- hooks for pennylane


## Install
```py
pip install abrax
Expand Down
29 changes: 20 additions & 9 deletions abrax/_utils_compile.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
# padd rows with [] to make them equal width
def equalize_widths(matrix):
mlen = max([len(i) for i in matrix])
for i in range(len(matrix)):
matrix[i] += ['---'] * (mlen - len(matrix[i]))

return matrix


def matrix_to_str(matrix):
qubits = len(matrix)
matrix = equalize_widths(matrix)
matrix = [*zip(*matrix)]
stris = ['-' + str(i) for i in range(len(matrix))]
stris = ['-' + str(i) for i in range(qubits)]
for i in range(len(matrix)):
for j in range(len(matrix[0])):
if isinstance(matrix[i][j], list):
arg = matrix[i][j][1]
# arg may have , in it and parser can deal with it
stris[j] += f' {matrix[i][j][0]}({arg})'
else:
stris[j] += f' {matrix[i][j]}'
if matrix[i][j] == 'id':
stris[j] += ' ' + ('-' * 3)
else:
stris[j] += f' {matrix[i][j]}'

# padding stris to make them equal width
max = 0
for j in stris:
if len(j) > max:
max = len(j)

mlen = max([len(j) for j in stris])
for j in range(len(stris)):
stris[j] += ' ' * (max - len(stris[j]))
stris[j] += ' ' * (mlen - len(stris[j]))

return '\n'.join(stris)

Expand All @@ -26,9 +36,10 @@ def matrix_to_str(matrix):
['h', 'x', 'y', 'z'],
['s', 't', 'sdg', 'tdg'],
['cx', 'cy', 'cz'],
['crx', 'cry', 'crz'],
['swap', 'iswap'],
['rx', 'ry', 'rz'],
['u', 'u1', 'u2', 'u3'],
['id'],
['id', 'p'],
]
valid_gates = [i for j in valid_gates for i in j]
81 changes: 72 additions & 9 deletions abrax/_utils_parse.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,84 @@
from dataclasses import dataclass


@dataclass
class PNLGate:
name: str
wires: list
parameters: list


def pwires(wires):
if not isinstance(wires, list):
wires = wires.tolist()
return wires


def pnl_tdg(wire):
return [
PNLGate('RZ', [wire], [-0.7854]),
]


# rot(a,b,c)->rz(c)ry(b)rz(a)
def pnl_rot(rot):
params = rot.parameters
wires = pwires(rot.wires)
a, b, c = params

return [
PNLGate('RZ', wires, [c]),
PNLGate('RY', wires, [b]),
PNLGate('RZ', wires, [a]),
]


def pnl_toffoli(toffoli):
a, b, c = pwires(toffoli.wires)
# Ref: Treat as same
# Tdg = RZ(-pi/4)
# T = RZ(pi/4)
# Sdg = RZ(-pi/2)
# S = RZ(pi/2)
substitute = [
['Hadamard', [c]],
['CNOT', [c, b]],
['Tdg', [c]],
['CNOT', [c, a]],
['T', [c]],
['CNOT', [c, b]],
['T', [b]],
['Tdg', [c]],
['CNOT', [c, a]],
['CNOT', [b, a]],
['T', [c]],
['T', [a]],
['Tdg', [b]],
['Hadamard', [c]],
['CNOT', [b, a]],
]

for j in range(len(substitute)):
substitute[j] = PNLGate(
name=substitute[j][0],
wires=substitute[j][1],
parameters=[],
)

return substitute


pnl_gate_map = {
'id': 'Identity',
'h': 'Hadamard',
'x': 'PauliX',
'y': 'PauliY',
'z': 'PauliZ',
's': 'S',
't': 'T',
'rx': 'RX',
'ry': 'RY',
'rz': 'RZ',
'u': 'U3',
'cx': 'CNOT',
'cz': 'CZ',
'cy': 'CY',
'swap': 'SWAP',
'iswap': 'ISWAP',
'p': 'PhaseShift',
}


Expand Down Expand Up @@ -43,5 +107,4 @@ def parse_circuit(string):
return list(map(list, zip(*by_rows)))


def isClose(a, b):
return abs(a - b) < 0.001
isClose = lambda a, b: abs(a - b) < 0.001
109 changes: 88 additions & 21 deletions abrax/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ def isIndex(qc):
return '_index'


def qis_preprep(qc):
from typing import Dict, Union, Callable, List

Hook = Union[Callable, List[Callable]]


def qis_preprep(qc, hooks: Dict[str, Hook]):
from qiskit import QuantumCircuit
from numpy import pi

Expand All @@ -19,8 +24,11 @@ def qis_preprep(qc):
newc = QuantumCircuit(*qregs, *cregs)

for i in qc.data:
# i = (gate, qargs, cargs)
# gate is just the gate applied
# qargs is the qubits the gate is applied to
gate = i[0].name
if gate not in valid_gates:
if (gate not in valid_gates) and (gate not in hooks.keys()):
raise ValueError(
f'Invalid gate: {gate}, try decomposing the circuit. Or it may be unsupported by abraxas.'
)
Expand All @@ -32,6 +40,15 @@ def qis_preprep(qc):
elif gate == 'u3':
newc.u(i[0].params[0], i[0].params[1], i[0].params[2], i[1][0])
else:
# if hooks is not None and gate in hooks:
# # a gate may be replaced by multiple gates
# # hook is key: lambda (circuit, instruction): ...
# if isinstance(hooks[gate], list):
# for j in hooks[gate]:
# j(newc, i)
# else:
# hooks[gate](newc, i)
# else:
newc.append(i)

return newc
Expand All @@ -44,41 +61,82 @@ def getParam(p):
return str(p)


def compile_pennylane(qfunc) -> str:
from ._utils_parse import pnl_gate_map
def flat(l):
out = []
for item in l:
if isinstance(item, (list, tuple)):
out.extend(flat(item))
else:
out.append(item)
return out


def compile_pennylane(qfunc, hooks) -> str:
from ._utils_parse import (
pnl_gate_map,
pnl_toffoli,
PNLGate,
pnl_rot,
pwires,
)

pi = 3.141592653589793
tape = qfunc.qtape
num_qubits = len(tape.wires.tolist())
matrix = [[] for _ in range(num_qubits)]

ops2 = tape.operations
# everytime there is a Tofolli, remove it and add 2 CNOTs a,b | b,c
for i in range(len(ops2)):
if ops2[i].name == 'Toffoli':
a, b, c = ops2[i].wires.tolist()
tape.operations[i] = tape.operations[i]._replace(
name='CNOT', wires=[a, b]
)
tape.operations.insert(
i + 1, tape.operations[i]._replace(name='CNOT', wires=[b, c])
)
ops = tape.operations
# DECOMPOSITIONS
for i in range(len(ops)):
if ops[i].name == 'Toffoli':
ops[i] = pnl_toffoli(ops[i])
elif ops[i].name == 'Rot':
ops[i] = pnl_rot(ops[i])
elif ops[i].name == 'S':
wires = pwires(ops[i].wires)
ops[i] = [PNLGate('T', wires, [])] * 2
elif ops[i].name == 'Sdg':
ops[i] = [PNLGate('Tdg', wires, [])] * 2
else:
pass

ops = flat(ops)

for i in tape.operations:
for i in ops:
gate = i.name
qubits = i.wires.tolist()
qubits = pwires(i.wires)
params = i.parameters
# name reverse
if gate in pnl_gate_map.values():
gate = list(pnl_gate_map.keys())[
list(pnl_gate_map.values()).index(gate)
]

# manual map in for T and Tdg
if gate == 'T':
gate = 'rz'
params = [pi / 4]
elif gate == 'Tdg':
gate = 'rz'
params = [-pi / 4]
else:
gate = gate.lower()

if gate not in valid_gates:
print(f'USED GATE: {gate}')
raise ValueError(
f'Invalid gate: {gate}, try decomposing the circuit. Or it may be unsupported by abraxas.'
)

if len(qubits) == 2:
# first fill all unequal rows with id
# then add cx
mlen = max([len(x) for x in matrix])
for l in range(len(matrix)):
if len(matrix[l]) < mlen:
for _ in range(mlen - len(matrix[l])):
matrix[l].append('id')

if len(params) > 0:
param = ','.join([getParam(x) for x in params])
else:
Expand Down Expand Up @@ -106,13 +164,20 @@ def compile_qiskit(qc) -> str:
for i in qc.data:
gate = i[0].name
if gate not in valid_gates:
print(qc.draw())
raise ValueError(
f'Invalid gate: {gate}, try decomposing the circuit. Or it may be unsupported by abraxas.'
)
if gate == 'measure':
continue

# for cx, fill all rows with id
if len(i[1]) == 2:
mlen = max([len(x) for x in matrix])
for l in range(len(matrix)):
if len(matrix[l]) < mlen:
for _ in range(mlen - len(matrix[l])):
matrix[l].append('id')

qargs = i[1]
if len(i[0].params) > 0:
param = ','.join([getParam(x) for x in i[0].params])
Expand All @@ -139,12 +204,14 @@ def compile_qiskit(qc) -> str:
return matrix_to_str(matrix)


def toPrime(qc):
def toPrime(qc, hooks=None):
name = qc.__class__.__name__
if hooks is None:
hooks = {}
if name == 'QuantumCircuit':
qc2 = qis_preprep(qc)
qc2 = qis_preprep(qc, hooks)
return compile_qiskit(qc2)
elif name == 'QNode':
return compile_pennylane(qc)
return compile_pennylane(qc, hooks)
else:
raise ValueError(f'Unsupported circuit: {name}')
10 changes: 9 additions & 1 deletion abrax/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def parse_param(p, typ=None):

def resolve_qiskit(circuit, qc):
for _, layer in enumerate(circuit):
# read top to down, default is bottom to top
layer.reverse()
for wireNo, gate in enumerate(layer):
wireNo = len(layer) - wireNo - 1
if no_gate.match(gate):
continue
gate = gate.lower()
Expand Down Expand Up @@ -108,6 +111,7 @@ def u(self, a, b, c, u):
kernel.rz(c, u)

Kernel.u = u
Kernel.p = Kernel.r1

for _, layer in enumerate(circuit):
for wireNo, gate in enumerate(layer):
Expand Down Expand Up @@ -157,7 +161,11 @@ def resolve_pennylane(circuit):
gate = gate.lower()
if '(' in gate and ')' in gate:
gate_name = gate[: gate.index('(')]
op = pnl_gate_map[gate_name]
if gate_name in pnl_gate_map:
op = pnl_gate_map[gate_name]
else:
op = gate_name.upper()

param = gate[gate.index('(') + 1 : gate.index(')')]

if ',' in param:
Expand Down
Loading

0 comments on commit 3fd6ade

Please sign in to comment.