Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/mps arguments #35

Merged
merged 20 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions docs/modules/mps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,16 @@ Matrix Product State (MPS)
Simulation
~~~~~~~~~~

.. autofunction:: pytket.extensions.cutensornet.mps.simulate

.. autoenum:: pytket.extensions.cutensornet.mps.ContractionAlg()
:members:

.. autofunction:: pytket.extensions.cutensornet.mps.simulate
.. autoclass:: pytket.extensions.cutensornet.mps.ConfigMPS()

.. automethod:: __init__

.. autoclass:: pytket.extensions.cutensornet.mps.CuTensorNetHandle


Classes
Expand Down Expand Up @@ -47,8 +53,6 @@ Classes

.. automethod:: __init__

.. autoclass:: pytket.extensions.cutensornet.mps.CuTensorNetHandle


Miscellaneous
~~~~~~~~~~~~~
Expand Down
3 changes: 2 additions & 1 deletion examples/mpi/mpi_overlap_bcast_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

from pytket.extensions.cutensornet.mps import (
simulate,
ConfigMPS,
ContractionAlg,
CuTensorNetHandle,
)
Expand Down Expand Up @@ -108,7 +109,7 @@
this_proc_mps = []
with CuTensorNetHandle(device_id) as libhandle: # Different handle for each process
for circ in this_proc_circs:
mps = simulate(libhandle, circ, ContractionAlg.MPSxGate)
mps = simulate(libhandle, circ, ContractionAlg.MPSxGate, ConfigMPS())
this_proc_mps.append(mps)

