Skip to content

Commit

Permalink
Add normalization function of hypervectors and deprecate hard_quantize (#173)
Browse files Browse the repository at this point in the history

* Add normalization function of hypervectors and deprecate hard_quantize

* [github-action] formatting fixes

* Test newer python versions

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
mikeheddes and github-actions[bot] authored Sep 1, 2024
1 parent c23e945 commit bdffc21
Show file tree
Hide file tree
Showing 14 changed files with 272 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
timeout-minutes: 20
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.10', '3.11', '3.12']
os: [ubuntu-latest, windows-latest, macos-latest]

steps:
Expand Down
1 change: 1 addition & 0 deletions docs/torchhd.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Operations
permute
inverse
negative
normalize
cleanup
randsel
multirandsel
Expand Down
2 changes: 2 additions & 0 deletions torchhd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
permute,
inverse,
negative,
normalize,
cleanup,
create_random_permute,
randsel,
Expand Down Expand Up @@ -109,6 +110,7 @@
"permute",
"inverse",
"negative",
"normalize",
"cleanup",
"create_random_permute",
"randsel",
Expand Down
47 changes: 47 additions & 0 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import torch
from torch import LongTensor, FloatTensor, Tensor
from collections import deque
import warnings

from torchhd.tensors.base import VSATensor
from torchhd.tensors.bsc import BSCTensor
Expand All @@ -50,6 +51,7 @@
"permute",
"inverse",
"negative",
"normalize",
"cleanup",
"create_random_permute",
"hard_quantize",
Expand Down Expand Up @@ -673,6 +675,11 @@ def bundle(input: VSATensor, other: VSATensor) -> VSATensor:
\oplus: \mathcal{H} \times \mathcal{H} \to \mathcal{H}
.. note::
This operation does not normalize the resulting hypervectors.
Normalized hypervectors can be obtained with :func:`~torchhd.normalize`.
Args:
input (VSATensor): input hypervector
other (VSATensor): other input hypervector
Expand Down Expand Up @@ -885,6 +892,12 @@ def hard_quantize(input: Tensor):
tensor([ 1., -1., -1., -1., 1., -1.])
"""
warnings.warn(
"torchhd.hard_quantize is deprecated, consider using torchhd.normalize instead.",
DeprecationWarning,
stacklevel=2,
)

# Make sure that the output tensor has the same dtype and device
# as the input tensor.
positive = torch.tensor(1.0, dtype=input.dtype, device=input.device)
Expand All @@ -893,6 +906,35 @@ def hard_quantize(input: Tensor):
return torch.where(input > 0, positive, negative)


def normalize(input: VSATensor) -> VSATensor:
    """Normalize the input hypervectors.

    The input is first passed through ``ensure_vsa_tensor`` and the actual
    normalization is delegated to the ``normalize`` method of the resulting
    tensor, so the rule that is applied depends on the VSA model of the input
    (see the examples below for MAP and HRR).

    Args:
        input (VSATensor): input hypervector

    Shapes:
        - Input: :math:`(*)`
        - Output: :math:`(*)`

    Examples::

        >>> x = torchhd.random(4, 10, "MAP").multibundle()
        >>> x
        MAPTensor([ 0.,  0., -2., -2.,  2., -2.,  2.,  2.,  2.,  0.])
        >>> torchhd.normalize(x)
        MAPTensor([-1., -1., -1., -1.,  1., -1.,  1.,  1.,  1., -1.])

        >>> x = torchhd.random(4, 10, "HRR").multibundle()
        >>> x
        HRRTensor([-0.2999,  0.4686,  0.1797, -0.4830,  0.2718, -0.3663,  0.3079,  0.2558, -1.5157, -0.5196])
        >>> torchhd.normalize(x)
        HRRTensor([-0.1601,  0.2501,  0.0959, -0.2578,  0.1451, -0.1955,  0.1643,  0.1365, -0.8089, -0.2773])

    """
    input = ensure_vsa_tensor(input)
    return input.normalize()


def dot_similarity(input: VSATensor, others: VSATensor, **kwargs) -> VSATensor:
"""Dot product between the input vector and each vector in others.
Expand Down Expand Up @@ -1037,6 +1079,11 @@ def multiset(input: VSATensor) -> VSATensor:
\bigoplus_{i=0}^{n-1} V_i
.. note::
This operation does not normalize the resulting or intermediate hypervectors.
Normalized hypervectors can be obtained with :func:`~torchhd.normalize`.
Args:
input (VSATensor): input hypervector tensor
Expand Down
18 changes: 9 additions & 9 deletions torchhd/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,22 +121,22 @@ def read(self, query: Tensor) -> VSATensor:
"""
# first dims from query, last dim from value
out_shape = (*query.shape[:-1], self.value_dim)
out_shape = tuple(query.shape[:-1]) + (self.value_dim,)

if query.dim() == 1:
query = query.unsqueeze(0)

# make sure to have at least two dimensions for index_add_
intermediate_shape = (*query.shape[:-1], self.value_dim)
intermediate_shape = tuple(query.shape[:-1]) + (self.value_dim,)

similarity = query @ self.keys.T
is_active = similarity >= self.threshold

# sparse matrix-vector multiplication
r_indices, v_indices = is_active.nonzero().T
read = query.new_zeros(intermediate_shape)
read.index_add_(0, r_indices, self.values[v_indices])
return read.view(out_shape)
# Sparse matrix-vector multiplication.
to_indices, from_indices = is_active.nonzero().T

read = torch.zeros(intermediate_shape, dtype=query.dtype, device=query.device)
read.index_add_(0, to_indices, self.values[from_indices])
return read.view(out_shape).as_subclass(functional.MAPTensor)

@torch.no_grad()
def write(self, keys: Tensor, values: Tensor) -> None:
Expand All @@ -161,7 +161,7 @@ def write(self, keys: Tensor, values: Tensor) -> None:
similarity = keys @ self.keys.T
is_active = similarity >= self.threshold

# sparse outer product and addition
# Sparse outer product and addition.
from_indices, to_indices = is_active.nonzero().T
self.values.index_add_(0, to_indices, values[from_indices])

Expand Down
6 changes: 5 additions & 1 deletion torchhd/tensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
from typing import List, Set, Any
from typing import List, Set
import torch
from torch import Tensor

Expand Down Expand Up @@ -131,6 +131,10 @@ def permute(self, shifts: int = 1) -> "VSATensor":
"""Permute the hypervector"""
raise NotImplementedError

def normalize(self) -> "VSATensor":
    """Normalize the hypervector.

    The base class provides no implementation; each VSA model is expected
    to override this method with its own normalization rule.

    Raises:
        NotImplementedError: always, when called on the base implementation.
    """
    raise NotImplementedError

def dot_similarity(self, others: "VSATensor") -> Tensor:
    """Inner product with other hypervectors.

    The base class provides no implementation; each VSA model is expected
    to override this method.

    Args:
        others (VSATensor): hypervectors to compare against.

    Raises:
        NotImplementedError: always, when called on the base implementation.
    """
    raise NotImplementedError
Expand Down
20 changes: 20 additions & 0 deletions torchhd/tensors/bsbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,26 @@ def permute(self, shifts: int = 1) -> "BSBCTensor":
"""
return torch.roll(self, shifts=shifts, dims=-1)

def normalize(self) -> "BSBCTensor":
    r"""Normalize the hypervector.

    Each operation on BSBC hypervectors ensures it remains normalized,
    so this returns a copy of self.

    Shapes:
        - Self: :math:`(*)`
        - Output: :math:`(*)`

    Examples::

        >>> x = torchhd.BSBCTensor.random(4, 6, block_size=64).multibundle()
        >>> x
        BSBCTensor([28, 27, 20, 44, 57, 18])
        >>> x.normalize()
        BSBCTensor([28, 27, 20, 44, 57, 18])

    """
    # Return a clone rather than self so the caller can modify the result
    # without touching the original hypervector.
    return self.clone()

def dot_similarity(self, others: "BSBCTensor", *, dtype=None) -> Tensor:
"""Inner product with other hypervectors"""
if dtype is None:
Expand Down
20 changes: 20 additions & 0 deletions torchhd/tensors/bsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,26 @@ def permute(self, shifts: int = 1) -> "BSCTensor":
"""
return super().roll(shifts=shifts, dims=-1)

def normalize(self) -> "BSCTensor":
    r"""Normalize the hypervector.

    Each operation on BSC hypervectors ensures it remains normalized,
    so this returns a copy of self.

    Shapes:
        - Self: :math:`(*)`
        - Output: :math:`(*)`

    Examples::

        >>> x = torchhd.BSCTensor.random(4, 6).multibundle()
        >>> x
        BSCTensor([ True, False, False, False, False, False])
        >>> x.normalize()
        BSCTensor([ True, False, False, False, False, False])

    """
    # Return a clone rather than self so the caller can modify the result
    # without touching the original hypervector.
    return self.clone()

def dot_similarity(self, others: "BSCTensor", *, dtype=None) -> Tensor:
"""Inner product with other hypervectors."""
device = self.device
Expand Down
23 changes: 23 additions & 0 deletions torchhd/tensors/fhrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,29 @@ def permute(self, shifts: int = 1) -> "FHRRTensor":
"""
return torch.roll(self, shifts=shifts, dims=-1)

def normalize(self) -> "FHRRTensor":
    r"""Normalize the hypervector.

    The normalization preserves the element phase but sets the magnitude to one.

    Shapes:
        - Self: :math:`(*)`
        - Output: :math:`(*)`

    Examples::

        >>> x = torchhd.FHRRTensor.random(4, 6).multibundle()
        >>> x
        FHRRTensor([ 1.0878+0.9382j,  2.0057-1.5603j, -2.2828-1.4410j,  1.9643-1.8269j,
                    -0.9710-0.0120j, -0.7432+0.6956j])
        >>> x.normalize()
        FHRRTensor([ 0.7572+0.6531j,  0.7893-0.6140j, -0.8456-0.5338j,  0.7322-0.6810j,
                    -0.9999-0.0124j, -0.7301+0.6833j])

    """
    # Rebuild every element from its phase alone: cos(angle) + i*sin(angle)
    # has magnitude one by construction, so no division by the magnitude
    # is needed.
    angle = self.angle()
    return torch.complex(angle.cos(), angle.sin())

def dot_similarity(self, others: "FHRRTensor") -> Tensor:
"""Inner product with other hypervectors"""
if others.dim() >= 2:
Expand Down
28 changes: 25 additions & 3 deletions torchhd/tensors/hrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import torch
from torch import Tensor
from torch.fft import fft, ifft
import torch.nn.functional as F
import math

from torchhd.tensors.base import VSATensor
Expand Down Expand Up @@ -155,7 +156,7 @@ def random(
) -> "HRRTensor":
"""Creates a set of random independent hypervectors.
The resulting hypervectors are sampled at random from a normal with mean 0 and standard deviation 1/dimensions.
The resulting hypervectors are sampled uniformly at random from the (dimensions - 1)-unit sphere.
Args:
num_vectors (int): the number of hypervectors to generate.
Expand Down Expand Up @@ -186,8 +187,8 @@ def random(
raise ValueError(f"{name} vectors must be one of dtype {options}.")

size = (num_vectors, dimensions)
result = torch.empty(size, dtype=dtype, device=device)
result.normal_(0, 1.0 / math.sqrt(dimensions), generator=generator)
result = torch.randn(size, dtype=dtype, device=device, generator=generator)
result = F.normalize(result, p=2, dim=-1)

result.requires_grad = requires_grad
return result.as_subclass(cls)
Expand Down Expand Up @@ -362,6 +363,27 @@ def permute(self, shifts: int = 1) -> "HRRTensor":
"""
return torch.roll(self, shifts=shifts, dims=-1)

def normalize(self) -> "HRRTensor":
    r"""Normalize the hypervector.

    The normalization preserves the direction of the hypervector but makes it unit norm.
    This means that it is mapped to the closest point on the unit sphere.
    The L2 norm is taken over the last dimension, so every hypervector in
    a batch is normalized independently.

    Shapes:
        - Self: :math:`(*)`
        - Output: :math:`(*)`

    Examples::

        >>> x = torchhd.HRRTensor.random(4, 6).multibundle()
        >>> x
        HRRTensor([-0.6150,  0.4260,  0.6975,  0.3110,  0.9387,  0.0696])
        >>> x.normalize()
        HRRTensor([-0.4317,  0.2990,  0.4897,  0.2184,  0.6590,  0.0489])

    """
    return F.normalize(self, p=2, dim=-1)

def dot_similarity(self, others: "HRRTensor") -> Tensor:
"""Inner product with other hypervectors"""
if others.dim() >= 2:
Expand Down
27 changes: 24 additions & 3 deletions torchhd/tensors/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#
import torch
from torch import Tensor
import torch.nn.functional as F
from typing import Set

from torchhd.tensors.base import VSATensor
Expand All @@ -38,8 +37,6 @@ class MAPTensor(VSATensor):
supported_dtypes: Set[torch.dtype] = {
torch.float32,
torch.float64,
torch.complex64,
torch.complex128,
torch.int8,
torch.int16,
torch.int32,
Expand Down Expand Up @@ -318,6 +315,30 @@ def permute(self, shifts: int = 1) -> "MAPTensor":
"""
return torch.roll(self, shifts=shifts, dims=-1)

def normalize(self) -> "MAPTensor":
    r"""Normalize the hypervector.

    The normalization sets all positive entries to +1 and all other entries to -1.
    Zero entries therefore become -1.

    Shapes:
        - Self: :math:`(*)`
        - Output: :math:`(*)`

    Examples::

        >>> x = torchhd.MAPTensor.random(4, 6).multibundle()
        >>> x
        MAPTensor([-2., -4.,  4.,  0.,  4., -2.])
        >>> x.normalize()
        MAPTensor([-1., -1.,  1., -1.,  1., -1.])

    """
    # Ensure that the output tensor has the same dtype and device as the self tensor.
    positive = torch.tensor(1.0, dtype=self.dtype, device=self.device)
    negative = torch.tensor(-1.0, dtype=self.dtype, device=self.device)

    return torch.where(self > 0, positive, negative)

def clipping(self, kappa) -> "MAPTensor":
r"""Performs the clipping function that clips the lower and upper values.
Expand Down
Loading

0 comments on commit bdffc21

Please sign in to comment.