diff --git a/torchhd/classify.py b/torchhd/classify.py index 53d12db1..501b996c 100644 --- a/torchhd/classify.py +++ b/torchhd/classify.py @@ -26,7 +26,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import TensorDataset, DataLoader from torch import Tensor, LongTensor from torch.nn.parameter import Parameter @@ -35,6 +34,9 @@ from torchhd.models import Centroid +DataLoader = Iterable[Tuple[Tensor, LongTensor]] + + __all__ = [ "Classifier", "Vanilla", @@ -67,7 +69,7 @@ def __init__( def forward(self, samples: Tensor) -> Tensor: return self.model(self.encoder(samples)) - def fit(self, samples: Tensor, labels: LongTensor) -> Self: + def fit(self, data_loader: DataLoader) -> Self: raise NotImplementedError() def predict(self, samples: Tensor) -> LongTensor: @@ -115,13 +117,9 @@ def __init__( def encoder(self, samples: Tensor) -> Tensor: return functional.hash_table(self.keys.weight, self.levels(samples)).sign() - def fit(self, samples: Tensor, labels: LongTensor) -> Self: - - loader = DataLoader( - TensorDataset(samples, labels), self.batch_size, shuffle=False - ) + def fit(self, data_loader: DataLoader) -> Self: - for samples, labels in loader: + for samples, labels in data_loader: encoded = self.encoder(samples) self.model.add(encoded, labels) @@ -170,14 +168,10 @@ def __init__( def encoder(self, samples: Tensor) -> Tensor: return functional.hash_table(self.keys.weight, self.levels(samples)).sign() - def fit(self, samples: Tensor, labels: LongTensor) -> Self: - - loader = DataLoader( - TensorDataset(samples, labels), self.batch_size, shuffle=True - ) + def fit(self, data_loader: DataLoader) -> Self: for _ in range(self.epochs): - for samples, labels in loader: + for samples, labels in data_loader: encoded = self.encoder(samples) self.model.add_adapt(encoded, labels, lr=self.lr) @@ -214,20 +208,75 @@ def __init__( self.encoder = Sinusoid(n_features, n_dimensions, device=device, dtype=dtype) self.model = 
# Adapted from: https://gitlab.com/biaslab/neuralhd
class NeuralHD(Classifier):
    r"""Implements `Scalable edge-based hyperdimensional learning system with brain-like neural adaptation`.

    Training bundles every sample into its class prototype once, then refines
    the prototypes with ``add_adapt`` for the remaining epochs. Periodically the
    dimensions whose normalized class weights vary the least across classes are
    regenerated: the model weights of those dimensions are zeroed and the
    matching encoder parameters are re-sampled.

    Args:
        n_features: Passed to the base class and to the ``Sinusoid`` encoder.
        n_dimensions: Passed to the base class, the encoder and the ``Centroid`` model.
        n_classes: Passed to the base class and to the ``Centroid`` model.
        regen_freq: Regeneration runs at the end of every epoch whose index
            satisfies ``epoch_idx % regen_freq == regen_freq - 1``.
        regen_rate: Fraction of dimensions regenerated each time; the count is
            ``ceil(regen_rate * n_dimensions)``.
        epochs: Total number of passes over the data, counting the initial
            bundling pass (so ``epochs - 1`` adaptive passes follow it).
        lr: Learning rate forwarded to ``Centroid.add_adapt``.
        batch_size: Stored on the instance; ``fit`` itself does not read it
            because batching is determined by the supplied data loader.
        device: Forwarded to the base class, encoder and model.
        dtype: Forwarded to the base class, encoder and model.
    """

    encoder: Sinusoid
    model: Centroid

    def __init__(
        self,
        n_features: int,
        n_dimensions: int,
        n_classes: int,
        *,
        regen_freq: int = 20,
        regen_rate: float = 0.04,
        epochs: int = 120,
        lr: float = 0.37,
        batch_size: Union[int, None] = 1024,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ) -> None:
        super().__init__(
            n_features, n_dimensions, n_classes, device=device, dtype=dtype
        )

        self.regen_freq = regen_freq
        self.regen_rate = regen_rate
        self.epochs = epochs
        self.lr = lr
        self.batch_size = batch_size

        self.encoder = Sinusoid(n_features, n_dimensions, device=device, dtype=dtype)
        self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype)

    def fit(self, data_loader: DataLoader) -> Self:
        """Train the classifier on batches of ``(samples, labels)``.

        Args:
            data_loader: Iterable yielding ``(samples, labels)`` batches. It is
                iterated once per epoch, so it must be re-iterable (a one-shot
                generator would be exhausted after the first pass).

        Returns:
            The classifier itself, to allow call chaining.
        """
        n_regen_dims = math.ceil(self.regen_rate * self.n_dimensions)

        # Initial pass: plain bundling of every encoded sample into its class.
        for samples, labels in data_loader:
            encoded = self.encoder(samples)
            self.model.add(encoded, labels)

        # Remaining passes start at 1 because the pass above counts as epoch 0.
        for epoch_idx in range(1, self.epochs):
            for samples, labels in data_loader:
                encoded = self.encoder(samples)
                self.model.add_adapt(encoded, labels, lr=self.lr)

            if (epoch_idx % self.regen_freq) == (self.regen_freq - 1):
                # Dimensions with the lowest variance across the normalized
                # class weights are the ones selected for regeneration.
                weight = F.normalize(self.model.weight, dim=1)
                scores = torch.var(weight, dim=0)
                regen_dims = torch.topk(scores, n_regen_dims, largest=False).indices

                # Indexing with a LongTensor returns a *copy*, so calling an
                # in-place op on ``t[idx]`` would leave ``t`` untouched. Indexed
                # assignment writes through to the parameter storage instead.
                self.model.weight.data[:, regen_dims] = 0

                enc_weight = self.encoder.weight.data
                enc_weight[regen_dims, :] = torch.empty_like(
                    enc_weight[regen_dims, :]
                ).normal_()

                # NOTE(review): the last axis is indexed so this works whether
                # the encoder bias is 1-D or carries a leading singleton axis;
                # confirm against the Sinusoid embedding's bias shape.
                enc_bias = self.encoder.bias.data
                enc_bias[..., regen_dims] = torch.empty_like(
                    enc_bias[..., regen_dims]
                ).uniform_(0, 2 * math.pi)

        return self
torch.topk(scores, n_regen_dims, largest=False).indices - self.model.weight.data[:, regen_dims].zero_() - self.encoder.weight.data[regen_dims, :].normal_() + regen_dims = torch.topk(scores, n_regen_dims, largest=False).indices + self.model.weight.data[:, regen_dims].zero_() + self.encoder.weight.data[regen_dims, :].normal_() return self