From 6c103b61a7f5084354b10d720a079849b8e61d03 Mon Sep 17 00:00:00 2001
From: "Kevin J. Sung" <kevjsung@umich.edu>
Date: Fri, 23 Feb 2024 14:38:12 -0600
Subject: [PATCH] only return spin square, not multiplicity

---
 python/ffsim/states/states.py    | 5 +----
 tests/gates/diag_coulomb_test.py | 4 ++--
 tests/variational/ucj_test.py    | 8 ++++----
 3 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/python/ffsim/states/states.py b/python/ffsim/states/states.py
index ad8f789c7..6a83538ed 100644
--- a/python/ffsim/states/states.py
+++ b/python/ffsim/states/states.py
@@ -215,7 +215,4 @@ def spin_square(fcivec: np.ndarray, norb: int, nelec: tuple[int, int]):
         ci1 += 1j * contract_ss(fcivec.imag, norb, nelec)
     else:
         ci1 = contract_ss(fcivec, norb, nelec)
-    ss = np.einsum("ij,ij->", fcivec.reshape(ci1.shape), ci1.conj()).real
-    s = np.sqrt(ss + 0.25) - 0.5
-    multip = s * 2 + 1
-    return ss, multip
+    return np.einsum("ij,ij->", fcivec.reshape(ci1.shape), ci1.conj()).real
diff --git a/tests/gates/diag_coulomb_test.py b/tests/gates/diag_coulomb_test.py
index 825098abe..c5ea65a1c 100644
--- a/tests/gates/diag_coulomb_test.py
+++ b/tests/gates/diag_coulomb_test.py
@@ -80,7 +80,7 @@ def test_apply_diag_coulomb_evolution_spin(z_representation: bool):
         orbital_rotation = ffsim.random.random_unitary(norb, seed=rng)
         vec = ffsim.random.random_statevector(dim, seed=rng)
 
-        spin_squared_init, _ = ffsim.spin_square(vec, norb=norb, nelec=nelec)
+        spin_squared_init = ffsim.spin_square(vec, norb=norb, nelec=nelec)
 
         time = rng.uniform()
         result = ffsim.apply_diag_coulomb_evolution(
@@ -93,7 +93,7 @@ def test_apply_diag_coulomb_evolution_spin(z_representation: bool):
             z_representation=z_representation,
         )
 
-        spin_squared_result, _ = ffsim.spin_square(result, norb=norb, nelec=nelec)
+        spin_squared_result = ffsim.spin_square(result, norb=norb, nelec=nelec)
 
         np.testing.assert_allclose(spin_squared_result, spin_squared_init)
 
diff --git a/tests/variational/ucj_test.py b/tests/variational/ucj_test.py
index fa84d86f4..0b09a8b20 100644
--- a/tests/variational/ucj_test.py
+++ b/tests/variational/ucj_test.py
@@ -208,7 +208,7 @@ def test_t_amplitudes_spin():
     reference_state = ffsim.slater_determinant(
         norb=norb, occupied_orbitals=(range(n_alpha), range(n_beta))
     )
-    spin_squared, _ = ffsim.spin_square(reference_state, norb=norb, nelec=nelec)
+    spin_squared = ffsim.spin_square(reference_state, norb=norb, nelec=nelec)
     np.testing.assert_allclose(spin_squared, 0)
 
     # Apply the operator to the reference state
@@ -222,7 +222,7 @@ def test_t_amplitudes_spin():
     np.testing.assert_allclose(energy, -108.595692)
 
     # Compute the spin of the ansatz state
-    spin_squared, _ = ffsim.spin_square(ansatz_state, norb=norb, nelec=nelec)
+    spin_squared = ffsim.spin_square(ansatz_state, norb=norb, nelec=nelec)
     np.testing.assert_allclose(spin_squared, 0, atol=1e-12)
 
 
@@ -407,7 +407,7 @@ def test_real_ucj_t_amplitudes_spin():
     reference_state = ffsim.slater_determinant(
         norb=norb, occupied_orbitals=(range(n_alpha), range(n_beta))
     )
-    spin_squared, _ = ffsim.spin_square(reference_state, norb=norb, nelec=nelec)
+    spin_squared = ffsim.spin_square(reference_state, norb=norb, nelec=nelec)
     np.testing.assert_allclose(spin_squared, 0)
 
     # Apply the operator to the reference state
@@ -421,7 +421,7 @@ def test_real_ucj_t_amplitudes_spin():
     np.testing.assert_allclose(energy, -108.595692)
 
     # Compute the spin of the ansatz state
-    spin_squared, _ = ffsim.spin_square(ansatz_state, norb=norb, nelec=nelec)
+    spin_squared = ffsim.spin_square(ansatz_state, norb=norb, nelec=nelec)
     np.testing.assert_allclose(spin_squared, 0, atol=1e-12)