Add MaxRL algorithm to VCS and IFTLSP

Remmy Zen committed Jun 6, 2024
1 parent 862eccb commit 9418a96
Showing 7 changed files with 107 additions and 214 deletions.
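Editorial summary (not part of the commit message): MaxRL optimizes the best reward reached along a trajectory rather than the cumulative reward. This commit implements it as a reward transformation driven by an auxiliary variable, max_diff. A minimal plain-Python sketch of that transformation, paraphrasing the JAX code in step_env below (the function name and list-based interface are illustrative, not from the repository):

def maxrl_transform(rewards):
    """Adapt per-step rewards so their sum equals the best partial return."""
    max_diff = 0.0  # mirrors EnvState.max_diff, reset to 0.0 each episode
    adapted = []
    for r in rewards:
        # Both updates read the pre-step max_diff, exactly as in step_env.
        adapted.append(max(0.0, r - max_diff))  # "reward_adapted"
        max_diff = max(0.0, max_diff - r)       # "new_max_diff"
    return adapted

# Prefix sums of [1.0, -2.0, 4.0] are 1.0, -1.0, 3.0; the best is 3.0,
# and the adapted rewards sum to exactly that value.
assert sum(maxrl_transform([1.0, -2.0, 4.0])) == 3.0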
15 changes: 13 additions & 2 deletions notebooks/02 - Verification Circuit Synthesis.ipynb
@@ -92,7 +92,10 @@
"cell_type": "code",
"execution_count": null,
"id": "67e00217-a5a4-45f5-bac9-6e3ae09bc284",
"metadata": {},
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"vcs.run()"
@@ -249,7 +252,7 @@
"gates = [cliff_gates.cx, cliff_gates.cz]\n",
"plus_ancilla_position = [5,6]\n",
"\n",
"vcs = VerificationCircuitSynthesis(circ, num_ancillas = 2, gates=gates, plus_ancilla_position = plus_ancilla_position, gates_between_ancilla = False)\n",
"vcs = VerificationCircuitSynthesis(circ, num_ancillas = 2, gates=gates, plus_ancilla_position = plus_ancilla_position, gates_between_ancilla = False, use_max_reward = False)\n",
"## Need to change training config such that the agent explore more.\n",
"vcs.training_config[\"TOTAL_TIMESTEPS\"] = 1e6\n",
"vcs.training_config[\"LR\"] = 5e-4\n",
@@ -288,6 +291,14 @@
"## Log the training process\n",
"vcs.log()"
]
-}
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "c5f95696-0151-4171-b826-70f7385e2db7",
+"metadata": {},
+"outputs": [],
+"source": []
+}
],
"metadata": {
@@ -157,186 +157,6 @@
"## One can also customize the folder name to save log\n",
"ftlsp.log(results_folder_name='logs')"
]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "63b30d72-d7c7-4f5b-ac28-8ce70f2cf86d",
-"metadata": {},
-"outputs": [],
-"source": [
-"from rlftqc.logical_state_preparation import LogicalStatePreparation\n",
-"from rlftqc.simulators.clifford_gates import CliffordGates\n",
-"\n",
-"## Define the target stabilizers\n",
-"## For example, zero logical of 5 qubit perfect code.\n",
-"target = [\n",
-" \"+ZZZZZ\",\n",
-" \"+IXZZX\",\n",
-" \"+XZZXI\",\n",
-" \"+ZZXIX\",\n",
-" \"+ZXIXZ\"]\n",
-"\n",
-"## Specify gates\n",
-"cliff_gates = CliffordGates(5)\n",
-"gates = [cliff_gates.s, cliff_gates.cx, cliff_gates.sqrt_x, cliff_gates.x]\n",
-"\n",
-"## Create next-nearest neighbors connectivity graph\n",
-"graph = []\n",
-"for ii in range(4):\n",
-" graph.append((ii, ii+1))\n",
-" graph.append((ii+1, ii))\n",
-"print(graph)\n",
-" \n",
-"## Create class\n",
-"lsp = LogicalStatePreparation(target, gates=gates, graph=graph)"
-]
-},
-{
-"cell_type": "markdown",
-"id": "1574325c-644d-4636-a319-44daac85a641",
-"metadata": {},
-"source": [
-"We now train the agent. It takes around 60 seconds to train. "
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "35267c81-8167-4771-a458-58c891331ff8",
-"metadata": {},
-"outputs": [],
-"source": [
-"lsp.train()"
-]
-},
-{
-"cell_type": "markdown",
-"id": "c809ca77-a608-4b37-bd08-cd8a48e2b51e",
-"metadata": {},
-"source": [
-"Run the agent and get the prepared circuit"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "ec17c2ec-b288-4e90-84bb-9cca5c9d0f71",
-"metadata": {},
-"outputs": [],
-"source": [
-"lsp.run()"
-]
-},
-{
-"cell_type": "markdown",
-"id": "dcd0b753-64c0-4031-abf9-e8cabc0dbadd",
-"metadata": {},
-"source": [
-"We can also log the result to check the training convergence."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "0e5110f6-ff4f-4f48-a87a-596ef3f6b90b",
-"metadata": {},
-"outputs": [],
-"source": [
-"lsp.log()"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "66d471ba-dbd5-4ea2-abf7-94d806f50b09",
-"metadata": {},
-"outputs": [],
-"source": [
-"from rlftqc.logical_state_preparation import LogicalStatePreparation\n",
-"\n",
-"## Define the target stabilizers\n",
-"## For example, zero logical of 7 qubit Steane code.\n",
-"target = [\"ZZZZZZZ\",\n",
-" \"ZIZIZIZ\",\n",
-" \"XIXIXIX\",\n",
-" \"IZZIIZZ\",\n",
-" \"IXXIIXX\",\n",
-" \"IIIZZZZ\",\n",
-" \"IIIXXXX\",\n",
-" ]\n",
-"\n",
-"## Create class\n",
-"lsp = LogicalStatePreparation(target)\n"
-]
-},
-{
-"cell_type": "markdown",
-"id": "c724948a-6fdf-4177-82b5-3ad7de6ac3b1",
-"metadata": {},
-"source": [
-"Change the number of possible gates for training with the max_steps."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "44286fcf-cd66-4483-abd5-91553fa82e44",
-"metadata": {},
-"outputs": [],
-"source": [
-"lsp = LogicalStatePreparation(target, max_steps = 100)"
-]
-},
-{
-"cell_type": "markdown",
-"id": "31a9429b-c791-4019-86cc-2693b5808c82",
-"metadata": {},
-"source": [
-"Change seed for training."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "ef558728-d15f-4796-9e44-6895888630ba",
-"metadata": {},
-"outputs": [],
-"source": [
-"lsp = LogicalStatePreparation(target, seed = 123)"
-]
-},
-{
-"cell_type": "markdown",
-"id": "538841d6-f720-4a13-bec2-acf25205f00c",
-"metadata": {},
-"source": [
-"For more advanced training configurations, we can change the training config."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "760f5a6d-789a-4b91-ba13-91d9be0365ce",
-"metadata": {},
-"outputs": [],
-"source": [
-"lsp.training_config"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "314535ff-127e-4330-9704-c2896f71115b",
-"metadata": {},
-"outputs": [],
-"source": [
-"# NUM_AGENTS change the number of parallel agents to train (default: 1).\n",
-"lsp.training_config['NUM_AGENTS'] = 5\n",
-"\n",
-"# TOTAL_TIMESTEPS change the number of total timesteps for training (default: 5e5), increase this for longer training.\n",
-"lsp.training_config['TOTAL_TIMESTEPS'] = 1e7\n"
-]
}
],
"metadata": {
Expand Down
61 changes: 45 additions & 16 deletions rlftqc/envs/ft_logical_state_preparation_env.py
@@ -22,7 +22,9 @@ class EnvState:
previous_product_ancilla: float
number_of_errors: int
time: int

+max_diff: float


@struct.dataclass
class EnvParams:
max_steps: int = 50
@@ -55,6 +57,8 @@ class FTLogicalStatePreparationEnv(environment.Environment):
cz_ancilla_only (boolean, optional): If true, CZ gates are only applied to the ancilla qubits. (Default: False)
distance_metric (str, optional): Distance metric to use for the complementary distance reward.
Currently only supports 'hamming' or 'jaccard' (default).
+use_max_reward (boolean, optional): If true, use the MaxRL (maximum-reward RL) algorithm. (Default: False)
"""

def __init__(self,
@@ -79,7 +83,8 @@ def __init__(self,
group_ancillas = False,
cz_ancilla_only = False,
plus_ancilla_position = [],
-distance_metric = 'jaccard'
+distance_metric = 'jaccard',
+use_max_reward = False
):
""" Initialize a integrated fault-tolerant logical state preparation environment.
"""
@@ -89,6 +94,7 @@ def __init__(self,
self.target = target
self.target_sign = []
self.n_qubits_physical_encoding = len(self.target)

target_wo_sign = []
for stabs in self.target:
## Process sign
Expand Down Expand Up @@ -202,6 +208,8 @@ def __init__(self,
self.two_qubit_errors = list(itertools.product('IXYZ', repeat=2))[1:]
self.initial_propagated_errors = PauliString(self.n_qubits_physical, 15 * self.max_steps)

+self.use_max_reward = use_max_reward

self.actions = self.action_matrix()

self.obs_shape = self.get_observation(self.initial_tableau_with_ancillas.current_tableau[0]).flatten().shape[0] + self.n_qubits_physical
@@ -717,19 +725,31 @@ def step_env(self, key: chex.PRNGKey, state: EnvState, action: int, params=None):
self.weight_flag * (current_flagged_errors - state.previous_flagged_errors) + \
self.weight_distance * (current_distance - state.previous_distance)

-state = EnvState(new_state.astype(jnp.uint8), new_sign.astype(jnp.uint8), new_propagated_error.astype(jnp.uint8), current_flagged_errors, current_distance, current_product_ancilla, new_number_of_errors, state.time + 1)
+new_max_diff = jnp.max(jnp.array([0.0, state.max_diff - reward]))
+reward_adapted = jnp.max(jnp.array([0.0, reward - state.max_diff]))
+
+state = EnvState(new_state.astype(jnp.uint8), new_sign.astype(jnp.uint8), new_propagated_error.astype(jnp.uint8), current_flagged_errors, current_distance, current_product_ancilla, new_number_of_errors, state.time + 1, new_max_diff)

# Evaluate termination conditions
done = self.is_terminal(state)


-return (
-jax.lax.stop_gradient(self.get_obs(state)),
-jax.lax.stop_gradient(state),
-reward,
-done,
-{"discount": self.discount(state, params)}
-)
+if self.use_max_reward:
+return (
+jax.lax.stop_gradient(self.get_obs(state)),
+jax.lax.stop_gradient(state),
+reward_adapted,
+done,
+{"discount": self.discount(state, params)}
+)
+else:
+return (
+jax.lax.stop_gradient(self.get_obs(state)),
+jax.lax.stop_gradient(state),
+reward,
+done,
+{"discount": self.discount(state, params)}
+)
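# [Editor's note, not part of this commit] Both added lines read the pre-step
# value state.max_diff, so with m_0 = 0 the update implements
#     r_hat_t = max(0, r_t - m_{t-1}),   m_t = max(0, m_{t-1} - r_t),
# and the adapted rewards telescope: sum_t r_hat_t = max(0, max_k sum_{i<=k} r_i).
# A standard cumulative-reward agent trained on r_hat therefore optimizes the
# maximum partial return, i.e. the MaxRL objective named in the commit title.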

def reset_env(self, key: chex.PRNGKey, params: EnvParams) -> Tuple[chex.Array, EnvState]:
"""Performs resetting of environment.
Expand Down Expand Up @@ -764,7 +784,8 @@ def reset_env(self, key: chex.PRNGKey, params: EnvParams) -> Tuple[chex.Array, E
previous_distance=previous_distance,
previous_product_ancilla=previous_product_ancilla,
number_of_errors = 0,
-time = 0
+time = 0,
+max_diff = 0.0
)

return self.get_obs(state), state
@@ -780,8 +801,11 @@ def get_obs(self, state: EnvState, params: Optional[EnvParams] = EnvParams) -> chex.Array:
Observations by appending the tableau and the sign
"""
obs_tab, obs_sign = self.canonical_stabilizers(self.get_observation(state.tableau), state.sign[self.n_qubits_physical:] * 2)
-return jnp.append(obs_tab.flatten(), obs_sign // 2)
+if self.use_max_reward:
+return jnp.append(jnp.append(obs_tab.flatten(), obs_sign // 2), state.max_diff)
+else:
+return jnp.append(obs_tab.flatten(), obs_sign // 2)
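# [Editor's note, not part of this commit] Appending state.max_diff keeps the
# decision process Markovian under the adapted reward, since the next adapted
# reward depends on this running gap. Note that jnp.append promotes the uint8
# tableau/sign observation to a floating dtype once the float max_diff is
# appended.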

def is_terminal(self, state: EnvState, params=None) -> bool:
"""Check whether state is terminal.
@@ -832,7 +856,8 @@ def copy(self):
self.group_ancillas,
self.cz_ancilla_only,
self.plus_ancilla_position,
-self.distance_metric
+self.distance_metric,
+self.use_max_reward
)

@property
@@ -851,7 +876,11 @@ def action_space(self, params: Optional[EnvParams] = EnvParams) -> spaces.Discrete:

def observation_space(self, params: EnvParams) -> spaces.Box:
"""Observation space of the environment."""
-return spaces.Box(0, 1, self.obs_shape, dtype=jnp.uint8)
+if self.use_max_reward:
+## Add one observation entry for the MaxRL auxiliary variable max_diff
+return spaces.Box(0, 1, self.obs_shape + 1, dtype=jnp.uint8)
+else:
+return spaces.Box(0, 1, self.obs_shape, dtype=jnp.uint8)

def state_space(self, params: EnvParams) -> spaces.Dict:
"""State space of the environment."""
2 changes: 2 additions & 0 deletions rlftqc/envs/logical_state_preparation_env.py
@@ -40,6 +40,8 @@ class LogicalStatePreparationEnv(environment.Environment):
threshold (float, optional): The complementary distance threshold to indicate success. Default: 0.99
initialize_plus (list(int), optional): Initialize qubits given in the list as plus state instead of zero state.
This is useful for large CSS codes or when CZ is used in the gate set.
+use_max_reward (boolean, optional): Whether to use the MaxRL algorithm.
"""
def __init__(self,
target,
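An analogous flag is added to LogicalStatePreparationEnv here. Presumably the higher-level wrapper forwards it the way VerificationCircuitSynthesis does in the notebook above; a hedged sketch (the forwarding is not shown in this diff, and the target is the Steane-code example from the removed notebook cells):

from rlftqc.logical_state_preparation import LogicalStatePreparation

target = ["ZZZZZZZ", "ZIZIZIZ", "XIXIXIX", "IZZIIZZ",
          "IXXIIXX", "IIIZZZZ", "IIIXXXX"]

lsp = LogicalStatePreparation(target, use_max_reward = True)  # hypothetical forwarding
lsp.train()
lsp.run()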