From 545860c7a56c7f197ec6fb87d4d9793360bdfb19 Mon Sep 17 00:00:00 2001 From: Mike Heddes Date: Tue, 5 Mar 2024 09:59:27 -0800 Subject: [PATCH 1/9] Add AdaptHD centroid update rule, and fix #120 --- torchhd/models.py | 43 ++++++++++++++++++++++++++++++++++------- torchhd/tensors/fhrr.py | 5 +++++ torchhd/tensors/hrr.py | 5 +++++ torchhd/tensors/map.py | 5 +++++ torchhd/tensors/vtb.py | 5 +++++ 5 files changed, 56 insertions(+), 7 deletions(-) diff --git a/torchhd/models.py b/torchhd/models.py index 0f7de6dc..d806f657 100644 --- a/torchhd/models.py +++ b/torchhd/models.py @@ -28,12 +28,8 @@ from torch import Tensor from torch.nn.parameter import Parameter import torch.nn.init as init -import torch.utils.data as data -from tqdm import tqdm - import torchhd.functional as functional -import torchhd.datasets as datasets import torchhd.embeddings as embeddings @@ -71,6 +67,7 @@ class Centroid(nn.Module): >>> output.size() torch.Size([128, 30]) """ + __constants__ = ["in_features", "out_features"] in_features: int out_features: int @@ -108,6 +105,30 @@ def add(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None: """Adds the input vectors scaled by the lr to the target prototype vectors.""" self.weight.index_add_(0, target, input, alpha=lr) + @torch.no_grad() + def add_adapt(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None: + r"""Only updates the prototype vectors on wrongly predicted inputs. + + Implements the iterative training method as described in `AdaptHD: Adaptive Efficient Training for Brain-Inspired Hyperdimensional Computing `_. + + Subtracts the input from the mispredicted class prototype scaled by the learning rate + and adds the input to the target prototype scaled by the learning rate. + """ + logit = self(input) + pred = logit.argmax(1) + is_wrong = target != pred + + # cancel update if all predictions were correct + if is_wrong.sum().item() == 0: + return + + input = input[is_wrong] + target = target[is_wrong] + pred = pred[is_wrong] + + self.weight.index_add_(0, target, input, alpha=lr) + self.weight.index_add_(0, pred, input, alpha=-lr) + @torch.no_grad() def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None: r"""Only updates the prototype vectors on wrongly predicted inputs. @@ -137,8 +158,8 @@ def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None: alpha1 = 1.0 - logit.gather(1, target.unsqueeze(1)) alpha2 = logit.gather(1, pred.unsqueeze(1)) - 1.0 - self.weight.index_add_(0, target, lr * alpha1 * input) - self.weight.index_add_(0, pred, lr * alpha2 * input) + self.weight.index_add_(0, target, alpha1 * input, alpha=lr) + self.weight.index_add_(0, pred, alpha2 * input, alpha=lr) @torch.no_grad() def normalize(self, eps=1e-12) -> None: @@ -148,12 +169,20 @@ def normalize(self, eps=1e-12) -> None: Training further after calling this method is not advised. """ norms = self.weight.norm(dim=1, keepdim=True) + + if torch.isclose(norms, torch.zeros_like(norms), equal_nan=True).any(): + import warnings + + warnings.warn( + "The norm of a prototype vector is nearly zero upon normalizing, this could indicate a bug." + ) + norms.clamp_(min=eps) self.weight.div_(norms) def extra_repr(self) -> str: return "in_features={}, out_features={}".format( - self.in_features, self.out_features is not None + self.in_features, self.out_features ) diff --git a/torchhd/tensors/fhrr.py b/torchhd/tensors/fhrr.py index 7f8d0fa7..94a75cb8 100644 --- a/torchhd/tensors/fhrr.py +++ b/torchhd/tensors/fhrr.py @@ -395,5 +395,10 @@ def cosine_similarity(self, others: "FHRRTensor", *, eps=1e-08) -> Tensor: else: magnitude = self_mag * others_mag + if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any(): + import warnings + + warnings.warn("The norm of a vector is nearly zero, this could indicate a bug.") + magnitude = torch.clamp(magnitude, min=eps) return self.dot_similarity(others) / magnitude diff --git a/torchhd/tensors/hrr.py b/torchhd/tensors/hrr.py index 9fd08815..95541cb5 100644 --- a/torchhd/tensors/hrr.py +++ b/torchhd/tensors/hrr.py @@ -382,5 +382,10 @@ def cosine_similarity(self, others: "HRRTensor", *, eps=1e-08) -> Tensor: else: magnitude = self_mag * others_mag + if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any(): + import warnings + + warnings.warn("The norm of a vector is nearly zero, this could indicate a bug.") + magnitude = torch.clamp(magnitude, min=eps) return self.dot_similarity(others) / magnitude diff --git a/torchhd/tensors/map.py b/torchhd/tensors/map.py index 60e8e3ac..f325f84e 100644 --- a/torchhd/tensors/map.py +++ b/torchhd/tensors/map.py @@ -368,5 +368,10 @@ def cosine_similarity( else: magnitude = self_mag * others_mag + if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any(): + import warnings + + warnings.warn("The norm of a vector is nearly zero, this could indicate a bug.") + magnitude = torch.clamp(magnitude, min=eps) return self.dot_similarity(others, dtype=dtype) / magnitude diff --git a/torchhd/tensors/vtb.py b/torchhd/tensors/vtb.py index f7bd84de..f79697b3 100644 --- a/torchhd/tensors/vtb.py +++ b/torchhd/tensors/vtb.py @@ -411,5 +411,10 @@ def cosine_similarity(self, others: "VTBTensor", *, eps=1e-08) -> Tensor: else: magnitude = self_mag * others_mag + if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any(): + import warnings + + warnings.warn("The norm of a vector is nearly zero, this could indicate a bug.") + magnitude = torch.clamp(magnitude, min=eps) return self.dot_similarity(others) / magnitude From 4f7135d2c0e2f20f7ec925b44efb8829065d1dd1 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 5 Mar 2024 20:53:41 +0000 Subject: [PATCH 2/9] [github-action] formatting fixes --- setup.py | 1 + torchhd/structures.py | 30 ++++++++++-------------------- torchhd/tensors/bsbc.py | 1 + torchhd/tensors/fhrr.py | 4 +++- torchhd/tensors/hrr.py | 4 +++- torchhd/tensors/map.py | 4 +++- torchhd/tensors/vtb.py | 4 +++- 7 files changed, 24 insertions(+), 24 deletions(-) diff --git a/setup.py b/setup.py index 61565201..13b260f5 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ https://packaging.python.org/guides/distributing-packages-using-setuptools/ https://github.com/pypa/sampleproject """ + from setuptools import setup, find_packages # Read the version without importing any dependencies diff --git a/torchhd/structures.py b/torchhd/structures.py index ec1f96bf..c381d76e 100644 --- a/torchhd/structures.py +++ b/torchhd/structures.py @@ -186,12 +186,10 @@ class Multiset: @overload def __init__( self, dimensions: int, vsa: VSAOptions = "MAP", *, device=None, dtype=None - ): - ... + ): ... @overload - def __init__(self, input: VSATensor, *, size=0): - ... + def __init__(self, input: VSATensor, *, size=0): ... def __init__(self, dim_or_input: Any, vsa: VSAOptions = "MAP", **kwargs): self.size = kwargs.get("size", 0) @@ -334,12 +332,10 @@ class HashTable: @overload def __init__( self, dimensions: int, vsa: VSAOptions = "MAP", *, device=None, dtype=None - ): - ... + ): ... @overload - def __init__(self, input: VSATensor, *, size=0): - ... + def __init__(self, input: VSATensor, *, size=0): ... def __init__(self, dim_or_input: int, vsa: VSAOptions = "MAP", **kwargs): self.size = kwargs.get("size", 0) @@ -501,12 +497,10 @@ class BundleSequence: @overload def __init__( self, dimensions: int, vsa: VSAOptions = "MAP", *, device=None, dtype=None - ): - ... + ): ... @overload - def __init__(self, input: VSATensor, *, size=0): - ... + def __init__(self, input: VSATensor, *, size=0): ... def __init__(self, dim_or_input: int, vsa: VSAOptions = "MAP", **kwargs): self.size = kwargs.get("size", 0) @@ -693,12 +687,10 @@ class BindSequence: @overload def __init__( self, dimensions: int, vsa: VSAOptions = "MAP", *, device=None, dtype=None - ): - ... + ): ... @overload - def __init__(self, input: VSATensor, *, size=0): - ... + def __init__(self, input: VSATensor, *, size=0): ... def __init__(self, dim_or_input: int, vsa: VSAOptions = "MAP", **kwargs): self.size = kwargs.get("size", 0) @@ -861,12 +853,10 @@ def __init__( directed=False, device=None, dtype=None - ): - ... + ): ... @overload - def __init__(self, input: VSATensor, *, directed=False): - ... + def __init__(self, input: VSATensor, *, directed=False): ... def __init__(self, dim_or_input: int, vsa: VSAOptions = "MAP", **kwargs): self.is_directed = kwargs.get("directed", False) diff --git a/torchhd/tensors/bsbc.py b/torchhd/tensors/bsbc.py index e2c1688e..3f79d0bc 100644 --- a/torchhd/tensors/bsbc.py +++ b/torchhd/tensors/bsbc.py @@ -36,6 +36,7 @@ class BSBCTensor(VSATensor): Because the vectors are sparse and have a fixed magnitude, we only represent the index of the non-zero value. """ + block_size: int supported_dtypes: Set[torch.dtype] = { torch.float32, diff --git a/torchhd/tensors/fhrr.py b/torchhd/tensors/fhrr.py index 94a75cb8..55d0ddf5 100644 --- a/torchhd/tensors/fhrr.py +++ b/torchhd/tensors/fhrr.py @@ -398,7 +398,9 @@ def cosine_similarity(self, others: "FHRRTensor", *, eps=1e-08) -> Tensor: if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any(): import warnings - warnings.warn("The norm of a vector is nearly zero, this could indicate a bug.") + warnings.warn( + "The norm of a vector is nearly zero, this could indicate a bug." + ) magnitude = torch.clamp(magnitude, min=eps) return self.dot_similarity(others) / magnitude diff --git a/torchhd/tensors/hrr.py b/torchhd/tensors/hrr.py index 95541cb5..34ffca4f 100644 --- a/torchhd/tensors/hrr.py +++ b/torchhd/tensors/hrr.py @@ -385,7 +385,9 @@ def cosine_similarity(self, others: "HRRTensor", *, eps=1e-08) -> Tensor: if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any(): import warnings - warnings.warn("The norm of a vector is nearly zero, this could indicate a bug.") + warnings.warn( + "The norm of a vector is nearly zero, this could indicate a bug." + ) magnitude = torch.clamp(magnitude, min=eps) return self.dot_similarity(others) / magnitude diff --git a/torchhd/tensors/map.py b/torchhd/tensors/map.py index f325f84e..b93c4a54 100644 --- a/torchhd/tensors/map.py +++ b/torchhd/tensors/map.py @@ -371,7 +371,9 @@ def cosine_similarity( if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any(): import warnings - warnings.warn("The norm of a vector is nearly zero, this could indicate a bug.") + warnings.warn( + "The norm of a vector is nearly zero, this could indicate a bug." + ) magnitude = torch.clamp(magnitude, min=eps) return self.dot_similarity(others, dtype=dtype) / magnitude diff --git a/torchhd/tensors/vtb.py b/torchhd/tensors/vtb.py index f79697b3..8329bb86 100644 --- a/torchhd/tensors/vtb.py +++ b/torchhd/tensors/vtb.py @@ -414,7 +414,9 @@ def cosine_similarity(self, others: "VTBTensor", *, eps=1e-08) -> Tensor: if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any(): import warnings - warnings.warn("The norm of a vector is nearly zero, this could indicate a bug.") + warnings.warn( + "The norm of a vector is nearly zero, this could indicate a bug." + ) magnitude = torch.clamp(magnitude, min=eps) return self.dot_similarity(others) / magnitude From 0ac1add477c5c1a94bc111cc1f7650f8b8890651 Mon Sep 17 00:00:00 2001 From: Mike Heddes Date: Tue, 5 Mar 2024 20:14:29 -0800 Subject: [PATCH 3/9] Add test --- torchhd/models.py | 1 - torchhd/tests/test_models.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/torchhd/models.py b/torchhd/models.py index d806f657..af6d7b2b 100644 --- a/torchhd/models.py +++ b/torchhd/models.py @@ -161,7 +161,6 @@ def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None: self.weight.index_add_(0, target, alpha1 * input, alpha=lr) self.weight.index_add_(0, pred, alpha2 * input, alpha=lr) - @torch.no_grad() def normalize(self, eps=1e-12) -> None: """Transforms all the class prototype vectors into unit vectors. diff --git a/torchhd/tests/test_models.py b/torchhd/tests/test_models.py index e4226f50..93a9eca7 100644 --- a/torchhd/tests/test_models.py +++ b/torchhd/tests/test_models.py @@ -82,6 +82,16 @@ def test_add_online(self): logits = model(samples) assert logits.shape == (10, 3) + def test_add_adapt(self): + samples = torch.randn(10, 12) + targets = torch.randint(0, 3, (10,)) + + model = models.Centroid(12, 3) + model.add_adapt(samples, targets) + + logits = model(samples) + assert logits.shape == (10, 3) + class TestIntRVFL: @pytest.mark.parametrize("dtype", torch_dtypes) From 5e32cbebac793dddb75b1211e13b7d9baa783f58 Mon Sep 17 00:00:00 2001 From: Mike Heddes Date: Tue, 5 Mar 2024 20:56:41 -0800 Subject: [PATCH 4/9] Simplify google drive download --- torchhd/datasets/utils.py | 85 ++++----------------------------------- 1 file changed, 7 insertions(+), 78 deletions(-) diff --git a/torchhd/datasets/utils.py b/torchhd/datasets/utils.py index 05898e84..71cd97f9 100644 --- a/torchhd/datasets/utils.py +++ b/torchhd/datasets/utils.py @@ -23,13 +23,8 @@ # import zipfile import requests -import re import tqdm -# Code adapted from: -# https://github.com/wkentaro/gdown/blob/941200a9a1f4fd7ab903fb595baa5cad34a30a45/gdown/download.py -# https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url - def download_file(url, destination): response = requests.get(url, allow_redirects=True, stream=True) @@ -37,79 +32,13 @@ def download_file(url, destination): def download_file_from_google_drive(file_id, destination): - URL = "https://docs.google.com/uc" - params = dict(id=file_id, export="download") - - with requests.Session() as session: - response = session.get(URL, params=params, stream=True) - - # downloads right away - if "Content-Disposition" in response.headers: - write_response_to_disk(response, destination) - return - - # try to find a confirmation token - token = get_google_drive_confirm_token(response) - - if token: - params = dict(id=id, confirm=token) - response = session.get(URL, params=params, stream=True) - - # download if confirmation token worked - if "Content-Disposition" in response.headers: - write_response_to_disk(response, destination) - return - - # extract download url from confirmation page - url = get_url_from_gdrive_confirmation(response.text) - response = session.get(url, stream=True) - - write_response_to_disk(response, destination) - - -def get_google_drive_confirm_token(response): - for key, value in response.cookies.items(): - if key.startswith("download_warning"): - return value - - return None - - -def get_url_from_gdrive_confirmation(contents): - url = "" - for line in contents.splitlines(): - m = re.search(r'href="(\/uc\?export=download[^"]+)', line) - if m: - url = "https://docs.google.com" + m.groups()[0] - url = url.replace("&", "&") - break - m = re.search('id="downloadForm" action="(.+?)"', line) - if m: - url = m.groups()[0] - url = url.replace("&", "&") - break - m = re.search('id="download-form" action="(.+?)"', line) - if m: - url = m.groups()[0] - url = url.replace("&", "&") - break - m = re.search('"downloadUrl":"([^"]+)', line) - if m: - url = m.groups()[0] - url = url.replace("\\u003d", "=") - url = url.replace("\\u0026", "&") - break - m = re.search('

