From 150cfeabe0c3e456834e98de88130abbf90ac657 Mon Sep 17 00:00:00 2001 From: milad2073 Date: Wed, 16 Oct 2024 09:51:25 +0330 Subject: [PATCH] Add similarity calculation for all pairs of vectors in the unit test for the circular function --- torchhd/tests/basis_hv/test_circular_hv.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/torchhd/tests/basis_hv/test_circular_hv.py b/torchhd/tests/basis_hv/test_circular_hv.py index 8aec30f..6989f06 100644 --- a/torchhd/tests/basis_hv/test_circular_hv.py +++ b/torchhd/tests/basis_hv/test_circular_hv.py @@ -127,15 +127,18 @@ def test_value(self, dtype, vsa): ) else: hv = functional.circular(8, 1000000, vsa, generator=generator, dtype=dtype) - sims = functional.cosine_similarity(hv[0], hv) - sims_diff = sims[:-1] - sims[1:] - assert torch.all( - sims_diff.sign() == torch.tensor([1, 1, 1, 1, -1, -1, -1]) - ), "second half must get more similar" - - assert torch.allclose( - sims_diff.abs(), torch.tensor(0.25, dtype=sims_diff.dtype), atol=0.005 - ), "similarity decreases linearly" + + for i in range(8-1): + sims = functional.cosine_similarity(hv[0], hv) + sims_diff = sims[:-1] - sims[1:] + assert torch.all( + sims_diff.sign() == torch.tensor([1, 1, 1, 1, -1, -1, -1]) + ), f"element #{i}: second half must get more similar" + + assert torch.allclose( + sims_diff.abs(), torch.tensor(0.25, dtype=sims_diff.dtype), atol=0.005 + ), f"element #{i}: similarity decreases linearly" + hv = torch.roll(hv,1,0) @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.756, 1.0]) @pytest.mark.parametrize("dtype", torch_dtypes)