if rank == root:
Expand Down
2,243 changes: 1,126 additions & 1,117 deletions examples/mps_tutorial.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pytket/extensions/cutensornet/mps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .mps import (
CuTensorNetHandle,
DirectionMPS,
ConfigMPS,
Handle,
Tensor,
MPS,
Expand Down
174 changes: 112 additions & 62 deletions pytket/extensions/cutensornet/mps/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,76 +85,68 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
self._is_destroyed = True


class MPS:
"""Represents a state as a Matrix Product State.

Attributes:
chi (int): The maximum allowed dimension of a virtual bond.
truncation_fidelity (float): The target fidelity of SVD truncation.
tensors (list[Tensor]): A list of tensors in the MPS; ``tensors[0]`` is
the leftmost and ``tensors[len(self)-1]`` is the rightmost; ``tensors[i]``
and ``tensors[i+1]`` are connected in the MPS via a bond. All of the
tensors are rank three, with the dimensions listed in ``.shape`` matching
the left, right and physical bonds, in that order.
canonical_form (dict[int, Optional[DirectionMPS]]): A dictionary mapping
positions to the canonical form direction of the corresponding tensor,
or ``None`` if it the tensor is not canonicalised.
qubit_position (dict[pytket.circuit.Qubit, int]): A dictionary mapping circuit
qubits to the position its tensor is at in the MPS.
fidelity (float): A lower bound of the fidelity, obtained by multiplying
the fidelities after each contraction. The fidelity of a contraction
corresponds to ``|<psi|phi>|^2`` where ``|psi>`` and ``|phi>`` are the
states before and after truncation (assuming both are normalised).
"""
class ConfigMPS:
"""Configuration class for simulation using MPS."""

def __init__(
self,
libhandle: CuTensorNetHandle,
qubits: list[Qubit],
chi: Optional[int] = None,
truncation_fidelity: Optional[float] = None,
float_precision: Optional[Union[np.float32, np.float64]] = None,
k: int = 4,
optim_delta: float = 1e-5,
yapolyak marked this conversation as resolved.
Show resolved Hide resolved
float_precision: Union[np.float32, np.float64] = np.float64, # type: ignore
value_of_zero: float = 1e-16,
loglevel: int = logging.WARNING,
):
"""Initialise an MPS on the computational state ``|0>``.
"""Instantiate a configuration object for MPS simulation.

Note:
A ``libhandle`` should be created via a ``with CuTensorNet() as libhandle:``
statement. The device where the MPS is stored will match the one specified
by the library handle.

Providing both a custom ``chi`` and ``truncation_fidelity`` will raise an
exception. Choose one or the other (or neither, for exact simulation).

Args:
libhandle: The cuTensorNet library handle that will be used to carry out
tensor operations on the MPS.
qubits: The list of qubits in the circuit to be simulated.
chi: The maximum value allowed for the dimension of the virtual
bonds. Higher implies better approximation but more
computational resources. If not provided, ``chi`` will be set
to ``2**(len(qubits) // 2)``, which is enough for exact contraction.
computational resources. If not provided, ``chi`` will be unbounded.
truncation_fidelity: Every time a two-qubit gate is applied, the virtual
bond will be truncated to the minimum dimension that satisfies
``|<psi|phi>|^2 >= trucantion_fidelity``, where ``|psi>`` and ``|phi>``
are the states before and after truncation (both normalised).
If not provided, it will default to its maximum value 1.
k: If using MPSxMPO, the maximum number of layers the MPO is allowed to
have before being contracted. Increasing this might increase fidelity,
but it will also increase resource requirements exponentially.
Ignored if not using MPSxMPO. Default value is 4.
optim_delta: If using MPSxMPO, stopping criteria for the optimisation when
contracting the ``k`` layers of MPO. Stops when the increase of fidelity
between iterations is smaller than ``optim_delta``.
Ignored if not using MPSxMPO. Default value is ``1e-5``.
float_precision: The floating point precision used in tensor calculations;
choose from ``numpy`` types: ``np.float64`` or ``np.float32``.
Complex numbers are represented using two of such
``float`` numbers. Default is ``np.float64``.
loglevel: Internal logger output level.
value_of_zero: Any number below this value will be considered equal to zero.
Even when no ``chi`` or ``truncation_fidelity`` is provided, singular
values below this number will be truncated.
We suggest to use a value slightly below what your chosen
``float_precision`` can reasonably achieve. For instance, ``1e-16`` for
``np.float64`` precision (default) and ``1e-7`` for ``np.float32``.
loglevel: Internal logger output level. Use 30 for warnings only, 20 for
verbose and 10 for debug mode.

Raises:
ValueError: If less than two qubits are provided.
ValueError: If both ``chi`` and ``truncation_fidelity`` are fixed.
ValueError: If the value of ``chi`` is set below 2.
ValueError: If the value of ``truncation_fidelity`` is not in [0,1].
"""
if chi is not None and truncation_fidelity is not None:
if (
chi is not None
and truncation_fidelity is not None
and truncation_fidelity != 1.0
):
raise ValueError("Cannot fix both chi and truncation_fidelity.")
if chi is None:
chi = max(2 ** (len(qubits) // 2), 2)
chi = 2**60 # In practice, this is like having it be unbounded
if truncation_fidelity is None:
truncation_fidelity = 1

Expand All @@ -163,6 +155,9 @@ def __init__(
if truncation_fidelity < 0 or truncation_fidelity > 1:
raise ValueError("Provide a value of truncation_fidelity in [0,1].")

self.chi = chi
self.truncation_fidelity = truncation_fidelity

if float_precision is None or float_precision == np.float64: # Double precision
self._real_t = np.float64 # type: ignore
self._complex_t = np.complex128 # type: ignore
Expand All @@ -176,16 +171,75 @@ def __init__(
raise TypeError(
f"Value of float_precision must be in {allowed_precisions}."
)
self.zero = value_of_zero

self._lib = libhandle
self._logger = set_logger("MPS", level=loglevel)
if value_of_zero > self._atol / 1000:
warnings.warn(
"Your chosen value_of_zero is relatively large. "
"Faithfulness of final fidelity estimate is not guaranteed.",
UserWarning,
)

#######################################
# Initialise the MPS with a |0> state #
#######################################
self.k = k
self.optim_delta = 1e-5
self.loglevel = loglevel

def copy(self) -> ConfigMPS:
"""Standard copy of the contents."""
return ConfigMPS(
chi=self.chi,
truncation_fidelity=self.truncation_fidelity,
k=self.k,
optim_delta=self.optim_delta,
float_precision=self._real_t, # type: ignore
yapolyak marked this conversation as resolved.
Show resolved Hide resolved
)

self.chi = chi
self.truncation_fidelity = truncation_fidelity

class MPS:
"""Represents a state as a Matrix Product State.

Attributes:
tensors (list[Tensor]): A list of tensors in the MPS; ``tensors[0]`` is
the leftmost and ``tensors[len(self)-1]`` is the rightmost; ``tensors[i]``
and ``tensors[i+1]`` are connected in the MPS via a bond. All of the
tensors are rank three, with the dimensions listed in ``.shape`` matching
the left, right and physical bonds, in that order.
canonical_form (dict[int, Optional[DirectionMPS]]): A dictionary mapping
positions to the canonical form direction of the corresponding tensor,
or ``None`` if it the tensor is not canonicalised.
qubit_position (dict[pytket.circuit.Qubit, int]): A dictionary mapping circuit
qubits to the position its tensor is at in the MPS.
fidelity (float): A lower bound of the fidelity, obtained by multiplying
the fidelities after each contraction. The fidelity of a contraction
corresponds to ``|<psi|phi>|^2`` where ``|psi>`` and ``|phi>`` are the
states before and after truncation (assuming both are normalised).
"""

def __init__(
self,
libhandle: CuTensorNetHandle,
qubits: list[Qubit],
config: ConfigMPS,
):
"""Initialise an MPS on the computational state ``|0>``.

Note:
A ``libhandle`` should be created via a ``with CuTensorNet() as libhandle:``
statement. The device where the MPS is stored will match the one specified
by the library handle.

Args:
libhandle: The cuTensorNet library handle that will be used to carry out
tensor operations on the MPS.
qubits: The list of qubits in the circuit to be simulated.
config: The object describing the configuration for simulation.

Raises:
ValueError: If less than two qubits are provided.
"""
self._lib = libhandle
self._cfg = config
self._logger = set_logger("MPS", level=config.loglevel)
self.fidelity = 1.0

n_tensors = len(qubits)
Expand All @@ -203,7 +257,7 @@ def __init__(
# Append each of the tensors initialised in state |0>
m_shape = (1, 1, 2) # Two virtual bonds (dim=1) and one physical
for i in range(n_tensors):
m_tensor = cp.empty(m_shape, dtype=self._complex_t)
m_tensor = cp.empty(m_shape, dtype=self._cfg._complex_t)
# Initialise the tensor to ket 0
m_tensor[0][0][0] = 1
m_tensor[0][0][1] = 0
Expand All @@ -222,7 +276,7 @@ def is_valid(self) -> bool:
self._flush()

chi_ok = all(
all(dim <= self.chi for dim in self.get_virtual_dimensions(pos))
all(dim <= self._cfg.chi for dim in self.get_virtual_dimensions(pos))
for pos in range(len(self))
)
phys_ok = all(self.get_physical_dimension(pos) == 2 for pos in range(len(self)))
Expand Down Expand Up @@ -526,7 +580,7 @@ def measure(self, qubits: set[Qubit]) -> dict[Qubit, int]:
self._logger.debug(f"Measuring qubits={position_qubit_map}")

# Tensor for postselection to |0>
zero_tensor = cp.zeros(2, dtype=self._complex_t)
zero_tensor = cp.zeros(2, dtype=self._cfg._complex_t)
zero_tensor[0] = 1

# Measure and postselect each of the positions, one by one
Expand Down Expand Up @@ -559,7 +613,7 @@ def measure(self, qubits: set[Qubit]) -> dict[Qubit, int]:
self._logger.debug(f"Outcome of qubit at {pos} is {outcome}.")

# Postselect the MPS for this outcome, renormalising at the same time
postselection_tensor = cp.zeros(2, dtype=self._complex_t)
postselection_tensor = cp.zeros(2, dtype=self._cfg._complex_t)
postselection_tensor[outcome] = 1 / np.sqrt(
abs(outcome - prob)
) # Normalise
Expand Down Expand Up @@ -605,18 +659,18 @@ def postselect(self, qubit_outcomes: dict[Qubit, int]) -> float:
# Apply a postselection for each of the qubits
for qubit, outcome in qubit_outcomes.items():
# Create the rank-1 postselection tensor
postselection_tensor = cp.zeros(2, dtype=self._complex_t)
postselection_tensor = cp.zeros(2, dtype=self._cfg._complex_t)
postselection_tensor[outcome] = 1
# Apply postselection
self._postselect_qubit(qubit, postselection_tensor)

# Calculate the squared norm of the postselected state; this is its probability
prob = self.vdot(self)
assert np.isclose(prob.imag, 0.0, atol=self._atol)
assert np.isclose(prob.imag, 0.0, atol=self._cfg._atol)
prob = prob.real

# Renormalise; it suffices to update the first tensor
if len(self) > 0 and not np.isclose(prob, 0.0, atol=self._atol):
if len(self) > 0 and not np.isclose(prob, 0.0, atol=self._cfg._atol):
self.tensors[0] = self.tensors[0] / np.sqrt(prob)
self.canonical_form[0] = None

Expand Down Expand Up @@ -696,8 +750,8 @@ def expectation_value(self, pauli_string: QubitPauliString) -> float:
pos = mps_copy.qubit_position[qubit]
pauli_unitary = Op.create(pauli_optype[pauli]).get_unitary()
pauli_tensor = cp.asarray(
pauli_unitary.astype(dtype=self._complex_t, copy=False),
dtype=self._complex_t,
pauli_unitary.astype(dtype=self._cfg._complex_t, copy=False),
dtype=self._cfg._complex_t,
)

# Contract the Pauli to the MPS tensor of the corresponding qubit
Expand All @@ -707,7 +761,7 @@ def expectation_value(self, pauli_string: QubitPauliString) -> float:

# Obtain the inner product
value = self.vdot(mps_copy)
assert np.isclose(value.imag, 0.0, atol=self._atol)
assert np.isclose(value.imag, 0.0, atol=self._cfg._atol)

self._logger.debug(f"Expectation value is {value.real}.")
return value.real
Expand Down Expand Up @@ -772,10 +826,10 @@ def get_amplitude(self, state: int) -> complex:
mps_pos_bitvalue[pos] = bitvalue

# Carry out the contraction, starting from a dummy tensor
result_tensor = cp.ones(1, dtype=self._complex_t) # rank-1, dimension 1
result_tensor = cp.ones(1, dtype=self._cfg._complex_t) # rank-1, dimension 1

for pos in range(len(self)):
postselection_tensor = cp.zeros(2, dtype=self._complex_t)
postselection_tensor = cp.zeros(2, dtype=self._cfg._complex_t)
postselection_tensor[mps_pos_bitvalue[pos]] = 1
# Contract postselection with qubit into the result_tensor
result_tensor = cq.contract(
Expand Down Expand Up @@ -868,16 +922,12 @@ def copy(self) -> MPS:
self._flush()

# Create a dummy object
new_mps = MPS(self._lib, qubits=[])
new_mps = MPS(self._lib, qubits=[], config=self._cfg.copy())
# Copy all data
new_mps.chi = self.chi
new_mps.truncation_fidelity = self.truncation_fidelity
new_mps.fidelity = self.fidelity
new_mps.tensors = [t.copy() for t in self.tensors]
new_mps.canonical_form = self.canonical_form.copy()
new_mps.qubit_position = self.qubit_position.copy()
new_mps._complex_t = self._complex_t
new_mps._real_t = self._real_t

self._logger.debug(
"Successfully copied an MPS "
Expand Down
Loading