(.*)

', line) - if m: - error = m.groups()[0] - raise RuntimeError(error) - if not url: - raise RuntimeError( - "Cannot retrieve the public link of the file. " - "You may need to change the permission to " - "'Anyone with the link', or have had many accesses." - ) - return url + try: + import gdown + except ImportError: + raise ImportError("Downloading files from Google drive requires gdown to be installed, see: https://github.com/wkentaro/gdown") + + url = f"https://drive.google.com/uc?id={file_id}" + gdown.download(url, destination) def get_download_progress_bar(response): From 2c87ecc0d9b92fe860ecbc4ce48021b4177fa1bf Mon Sep 17 00:00:00 2001 From: Mike Heddes Date: Tue, 5 Mar 2024 20:57:45 -0800 Subject: [PATCH 5/9] Add gdown to dev dependencies --- dev-requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index f242a5cd..cb8699a8 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -9,4 +9,5 @@ pytest black tqdm openpyxl -coverage \ No newline at end of file +coverage +gdown \ No newline at end of file From 5a75491b09b313d9b90c9932f4176c4355e8e7cf Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 6 Mar 2024 04:58:11 +0000 Subject: [PATCH 6/9] [github-action] formatting fixes --- torchhd/datasets/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchhd/datasets/utils.py b/torchhd/datasets/utils.py index 71cd97f9..0b7be1f8 100644 --- a/torchhd/datasets/utils.py +++ b/torchhd/datasets/utils.py @@ -35,8 +35,10 @@ def download_file_from_google_drive(file_id, destination): try: import gdown except ImportError: - raise ImportError("Downloading files from Google drive requires gdown to be installed, see: https://github.com/wkentaro/gdown") - + raise ImportError( + "Downloading files from Google drive requires gdown to be installed, see: https://github.com/wkentaro/gdown" + ) + url = f"https://drive.google.com/uc?id={file_id}" gdown.download(url, destination) From f0b9232b5b10b26cc455d7e72f38e57fbac8a2fd Mon Sep 17 00:00:00 2001 From: Mike Heddes Date: Tue, 5 Mar 2024 21:18:51 -0800 Subject: [PATCH 7/9] Update tests --- torchhd/tests/test_embeddings.py | 22 +++++++++++----------- torchhd/tests/test_encodings.py | 8 -------- torchhd/tests/test_similarities.py | 6 +++--- 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/torchhd/tests/test_embeddings.py b/torchhd/tests/test_embeddings.py index a9abda34..17b6362f 100644 --- a/torchhd/tests/test_embeddings.py +++ b/torchhd/tests/test_embeddings.py @@ -74,7 +74,7 @@ def test_dtype(self, vsa): if vsa == "BSC": assert emb(idx).dtype == torch.bool elif vsa == "MAP" or vsa == "HRR": - assert emb(idx).dtype == torch.float + assert emb(idx).dtype == torch.get_default_dtype() elif vsa == "FHRR": assert ( emb(idx).dtype == torch.complex64 or emb(idx).dtype == torch.complex32 @@ -142,7 +142,7 @@ def test_dtype(self, vsa): if vsa == "BSC": assert emb(idx).dtype == torch.bool elif vsa in {"MAP", "HRR", "VTB"}: - assert emb(idx).dtype == torch.float + assert emb(idx).dtype == torch.get_default_dtype() elif vsa == "FHRR": assert emb(idx).dtype in {torch.complex64, torch.complex32} @@ -244,7 +244,7 @@ def test_dtype(self, vsa): if vsa == "BSC": assert emb(idx).dtype == torch.bool elif vsa in {"MAP", "HRR", "VTB"}: - assert emb(idx).dtype == torch.float + assert emb(idx).dtype == torch.get_default_dtype() elif vsa == "FHRR": assert emb(idx).dtype in {torch.complex64, torch.complex32} @@ -295,7 +295,7 @@ def test_dtype(self, vsa): if vsa == "BSC": assert emb(idx).dtype == torch.bool elif vsa in {"MAP", "HRR", "VTB"}: - assert emb(idx).dtype == torch.float + assert emb(idx).dtype == torch.get_default_dtype() elif vsa == "FHRR": assert emb(idx).dtype in {torch.complex64, torch.complex32} @@ -365,7 +365,7 @@ def test_dtype(self, vsa): if vsa == "BSC": assert emb(angle).dtype == torch.bool elif vsa == "MAP": - assert emb(angle).dtype == torch.float + assert emb(angle).dtype == torch.get_default_dtype() elif vsa == "FHRR": assert ( emb(angle).dtype == torch.complex64 @@ -441,7 +441,7 @@ def test_dtype(self, vsa): if vsa == "BSC": assert emb(angle).dtype == torch.bool elif vsa == "MAP": - assert emb(angle).dtype == torch.float + assert emb(angle).dtype == torch.get_default_dtype() elif vsa == "FHRR": assert ( emb(angle).dtype == torch.complex64 @@ -504,7 +504,7 @@ def test_dtype(self, vsa): emb = embeddings.Projection(in_features, out_features, vsa=vsa) x = torch.randn(1, in_features) if vsa == "MAP" or vsa == "HRR": - assert emb(x).dtype == torch.float + assert emb(x).dtype == torch.get_default_dtype() else: return @@ -549,7 +549,7 @@ def test_dtype(self, vsa): emb = embeddings.Sinusoid(in_features, out_features, vsa=vsa) x = torch.randn(1, in_features) if vsa == "MAP" or vsa == "HRR": - assert emb(x).dtype == torch.float + assert emb(x).dtype == torch.get_default_dtype() else: return @@ -611,7 +611,7 @@ def test_dtype(self, vsa): if vsa == "BSC": assert emb(x).dtype == torch.bool elif vsa == "MAP": - assert emb(x).dtype == torch.float + assert emb(x).dtype == torch.get_default_dtype() elif vsa == "FHRR": assert emb(x).dtype == torch.complex64 or emb(x).dtype == torch.complex32 else: @@ -664,9 +664,9 @@ def test_default_dtype(self, vsa): assert y.shape == (2, dimensions) if vsa == "HRR": - assert y.dtype == torch.float32 + assert y.dtype == torch.get_default_dtype() elif vsa == "FHRR": - assert y.dtype == torch.complex64 + assert fhrr_type_conversion[y.dtype] == torch.get_default_dtype() else: return diff --git a/torchhd/tests/test_encodings.py b/torchhd/tests/test_encodings.py index 927993b3..af205bb9 100644 --- a/torchhd/tests/test_encodings.py +++ b/torchhd/tests/test_encodings.py @@ -141,10 +141,6 @@ def test_dtype(self, dtype): hv = torch.zeros(23, 1000, dtype=dtype).as_subclass(MAPTensor) if dtype in {torch.float16}: - # torch.product is not implemented on CPU for these dtypes - with pytest.raises(RuntimeError): - functional.multibind(hv) - return res = functional.multibind(hv) @@ -288,10 +284,6 @@ def test_dtype(self, dtype): hv = torch.zeros(23, 1000, dtype=dtype).as_subclass(MAPTensor) if dtype in {torch.float16}: - # torch.product is not implemented on CPU for these dtypes - with pytest.raises(RuntimeError): - functional.multibind(hv) - return res = functional.bind_sequence(hv) diff --git a/torchhd/tests/test_similarities.py b/torchhd/tests/test_similarities.py index eb104885..d33c0b48 100644 --- a/torchhd/tests/test_similarities.py +++ b/torchhd/tests/test_similarities.py @@ -118,7 +118,7 @@ def test_value(self, vsa, dtype): ).as_subclass(BSCTensor) res = functional.dot_similarity(hv, hv) - exp = torch.tensor([[10, 4], [4, 10]], dtype=torch.long) + exp = torch.tensor([[10, 4], [4, 10]], dtype=res.dtype) assert torch.all(res == exp).item() elif vsa == "FHRR": @@ -339,7 +339,7 @@ def test_value(self, vsa, dtype): ).as_subclass(BSCTensor) res = functional.cosine_similarity(hv, hv) - exp = torch.tensor([[1, 0.4], [0.4, 1]], dtype=torch.float) + exp = torch.tensor([[1, 0.4], [0.4, 1]], dtype=res.dtype) assert torch.allclose(res, exp) elif vsa == "FHRR": @@ -529,7 +529,7 @@ def test_value(self, vsa, dtype): ).as_subclass(BSCTensor) res = functional.hamming_similarity(hv, hv) - exp = torch.tensor([[10, 7], [7, 10]], dtype=torch.long) + exp = torch.tensor([[10, 7], [7, 10]], dtype=res.dtype) assert torch.all(res == exp).item() elif vsa == "FHRR": From 19ca16321499de27223cac31a92aeaffd8cd2a8c Mon Sep 17 00:00:00 2001 From: Mike Heddes Date: Tue, 5 Mar 2024 21:44:13 -0800 Subject: [PATCH 8/9] Fix test --- torchhd/tests/test_similarities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchhd/tests/test_similarities.py b/torchhd/tests/test_similarities.py index d33c0b48..96f28717 100644 --- a/torchhd/tests/test_similarities.py +++ b/torchhd/tests/test_similarities.py @@ -388,7 +388,7 @@ def test_value(self, vsa, dtype): ).as_subclass(MAPTensor) res = functional.cosine_similarity(hv, hv) - exp = torch.tensor([[1, -0.4], [-0.4, 1]], dtype=torch.float) + exp = torch.tensor([[1, -0.4], [-0.4, 1]], dtype=res.dtype) assert torch.allclose(res, exp) @pytest.mark.parametrize("vsa", vsa_tensors) From fc3b80c52eb79868112a756408b71750127c907a Mon Sep 17 00:00:00 2001 From: Mike Heddes Date: Tue, 5 Mar 2024 22:00:26 -0800 Subject: [PATCH 9/9] Simpler intrvfl setup --- torchhd/tests/test_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchhd/tests/test_models.py b/torchhd/tests/test_models.py index 93a9eca7..a721a31f 100644 --- a/torchhd/tests/test_models.py +++ b/torchhd/tests/test_models.py @@ -113,7 +113,9 @@ def test_initialization(self, dtype): assert model.weight.device.type == device.type def test_fit_ridge_regression(self): - samples = torch.eye(10, 12) + a = torch.randn(10) + b = torch.randn(12) + samples = torch.outer(a, b) targets = torch.arange(10) model = models.IntRVFL(12, 1245, 10)