diff --git a/pytket/extensions/cutensornet/structured_state/general.py b/pytket/extensions/cutensornet/structured_state/general.py index 0cc5623..dcaed56 100644 --- a/pytket/extensions/cutensornet/structured_state/general.py +++ b/pytket/extensions/cutensornet/structured_state/general.py @@ -228,10 +228,6 @@ def _apply_command( except: raise ValueError(f"The command {op.type} introduced is not supported.") - # Load the gate's unitary to the GPU memory - unitary = unitary.astype(dtype=self._cfg._complex_t, copy=False) - unitary = cp.asarray(unitary, dtype=self._cfg._complex_t) - if len(qubits) not in [1, 2]: raise ValueError( "Gates must act on only 1 or 2 qubits! " diff --git a/pytket/extensions/cutensornet/structured_state/mps.py b/pytket/extensions/cutensornet/structured_state/mps.py index b9aee6e..25835e8 100644 --- a/pytket/extensions/cutensornet/structured_state/mps.py +++ b/pytket/extensions/cutensornet/structured_state/mps.py @@ -18,6 +18,7 @@ from random import Random # type: ignore import numpy as np # type: ignore +from numpy.typing import NDArray # type: ignore try: import cupy as cp # type: ignore @@ -38,7 +39,12 @@ class DirMPS(Enum): - """An enum to refer to relative directions within the MPS.""" + """An enum to refer to relative directions within the MPS. + + When used to refer to the canonical form of a tensor, LEFT means that its conjugate + transpose is its inverse when connected to its left bond and physical bond. + Similarly for RIGHT. + """ LEFT = 0 RIGHT = 1 @@ -148,9 +154,7 @@ def is_valid(self) -> bool: return chi_ok and phys_ok and shape_ok and ds_ok - def apply_unitary( - self, unitary: cp.ndarray, qubits: list[Qubit] - ) -> StructuredState: + def apply_unitary(self, unitary: NDArray, qubits: list[Qubit]) -> StructuredState: """Applies the unitary to the specified qubits of the StructuredState. Note: @@ -158,8 +162,9 @@ def apply_unitary( not the case, the program will still run, but its behaviour is undefined. Args: - unitary: The matrix to be applied as a CuPy ndarray. It should either be - a 2x2 matrix if acting on one qubit or a 4x4 matrix if acting on two. + unitary: The matrix to be applied as a NumPy or CuPy ndarray. It should + either be a 2x2 matrix if acting on one qubit or a 4x4 matrix if acting + on two. qubits: The qubits the unitary acts on. Only one qubit and two qubit unitaries are supported. @@ -178,6 +183,11 @@ def apply_unitary( "See the documentation of update_libhandle and CuTensorNetHandle.", ) + if not isinstance(unitary, cp.ndarray): + # Load the gate's unitary to the GPU memory + unitary = unitary.astype(dtype=self._cfg._complex_t, copy=False) + unitary = cp.asarray(unitary, dtype=self._cfg._complex_t) + self._logger.debug(f"Applying unitary {unitary} on {qubits}.") if len(qubits) == 1: diff --git a/pytket/extensions/cutensornet/structured_state/mps_gate.py b/pytket/extensions/cutensornet/structured_state/mps_gate.py index 4a7fa7f..322dabe 100644 --- a/pytket/extensions/cutensornet/structured_state/mps_gate.py +++ b/pytket/extensions/cutensornet/structured_state/mps_gate.py @@ -284,8 +284,8 @@ def _apply_2q_unitary_nonadjacent( optimize={"path": [(0, 1)]}, ) - # The site tensor is now in canonical form (since S is contracted to the right) - self.canonical_form[l_pos] = DirMPS.RIGHT # type: ignore + # The site tensor is now in canonical form + self.canonical_form[l_pos] = DirMPS.LEFT # type: ignore # Next, "push" the `msg_tensor` through all site tensors between `l_pos` # and `r_pos`. Once again, this is just contract_decompose on each. @@ -306,7 +306,7 @@ def _apply_2q_unitary_nonadjacent( ) # The site tensor is now in canonical form - self.canonical_form[pos] = DirMPS.RIGHT # type: ignore + self.canonical_form[pos] = DirMPS.LEFT # type: ignore # Finally, contract the `msg_tensor` with the site tensor in `r_pos` and the # `r_gate_tensor` from the decomposition of `gate_tensor` @@ -402,7 +402,7 @@ def _apply_2q_unitary_nonadjacent( # Since we are contracting S to the "left" in `svd_method`, the site tensor # at `pos+1` is canonicalised, whereas the site tensor at `pos` is the one # where S has been contracted to and, hence, is not in canonical form - self.canonical_form[pos + 1] = DirMPS.LEFT # type: ignore + self.canonical_form[pos + 1] = DirMPS.RIGHT # type: ignore self.canonical_form[pos] = None # Update fidelity lower bound this_fidelity = 1.0 - info.svd_info.discarded_weight diff --git a/pytket/extensions/cutensornet/structured_state/ttn.py b/pytket/extensions/cutensornet/structured_state/ttn.py index 3a1a412..2144471 100644 --- a/pytket/extensions/cutensornet/structured_state/ttn.py +++ b/pytket/extensions/cutensornet/structured_state/ttn.py @@ -19,6 +19,7 @@ from random import Random # type: ignore import math # type: ignore import numpy as np # type: ignore +from numpy.typing import NDArray # type: ignore try: import cupy as cp # type: ignore @@ -242,9 +243,7 @@ def is_valid(self) -> bool: ) return chi_ok and phys_ok and rank_ok and shape_ok - def apply_unitary( - self, unitary: cp.ndarray, qubits: list[Qubit] - ) -> StructuredState: + def apply_unitary(self, unitary: NDArray, qubits: list[Qubit]) -> StructuredState: """Applies the unitary to the specified qubits of the StructuredState. Note: @@ -252,8 +251,9 @@ def apply_unitary( not the case, the program will still run, but its behaviour is undefined. Args: - unitary: The matrix to be applied as a CuPy ndarray. It should either be - a 2x2 matrix if acting on one qubit or a 4x4 matrix if acting on two. + unitary: The matrix to be applied as a NumPy or CuPy ndarray. It should + either be a 2x2 matrix if acting on one qubit or a 4x4 matrix if acting + on two. qubits: The qubits the unitary acts on. Only one qubit and two qubit unitaries are supported. @@ -272,6 +272,11 @@ def apply_unitary( "See the documentation of update_libhandle and CuTensorNetHandle.", ) + if not isinstance(unitary, cp.ndarray): + # Load the gate's unitary to the GPU memory + unitary = unitary.astype(dtype=self._cfg._complex_t, copy=False) + unitary = cp.asarray(unitary, dtype=self._cfg._complex_t) + self._logger.debug(f"Applying unitary {unitary} on {qubits}.") if len(qubits) == 1: