Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add implementation of AdaptHD and zero-norm warning #165

Merged
merged 9 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ pytest
black
tqdm
openpyxl
coverage
coverage
gdown
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
https://packaging.python.org/guides/distributing-packages-using-setuptools/
https://github.com/pypa/sampleproject
"""

from setuptools import setup, find_packages

# Read the version without importing any dependencies
Expand Down
85 changes: 8 additions & 77 deletions torchhd/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,93 +23,24 @@
#
import zipfile
import requests
import re
import tqdm

# Code adapted from:
# https://github.com/wkentaro/gdown/blob/941200a9a1f4fd7ab903fb595baa5cad34a30a45/gdown/download.py
# https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url


def download_file(url, destination):
response = requests.get(url, allow_redirects=True, stream=True)
write_response_to_disk(response, destination)


def download_file_from_google_drive(file_id, destination):
URL = "https://docs.google.com/uc"
params = dict(id=file_id, export="download")

with requests.Session() as session:
response = session.get(URL, params=params, stream=True)

# downloads right away
if "Content-Disposition" in response.headers:
write_response_to_disk(response, destination)
return

# try to find a confirmation token
token = get_google_drive_confirm_token(response)

if token:
params = dict(id=id, confirm=token)
response = session.get(URL, params=params, stream=True)

# download if confirmation token worked
if "Content-Disposition" in response.headers:
write_response_to_disk(response, destination)
return

# extract download url from confirmation page
url = get_url_from_gdrive_confirmation(response.text)
response = session.get(url, stream=True)

write_response_to_disk(response, destination)


def get_google_drive_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith("download_warning"):
return value

return None


def get_url_from_gdrive_confirmation(contents):
url = ""
for line in contents.splitlines():
m = re.search(r'href="(\/uc\?export=download[^"]+)', line)
if m:
url = "https://docs.google.com" + m.groups()[0]
url = url.replace("&", "&")
break
m = re.search('id="downloadForm" action="(.+?)"', line)
if m:
url = m.groups()[0]
url = url.replace("&", "&")
break
m = re.search('id="download-form" action="(.+?)"', line)
if m:
url = m.groups()[0]
url = url.replace("&", "&")
break
m = re.search('"downloadUrl":"([^"]+)', line)
if m:
url = m.groups()[0]
url = url.replace("\\u003d", "=")
url = url.replace("\\u0026", "&")
break
m = re.search('<p class="uc-error-subcaption">(.*)</p>', line)
if m:
error = m.groups()[0]
raise RuntimeError(error)
if not url:
raise RuntimeError(
"Cannot retrieve the public link of the file. "
"You may need to change the permission to "
"'Anyone with the link', or have had many accesses."
try:
import gdown
except ImportError:
raise ImportError(
"Downloading files from Google drive requires gdown to be installed, see: https://github.com/wkentaro/gdown"
)
return url

url = f"https://drive.google.com/uc?id={file_id}"
gdown.download(url, destination)


def get_download_progress_bar(response):
Expand Down
44 changes: 36 additions & 8 deletions torchhd/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@
from torch import Tensor
from torch.nn.parameter import Parameter
import torch.nn.init as init
import torch.utils.data as data
from tqdm import tqdm


import torchhd.functional as functional
import torchhd.datasets as datasets
import torchhd.embeddings as embeddings


Expand Down Expand Up @@ -71,6 +67,7 @@ class Centroid(nn.Module):
>>> output.size()
torch.Size([128, 30])
"""

