Skip to content

Commit

Permalink
Added support for ClExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
PabloAndresCQ committed Nov 11, 2024
1 parent 935f509 commit c2dd14a
Show file tree
Hide file tree
Showing 2 changed files with 281 additions and 1 deletion.
89 changes: 89 additions & 0 deletions pytket/extensions/cutensornet/structured_state/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
SetBitsOp,
CopyBitsOp,
RangePredicateOp,
ClExprOp,
ClassicalExpBox,
LogicExp,
BitWiseOp,
RegWiseOp,
)
from pytket._tket.circuit import ClExpr, ClOp, ClBitVar, ClRegVar


ExtendedLogicExp = Union[LogicExp, Bit, BitRegister, int]
Expand Down Expand Up @@ -56,6 +58,20 @@ def apply_classical_command(
# Check that the value is in the range
bits_dict[res_bit] = val >= op.lower and val <= op.upper

elif isinstance(op, ClExprOp):
# Convert bit_posn to dictionary of `ClBitVar` index to its value
bitvar_val = {var_id: int(bits_dict[args[bit_pos]]) for var_id, bit_pos in op.expr.bit_posn.items()}
# Convert reg_posn to dictionary of `ClRegVar` index to its value
regvar_val = {var_id: from_little_endian([bits_dict[args[bit_pos]] for bit_pos in reg_pos_list]) for var_id, reg_pos_list in op.expr.reg_posn.items()}
result = evaluate_clexpr(op.expr.expr, bitvar_val, regvar_val)

# The result is an int in little-endian encoding. We update the
# output register accordingly.
for bit_pos in op.expr.output_posn:
bits_dict[args[bit_pos]] = (result % 2) == 1
result = result >> 1
assert result == 0 # All bits consumed

elif isinstance(op, ClassicalExpBox):
the_exp = op.get_exp()
result = evaluate_logic_exp(the_exp, bits_dict)
Expand All @@ -74,6 +90,79 @@ def apply_classical_command(
raise NotImplementedError(f"Commands of type {op.type} are not supported.")


def evaluate_clexpr(expr: ClExpr, bitvar_val: dict[int, int], regvar_val: dict[int, int]) -> int:
"""Recursive evaluation of a ClExpr."""

# Evaluate arguments to operation
args_val = []
for arg in expr.args:
if isinstance(arg, int):
value = arg
elif isinstance(arg, ClBitVar):
value = bitvar_val[arg.index]
elif isinstance(arg, ClRegVar):
value = regvar_val[arg.index]
elif isinstance(arg, ClExpr):
value = evaluate_clexpr(arg, bitvar_val, regvar_val)
else:
raise Exception(f"Unrecognised argument type of ClExpr: {type(arg)}.")

args_val.append(value)

# Apply the operation at the root of this ClExpr
if expr.op in [ClOp.BitAnd, ClOp.RegAnd]:
result = args_val[0] & args_val[1]
elif expr.op in [ClOp.BitOr, ClOp.RegOr]:
result = args_val[0] | args_val[1]
elif expr.op in [ClOp.BitXor, ClOp.RegXor]:
result = args_val[0] ^ args_val[1]
elif expr.op in [ClOp.BitEq, ClOp.RegEq]:
result = int(args_val[0] == args_val[1])
elif expr.op in [ClOp.BitNeq, ClOp.RegNeq]:
result = int(args_val[0] != args_val[1])
elif expr.op == ClOp.RegGeq:
result = int(args_val[0] >= args_val[1])
elif expr.op == ClOp.RegGt:
result = int(args_val[0] > args_val[1])
elif expr.op == ClOp.RegLeq:
result = int(args_val[0] <= args_val[1])
elif expr.op == ClOp.RegLt:
result = int(args_val[0] < args_val[1])
elif expr.op == ClOp.BitNot:
result = 1 - args_val[0]
# elif expr.op == ClOp.RegNot:
# result = int(args_val[0] == 0)
elif expr.op in [ClOp.BitZero, ClOp.RegZero]:
result = 0
elif expr.op in [ClOp.BitOne, ClOp.RegOne]:
result = 1
# elif expr.op == ClOp.RegAdd:
# result = args_val[0] + args_val[1]
# elif expr.op == ClOp.RegSub:
# result = args_val[0] - args_val[1]
# elif expr.op == ClOp.RegMul:
# result = args_val[0] * args_val[1]
# elif expr.op == ClOp.RegPow:
# result = int(args_val[0] ** args_val[1])
elif expr.op == ClOp.RegRsh:
result = args_val[0] >> args_val[1]
# elif expr.op == ClOp.RegNeg:
# result = -args_val[0]
else:
# TODO: Currently not supporting ClOp's RegDiv since it does not return int,
# so I am unsure what the semantic is meant to be.
# TODO: I don't now what to do with RegNot, since input
# is not guaranteed to be 0 or 1.
# TODO: It is not clear what to do with overflow of ADD, etc.
# so I have decided to not support them for now.
raise NotImplementedError(
f"Evaluation of {expr.op} not supported in ClExpr ",
"by pytket-cutensornet.",
)

return result


def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int:
"""Recursive evaluation of a LogicExp."""

Expand Down
193 changes: 192 additions & 1 deletion tests/test_structured_state_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
reg_eq,
)
from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp
from pytket.circuit.clexpr import wired_clexpr_from_logic_exp

from pytket.extensions.cutensornet.structured_state import (
CuTensorNetHandle,
Expand All @@ -25,6 +26,35 @@
# (see https://github.com/CQCL/pytket-qir/blob/main/tests/conditional_test.py)
# Further down, there are tests to check that the simulation works correctly.

def test_circuit_with_clexpr_i() -> None:
# test conditional handling

circ = Circuit(3)
a = circ.add_c_register("a", 5)
b = circ.add_c_register("b", 5)
c = circ.add_c_register("c", 5)
d = circ.add_c_register("d", 5)
circ.H(0)
wexpr, args = wired_clexpr_from_logic_exp(a | b, c)
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(c | b, d)
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(c | b, d)
circ.add_clexpr(wexpr, args, condition=a[4])
circ.H(0)
circ.Measure(Qubit(0), d[4])
circ.H(1)
circ.Measure(Qubit(1), d[3])
circ.H(2)
circ.Measure(Qubit(2), d[2])

with CuTensorNetHandle() as libhandle:
cfg = Config()
state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg)
assert state.is_valid()
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0


def test_circuit_with_classicalexpbox_i() -> None:
# test conditional handling
Expand Down Expand Up @@ -52,6 +82,35 @@ def test_circuit_with_classicalexpbox_i() -> None:
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0

def test_circuit_with_clexpr_ii() -> None:
# test conditional handling with else case

circ = Circuit(3)
a = circ.add_c_register("a", 5)
b = circ.add_c_register("b", 5)
c = circ.add_c_register("c", 5)
d = circ.add_c_register("d", 5)
circ.H(0)
wexpr, args = wired_clexpr_from_logic_exp(a | b, c)
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(c | b, d)
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(c | b, d)
circ.add_clexpr(wexpr, args, condition=if_not_bit(a[4]))
circ.H(0)
circ.Measure(Qubit(0), d[4])
circ.H(1)
circ.Measure(Qubit(1), d[3])
circ.H(2)
circ.Measure(Qubit(2), d[2])

with CuTensorNetHandle() as libhandle:
cfg = Config()
state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg)
assert state.is_valid()
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0


