Skip to content

Commit

Permalink
Added suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
PabloAndresCQ committed Nov 12, 2024
1 parent d66eba3 commit 5da5d19
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 45 deletions.
69 changes: 48 additions & 21 deletions pytket/extensions/cutensornet/structured_state/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,26 @@ def apply_classical_command(
)
for var_id, reg_pos_list in op.expr.reg_posn.items()
}
result = evaluate_clexpr(op.expr.expr, bitvar_val, regvar_val)
# Identify number of bits on each register
regvar_size = {
var_id: len(reg_pos_list)
for var_id, reg_pos_list in op.expr.reg_posn.items()
}
# Identify number of bits in output register
output_size = len(op.expr.output_posn)
result = evaluate_clexpr(
op.expr.expr, bitvar_val, regvar_val, regvar_size, output_size
)

# The result is an int in little-endian encoding. We update the
# output register accordingly.
for bit_pos in op.expr.output_posn:
bits_dict[args[bit_pos]] = (result % 2) == 1
result = result >> 1
assert result == 0 # All bits consumed
# If there has been overflow in the operations, error out.
# This can be detected if `result != 0`
if result != 0:
raise ValueError("Evaluation of the ClExpr resulted in overflow.")

elif isinstance(op, ClassicalExpBox):
the_exp = op.get_exp()
Expand All @@ -99,7 +111,11 @@ def apply_classical_command(


def evaluate_clexpr(
expr: ClExpr, bitvar_val: dict[int, int], regvar_val: dict[int, int]
expr: ClExpr,
bitvar_val: dict[int, int],
regvar_val: dict[int, int],
regvar_size: dict[int, int],
output_size: int,
) -> int:
"""Recursive evaluation of a ClExpr."""

Expand All @@ -113,7 +129,9 @@ def evaluate_clexpr(
elif isinstance(arg, ClRegVar):
value = regvar_val[arg.index]
elif isinstance(arg, ClExpr):
value = evaluate_clexpr(arg, bitvar_val, regvar_val)
value = evaluate_clexpr(
arg, bitvar_val, regvar_val, regvar_size, output_size
)
else:
raise Exception(f"Unrecognised argument type of ClExpr: {type(arg)}.")

Expand All @@ -140,31 +158,39 @@ def evaluate_clexpr(
result = int(args_val[0] < args_val[1])
elif expr.op == ClOp.BitNot:
result = 1 - args_val[0]
# elif expr.op == ClOp.RegNot:
# result = int(args_val[0] == 0)
elif expr.op == ClOp.RegNot: # Bit-wise NOT (flip all bits)
n_bits = regvar_size[expr.args[0].index] # type: ignore
result = (2**n_bits - 1) ^ args_val[0] # XOR with all 1s bitstring
elif expr.op in [ClOp.BitZero, ClOp.RegZero]:
result = 0
elif expr.op in [ClOp.BitOne, ClOp.RegOne]:
elif expr.op == ClOp.BitOne:
result = 1
# elif expr.op == ClOp.RegAdd:
# result = args_val[0] + args_val[1]
# elif expr.op == ClOp.RegSub:
# result = args_val[0] - args_val[1]
# elif expr.op == ClOp.RegMul:
# result = args_val[0] * args_val[1]
# elif expr.op == ClOp.RegPow:
# result = int(args_val[0] ** args_val[1])
elif expr.op == ClOp.RegOne: # All 1s bitstring
n_bits = output_size
result = 2**n_bits - 1
elif expr.op == ClOp.RegAdd:
result = args_val[0] + args_val[1]
elif expr.op == ClOp.RegSub:
if args_val[0] < args_val[1]:
raise NotImplementedError(
"Currently not supporting ClOp.RegSub where the outcome is negative."
)
result = args_val[0] - args_val[1]
elif expr.op == ClOp.RegMul:
result = args_val[0] * args_val[1]
elif expr.op == ClOp.RegDiv: # floor(a / b)
result = args_val[0] // args_val[1]
elif expr.op == ClOp.RegPow:
result = int(args_val[0] ** args_val[1])
elif expr.op == ClOp.RegLsh:
result = args_val[0] << args_val[1]
elif expr.op == ClOp.RegRsh:
result = args_val[0] >> args_val[1]
# elif expr.op == ClOp.RegNeg:
# result = -args_val[0]
else:
# TODO: Currently not supporting ClOp's RegDiv since it does not return int,
# so I am unsure what the semantic is meant to be.
# TODO: I don't now what to do with RegNot, since input
# is not guaranteed to be 0 or 1.
# TODO: It is not clear what to do with overflow of ADD, etc.
# so I have decided to not support them for now.
# TODO: Not supporting RegNeg because I do not know if we have agreed how to
# specify signed ints.
raise NotImplementedError(
f"Evaluation of {expr.op} not supported in ClExpr ",
"by pytket-cutensornet.",
Expand Down Expand Up @@ -231,4 +257,5 @@ def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int
def from_little_endian(bitstring: list[bool]) -> int:
"""Obtain the integer from the little-endian encoded bitstring (i.e. bitstring
[False, True] is interpreted as the integer 2)."""
# TODO: Assumes unisigned integer. What are the specs for signed integers?
return sum(1 << i for i, b in enumerate(bitstring) if b)
87 changes: 63 additions & 24 deletions tests/test_structured_state_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
Bit,
if_not_bit,
reg_eq,
WiredClExpr,
ClExpr,
ClOp,
)
from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp
from pytket.circuit.clexpr import wired_clexpr_from_logic_exp
Expand All @@ -36,11 +39,11 @@ def test_circuit_with_clexpr_i() -> None:
c = circ.add_c_register("c", 5)
d = circ.add_c_register("d", 5)
circ.H(0)
wexpr, args = wired_clexpr_from_logic_exp(a | b, c) # type: ignore
wexpr, args = wired_clexpr_from_logic_exp(a | b, c.to_list())
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore
wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list())
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore
wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list())
circ.add_clexpr(wexpr, args, condition=a[4])
circ.H(0)
circ.Measure(Qubit(0), d[4])
Expand All @@ -66,9 +69,9 @@ def test_circuit_with_classicalexpbox_i() -> None:
c = circ.add_c_register("c", 5)
d = circ.add_c_register("d", 5)
circ.H(0)
circ.add_classicalexpbox_register(a | b, c) # type: ignore
circ.add_classicalexpbox_register(c | b, d) # type: ignore
circ.add_classicalexpbox_register(c | b, d, condition=a[4]) # type: ignore
circ.add_classicalexpbox_register(a | b, c.to_list())
circ.add_classicalexpbox_register(c | b, d.to_list())
circ.add_classicalexpbox_register(c | b, d.to_list(), condition=a[4])
circ.H(0)
circ.Measure(Qubit(0), d[4])
circ.H(1)
Expand All @@ -93,11 +96,11 @@ def test_circuit_with_clexpr_ii() -> None:
c = circ.add_c_register("c", 5)
d = circ.add_c_register("d", 5)
circ.H(0)
wexpr, args = wired_clexpr_from_logic_exp(a | b, c) # type: ignore
wexpr, args = wired_clexpr_from_logic_exp(a | b, c.to_list())
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore
wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list())
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore
wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list())
circ.add_clexpr(wexpr, args, condition=if_not_bit(a[4]))
circ.H(0)
circ.Measure(Qubit(0), d[4])
Expand All @@ -123,11 +126,9 @@ def test_circuit_with_classicalexpbox_ii() -> None:
c = circ.add_c_register("c", 5)
d = circ.add_c_register("d", 5)
circ.H(0)
circ.add_classicalexpbox_register(a | b, c) # type: ignore
circ.add_classicalexpbox_register(c | b, d) # type: ignore
circ.add_classicalexpbox_register(
c | b, d, condition=if_not_bit(a[4]) # type: ignore
)
circ.add_classicalexpbox_register(a | b, c.to_list())
circ.add_classicalexpbox_register(c | b, d.to_list())
circ.add_classicalexpbox_register(c | b, d.to_list(), condition=if_not_bit(a[4]))
circ.H(0)
circ.Measure(Qubit(0), d[4])
circ.H(1)
Expand Down Expand Up @@ -160,9 +161,9 @@ def test_circuit_with_clexpr_iii() -> None:
big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8]
circ.H(0, condition=big_exp)

wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c) # type: ignore
wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c.to_list())
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e) # type: ignore
wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e.to_list())
circ.add_clexpr(wexpr, args)

with CuTensorNetHandle() as libhandle:
Expand Down Expand Up @@ -190,8 +191,8 @@ def test_circuit_with_classicalexpbox_iii() -> None:
big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8]
circ.H(0, condition=big_exp)

circ.add_classicalexpbox_register(a + b - d, c) # type: ignore
circ.add_classicalexpbox_register(a * b * d * c, e) # type: ignore
circ.add_classicalexpbox_register(a + b - d, c.to_list())
circ.add_classicalexpbox_register(a * b * d * c, e.to_list())

with CuTensorNetHandle() as libhandle:
cfg = Config()
Expand Down Expand Up @@ -268,7 +269,7 @@ def test_circuit_with_conditional_gate_iv() -> None:
assert state.get_fidelity() == 1.0


def test_pytket_qir_conditional_8() -> None:
def test_pytket_basic_conditional_i() -> None:
c = Circuit(4)
c.H(0)
c.H(1)
Expand All @@ -287,7 +288,7 @@ def test_pytket_qir_conditional_8() -> None:
assert state.get_fidelity() == 1.0


def test_pytket_qir_conditional_9() -> None:
def test_pytket_basic_conditional_ii() -> None:
c = Circuit(4)
c.X(0)
c.Y(1)
Expand All @@ -306,7 +307,7 @@ def test_pytket_qir_conditional_9() -> None:
assert state.get_fidelity() == 1.0


