diff --git a/python/ffsim/states/states.py b/python/ffsim/states/states.py index 78a0ae7ef..3432ce345 100644 --- a/python/ffsim/states/states.py +++ b/python/ffsim/states/states.py @@ -599,13 +599,14 @@ def sample_state_vector( TypeError: When passing vec as a StateVector, norb and nelec must both be None. """ vec, norb, nelec = canonicalize_vec_norb_nelec(vec, norb, nelec) + all_orbs = list(range(norb if isinstance(nelec, (int, np.integer)) else 2 * norb)) if orbs is None: - orbs = range(2 * norb) + orbs = all_orbs rng = np.random.default_rng(seed) probabilities = np.abs(vec) ** 2 samples = rng.choice(len(vec), size=shots, p=probabilities) bitstrings = indices_to_strings(samples, norb, nelec) - if list(orbs) == list(range(2 * norb)): + if list(orbs) == all_orbs: return bitstrings return ["".join(bitstring[-1 - i] for i in orbs[::-1]) for bitstring in bitstrings] diff --git a/tests/python/states/states_test.py b/tests/python/states/states_test.py index 8408a5ac0..bc500a4e1 100644 --- a/tests/python/states/states_test.py +++ b/tests/python/states/states_test.py @@ -364,8 +364,8 @@ def test_slater_determinant_one_rdm_same_rotation( np.testing.assert_allclose(rdm, expected, atol=1e-12) -def test_sample_state_vector(): - """Test sampling state vector.""" +def test_sample_state_vector_spinful(): + """Test sampling state vector, spinful.""" norb = 5 nelec = (3, 2) index = ffsim.strings_to_indices(["1000101101"], norb=norb, nelec=nelec)[0] @@ -380,6 +380,22 @@ def test_sample_state_vector(): assert samples == ["101101"] * 10 +def test_sample_state_vector_spinless(): + """Test sampling state vector, spinless.""" + norb = 5 + nelec = 3 + index = ffsim.strings_to_indices(["01101"], norb=norb, nelec=nelec)[0] + vec = ffsim.linalg.one_hot(ffsim.dim(norb, nelec), index) + + samples = ffsim.sample_state_vector(vec, norb=norb, nelec=nelec) + assert samples == ["01101"] + + samples = ffsim.sample_state_vector( + vec, orbs=[0, 1, 3], shots=10, norb=norb, nelec=nelec + ) + assert samples == ["101"] * 10 + + @pytest.mark.parametrize( "norb, nelec, spin_summed", [