From 9418a96a247a671f505f0444af8451003db4e5ac Mon Sep 17 00:00:00 2001 From: Remmy Zen Date: Thu, 6 Jun 2024 09:54:27 +0200 Subject: [PATCH] Add MaxRL algorithm to VCS and IFTLSP --- .../02 - Verification Circuit Synthesis.ipynb | 15 +- ...t-Tolerant Logical State Preparation.ipynb | 180 ------------------ .../envs/ft_logical_state_preparation_env.py | 61 ++++-- rlftqc/envs/logical_state_preparation_env.py | 2 + .../verification_circuit_synthesis_env.py | 52 +++-- rlftqc/ft_logical_state_preparation.py | 5 +- rlftqc/verification_circuit_synthesis.py | 6 +- 7 files changed, 107 insertions(+), 214 deletions(-) diff --git a/notebooks/02 - Verification Circuit Synthesis.ipynb b/notebooks/02 - Verification Circuit Synthesis.ipynb index 905a696..346cd29 100644 --- a/notebooks/02 - Verification Circuit Synthesis.ipynb +++ b/notebooks/02 - Verification Circuit Synthesis.ipynb @@ -92,7 +92,10 @@ "cell_type": "code", "execution_count": null, "id": "67e00217-a5a4-45f5-bac9-6e3ae09bc284", - "metadata": {}, + "metadata": { + "scrolled": true, + "tags": [] + }, "outputs": [], "source": [ "vcs.run()" @@ -249,7 +252,7 @@ "gates = [cliff_gates.cx, cliff_gates.cz]\n", "plus_ancilla_position = [5,6]\n", "\n", - "vcs = VerificationCircuitSynthesis(circ, num_ancillas = 2, gates=gates, plus_ancilla_position = plus_ancilla_position, gates_between_ancilla = False)\n", + "vcs = VerificationCircuitSynthesis(circ, num_ancillas = 2, gates=gates, plus_ancilla_position = plus_ancilla_position, gates_between_ancilla = False, use_max_reward = False)\n", "## Need to change training config such that the agent explore more.\n", "vcs.training_config[\"TOTAL_TIMESTEPS\"] = 1e6\n", "vcs.training_config[\"LR\"] = 5e-4\n", @@ -288,6 +291,14 @@ "## Log the training process\n", "vcs.log()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5f95696-0151-4171-b826-70f7385e2db7", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/notebooks/03 - Integrated Fault-Tolerant Logical State Preparation.ipynb b/notebooks/03 - Integrated Fault-Tolerant Logical State Preparation.ipynb index d1d6504..a586676 100644 --- a/notebooks/03 - Integrated Fault-Tolerant Logical State Preparation.ipynb +++ b/notebooks/03 - Integrated Fault-Tolerant Logical State Preparation.ipynb @@ -157,186 +157,6 @@ "## One can also customize the folder name to save log\n", "ftlsp.log(results_folder_name='logs')" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "63b30d72-d7c7-4f5b-ac28-8ce70f2cf86d", - "metadata": {}, - "outputs": [], - "source": [ - "from rlftqc.logical_state_preparation import LogicalStatePreparation\n", - "from rlftqc.simulators.clifford_gates import CliffordGates\n", - "\n", - "## Define the target stabilizers\n", - "## For example, zero logical of 5 qubit perfect code.\n", - "target = [\n", - " \"+ZZZZZ\",\n", - " \"+IXZZX\",\n", - " \"+XZZXI\",\n", - " \"+ZZXIX\",\n", - " \"+ZXIXZ\"]\n", - "\n", - "## Specify gates\n", - "cliff_gates = CliffordGates(5)\n", - "gates = [cliff_gates.s, cliff_gates.cx, cliff_gates.sqrt_x, cliff_gates.x]\n", - "\n", - "## Create next-nearest neighbors connectivity graph\n", - "graph = []\n", - "for ii in range(4):\n", - " graph.append((ii, ii+1))\n", - " graph.append((ii+1, ii))\n", - "print(graph)\n", - " \n", - "## Create class\n", - "lsp = LogicalStatePreparation(target, gates=gates, graph=graph)" - ] - }, - { - "cell_type": "markdown", - "id": "1574325c-644d-4636-a319-44daac85a641", - "metadata": {}, - "source": [ - "We now train the agent. 
It takes around 60 seconds to train. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "35267c81-8167-4771-a458-58c891331ff8", - "metadata": {}, - "outputs": [], - "source": [ - "lsp.train()" - ] - }, - { - "cell_type": "markdown", - "id": "c809ca77-a608-4b37-bd08-cd8a48e2b51e", - "metadata": {}, - "source": [ - "Run the agent and get the prepared circuit" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ec17c2ec-b288-4e90-84bb-9cca5c9d0f71", - "metadata": {}, - "outputs": [], - "source": [ - "lsp.run()" - ] - }, - { - "cell_type": "markdown", - "id": "dcd0b753-64c0-4031-abf9-e8cabc0dbadd", - "metadata": {}, - "source": [ - "We can also log the result to check the training convergence." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0e5110f6-ff4f-4f48-a87a-596ef3f6b90b", - "metadata": {}, - "outputs": [], - "source": [ - "lsp.log()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "66d471ba-dbd5-4ea2-abf7-94d806f50b09", - "metadata": {}, - "outputs": [], - "source": [ - "from rlftqc.logical_state_preparation import LogicalStatePreparation\n", - "\n", - "## Define the target stabilizers\n", - "## For example, zero logical of 7 qubit Steane code.\n", - "target = [\"ZZZZZZZ\",\n", - " \"ZIZIZIZ\",\n", - " \"XIXIXIX\",\n", - " \"IZZIIZZ\",\n", - " \"IXXIIXX\",\n", - " \"IIIZZZZ\",\n", - " \"IIIXXXX\",\n", - " ]\n", - "\n", - "## Create class\n", - "lsp = LogicalStatePreparation(target)\n" - ] - }, - { - "cell_type": "markdown", - "id": "c724948a-6fdf-4177-82b5-3ad7de6ac3b1", - "metadata": {}, - "source": [ - "Change the number of possible gates for training with the max_steps." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "44286fcf-cd66-4483-abd5-91553fa82e44", - "metadata": {}, - "outputs": [], - "source": [ - "lsp = LogicalStatePreparation(target, max_steps = 100)" - ] - }, - { - "cell_type": "markdown", - "id": "31a9429b-c791-4019-86cc-2693b5808c82", - "metadata": {}, - "source": [ - "Change seed for training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ef558728-d15f-4796-9e44-6895888630ba", - "metadata": {}, - "outputs": [], - "source": [ - "lsp = LogicalStatePreparation(target, seed = 123)" - ] - }, - { - "cell_type": "markdown", - "id": "538841d6-f720-4a13-bec2-acf25205f00c", - "metadata": {}, - "source": [ - "For more advanced training configurations, we can change the training config." 
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "760f5a6d-789a-4b91-ba13-91d9be0365ce",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "lsp.training_config"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "314535ff-127e-4330-9704-c2896f71115b",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# NUM_AGENTS change the number of parallel agents to train (default: 1).\n",
-    "lsp.training_config['NUM_AGENTS'] = 5\n",
-    "\n",
-    "# TOTAL_TIMESTEPS change the number of total timesteps for training (default: 5e5), increase this for longer training.\n",
-    "lsp.training_config['TOTAL_TIMESTEPS'] = 1e7\n"
-   ]
-  }
  ],
  "metadata": {
diff --git a/rlftqc/envs/ft_logical_state_preparation_env.py b/rlftqc/envs/ft_logical_state_preparation_env.py
index c83ab0b..aeffb96 100644
--- a/rlftqc/envs/ft_logical_state_preparation_env.py
+++ b/rlftqc/envs/ft_logical_state_preparation_env.py
@@ -22,7 +22,9 @@ class EnvState:
     previous_product_ancilla: float
     number_of_errors: int
     time: int
-    
+    max_diff: float
+    
+    
 @struct.dataclass
 class EnvParams:
     max_steps: int = 50
@@ -55,6 +57,8 @@ class FTLogicalStatePreparationEnv(environment.Environment):
         cz_ancilla_only (boolean, optional): If true, CZ is only applied to the ancilla. (Default: False)
         distance_metric (str, optional): Distance metric to use for the complementary distance reward. Currently only supports 'hamming' or 'jaccard' (default).
+        use_max_reward (boolean, optional): Whether to use the MaxRL algorithm. (Default: False)
+
     """

     def __init__(self,
@@ -79,7 +83,8 @@ def __init__(self,
                 group_ancillas = False,
                 cz_ancilla_only = False,
                 plus_ancilla_position = [],
-                distance_metric = 'jaccard'
+                distance_metric = 'jaccard',
+                use_max_reward = False
                 ):
        """ Initialize an integrated fault-tolerant logical state preparation environment. 
""" @@ -89,6 +94,7 @@ def __init__(self, self.target = target self.target_sign = [] self.n_qubits_physical_encoding = len(self.target) + target_wo_sign = [] for stabs in self.target: ## Process sign @@ -202,6 +208,8 @@ def __init__(self, self.two_qubit_errors = list(itertools.product('IXYZ', repeat=2))[1:] self.initial_propagated_errors = PauliString(self.n_qubits_physical, 15 * self.max_steps) + self.use_max_reward = use_max_reward + self.actions = self.action_matrix() self.obs_shape = self.get_observation(self.initial_tableau_with_ancillas.current_tableau[0]).flatten().shape[0] + self.n_qubits_physical @@ -717,19 +725,31 @@ def step_env(self, key: chex.PRNGKey, state: EnvState, action: int, params=None) self.weight_flag * (current_flagged_errors - state.previous_flagged_errors) + \ self.weight_distance * (current_distance - state.previous_distance) - state = EnvState(new_state.astype(jnp.uint8), new_sign.astype(jnp.uint8), new_propagated_error.astype(jnp.uint8), current_flagged_errors, current_distance, current_product_ancilla, new_number_of_errors, state.time + 1) + + new_max_diff = jnp.max(jnp.array([0.0, state.max_diff - reward])) + reward_adapted = jnp.max(jnp.array([0.0, reward - state.max_diff])) + + state = EnvState(new_state.astype(jnp.uint8), new_sign.astype(jnp.uint8), new_propagated_error.astype(jnp.uint8), current_flagged_errors, current_distance, current_product_ancilla, new_number_of_errors, state.time + 1, new_max_diff) # Evaluate termination conditions done = self.is_terminal(state) - - return ( - jax.lax.stop_gradient(self.get_obs(state)), - jax.lax.stop_gradient(state), - reward, - done, - {"discount": self.discount(state, params)} - ) + if self.use_max_reward: + return ( + jax.lax.stop_gradient(self.get_obs(state)), + jax.lax.stop_gradient(state), + reward_adapted, + done, + {"discount": self.discount(state, params)} + ) + else: + return ( + jax.lax.stop_gradient(self.get_obs(state)), + jax.lax.stop_gradient(state), + reward, + done, + {"discount": self.discount(state, params)} + ) def reset_env(self, key: chex.PRNGKey, params: EnvParams) -> Tuple[chex.Array, EnvState]: """Performs resetting of environment. @@ -764,7 +784,8 @@ def reset_env(self, key: chex.PRNGKey, params: EnvParams) -> Tuple[chex.Array, E previous_distance=previous_distance, previous_product_ancilla=previous_product_ancilla, number_of_errors = 0, - time = 0 + time = 0, + max_diff = 0.0 ) return self.get_obs(state), state @@ -780,8 +801,11 @@ def get_obs(self, state: EnvState, params: Optional[EnvParams] = EnvParams) -> c Observations by appending the tableau and the sign """ obs_tab, obs_sign = self.canonical_stabilizers(self.get_observation(state.tableau), state.sign[self.n_qubits_physical:] * 2) - return jnp.append(obs_tab.flatten(), obs_sign // 2) - + if self.use_max_reward: + return jnp.append(jnp.append(obs_tab.flatten(), obs_sign // 2), state.max_diff) + else: + return jnp.append(obs_tab.flatten(), obs_sign // 2) + def is_terminal(self, state: EnvState, params=None) -> bool: """Check whether state is terminal. 
@@ -832,7 +856,8 @@ def copy(self):
            self.group_ancillas,
            self.cz_ancilla_only,
            self.plus_ancilla_position,
-           self.distance_metric
+           self.distance_metric,
+           self.use_max_reward
        )

    @property
@@ -851,7 +876,11 @@ def action_space(self, params: Optional[EnvParams] = EnvParams) -> spaces.Discre
    def observation_space(self, params: EnvParams) -> spaces.Box:
        """Observation space of the environment."""
-        return spaces.Box(0, 1, self.obs_shape, dtype=jnp.uint8)
+        if self.use_max_reward:
+            ## Extra entry for the MaxRL max_diff value
+            return spaces.Box(0, 1, self.obs_shape + 1, dtype=jnp.uint8)
+        else:
+            return spaces.Box(0, 1, self.obs_shape, dtype=jnp.uint8)

    def state_space(self, params: EnvParams) -> spaces.Dict:
        """State space of the environment."""
diff --git a/rlftqc/envs/logical_state_preparation_env.py b/rlftqc/envs/logical_state_preparation_env.py
index effed4f..31d9a2f 100644
--- a/rlftqc/envs/logical_state_preparation_env.py
+++ b/rlftqc/envs/logical_state_preparation_env.py
@@ -40,6 +40,8 @@ class LogicalStatePreparationEnv(environment.Environment):
        threshold (float, optional): The complementary distance threshold to indicate success. Default: 0.99
        initialize_plus (list(int), optional): Initialize the qubits given in the list in the plus state instead of the zero state. This is useful for large CSS codes or when CZ is used in the gate set.
+        use_max_reward (boolean, optional): Whether to use the MaxRL algorithm. (Default: False)
+
    """
    def __init__(self, target,
diff --git a/rlftqc/envs/verification_circuit_synthesis_env.py b/rlftqc/envs/verification_circuit_synthesis_env.py
index 8f058fa..99abfc8 100644
--- a/rlftqc/envs/verification_circuit_synthesis_env.py
+++ b/rlftqc/envs/verification_circuit_synthesis_env.py
@@ -22,6 +22,7 @@ class EnvState:
    previous_distance: float
    number_of_errors: int
    time: int
+    max_diff: float

 @struct.dataclass
 class EnvParams:
@@ -54,6 +55,7 @@ class VerificationCircuitSynthesisEnv(environment.Environment):
        group_ancillas (boolean, optional): If set to True, this will group the ancillas into two. Useful to replicate the protocol in Chamberland and Chao's original paper. For example: if there are 4 flag qubits, there will be no two-qubit gates between flag qubits 1,2 and 3,4.
        plus_ancilla_position (list(int), optional): Initialize the flag qubits given in the list in the plus state and measure them in the X basis. This is useful for non-CSS codes.
+        use_max_reward (boolean, optional): Whether to use the MaxRL algorithm. (Default: False)
    """

    def __init__(self,
@@ -77,7 +79,8 @@ def __init__(self,
                 gates_between_ancilla = True,
                 gates_between_data = False,
                 group_ancillas = False,
-                plus_ancilla_position = []
+                plus_ancilla_position = [],
+                use_max_reward = False
                 ):
        """Initialize a verification circuit synthesis environment.
        """
@@ -205,6 +208,8 @@ def __init__(self,
        ## Generate observation
        self.obs_shape = self.get_observation(self.initial_tableau_with_ancillas.current_tableau[0]).flatten().shape

+        self.use_max_reward = use_max_reward
+
    def stim_tableau_to_numpy(self, stim_tableau, num_ancillas = 0):
        ''' Convert stim tableau to a proper numpy tableau for our simulator. 
@@ -800,20 +805,32 @@ def step_env(self, key: chex.PRNGKey, state: EnvState, action: int, params=None)
        reward = self.weight_ancillas * (current_product_ancilla - state.previous_product_ancilla) + \
            self.weight_flag * (current_flagged_errors - state.previous_flagged_errors) + \
            self.weight_distance * (current_distance - state.previous_distance)
+
+        new_max_diff = jnp.max(jnp.array([0.0, state.max_diff - reward]))
+        reward_adapted = jnp.max(jnp.array([0.0, reward - state.max_diff]))

-        state = EnvState(new_state.astype(jnp.uint8), new_sign.astype(jnp.uint8), new_propagated_error.astype(jnp.uint8), current_flagged_errors, current_product_ancilla, current_distance, new_number_of_errors, state.time + 1)
+        state = EnvState(new_state.astype(jnp.uint8), new_sign.astype(jnp.uint8), new_propagated_error.astype(jnp.uint8), current_flagged_errors, current_product_ancilla, current_distance, new_number_of_errors, state.time + 1, new_max_diff)

        # Evaluate termination conditions
        done = self.is_terminal(state)

-        return (
-            jax.lax.stop_gradient(self.get_obs(state)),
-            jax.lax.stop_gradient(state),
-            reward,
-            done,
-            {"discount": self.discount(state, params)},
-        )
+        if self.use_max_reward:
+            return (
+                jax.lax.stop_gradient(self.get_obs(state)),
+                jax.lax.stop_gradient(state),
+                reward_adapted,
+                done,
+                {"discount": self.discount(state, params), "max_reward": reward_adapted},
+            )
+        else:
+            return (
+                jax.lax.stop_gradient(self.get_obs(state)),
+                jax.lax.stop_gradient(state),
+                reward,
+                done,
+                {"discount": self.discount(state, params), "max_reward": reward_adapted},
+            )

    def reset_env(self, key: chex.PRNGKey, params: EnvParams) -> Tuple[chex.Array, EnvState]:
        """Performs resetting of environment.
@@ -846,7 +863,8 @@ def reset_env(self, key: chex.PRNGKey, params: EnvParams) -> Tuple[chex.Array, E
            previous_product_ancilla=previous_product_ancilla,
            previous_distance=previous_distance,
            number_of_errors = self.number_of_initial_errors,
-            time = 0
+            time = 0,
+            max_diff = 0.0
        )

        return self.get_obs(state), state
@@ -881,7 +899,10 @@ def get_obs(self, state: EnvState) -> chex.Array:
        Returns:
            Observations by appending the tableau and the sign
        """
-        return self.get_observation(state.tableau).flatten()
+        if self.use_max_reward:
+            return jnp.append(self.get_observation(state.tableau).flatten(), state.max_diff)
+        else:
+            return self.get_observation(state.tableau).flatten()

    def copy(self):
        """ Copy environment. """
@@ -906,7 +927,8 @@ def copy(self):
            self.gates_between_ancilla,
            self.gates_between_data,
            self.group_ancillas,
-            self.plus_ancilla_position
+            self.plus_ancilla_position,
+            self.use_max_reward
        )

    @property
@@ -925,7 +947,11 @@ def action_space(self, params: Optional[EnvParams] = EnvParams) -> spaces.Discre
    def observation_space(self, params: EnvParams) -> spaces.Box:
        """Observation space of the environment."""
-        return spaces.Box(0, 1, self.obs_shape, dtype=jnp.uint8)
+        if self.use_max_reward:
+            ## Extra entry for the MaxRL max_diff value
+            return spaces.Box(0, 1, self.obs_shape[0] + 1, dtype=jnp.uint8)
+        else:
+            return spaces.Box(0, 1, self.obs_shape, dtype=jnp.uint8)

    def state_space(self, params: EnvParams) -> spaces.Dict:
        """State space of the environment."""
diff --git a/rlftqc/ft_logical_state_preparation.py b/rlftqc/ft_logical_state_preparation.py
index b0aa299..006ab50 100644
--- a/rlftqc/ft_logical_state_preparation.py
+++ b/rlftqc/ft_logical_state_preparation.py
@@ -40,6 +40,7 @@ class FTLogicalStatePreparation:
            This is useful for non-CSS codes. 
        distance_metric (str, optional): Distance metric to use for the complementary distance reward. Currently only supports 'hamming' or 'jaccard' (default).
+        use_max_reward (boolean, optional): Whether to use the MaxRL algorithm. (Default: False)
        training_config (optional): Training configuration.
        seed (int, optional): Random seed (default: 42)
    """

    def __init__(self,
@@ -66,6 +67,7 @@ def __init__(self,
                 cz_ancilla_only = False,
                 plus_ancilla_position = [],
                 distance_metric = 'jaccard',
+                 use_max_reward = False,
                 training_config = None,
                 seed = 42):
        """ Initialize an integrated fault-tolerant state preparation task. """
@@ -92,7 +94,8 @@ def __init__(self,
            group_ancillas,
            cz_ancilla_only,
            plus_ancilla_position,
-            distance_metric)
+            distance_metric,
+            use_max_reward)

        self.seed = seed

diff --git a/rlftqc/verification_circuit_synthesis.py b/rlftqc/verification_circuit_synthesis.py
index af57030..f09ff08 100644
--- a/rlftqc/verification_circuit_synthesis.py
+++ b/rlftqc/verification_circuit_synthesis.py
@@ -62,7 +62,8 @@ def __init__(self,
                 gates_between_ancilla = True,
                 gates_between_data = False,
                 group_ancillas = False,
-                plus_ancilla_position = [],
+                plus_ancilla_position = [],
+                use_max_reward = False,
                 training_config = None,
                 seed = 42):
        """ Initialize a verification circuit synthesis task. """
@@ -91,7 +92,8 @@ def __init__(self,
            gates_between_ancilla,
            gates_between_data,
            group_ancillas,
-            plus_ancilla_position)
+            plus_ancilla_position,
+            use_max_reward)

        self.seed = seed

        ## Get the agent
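
Usage note: the notebook cell kept in this patch only exercises the new flag with use_max_reward = False. Below is a minimal sketch of enabling the MaxRL variant from the task-level wrapper, mirroring the notebook's own calls (run, log, training_config); it assumes `circ` (the encoding circuit) and `cliff_gates` are defined earlier in the notebook, as in the existing cells:

    from rlftqc.verification_circuit_synthesis import VerificationCircuitSynthesis

    ## Same verification task as in the notebook, with the MaxRL reward adaptation enabled.
    ## Assumes `circ` and `cliff_gates` are defined as in the preceding notebook cells.
    gates = [cliff_gates.cx, cliff_gates.cz]
    plus_ancilla_position = [5, 6]

    vcs = VerificationCircuitSynthesis(circ, num_ancillas = 2, gates = gates,
                                       plus_ancilla_position = plus_ancilla_position,
                                       gates_between_ancilla = False,
                                       use_max_reward = True)

    ## The agent still needs extra exploration for this task, so train longer.
    vcs.training_config["TOTAL_TIMESTEPS"] = 1e6
    vcs.training_config["LR"] = 5e-4

    vcs.run()   ## synthesize the verification circuit
    vcs.log()   ## log the training process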
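
For reference, the reward adaptation that use_max_reward switches on in both step_env methods is a two-line recursion: max_diff stores the gap between the best cumulative reward reached so far and the current cumulative reward, and the adapted reward is the increase of that running maximum. The adapted return therefore equals the maximum cumulative reward along the trajectory rather than the final one, and max_diff is appended to the observation so the policy can condition on the current gap. A self-contained sketch of the recursion (plain Python; maxrl_adapt is an illustrative helper, not part of the package):

    def maxrl_adapt(rewards):
        """Adapt per-step rewards as in step_env with use_max_reward=True:
        the adapted rewards sum to the running maximum of the cumulative reward."""
        max_diff = 0.0   # best cumulative reward so far minus current cumulative reward
        adapted = []
        for r in rewards:
            adapted.append(max(0.0, r - max_diff))  # increase of the running maximum
            max_diff = max(0.0, max_diff - r)       # updated gap after receiving r
        return adapted

    ## Cumulative rewards here are 2, 1, 1.5, so the running maximum stays at 2:
    ## the adapted rewards [2.0, 0.0, 0.0] sum to the best value reached, not the final 1.5.
    print(maxrl_adapt([2.0, -1.0, 0.5]))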