def test_pytket_qir_conditional_10() -> None:
def test_pytket_basic_conditional_iii_classicalexpbox() -> None:
box_circ = Circuit(4)
box_circ.X(0)
box_circ.Y(1)
Expand All @@ -315,7 +316,7 @@ def test_pytket_qir_conditional_10() -> None:
box_c = box_circ.add_c_register("c", 5)

box_circ.H(0)
box_circ.add_classicalexpbox_register(box_c | box_c, box_c) # type: ignore
box_circ.add_classicalexpbox_register(box_c | box_c, box_c.to_list())

cbox = CircBox(box_circ)
d = Circuit(4, 5)
Expand All @@ -330,7 +331,7 @@ def test_pytket_qir_conditional_10() -> None:
assert state.get_fidelity() == 1.0


def test_pytket_qir_conditional_11() -> None:
def test_pytket_basic_conditional_iii_clexpr() -> None:
box_circ = Circuit(4)
box_circ.X(0)
box_circ.Y(1)
Expand All @@ -340,7 +341,7 @@ def test_pytket_qir_conditional_11() -> None:

box_circ.H(0)

wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c) # type: ignore
wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c.to_list())
box_circ.add_clexpr(wexpr, args)

cbox = CircBox(box_circ)
Expand Down Expand Up @@ -697,3 +698,41 @@ def test_repeat_until_success_ii_classicalexpblox() -> None:
assert np.isclose(abs(global_phase), 1.0)
output_state *= global_phase
assert np.allclose(target_state, output_state)


def test_clexpr_on_regs() -> None:
"""Non-exhaustive test on some ClOp on registers."""
circ = Circuit(2)
a = circ.add_c_register("a", 5)
b = circ.add_c_register("b", 5)
c = circ.add_c_register("c", 5)
d = circ.add_c_register("d", 5)
e = circ.add_c_register("e", 5)

w_expr_regone = WiredClExpr(ClExpr(ClOp.RegOne, []), output_posn=list(range(5)))
circ.add_clexpr(w_expr_regone, a.to_list()) # a = 0b11111 = 31
circ.add_c_setbits([True, True, False, False, False], b.to_list()) # b = 3
circ.add_c_setbits([False, True, False, True, False], c.to_list()) # c = 10
circ.add_clexpr(*wired_clexpr_from_logic_exp(b | c, d.to_list())) # d = 11
circ.add_clexpr(*wired_clexpr_from_logic_exp(a - d, e.to_list())) # e = 20

with CuTensorNetHandle() as libhandle:
cfg = Config()

state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg)
assert state.is_valid()
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0

# Check the bits
bits_dict = state.get_bits()
a_bitstring = list(bits_dict[bit] for bit in a)
assert all(a_bitstring) # a = 0b11111
b_bitstring = list(bits_dict[bit] for bit in b)
assert b_bitstring == [True, True, False, False, False] # b = 0b11000
c_bitstring = list(bits_dict[bit] for bit in c)
assert c_bitstring == [False, True, False, True, False] # c = 0b01010
d_bitstring = list(bits_dict[bit] for bit in d)
assert d_bitstring == [True, True, False, True, False] # d = 0b11010
e_bitstring = list(bits_dict[bit] for bit in e)
assert e_bitstring == [False, False, True, False, True] # e = 0b00101

0 comments on commit 5da5d19

Please sign in to comment.