def test_circuit_with_classicalexpbox_ii() -> None:
# test conditional handling with else case
Expand Down Expand Up @@ -81,6 +140,35 @@ def test_circuit_with_classicalexpbox_ii() -> None:
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0

@pytest.mark.skip(reason="Currently not supporting arithmetic operations in ClExpr")
def test_circuit_with_clexpr_iii() -> None:
# test complicated conditions and recursive classical op

circ = Circuit(2)

a = circ.add_c_register("a", 15)
b = circ.add_c_register("b", 15)
c = circ.add_c_register("c", 15)
d = circ.add_c_register("d", 15)
e = circ.add_c_register("e", 15)

circ.H(0)
bits = [Bit(i) for i in range(10)]
big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8]
circ.H(0, condition=big_exp)

wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c)
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e)
circ.add_clexpr(wexpr, args)

with CuTensorNetHandle() as libhandle:
cfg = Config()
state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg)
assert state.is_valid()
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0


@pytest.mark.skip(reason="Currently not supporting arithmetic operations in LogicExp")
def test_circuit_with_classicalexpbox_iii() -> None:
Expand Down Expand Up @@ -239,6 +327,32 @@ def test_pytket_qir_conditional_10() -> None:
assert state.get_fidelity() == 1.0


def test_pytket_qir_conditional_11() -> None:
box_circ = Circuit(4)
box_circ.X(0)
box_circ.Y(1)
box_circ.Z(2)
box_circ.H(3)
box_c = box_circ.add_c_register("c", 5)

box_circ.H(0)

wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c)
box_circ.add_clexpr(wexpr, args)