__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
Expand Down Expand Up @@ -108,6 +105,30 @@ def add(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
"""Adds the input vectors scaled by the lr to the target prototype vectors."""
self.weight.index_add_(0, target, input, alpha=lr)

@torch.no_grad()
def add_adapt(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
r"""Only updates the prototype vectors on wrongly predicted inputs.

Implements the iterative training method as described in `AdaptHD: Adaptive Efficient Training for Brain-Inspired Hyperdimensional Computing <https://ieeexplore.ieee.org/document/8918974>`_.

Subtracts the input from the mispredicted class prototype scaled by the learning rate
and adds the input to the target prototype scaled by the learning rate.
"""
logit = self(input)
pred = logit.argmax(1)
is_wrong = target != pred

# cancel update if all predictions were correct
if is_wrong.sum().item() == 0:
return

input = input[is_wrong]
target = target[is_wrong]
pred = pred[is_wrong]

self.weight.index_add_(0, target, input, alpha=lr)
self.weight.index_add_(0, pred, input, alpha=-lr)

@torch.no_grad()
def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
r"""Only updates the prototype vectors on wrongly predicted inputs.
Expand Down Expand Up @@ -137,23 +158,30 @@ def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
alpha1 = 1.0 - logit.gather(1, target.unsqueeze(1))
alpha2 = logit.gather(1, pred.unsqueeze(1)) - 1.0

self.weight.index_add_(0, target, lr * alpha1 * input)
self.weight.index_add_(0, pred, lr * alpha2 * input)
self.weight.index_add_(0, target, alpha1 * input, alpha=lr)
self.weight.index_add_(0, pred, alpha2 * input, alpha=lr)

@torch.no_grad()
def normalize(self, eps=1e-12) -> None:
"""Transforms all the class prototype vectors into unit vectors.

After calling this, inferences can be made more efficiently by specifying ``dot=True`` in the forward pass.
Training further after calling this method is not advised.
"""
norms = self.weight.norm(dim=1, keepdim=True)

if torch.isclose(norms, torch.zeros_like(norms), equal_nan=True).any():
import warnings

warnings.warn(
"The norm of a prototype vector is nearly zero upon normalizing, this could indicate a bug."
)

norms.clamp_(min=eps)
self.weight.div_(norms)

def extra_repr(self) -> str:
return "in_features={}, out_features={}".format(
self.in_features, self.out_features is not None
self.in_features, self.out_features
)


Expand Down
30 changes: 10 additions & 20 deletions torchhd/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,10 @@ class Multiset:
@overload
def __init__(
self, dimensions: int, vsa: VSAOptions = "MAP", *, device=None, dtype=None
):
...
): ...

@overload
def __init__(self, input: VSATensor, *, size=0):
...
def __init__(self, input: VSATensor, *, size=0): ...

def __init__(self, dim_or_input: Any, vsa: VSAOptions = "MAP", **kwargs):
self.size = kwargs.get("size", 0)
Expand Down Expand Up @@ -334,12 +332,10 @@ class HashTable:
@overload
def __init__(
self, dimensions: int, vsa: VSAOptions = "MAP", *, device=None, dtype=None
):
...
): ...

@overload
def __init__(self, input: VSATensor, *, size=0):
...
def __init__(self, input: VSATensor, *, size=0): ...

def __init__(self, dim_or_input: int, vsa: VSAOptions = "MAP", **kwargs):
self.size = kwargs.get("size", 0)
Expand Down Expand Up @@ -501,12 +497,10 @@ class BundleSequence:
@overload
def __init__(
self, dimensions: int, vsa: VSAOptions = "MAP", *, device=None, dtype=None
):
...
): ...

@overload
def __init__(self, input: VSATensor, *, size=0):
...
def __init__(self, input: VSATensor, *, size=0): ...

def __init__(self, dim_or_input: int, vsa: VSAOptions = "MAP", **kwargs):
self.size = kwargs.get("size", 0)
Expand Down Expand Up @@ -693,12 +687,10 @@ class BindSequence:
@overload
def __init__(
self, dimensions: int, vsa: VSAOptions = "MAP", *, device=None, dtype=None
):
...
): ...

@overload
def __init__(self, input: VSATensor, *, size=0):
...
def __init__(self, input: VSATensor, *, size=0): ...

def __init__(self, dim_or_input: int, vsa: VSAOptions = "MAP", **kwargs):
self.size = kwargs.get("size", 0)
Expand Down Expand Up @@ -861,12 +853,10 @@ def __init__(
directed=False,
device=None,
dtype=None
):
...
): ...

@overload
def __init__(self, input: VSATensor, *, directed=False):
...
def __init__(self, input: VSATensor, *, directed=False): ...

def __init__(self, dim_or_input: int, vsa: VSAOptions = "MAP", **kwargs):
self.is_directed = kwargs.get("directed", False)
Expand Down
1 change: 1 addition & 0 deletions torchhd/tensors/bsbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class BSBCTensor(VSATensor):

Because the vectors are sparse and have a fixed magnitude, we only represent the index of the non-zero value.
"""

block_size: int
supported_dtypes: Set[torch.dtype] = {
torch.float32,
Expand Down
7 changes: 7 additions & 0 deletions torchhd/tensors/fhrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,5 +395,12 @@ def cosine_similarity(self, others: "FHRRTensor", *, eps=1e-08) -> Tensor:
else:
magnitude = self_mag * others_mag

if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any():
import warnings

warnings.warn(
"The norm of a vector is nearly zero, this could indicate a bug."
)

magnitude = torch.clamp(magnitude, min=eps)
return self.dot_similarity(others) / magnitude
7 changes: 7 additions & 0 deletions torchhd/tensors/hrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,5 +382,12 @@ def cosine_similarity(self, others: "HRRTensor", *, eps=1e-08) -> Tensor:
else:
magnitude = self_mag * others_mag

if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any():
import warnings

warnings.warn(
"The norm of a vector is nearly zero, this could indicate a bug."
)

magnitude = torch.clamp(magnitude, min=eps)
return self.dot_similarity(others) / magnitude
7 changes: 7 additions & 0 deletions torchhd/tensors/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,5 +368,12 @@ def cosine_similarity(
else:
magnitude = self_mag * others_mag

if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any():
import warnings

warnings.warn(
"The norm of a vector is nearly zero, this could indicate a bug."
)

magnitude = torch.clamp(magnitude, min=eps)
return self.dot_similarity(others, dtype=dtype) / magnitude
7 changes: 7 additions & 0 deletions torchhd/tensors/vtb.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,5 +411,12 @@ def cosine_similarity(self, others: "VTBTensor", *, eps=1e-08) -> Tensor:
else:
magnitude = self_mag * others_mag

if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any():
import warnings

warnings.warn(
"The norm of a vector is nearly zero, this could indicate a bug."
)

magnitude = torch.clamp(magnitude, min=eps)
return self.dot_similarity(others) / magnitude
Loading
Loading