From 1ad0eb3d36b672f5e050e58f0bd9af0363f6c0ac Mon Sep 17 00:00:00 2001 From: Pablo Andres-Martinez <104848389+PabloAndresCQ@users.noreply.github.com> Date: Wed, 3 Apr 2024 17:09:11 +0100 Subject: [PATCH] [bugfix] Default value of chi causes an error when state is copied (#93) * Now checking if chi is set to the default 'essentially unbounded' value; if so, do not raise error if non-default truncation fidelity is set. * Adding a test of copy state using different configurations. --- .../cutensornet/structured_state/general.py | 4 +- tests/test_structured_state.py | 38 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/pytket/extensions/cutensornet/structured_state/general.py b/pytket/extensions/cutensornet/structured_state/general.py index 25d65876..fc862dc2 100644 --- a/pytket/extensions/cutensornet/structured_state/general.py +++ b/pytket/extensions/cutensornet/structured_state/general.py @@ -128,14 +128,16 @@ def __init__( ValueError: If the value of ``chi`` is set below 2. ValueError: If the value of ``truncation_fidelity`` is not in [0,1]. """ + _CHI_LIMIT = 2**60 if ( chi is not None + and chi < _CHI_LIMIT and truncation_fidelity is not None and truncation_fidelity != 1.0 ): raise ValueError("Cannot fix both chi and truncation_fidelity.") if chi is None: - chi = 2**60 # In practice, this is like having it be unbounded + chi = _CHI_LIMIT # In practice, this is like having it be unbounded if truncation_fidelity is None: truncation_fidelity = 1 diff --git a/tests/test_structured_state.py b/tests/test_structured_state.py index 6a14a87a..5617ca44 100644 --- a/tests/test_structured_state.py +++ b/tests/test_structured_state.py @@ -51,6 +51,44 @@ def test_init() -> None: assert ttn_gate.is_valid() +@pytest.mark.parametrize( + "algorithm", + [ + SimulationAlgorithm.MPSxGate, + SimulationAlgorithm.MPSxMPO, + SimulationAlgorithm.TTNxGate, + ], +) +def test_copy(algorithm: SimulationAlgorithm) -> None: + simple_circ = Circuit(2).H(0).H(1).CX(0, 1) + + with CuTensorNetHandle() as libhandle: + + # Default config + cfg = Config() + state = simulate(libhandle, simple_circ, algorithm, cfg) + assert state.is_valid() + copy_state = state.copy() + assert copy_state.is_valid() + assert np.isclose(copy_state.vdot(state), 1.0, atol=cfg._atol) + + # Bounded chi + cfg = Config(chi=8) + state = simulate(libhandle, simple_circ, algorithm, cfg) + assert state.is_valid() + copy_state = state.copy() + assert copy_state.is_valid() + assert np.isclose(copy_state.vdot(state), 1.0, atol=cfg._atol) + + # Bounded truncation_fidelity + cfg = Config(truncation_fidelity=0.9999) + state = simulate(libhandle, simple_circ, algorithm, cfg) + assert state.is_valid() + copy_state = state.copy() + assert copy_state.is_valid() + assert np.isclose(copy_state.vdot(state), 1.0, atol=cfg._atol) + + def test_canonicalise_mps() -> None: cp.random.seed(1) circ = Circuit(5)