cbox = CircBox(box_circ)
d = Circuit(4, 5)
a = d.add_c_register("a", 4)
d.add_circbox(cbox, [0, 2, 1, 3, 0, 1, 2, 3, 4], condition=a[0])

with CuTensorNetHandle() as libhandle:
cfg = Config()
state = simulate(libhandle, d, SimulationAlgorithm.MPSxGate, cfg)
assert state.is_valid()
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0


def test_circuit_with_conditional_gate_v() -> None:
# test conditional with no register

Expand Down Expand Up @@ -430,7 +544,84 @@ def test_repeat_until_success_i() -> None:
assert np.allclose(target_state, output_state)


def test_repeat_until_success_ii() -> None:
def test_repeat_until_success_ii_clexpr() -> None:
# From Figure 1(c) of https://arxiv.org/pdf/1311.1074

attempts = 100

circ = Circuit()
qin = circ.add_q_register("qin", 1)
qaux = circ.add_q_register("aux", 2)
flag = circ.add_c_register("flag", 3)
circ.add_c_setbits([True, True], [flag[0], flag[1]]) # Set flag bits to 11
circ.H(qin[0]) # Use to convert gate to sqrt(1/5)*I + i*sqrt(4/5)*X (i.e. Z -> X)

for _ in range(attempts):
wexpr, args = wired_clexpr_from_logic_exp(
flag[0] | flag[1], [flag[2]] # Success if both are zero
)
circ.add_clexpr(wexpr, args)

circ.add_gate(
OpType.Reset, [qaux[0]], condition_bits=[flag[2]], condition_value=1
)
circ.add_gate(
OpType.Reset, [qaux[1]], condition_bits=[flag[2]], condition_value=1
)
circ.add_gate(OpType.H, [qaux[0]], condition_bits=[flag[2]], condition_value=1)
circ.add_gate(OpType.H, [qaux[1]], condition_bits=[flag[2]], condition_value=1)

circ.add_gate(OpType.T, [qin[0]], condition_bits=[flag[2]], condition_value=1)
circ.add_gate(OpType.Z, [qin[0]], condition_bits=[flag[2]], condition_value=1)
circ.add_gate(
OpType.Tdg, [qaux[0]], condition_bits=[flag[2]], condition_value=1
)
circ.add_gate(
OpType.CX, [qaux[1], qaux[0]], condition_bits=[flag[2]], condition_value=1
)
circ.add_gate(OpType.T, [qaux[0]], condition_bits=[flag[2]], condition_value=1)
circ.add_gate(
OpType.CX, [qin[0], qaux[1]], condition_bits=[flag[2]], condition_value=1
)
circ.add_gate(OpType.T, [qaux[1]], condition_bits=[flag[2]], condition_value=1)

circ.add_gate(OpType.H, [qaux[0]], condition_bits=[flag[2]], condition_value=1)
circ.add_gate(OpType.H, [qaux[1]], condition_bits=[flag[2]], condition_value=1)
circ.Measure(qaux[0], flag[0], condition_bits=[flag[2]], condition_value=1)
circ.Measure(qaux[1], flag[1], condition_bits=[flag[2]], condition_value=1)

# From chat with Silas and exploring the RUS as a block matrix, we have noticed
# that the circuit is missing an X correction when this condition is satisfied
wexpr, args = wired_clexpr_from_logic_exp(flag[0] ^ flag[1], [flag[2]])
circ.add_clexpr(wexpr, args)
circ.add_gate(OpType.Z, [qin[0]], condition_bits=[flag[2]], condition_value=1)

circ.H(qin[0]) # Use to convert gate to sqrt(1/5)*I + i*sqrt(4/5)*X (i.e. Z -> X)

with CuTensorNetHandle() as libhandle:
cfg = Config()

state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg)
assert state.is_valid()
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0

# All of the flag bits should have turned False
assert all(not state.get_bits()[bit] for bit in flag)
# The auxiliary qubits should be in state |0>
prob = state.postselect({qaux[0]: 0, qaux[1]: 0})
assert np.isclose(prob, 1.0)

target_state = [np.sqrt(1 / 5), np.sqrt(4 / 5) * 1j]
output_state = state.get_statevector()
# As indicated in the paper, the gate is implemented up to global phase
global_phase = target_state[0] / output_state[0]
assert np.isclose(abs(global_phase), 1.0)
output_state *= global_phase
assert np.allclose(target_state, output_state)


def test_repeat_until_success_ii_classicalexpblox() -> None:
# From Figure 1(c) of https://arxiv.org/pdf/1311.1074

attempts = 100
Expand Down

0 comments on commit c2dd14a

Please sign in to comment.