Add AdaptHD centroid update rule, and fix #120
mikeheddes committed Mar 5, 2024
1 parent 53828b8 commit 545860c
Showing 5 changed files with 56 additions and 7 deletions.
43 changes: 36 additions & 7 deletions torchhd/models.py
@@ -28,12 +28,8 @@
from torch import Tensor
from torch.nn.parameter import Parameter
import torch.nn.init as init
-import torch.utils.data as data
-from tqdm import tqdm


import torchhd.functional as functional
-import torchhd.datasets as datasets
import torchhd.embeddings as embeddings


@@ -71,6 +67,7 @@ class Centroid(nn.Module):
>>> output.size()
torch.Size([128, 30])
"""

__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
@@ -108,6 +105,30 @@ def add(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
"""Adds the input vectors scaled by the lr to the target prototype vectors."""
self.weight.index_add_(0, target, input, alpha=lr)

@torch.no_grad()
def add_adapt(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
r"""Only updates the prototype vectors on wrongly predicted inputs.
Implements the iterative training method as described in `AdaptHD: Adaptive Efficient Training for Brain-Inspired Hyperdimensional Computing <https://ieeexplore.ieee.org/document/8918974>`_.
Subtracts the input from the mispredicted class prototype scaled by the learning rate
and adds the input to the target prototype scaled by the learning rate.
"""
logit = self(input)
pred = logit.argmax(1)
is_wrong = target != pred

# cancel update if all predictions were correct
if is_wrong.sum().item() == 0:
return

input = input[is_wrong]
target = target[is_wrong]
pred = pred[is_wrong]

self.weight.index_add_(0, target, input, alpha=lr)
self.weight.index_add_(0, pred, input, alpha=-lr)
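A minimal usage sketch of the new `add_adapt` rule (not part of the commit: the dimensionality, class count, and random stand-in hypervectors are assumptions, and a real pipeline would encode its inputs with torchhd embeddings first):

```python
import torch
import torchhd
from torchhd.models import Centroid

DIMENSIONS = 10000   # hypervector dimensionality (assumed for this sketch)
NUM_CLASSES = 10     # number of class prototypes (assumed)

model = Centroid(DIMENSIONS, NUM_CLASSES)

# Stand-ins for encoded sample hypervectors and their labels.
samples = torchhd.random(512, DIMENSIONS)
labels = torch.randint(0, NUM_CLASSES, (512,))

with torch.no_grad():
    # Single-pass bundling to initialize the class prototypes ...
    model.add(samples, labels)

    # ... then a few AdaptHD-style refinement passes: only misclassified samples
    # are added to their target prototype and subtracted from the mispredicted one.
    for _ in range(5):
        model.add_adapt(samples, labels, lr=1.0)
```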

@torch.no_grad()
def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
r"""Only updates the prototype vectors on wrongly predicted inputs.
@@ -137,8 +158,8 @@ def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
alpha1 = 1.0 - logit.gather(1, target.unsqueeze(1))
alpha2 = logit.gather(1, pred.unsqueeze(1)) - 1.0

-self.weight.index_add_(0, target, lr * alpha1 * input)
-self.weight.index_add_(0, pred, lr * alpha2 * input)
+self.weight.index_add_(0, target, alpha1 * input, alpha=lr)
+self.weight.index_add_(0, pred, alpha2 * input, alpha=lr)
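For the refactor just above (the learning rate moves out of the scaled tensor and into `index_add_`'s `alpha` argument, which leaves the update mathematically unchanged), a toy illustration of the OnlineHD-style scaling with made-up similarity values:

```python
# Toy numbers only (not from the commit): suppose an input's similarity to its
# true class prototype is 0.2 and to the wrongly predicted class is 0.9.
logit_target, logit_pred = 0.2, 0.9

alpha1 = 1.0 - logit_target   # 0.8  -> pull the true prototype strongly toward the input
alpha2 = logit_pred - 1.0     # -0.1 -> push the predicted prototype away only slightly

# Both forms apply the same update:
#   weight.index_add_(0, target, lr * alpha1 * input)        # old
#   weight.index_add_(0, target, alpha1 * input, alpha=lr)   # new
```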

@torch.no_grad()
def normalize(self, eps=1e-12) -> None:
@@ -148,12 +169,20 @@ def normalize(self, eps=1e-12) -> None:
Training further after calling this method is not advised.
"""
norms = self.weight.norm(dim=1, keepdim=True)

if torch.isclose(norms, torch.zeros_like(norms), equal_nan=True).any():
import warnings

warnings.warn(
"The norm of a prototype vector is nearly zero upon normalizing; this could indicate a bug."
)

norms.clamp_(min=eps)
self.weight.div_(norms)
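A short sketch of where `normalize` fits in a typical workflow (the sizes and random stand-in data are assumptions, not from the commit):

```python
import torch
import torchhd
from torchhd.models import Centroid

model = Centroid(10000, 10)
samples = torchhd.random(512, 10000)      # stand-in encodings (assumed)
labels = torch.randint(0, 10, (512,))

with torch.no_grad():
    model.add(samples, labels)

# Rescales every class prototype to unit norm, so the dot similarities returned
# by forward() behave like cosine similarities. As the docstring warns, do this
# once after training rather than before further updates.
model.normalize()
predictions = model(samples).argmax(1)
```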

def extra_repr(self) -> str:
return "in_features={}, out_features={}".format(
-self.in_features, self.out_features is not None
+self.in_features, self.out_features
)
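The one-line change above corrects `extra_repr`, which previously formatted `self.out_features is not None` (always `True`) instead of the actual value. A quick check, with the expected output shown as comments (assuming the standard `nn.Module` repr wrapping):

```python
from torchhd.models import Centroid

print(Centroid(10000, 30))
# before: Centroid(in_features=10000, out_features=True)
# after:  Centroid(in_features=10000, out_features=30)
```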


5 changes: 5 additions & 0 deletions torchhd/tensors/fhrr.py
@@ -395,5 +395,10 @@ def cosine_similarity(self, others: "FHRRTensor", *, eps=1e-08) -> Tensor:
else:
magnitude = self_mag * others_mag

if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any():
import warnings

warnings.warn("The norm of a vector is nearly zero; this could indicate a bug.")

magnitude = torch.clamp(magnitude, min=eps)
return self.dot_similarity(others) / magnitude
5 changes: 5 additions & 0 deletions torchhd/tensors/hrr.py
@@ -382,5 +382,10 @@ def cosine_similarity(self, others: "HRRTensor", *, eps=1e-08) -> Tensor:
else:
magnitude = self_mag * others_mag

if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any():
import warnings

warnings.warn("The norm of a vector is nearly zero; this could indicate a bug.")

magnitude = torch.clamp(magnitude, min=eps)
return self.dot_similarity(others) / magnitude
5 changes: 5 additions & 0 deletions torchhd/tensors/map.py
@@ -368,5 +368,10 @@ def cosine_similarity(
else:
magnitude = self_mag * others_mag

if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any():
import warnings

warnings.warn("The norm of a vector is nearly zero; this could indicate a bug.")

magnitude = torch.clamp(magnitude, min=eps)
return self.dot_similarity(others, dtype=dtype) / magnitude
5 changes: 5 additions & 0 deletions torchhd/tensors/vtb.py
@@ -411,5 +411,10 @@ def cosine_similarity(self, others: "VTBTensor", *, eps=1e-08) -> Tensor:
else:
magnitude = self_mag * others_mag

if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any():
import warnings

warnings.warn("The norm of a vector is nearly zero; this could indicate a bug.")

magnitude = torch.clamp(magnitude, min=eps)
return self.dot_similarity(others) / magnitude
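The same near-zero-norm guard is added to the FHRR, HRR, MAP, and VTB tensor classes. A minimal sketch of the situation it flags, using the MAP tensor for simplicity (the zero vector is a contrived input, not from the commit):

```python
import torch
from torchhd.tensors.map import MAPTensor

dims = 10000
hv = torch.randn(1, dims).sign().as_subclass(MAPTensor)   # a random bipolar hypervector
zero = torch.zeros(1, dims).as_subclass(MAPTensor)        # zero norm: cosine similarity is ill-defined

# With this change, the call below emits a warning about the near-zero norm;
# the magnitude is still clamped to eps afterwards, so the result stays finite.
sim = zero.cosine_similarity(hv)
```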
