Skip to content

Commit

Permalink
[github-action] formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions[bot] committed Jun 7, 2022
1 parent 1f7e9db commit d0826d0
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 12 deletions.
4 changes: 1 addition & 3 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,7 @@ def random_hv(
if dtype in {torch.complex64, torch.complex128}:
dtype = torch.float if dtype == torch.complex64 else torch.double

angle = torch.empty(
num_embeddings, embedding_dim, dtype=dtype, device=device
)
angle = torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device)
angle.uniform_(-math.pi, math.pi)
magnitude = torch.ones(
num_embeddings, embedding_dim, dtype=dtype, device=device
Expand Down
7 changes: 4 additions & 3 deletions torchhd/tests/basis_hv/test_circular_hv.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@ def test_value(self, dtype):
(hv == True) | (hv == False)
).item(), "values are either 1 or 0"
elif dtype in torch_complex_dtypes:
magnitudes= hv.abs()
assert torch.allclose(magnitudes, torch.tensor(1.0, dtype=magnitudes.dtype)), "magnitude must be 1"
magnitudes = hv.abs()
assert torch.allclose(
magnitudes, torch.tensor(1.0, dtype=magnitudes.dtype)
), "magnitude must be 1"
else:
assert torch.all(
(hv == -1) | (hv == 1)
).item(), "values are either -1 or +1"


hv = functional.circular_hv(8, 1000000, generator=generator, dtype=dtype)
if dtype in torch_complex_dtypes:
sims = functional.cosine_similarity(hv[0], hv)
Expand Down
6 changes: 4 additions & 2 deletions torchhd/tests/basis_hv/test_level_hv.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ def test_value(self, dtype):
(hv == True) | (hv == False)
).item(), "values are either 1 or 0"
elif dtype in torch_complex_dtypes:
magnitudes= hv.abs()
assert torch.allclose(magnitudes, torch.tensor(1.0, dtype=magnitudes.dtype)), "magnitude must be 1"
magnitudes = hv.abs()
assert torch.allclose(
magnitudes, torch.tensor(1.0, dtype=magnitudes.dtype)
), "magnitude must be 1"
else:
assert torch.all(
(hv == -1) | (hv == 1)
Expand Down
6 changes: 4 additions & 2 deletions torchhd/tests/basis_hv/test_random_hv.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ def test_value(self, dtype):
if dtype == torch.bool:
assert torch.all((hv == False) | (hv == True)).item()
elif dtype in torch_complex_dtypes:
magnitudes= hv.abs()
assert torch.allclose(magnitudes, torch.tensor(1.0, dtype=magnitudes.dtype)), "magnitude must be 1"
magnitudes = hv.abs()
assert torch.allclose(
magnitudes, torch.tensor(1.0, dtype=magnitudes.dtype)
), "magnitude must be 1"
else:
assert torch.all((hv == -1) | (hv == 1)).item()

Expand Down
2 changes: 1 addition & 1 deletion torchhd/tests/test_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_value(self, dtype):
@pytest.mark.parametrize("dtype", torch_dtypes)
def test_dtype(self, dtype):
hv = torch.zeros(23, 1000, dtype=dtype)

if dtype == torch.uint8:
with pytest.raises(ValueError):
functional.multibind(hv)
Expand Down
2 changes: 1 addition & 1 deletion torchhd/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def plot_pair_similarity(memory: Tensor, ax=None, **kwargs):
See https://matplotlib.org/stable/users/installing/index.html for more information."
)

similarity = functional.cosine_similarity(memory, memory).tolist()
similarity = functional.cosine_similarity(memory, memory).tolist()

if ax is None:
ax = plt.gca()
Expand Down

0 comments on commit d0826d0

Please sign in to comment.