diff --git a/python/ffsim/variational/ucj_spin_balanced.py b/python/ffsim/variational/ucj_spin_balanced.py index 0505c9135..c5641d99c 100644 --- a/python/ffsim/variational/ucj_spin_balanced.py +++ b/python/ffsim/variational/ucj_spin_balanced.py @@ -98,6 +98,12 @@ class UCJOpSpinBalanced: atol: InitVar[float] = 1e-8 def __post_init__(self, validate: bool, rtol: float, atol: float): + if not np.iscomplexobj(self.orbital_rotations): + raise TypeError("Orbital rotations must have complex data type.") + if self.final_orbital_rotation is not None and not np.iscomplexobj( + self.final_orbital_rotation + ): + raise TypeError("Final orbital rotation must have complex data type.") if validate: if self.diag_coulomb_mats.ndim != 4 or self.diag_coulomb_mats.shape[1] != 2: raise ValueError( diff --git a/python/ffsim/variational/ucj_spin_unbalanced.py b/python/ffsim/variational/ucj_spin_unbalanced.py index 13aa17ae5..3f7d6043c 100644 --- a/python/ffsim/variational/ucj_spin_unbalanced.py +++ b/python/ffsim/variational/ucj_spin_unbalanced.py @@ -100,6 +100,12 @@ class UCJOpSpinUnbalanced: atol: InitVar[float] = 1e-8 def __post_init__(self, validate: bool, rtol: float, atol: float): + if not np.iscomplexobj(self.orbital_rotations): + raise TypeError("Orbital rotations must have complex data type.") + if self.final_orbital_rotation is not None and not np.iscomplexobj( + self.final_orbital_rotation + ): + raise TypeError("Final orbital rotation must have complex data type.") if validate: if self.diag_coulomb_mats.ndim != 4 or self.diag_coulomb_mats.shape[1] != 3: raise ValueError( diff --git a/python/ffsim/variational/ucj_spinless.py b/python/ffsim/variational/ucj_spinless.py index 4e36fd9b9..6e9dfbde0 100644 --- a/python/ffsim/variational/ucj_spinless.py +++ b/python/ffsim/variational/ucj_spinless.py @@ -87,6 +87,12 @@ class UCJOpSpinless: atol: InitVar[float] = 1e-8 def __post_init__(self, validate: bool, rtol: float, atol: float): + if not np.iscomplexobj(self.orbital_rotations): + raise TypeError("Orbital rotations must have complex data type.") + if self.final_orbital_rotation is not None and not np.iscomplexobj( + self.final_orbital_rotation + ): + raise TypeError("Final orbital rotation must have complex data type.") if validate: if self.diag_coulomb_mats.ndim != 3: raise ValueError( diff --git a/tests/python/variational/ucj_spin_balanced_test.py b/tests/python/variational/ucj_spin_balanced_test.py index f92efcb71..f186d5e4e 100644 --- a/tests/python/variational/ucj_spin_balanced_test.py +++ b/tests/python/variational/ucj_spin_balanced_test.py @@ -214,7 +214,7 @@ def test_validate(): norb = 4 eye = np.eye(norb) diag_coulomb_mats = np.stack([np.stack([eye, eye]) for _ in range(n_reps)]) - orbital_rotations = np.stack([eye for _ in range(n_reps)]) + orbital_rotations = np.stack([eye.astype(complex) for _ in range(n_reps)]) _ = ffsim.UCJOpSpinBalanced( diag_coulomb_mats=rng.standard_normal(10), diff --git a/tests/python/variational/ucj_spin_unbalanced_test.py b/tests/python/variational/ucj_spin_unbalanced_test.py index 4e08b530a..216d51279 100644 --- a/tests/python/variational/ucj_spin_unbalanced_test.py +++ b/tests/python/variational/ucj_spin_unbalanced_test.py @@ -249,7 +249,9 @@ def test_validate(): norb = 4 eye = np.eye(norb) diag_coulomb_mats = np.stack([np.stack([eye, eye, eye]) for _ in range(n_reps)]) - orbital_rotations = np.stack([np.stack([eye, eye]) for _ in range(n_reps)]) + orbital_rotations = np.stack( + [np.stack([eye, eye]).astype(complex) for _ in range(n_reps)] + ) _ = ffsim.UCJOpSpinUnbalanced( diag_coulomb_mats=rng.standard_normal(10), diff --git a/tests/python/variational/ucj_spinless_test.py b/tests/python/variational/ucj_spinless_test.py index 2ec1ec785..538d365a4 100644 --- a/tests/python/variational/ucj_spinless_test.py +++ b/tests/python/variational/ucj_spinless_test.py @@ -219,7 +219,7 @@ def test_validate(): norb = 4 eye = np.eye(norb) diag_coulomb_mats = np.stack([eye for _ in range(n_reps)]) - orbital_rotations = np.stack([eye for _ in range(n_reps)]) + orbital_rotations = np.stack([eye.astype(complex) for _ in range(n_reps)]) _ = ffsim.UCJOpSpinless( diag_coulomb_mats=rng.standard_normal(10), @@ -241,13 +241,13 @@ def test_validate(): with pytest.raises(ValueError, match="shape"): _ = ffsim.UCJOpSpinless( diag_coulomb_mats=diag_coulomb_mats, - orbital_rotations=rng.standard_normal(10), + orbital_rotations=rng.standard_normal(10).astype(complex), ) with pytest.raises(ValueError, match="shape"): _ = ffsim.UCJOpSpinless( diag_coulomb_mats=diag_coulomb_mats, orbital_rotations=orbital_rotations, - final_orbital_rotation=rng.standard_normal(10), + final_orbital_rotation=rng.standard_normal(10).astype(complex), ) with pytest.raises(ValueError, match="dimension"): _ = ffsim.UCJOpSpinless( @@ -262,11 +262,11 @@ def test_validate(): with pytest.raises(ValueError, match="unitary"): _ = ffsim.UCJOpSpinless( diag_coulomb_mats=diag_coulomb_mats, - orbital_rotations=rng.standard_normal((n_reps, norb, norb)), + orbital_rotations=rng.standard_normal((n_reps, norb, norb)).astype(complex), ) with pytest.raises(ValueError, match="unitary"): _ = ffsim.UCJOpSpinless( diag_coulomb_mats=diag_coulomb_mats, orbital_rotations=orbital_rotations, - final_orbital_rotation=rng.standard_normal((norb, norb)), + final_orbital_rotation=rng.standard_normal((norb, norb)).astype(complex), )