Skip to content

Commit

Permalink
handle sample spinless
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsung committed Jun 30, 2024
1 parent 967e8bf commit 1abcec6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
5 changes: 3 additions & 2 deletions python/ffsim/states/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,13 +599,14 @@ def sample_state_vector(
TypeError: When passing vec as a StateVector, norb and nelec must both be None.
"""
vec, norb, nelec = canonicalize_vec_norb_nelec(vec, norb, nelec)
all_orbs = list(range(norb if isinstance(nelec, (int, np.integer)) else 2 * norb))
if orbs is None:
orbs = range(2 * norb)
orbs = all_orbs
rng = np.random.default_rng(seed)
probabilities = np.abs(vec) ** 2
samples = rng.choice(len(vec), size=shots, p=probabilities)
bitstrings = indices_to_strings(samples, norb, nelec)
if list(orbs) == list(range(2 * norb)):
if list(orbs) == all_orbs:
return bitstrings
return ["".join(bitstring[-1 - i] for i in orbs[::-1]) for bitstring in bitstrings]

Expand Down
20 changes: 18 additions & 2 deletions tests/python/states/states_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ def test_slater_determinant_one_rdm_same_rotation(
np.testing.assert_allclose(rdm, expected, atol=1e-12)


def test_sample_state_vector():
"""Test sampling state vector."""
def test_sample_state_vector_spinful():
"""Test sampling state vector, spinful."""
norb = 5
nelec = (3, 2)
index = ffsim.strings_to_indices(["1000101101"], norb=norb, nelec=nelec)[0]
Expand All @@ -380,6 +380,22 @@ def test_sample_state_vector():
assert samples == ["101101"] * 10


def test_sample_state_vector_spinless():
"""Test sampling state vector, spinless."""
norb = 5
nelec = 3
index = ffsim.strings_to_indices(["01101"], norb=norb, nelec=nelec)[0]
vec = ffsim.linalg.one_hot(ffsim.dim(norb, nelec), index)

samples = ffsim.sample_state_vector(vec, norb=norb, nelec=nelec)
assert samples == ["01101"]

samples = ffsim.sample_state_vector(
vec, orbs=[0, 1, 3], shots=10, norb=norb, nelec=nelec
)
assert samples == ["101"] * 10


@pytest.mark.parametrize(
"norb, nelec, spin_summed",
[
Expand Down

0 comments on commit 1abcec6

Please sign in to comment.