From 1f7e9dbf92a4a1be0d4285b6ccd7231afa85c2a3 Mon Sep 17 00:00:00 2001
From: Mike Heddes
Date: Tue, 7 Jun 2022 16:32:40 -0700
Subject: [PATCH] Add complex hypervector types (#81)

* WIP implement complex hypervectors

* Implement complex level and circular hypervector sets

* Add complex tests

* Simplify plotting utility

* Use unbind function

* Add unbind to docs
---
 docs/functional.rst                        |   1 +
 torchhd/__init__.py                        |   2 +
 torchhd/functional.py                      | 310 ++++++++++++++++-----
 torchhd/structures.py                      |  28 +-
 torchhd/tests/basis_hv/test_circular_hv.py |  45 ++-
 torchhd/tests/basis_hv/test_identity_hv.py |   6 -
 torchhd/tests/basis_hv/test_level_hv.py    |  43 ++-
 torchhd/tests/basis_hv/test_random_hv.py   |  49 ++--
 torchhd/tests/test_encodings.py            |  44 +--
 torchhd/tests/test_operations.py           |  18 --
 torchhd/tests/utils.py                     |   2 +-
 torchhd/utils.py                           |   4 +-
 12 files changed, 345 insertions(+), 207 deletions(-)

diff --git a/docs/functional.rst b/docs/functional.rst
index 4d363e85..1766cc48 100644
--- a/docs/functional.rst
+++ b/docs/functional.rst
@@ -28,6 +28,7 @@ Operations
     :template: function.rst

     bind
+    unbind
     bundle
     permute
     cleanup
diff --git a/torchhd/__init__.py b/torchhd/__init__.py
index a1e03fcb..e6e40eba 100644
--- a/torchhd/__init__.py
+++ b/torchhd/__init__.py
@@ -10,6 +10,7 @@
     level_hv,
     circular_hv,
     bind,
+    unbind,
     bundle,
     permute,
 )
@@ -27,6 +28,7 @@
     "level_hv",
     "circular_hv",
     "bind",
+    "unbind",
     "bundle",
     "permute",
 ]
diff --git a/torchhd/functional.py b/torchhd/functional.py
index 39f7a967..20c65a77 100644
--- a/torchhd/functional.py
+++ b/torchhd/functional.py
@@ -12,6 +12,7 @@
     "level_hv",
     "circular_hv",
     "bind",
+    "unbind",
     "bundle",
     "permute",
     "cleanup",
@@ -56,20 +57,42 @@ def identity_hv(

     Examples::

-        >>> functional.identity_hv(2, 3)
-        tensor([[ 1., 1., 1.],
-        [ 1., 1., 1.]])
+        >>> functional.identity_hv(3, 6)
+        tensor([[1., 1., 1., 1., 1., 1.],
+                [1., 1., 1., 1., 1., 1.],
+                [1., 1., 1., 1., 1., 1.]])
+
+        >>> functional.identity_hv(3, 6, dtype=torch.bool)
+        tensor([[False, False, False, False, False, False],
+                [False, False, False, False, False, False],
+                [False, False, False, False, False, False]])
+
+        >>> functional.identity_hv(3, 6, dtype=torch.long)
+        tensor([[1, 1, 1, 1, 1, 1],
+                [1, 1, 1, 1, 1, 1],
+                [1, 1, 1, 1, 1, 1]])
+
+        >>> functional.identity_hv(3, 6, dtype=torch.complex64)
+        tensor([[1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],
+                [1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],
+                [1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j]])

     """
     if dtype is None:
         dtype = torch.get_default_dtype()

-    if dtype in {torch.complex64, torch.complex128}:
-        raise NotImplementedError("Complex hypervectors are not supported yet.")
-
     if dtype == torch.uint8:
         raise ValueError("Unsigned integer hypervectors are not supported.")

+    if dtype in {torch.complex64, torch.complex128}:
+        return torch.full(
+            (num_embeddings, embedding_dim),
+            1 + 0j,
+            dtype=dtype,
+            device=device,
+            requires_grad=requires_grad,
+        )
+
     if dtype == torch.bool:
         return torch.zeros(
             num_embeddings,
@@ -107,7 +130,7 @@ def random_hv(
     Args:
         num_embeddings (int): the number of hypervectors to generate.
         embedding_dim (int): the dimensionality of the hypervectors.
-        sparsity (float, optional): the expected fraction of elements to be +1. Default: ``0.5``.
+        sparsity (float, optional): the expected fraction of elements to be inactive. Has no effect on complex hypervectors. Default: ``0.5``.
         generator (``torch.Generator``, optional): a pseudorandom number generator for sampling.
        dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: if ``None``, uses a global default (see ``torch.set_default_tensor_type()``).
        device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
@@ -116,25 +139,52 @@ def random_hv(
     Examples::

-        >>> functional.random_hv(2, 5)
-        tensor([[-1.,  1., -1., -1.,  1.],
-                [ 1., -1., -1., -1., -1.]])
-        >>> functional.random_hv(2, 5, sparsity=0.9)
-        tensor([[ 1.,  1.,  1., -1.,  1.],
-                [ 1.,  1.,  1.,  1.,  1.]])
-        >>> functional.random_hv(2, 5, dtype=torch.long)
-        tensor([[ 1, -1,  1,  1,  1],
-                [ 1,  1, -1, -1,  1]])
+        >>> functional.random_hv(3, 6)
+        tensor([[ 1., -1., -1.,  1., -1.,  1.],
+                [-1.,  1.,  1.,  1.,  1.,  1.],
+                [-1.,  1.,  1.,  1.,  1., -1.]])
+
+        >>> functional.random_hv(3, 6, sparsity=0.9)
+        tensor([[ 1.,  1.,  1., -1., -1.,  1.],
+                [-1.,  1.,  1.,  1.,  1.,  1.],
+                [ 1.,  1.,  1., -1.,  1.,  1.]])
+
+        >>> functional.random_hv(3, 6, dtype=torch.long)
+        tensor([[ 1,  1,  1,  1,  1, -1],
+                [ 1, -1,  1,  1, -1,  1],
+                [ 1,  1, -1,  1,  1, -1]])
+
+        >>> functional.random_hv(3, 6, dtype=torch.bool)
+        tensor([[ True, False, False, False, False,  True],
+                [ True,  True, False,  True,  True, False],
+                [False, False, False,  True, False,  True]])
+
+        >>> functional.random_hv(3, 6, dtype=torch.complex64)
+        tensor([[-0.9849-0.1734j,  0.1267+0.9919j, -0.9160+0.4012j,  0.5063-0.8624j,  0.9898-0.1424j, -0.4378+0.8991j],
+                [-0.4516+0.8922j,  0.7086-0.7056j,  0.8579+0.5138j,  0.9629-0.2699j, -0.2023+0.9793j, -0.9787-0.2052j],
+                [-0.2974+0.9548j, -0.9936+0.1127j, -0.9740+0.2264j, -0.9999+0.0113j,  0.4434-0.8963j,  0.3888+0.9213j]])

     """
     if dtype is None:
         dtype = torch.get_default_dtype()

-    if dtype in {torch.complex64, torch.complex128}:
-        raise NotImplementedError("Complex hypervectors are not supported yet.")
-
     if dtype == torch.uint8:
         raise ValueError("Unsigned integer hypervectors are not supported.")

+    if dtype in {torch.complex64, torch.complex128}:
+        dtype = torch.float if dtype == torch.complex64 else torch.double
+
+        angle = torch.empty(
+            num_embeddings, embedding_dim, dtype=dtype, device=device
+        )
+        angle.uniform_(-math.pi, math.pi, generator=generator)
+
+        magnitude = torch.ones(
+            num_embeddings, embedding_dim, dtype=dtype, device=device
+        )
+
+        result = torch.polar(magnitude, angle)
+        result.requires_grad = requires_grad
+        return result
+
     select = torch.empty(
         (
             num_embeddings,
@@ -173,7 +228,7 @@ def level_hv(
     Args:
         num_embeddings (int): the number of hypervectors to generate.
         embedding_dim (int): the dimensionality of the hypervectors.
-        sparsity (float, optional): the expected fraction of elements to be +1. Default: ``0.5``.
+        sparsity (float, optional): the expected fraction of elements to be inactive. Has no effect on complex hypervectors. Default: ``0.5``.
         randomness (float, optional): r-value to interpolate between level at ``0.0`` and random-hypervectors at ``1.0``. Default: ``0.0``.
         generator (``torch.Generator``, optional): a pseudorandom number generator for sampling.
         dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: if ``None``, uses a global default (see ``torch.set_default_tensor_type()``).
@@ -189,16 +244,32 @@ def level_hv(
             [ 1., -1.,  1., -1., -1., -1.,  1., -1., -1.,  1.],
             [ 1., -1.,  1., -1., -1.,  1.,  1.,  1., -1.,  1.]])

+        >>> functional.level_hv(5, 8, dtype=torch.bool)
+        tensor([[ True, False, False,  True, False, False, False,  True],
+                [ True,  True, False,  True,  True, False, False,  True],
+                [ True,  True, False,  True,  True, False, False, False],
+                [ True,  True, False,  True,  True, False,  True, False],
+                [ True,  True, False,  True,  True, False,  True, False]])
+
+        >>> functional.level_hv(5, 6, dtype=torch.complex64)
+        tensor([[ 9.4413e-01+0.3296j, -9.5562e-01-0.2946j,  7.9306e-04+1.0000j, -8.8154e-01-0.4721j, -6.6328e-01+0.7484j, -8.6131e-01-0.5081j],
+                [ 9.4413e-01+0.3296j, -9.5562e-01-0.2946j,  7.9306e-04+1.0000j, -8.8154e-01-0.4721j, -6.6328e-01+0.7484j, -8.6131e-01-0.5081j],
+                [ 9.4413e-01+0.3296j, -9.5562e-01-0.2946j,  7.9306e-04+1.0000j, -8.8154e-01-0.4721j, -6.6328e-01+0.7484j, -8.6131e-01-0.5081j],
+                [-9.9803e-01+0.0627j, -9.5562e-01-0.2946j,  7.9306e-04+1.0000j,  9.9992e-01+0.0126j, -6.6328e-01+0.7484j, -8.6131e-01-0.5081j],
+                [-9.9803e-01+0.0627j, -8.5366e-01+0.5208j,  6.5232e-01-0.7579j,  9.9992e-01+0.0126j,  3.6519e-01+0.9309j,  9.7333e-01-0.2294j]])
     """
     if dtype is None:
         dtype = torch.get_default_dtype()

-    if dtype in {torch.complex64, torch.complex128}:
-        raise NotImplementedError("Complex hypervectors are not supported yet.")
-
     if dtype == torch.uint8:
         raise ValueError("Unsigned integer hypervectors are not supported.")

+    # convert from the normalized "randomness" variable r to the number of orthogonal vector sets "span"
+    levels_per_span = (1 - randomness) * (num_embeddings - 1) + randomness * 1
+    # must be at least one to deal with the case that num_embeddings is less than 2
+    levels_per_span = max(levels_per_span, 1)
+    span = (num_embeddings - 1) / levels_per_span
+
     hv = torch.empty(
         num_embeddings,
         embedding_dim,
@@ -206,11 +277,6 @@ def level_hv(
         device=device,
     )

-    # convert from normalized "randomness" variable r to number of orthogonal vectors sets "span"
-    levels_per_span = (1 - randomness) * (num_embeddings - 1) + randomness * 1
-    # must be at least one to deal with the case that num_embeddings is less than 2
-    levels_per_span = max(levels_per_span, 1)
-    span = (num_embeddings - 1) / levels_per_span
     # generate the set of orthogonal vectors within the level vector set
     span_hv = random_hv(
         int(math.ceil(span + 1)),
@@ -274,7 +340,7 @@ def circular_hv(
     Args:
         num_embeddings (int): the number of hypervectors to generate.
         embedding_dim (int): the dimensionality of the hypervectors.
-        sparsity (float, optional): the expected fraction of elements to be +1. Default: ``0.5``.
+        sparsity (float, optional): the expected fraction of elements to be inactive. Has no effect on complex hypervectors. Default: ``0.5``.
        randomness (float, optional): r-value to interpolate between circular at ``0.0`` and random-hypervectors at ``1.0``. Default: ``0.0``.
        generator (``torch.Generator``, optional): a pseudorandom number generator for sampling.
        dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: if ``None``, uses a global default (see ``torch.set_default_tensor_type()``).
@@ -293,13 +359,34 @@ def circular_hv( [ 1., 1., 1., -1., 1., 1., -1., 1., -1., -1.], [ 1., 1., -1., -1., 1., 1., -1., 1., -1., -1.]]) + >>> functional.circular_hv(10, 8, dtype=torch.bool) + tensor([[False, True, False, False, True, False, True, True], + [False, True, False, False, True, False, True, True], + [False, True, False, False, True, False, True, True], + [False, True, False, False, True, False, True, False], + [False, True, False, False, False, False, False, False], + [ True, True, False, True, False, False, False, False], + [ True, True, False, True, False, False, False, False], + [ True, True, False, True, False, False, False, False], + [ True, True, False, True, False, False, False, True], + [ True, True, False, True, True, False, True, True]]) + + >>> functional.circular_hv(10, 6, dtype=torch.complex64) + tensor([[ 0.0691+0.9976j, -0.1847+0.9828j, -0.4434-0.8963j, -0.8287+0.5596j, -0.8357-0.5493j, -0.5358+0.8443j], + [ 0.0691+0.9976j, -0.1847+0.9828j, -0.4434-0.8963j, -0.8287+0.5596j, 0.9427-0.3336j, -0.5358+0.8443j], + [ 0.0691+0.9976j, -0.1847+0.9828j, -0.4434-0.8963j, -0.0339-0.9994j, 0.9427-0.3336j, -0.6510-0.7591j], + [ 0.0691+0.9976j, -0.3881+0.9216j, -0.4434-0.8963j, -0.0339-0.9994j, 0.9427-0.3336j, -0.6510-0.7591j], + [-0.6866-0.7271j, -0.3881+0.9216j, -0.4434-0.8963j, -0.0339-0.9994j, 0.9427-0.3336j, -0.6510-0.7591j], + [-0.6866-0.7271j, -0.3881+0.9216j, -0.7324+0.6809j, -0.0339-0.9994j, 0.9427-0.3336j, -0.6510-0.7591j], + [-0.6866-0.7271j, -0.3881+0.9216j, -0.7324+0.6809j, -0.0339-0.9994j, -0.8357-0.5493j, -0.6510-0.7591j], + [-0.6866-0.7271j, -0.3881+0.9216j, -0.7324+0.6809j, -0.8287+0.5596j, -0.8357-0.5493j, -0.5358+0.8443j], + [-0.6866-0.7271j, -0.1847+0.9828j, -0.7324+0.6809j, -0.8287+0.5596j, -0.8357-0.5493j, -0.5358+0.8443j], + [ 0.0691+0.9976j, -0.1847+0.9828j, -0.7324+0.6809j, -0.8287+0.5596j, -0.8357-0.5493j, -0.5358+0.8443j]]) + """ if dtype is None: dtype = torch.get_default_dtype() - if dtype in {torch.complex64, torch.complex128}: - raise NotImplementedError("Complex hypervectors are not supported yet.") - if dtype == torch.uint8: raise ValueError("Unsigned integer hypervectors are not supported.") @@ -362,7 +449,7 @@ def circular_hv( temp_hv = torch.where(threshold_v[span_idx] < t, span_start_hv, span_end_hv) - mutation_history.append(bind(temp_hv, mutation_hv)) + mutation_history.append(unbind(temp_hv, mutation_hv)) mutation_hv = temp_hv if i % 2 == 0: @@ -370,7 +457,7 @@ def circular_hv( for i in range(num_embeddings + 1, num_embeddings * 2 - 1): mut = mutation_history.popleft() - mutation_hv = bind(mutation_hv, mut) + mutation_hv = unbind(mutation_hv, mut) if i % 2 == 0: hv[i // 2] = mutation_hv @@ -411,8 +498,49 @@ def bind(input: Tensor, other: Tensor) -> Tensor: """ dtype = input.dtype - if torch.is_complex(input): - raise NotImplementedError("Complex hypervectors are not supported yet.") + if dtype == torch.uint8: + raise ValueError("Unsigned integer hypervectors are not supported.") + + if dtype == torch.bool: + return torch.logical_xor(input, other) + + return torch.mul(input, other) + + +def unbind(input: Tensor, other: Tensor) -> Tensor: + r"""Inverse of the binding operation. + + See :func:`~torchhd.functional.bind`. + + Aliased as ``torchhd.unbind``. 
+
+    Args:
+        input (Tensor): input hypervector
+        other (Tensor): other input hypervector
+
+    Shapes:
+        - Input: :math:`(*)`
+        - Other: :math:`(*)`
+        - Output: :math:`(*)`
+
+    Examples::
+
+        >>> x = functional.random_hv(2, 6)
+        >>> x
+        tensor([[-1.,  1.,  1., -1., -1.,  1.],
+                [-1., -1.,  1.,  1.,  1., -1.]])
+        >>> functional.unbind(functional.bind(x[0], x[1]), x[1])
+        tensor([-1.,  1.,  1., -1., -1.,  1.])
+
+        >>> x = functional.random_hv(2, 6, dtype=torch.complex64)
+        >>> x
+        tensor([[-0.6510+0.7591j, -0.9675+0.2528j,  0.7358-0.6772j, -0.1791-0.9838j, -0.9874-0.1585j, -0.3726+0.9280j],
+                [ 0.1429+0.9897j, -0.9173+0.3983j, -0.4906+0.8714j,  0.4710-0.8821j,  0.6478+0.7618j,  0.8753+0.4836j]])
+        >>> functional.unbind(functional.bind(x[0], x[1]), x[1])
+        tensor([-0.6510+0.7591j, -0.9675+0.2528j,  0.7358-0.6772j, -0.1791-0.9838j, -0.9874-0.1585j, -0.3726+0.9280j])
+
+    """
+    dtype = input.dtype

     if dtype == torch.uint8:
         raise ValueError("Unsigned integer hypervectors are not supported.")
@@ -420,6 +548,9 @@ def bind(input: Tensor, other: Tensor) -> Tensor:
     if dtype == torch.bool:
         return torch.logical_xor(input, other)

+    if torch.is_complex(input):
+        return torch.mul(input, other.conj())
+
     return torch.mul(input, other)


@@ -456,9 +587,6 @@ def bundle(input: Tensor, other: Tensor, *, tie: BoolTensor = None) -> Tensor:
     """
     dtype = input.dtype

-    if torch.is_complex(input):
-        raise NotImplementedError("Complex hypervectors are not supported yet.")
-
     if dtype == torch.uint8:
         raise ValueError("Unsigned integer hypervectors are not supported.")

@@ -559,54 +687,95 @@ def hard_quantize(input: Tensor):
     return torch.where(input > 0, positive, negative)


-def cosine_similarity(input: Tensor, others: Tensor) -> Tensor:
-    """Cosine similarity between the input vector and each vector in others.
+def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
+    """Dot product between the input vector and each vector in others.

     Args:
-        input (Tensor): one-dimensional tensor
-        others (Tensor): two-dimensional tensor
+        input (Tensor): hypervectors to compare against others
+        others (Tensor): hypervectors to compare with

     Shapes:
-        - Input: :math:`(d)`
-        - Others: :math:`(n, d)`
-        - Output: :math:`(n)`
+        - Input: :math:`(*, d)`
+        - Others: :math:`(n, d)` or :math:`(d)`
+        - Output: :math:`(*, n)` or :math:`(*)`, depending on the shape of ``others``

     Examples::

-        >>> x = functional.random_hv(2, 3)
+        >>> x = functional.random_hv(3, 6)
         >>> x
-        tensor([[-1., -1.,  1.],
-                [ 1.,  1., -1.]])
-        >>> functional.cosine_similarity(x[0], x)
-        tensor([ 1., -1.])
+        tensor([[ 1., -1., -1.,  1., -1., -1.],
+                [ 1., -1., -1., -1.,  1., -1.],
+                [-1.,  1.,  1., -1.,  1., -1.]])
+        >>> functional.dot_similarity(x, x)
+        tensor([[ 6.,  2., -4.],
+                [ 2.,  6.,  0.],
+                [-4.,  0.,  6.]])
+
+        >>> x = functional.random_hv(3, 6, dtype=torch.complex64)
+        >>> x
+        tensor([[ 0.5931+0.8051j, -0.7391+0.6736j, -0.9725+0.2328j, -0.9290+0.3701j, -0.8220+0.5696j,  0.9757-0.2190j],
+                [-0.1053+0.9944j,  0.6918-0.7221j, -0.6242+0.7813j, -0.9580-0.2869j,  0.4799+0.8773j, -0.4127+0.9109j],
+                [ 0.4230-0.9061j, -0.9658+0.2592j,  0.9961-0.0883j, -0.3829+0.9238j, -0.2551-0.9669j,  0.7827-0.6224j]])
+        >>> functional.dot_similarity(x, x)
+        tensor([[ 6.0000,  0.8164,  0.6771],
+                [ 0.8164,  6.0000, -4.2506],
+                [ 0.6771, -4.2506,  6.0000]])

     """
-    return F.cosine_similarity(input, others)
+    if torch.is_complex(input):
+        return F.linear(input, others.conj()).real
+
+    return F.linear(input, others)

-def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
-    """Dot product between the input vector and each vector in others.
+
+def cosine_similarity(input: Tensor, others: Tensor, *, eps=1e-08) -> Tensor:
+    """Cosine similarity between the input vector and each vector in others.

     Args:
-        input (Tensor): one-dimensional tensor
-        others (Tensor): two-dimensional tensor
+        input (Tensor): hypervectors to compare against others
+        others (Tensor): hypervectors to compare with

     Shapes:
-        - Input: :math:`(d)`
-        - Others: :math:`(n, d)`
-        - Output: :math:`(n)`
+        - Input: :math:`(*, d)`
+        - Others: :math:`(n, d)` or :math:`(d)`
+        - Output: :math:`(*, n)` or :math:`(*)`, depending on the shape of ``others``

     Examples::

-        >>> x = functional.random_hv(2, 3)
+        >>> x = functional.random_hv(3, 6)
         >>> x
-        tensor([[ 1., -1.,  1.],
-                [ 1.,  1.,  1.]])
-        >>> functional.dot_similarity(x[0], x)
-        tensor([3., 1.])
+        tensor([[-1.,  1.,  1., -1.,  1., -1.],
+                [ 1.,  1.,  1.,  1.,  1.,  1.],
+                [ 1.,  1.,  1., -1.,  1., -1.]])
+        >>> functional.cosine_similarity(x, x)
+        tensor([[1.0000, 0.0000, 0.6667],
+                [0.0000, 1.0000, 0.3333],
+                [0.6667, 0.3333, 1.0000]])
+
+        >>> x = functional.random_hv(3, 6, dtype=torch.complex64)
+        >>> x
+        tensor([[-0.5578-0.8299j, -0.0043-1.0000j, -0.0181+0.9998j,  0.1107+0.9939j, -0.8215-0.5702j, -0.4585+0.8887j],
+                [-0.7400-0.6726j,  0.6895-0.7243j, -0.8760+0.4823j, -0.4582-0.8889j, -0.6128+0.7903j, -0.4839-0.8751j],
+                [-0.7839+0.6209j, -0.9239-0.3827j, -0.9961-0.0884j,  0.4614+0.8872j, -0.8546+0.5193j, -0.5468-0.8372j]])
+        >>> functional.cosine_similarity(x, x)
+        tensor([[1.0000, 0.1255, 0.1806],
+                [0.1255, 1.0000, 0.2607],
+                [0.1806, 0.2607, 1.0000]])

     """
-    return F.linear(input, others)
+    if torch.is_complex(input):
+        input_mag = torch.real(input * input.conj()).sum(dim=-1).sqrt()
+        others_mag = torch.real(others * others.conj()).sum(dim=-1).sqrt()
+    else:
+        input_mag = torch.sum(input * input, dim=-1).sqrt()
+        others_mag = torch.sum(others * others, dim=-1).sqrt()
+
+    if input.dim() > 1:
+        magnitude = input_mag.unsqueeze(-1) * others_mag.unsqueeze(0)
+    else:
+        magnitude = input_mag * others_mag
+
+    return dot_similarity(input, others) / (magnitude + eps)


 def hamming_similarity(input: Tensor, others: Tensor) -> LongTensor:
@@ -666,9 +835,6 @@ def multiset(input: Tensor) -> Tensor:
     dim = -2
     dtype = input.dtype

-    if dtype in {torch.complex64, torch.complex128}:
-        raise NotImplementedError("Complex hypervectors are not supported yet.")
-
     if dtype == torch.uint8:
         raise ValueError("Unsigned integer hypervectors are not supported.")

@@ -714,14 +880,14 @@ def multibind(input: Tensor) -> Tensor:
         tensor([ 1.,  1., -1.])

     """
-    if input.dtype in {torch.complex64, torch.complex128}:
-        raise NotImplementedError("Complex hypervectors are not supported yet.")
+    dtype = input.dtype
+    dim = -2

-    if input.dtype == torch.uint8:
+    if dtype == torch.uint8:
         raise ValueError("Unsigned integer hypervectors are not supported.")

-    if input.dtype == torch.bool:
-        hvs = torch.unbind(input, -2)
+    if dtype == torch.bool:
+        hvs = torch.unbind(input, dim)
         result = hvs[0]

         for i in range(1, len(hvs)):
@@ -729,7 +895,7 @@ def multibind(input: Tensor) -> Tensor:

         return result

-    return torch.prod(input, dim=-2, dtype=input.dtype)
+    return torch.prod(input, dim=dim, dtype=dtype)


 def cross_product(input: Tensor, other: Tensor) -> Tensor:
diff --git a/torchhd/structures.py b/torchhd/structures.py
index 314d3d38..db554f2b 100644
--- a/torchhd/structures.py
+++ b/torchhd/structures.py
@@ -212,10 +212,10 @@ def contains(self, input: Tensor) -> Tensor:
         Examples::

             >>> M.contains(letters_hv[0])
-            tensor([0.4575])
+            tensor(0.4575)

         """
-        return functional.cosine_similarity(input,
self.value.unsqueeze(0)) + return functional.cosine_similarity(input, self.value) def __len__(self) -> int: """Returns the size of the multiset. @@ -363,7 +363,7 @@ def get(self, key: Tensor) -> Tensor: tensor([ 1., -1., 1., ..., -1., 1., -1.]) """ - return functional.bind(self.value, key) + return functional.unbind(self.value, key) def replace(self, key: Tensor, old: Tensor, new: Tensor) -> None: """Replace the value from key-value pair in the hash table. @@ -711,7 +711,7 @@ def pop(self, input: Tensor) -> None: """ self.size -= 1 - self.value = functional.bind(self.value, input) + self.value = functional.unbind(self.value, input) self.value = functional.permute(self.value, shifts=-1) def popleft(self, input: Tensor) -> None: @@ -727,7 +727,7 @@ def popleft(self, input: Tensor) -> None: """ self.size -= 1 rotated_input = functional.permute(input, shifts=len(self)) - self.value = functional.bind(self.value, rotated_input) + self.value = functional.unbind(self.value, rotated_input) def replace(self, index: int, old: Tensor, new: Tensor) -> None: """Replace the old hypervector value from the given index, for the new hypervector value. @@ -744,7 +744,7 @@ def replace(self, index: int, old: Tensor, new: Tensor) -> None: """ rotated_old = functional.permute(old, shifts=self.size - index - 1) - self.value = functional.bind(self.value, rotated_old) + self.value = functional.unbind(self.value, rotated_old) rotated_new = functional.permute(new, shifts=self.size - index - 1) self.value = functional.bind(self.value, rotated_new) @@ -880,13 +880,13 @@ def node_neighbors(self, input: Tensor, outgoing=True) -> Tensor: """ if self.is_directed: if outgoing: - permuted_neighbors = functional.bind(self.value, input) + permuted_neighbors = functional.unbind(self.value, input) return functional.permute(permuted_neighbors, shifts=-1) else: permuted_node = functional.permute(input, shifts=1) - return functional.bind(self.value, permuted_node) + return functional.unbind(self.value, permuted_node) else: - return functional.bind(self.value, input) + return functional.unbind(self.value, input) def contains(self, input: Tensor) -> Tensor: """Returns the cosine similarity of the input vector against the graph. @@ -898,9 +898,9 @@ def contains(self, input: Tensor) -> Tensor: >>> e = G.encode_edge(letters_hv[0], letters_hv[1]) >>> G.contains(e) - tensor([1.]) + tensor(1.) """ - return functional.cosine_similarity(input, self.value.unsqueeze(0)) + return functional.cosine_similarity(input, self.value) def clear(self) -> None: """Empties the graph. @@ -1012,7 +1012,7 @@ def get_leaf(self, path: List[str]) -> Tensor: hv_path, functional.permute(self.right, shifts=idx) ) - return functional.bind(hv_path, self.value) + return functional.unbind(self.value, hv_path) def clear(self) -> None: """Empties the tree. 
@@ -1084,8 +1084,8 @@ def transition(self, state: Tensor, action: Tensor) -> Tensor:
         tensor([ 1.,  1., -1.,  ..., -1., -1.,  1.])

         """
-        next_state = functional.bind(self.value, state)
-        next_state = functional.bind(next_state, action)
+        next_state = functional.unbind(self.value, state)
+        next_state = functional.unbind(next_state, action)
         return functional.permute(next_state, shifts=-1)

     def clear(self) -> None:
diff --git a/torchhd/tests/basis_hv/test_circular_hv.py b/torchhd/tests/basis_hv/test_circular_hv.py
index d848c3a1..3660ab6b 100644
--- a/torchhd/tests/basis_hv/test_circular_hv.py
+++ b/torchhd/tests/basis_hv/test_circular_hv.py
@@ -47,23 +47,40 @@ def test_value(self, dtype):
         assert torch.all(
             (hv == True) | (hv == False)
         ).item(), "values are either 1 or 0"
+    elif dtype in torch_complex_dtypes:
+        magnitudes = hv.abs()
+        assert torch.allclose(magnitudes, torch.tensor(1.0, dtype=magnitudes.dtype)), "magnitude must be 1"
     else:
         assert torch.all(
             (hv == -1) | (hv == 1)
         ).item(), "values are either -1 or +1"

+
     hv = functional.circular_hv(8, 1000000, generator=generator, dtype=dtype)
-    sims = functional.hamming_similarity(hv[0], hv).float() / 1000000
-    sims_diff = sims[:-1] - sims[1:]
+    if dtype in torch_complex_dtypes:
+        sims = functional.cosine_similarity(hv[0], hv)
+        sims_diff = sims[:-1] - sims[1:]

-    assert torch.all(
-        sims_diff.sign() == torch.tensor([1, 1, 1, 1, -1, -1, -1])
-    ), "second half must get more similar"
+        assert torch.all(
+            sims_diff.sign() == torch.tensor([1, 1, 1, 1, -1, -1, -1])
+        ), "second half must get more similar"

-    abs_sims_diff = sims_diff.abs()
-    assert torch.all(
-        (0.124 < abs_sims_diff) & (abs_sims_diff < 0.126)
-    ).item(), "similarity changes linearly"
+        abs_sims_diff = sims_diff.abs()
+        assert torch.all(
+            (0.248 < abs_sims_diff) & (abs_sims_diff < 0.252)
+        ).item(), "similarity changes linearly"
+    else:
+        sims = functional.hamming_similarity(hv[0], hv).float() / 1000000
+        sims_diff = sims[:-1] - sims[1:]
+
+        assert torch.all(
+            sims_diff.sign() == torch.tensor([1, 1, 1, 1, -1, -1, -1])
+        ), "second half must get more similar"
+
+        abs_sims_diff = sims_diff.abs()
+        assert torch.all(
+            (0.124 < abs_sims_diff) & (abs_sims_diff < 0.126)
+        ).item(), "similarity changes linearly"

 @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.756, 1.0])
 @pytest.mark.parametrize("dtype", torch_dtypes)
@@ -71,6 +88,10 @@ def test_sparsity(self, sparsity, dtype):
     if not supported_dtype(dtype):
         return

+    if dtype in torch_complex_dtypes:
+        # Complex hypervectors don't support sparsity.
+        return
+
     generator = torch.Generator()
     generator.manual_seed(seed)

@@ -96,12 +117,6 @@ def test_device(self, dtype):

 @pytest.mark.parametrize("dtype", torch_dtypes)
 def test_dtype(self, dtype):
-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.circular_hv(3, 26, dtype=dtype)
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.circular_hv(3, 26, dtype=dtype)
diff --git a/torchhd/tests/basis_hv/test_identity_hv.py b/torchhd/tests/basis_hv/test_identity_hv.py
index dce1f24a..1fb1d299 100644
--- a/torchhd/tests/basis_hv/test_identity_hv.py
+++ b/torchhd/tests/basis_hv/test_identity_hv.py
@@ -47,12 +47,6 @@ def test_device(self, dtype):

 @pytest.mark.parametrize("dtype", torch_dtypes)
 def test_dtype(self, dtype):
-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.identity_hv(3, 26, dtype=dtype)
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.identity_hv(3, 26, dtype=dtype)
diff --git a/torchhd/tests/basis_hv/test_level_hv.py b/torchhd/tests/basis_hv/test_level_hv.py
index 3379f500..366f5671 100644
--- a/torchhd/tests/basis_hv/test_level_hv.py
+++ b/torchhd/tests/basis_hv/test_level_hv.py
@@ -47,22 +47,37 @@ def test_value(self, dtype):
         assert torch.all(
             (hv == True) | (hv == False)
         ).item(), "values are either 1 or 0"
+    elif dtype in torch_complex_dtypes:
+        magnitudes = hv.abs()
+        assert torch.allclose(magnitudes, torch.tensor(1.0, dtype=magnitudes.dtype)), "magnitude must be 1"
     else:
         assert torch.all(
             (hv == -1) | (hv == 1)
         ).item(), "values are either -1 or +1"

     # look at the similarity profile w.r.t. the first hypervector
-    sims = functional.hamming_similarity(hv[0], hv).float() / 10000
-    sims_diff = sims[:-1] - sims[1:]
-    assert torch.all(sims_diff > 0).item(), "similarity must be decreasing"
+    if dtype in torch_complex_dtypes:
+        sims = functional.cosine_similarity(hv[0], hv)
+        sims_diff = sims[:-1] - sims[1:]
+        assert torch.all(sims_diff > 0).item(), "similarity must be decreasing"

-    hv = functional.level_hv(5, 1000000, generator=generator, dtype=dtype)
-    sims = functional.hamming_similarity(hv[0], hv).float() / 1000000
-    sims_diff = sims[:-1] - sims[1:]
-    assert torch.all(
-        (0.124 < sims_diff) & (sims_diff < 0.126)
-    ).item(), "similarity decreases linearly"
+        hv = functional.level_hv(5, 1000000, generator=generator, dtype=dtype)
+        sims = functional.cosine_similarity(hv[0], hv)
+        sims_diff = sims[:-1] - sims[1:]
+        assert torch.all(
+            (0.248 < sims_diff) & (sims_diff < 0.252)
+        ).item(), "similarity decreases linearly"
+    else:
+        sims = functional.hamming_similarity(hv[0], hv).float() / 10000
+        sims_diff = sims[:-1] - sims[1:]
+        assert torch.all(sims_diff > 0).item(), "similarity must be decreasing"
+
+        hv = functional.level_hv(5, 1000000, generator=generator, dtype=dtype)
+        sims = functional.hamming_similarity(hv[0], hv).float() / 1000000
+        sims_diff = sims[:-1] - sims[1:]
+        assert torch.all(
+            (0.124 < sims_diff) & (sims_diff < 0.126)
+        ).item(), "similarity decreases linearly"

 @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.756, 1.0])
 @pytest.mark.parametrize("dtype", torch_dtypes)
@@ -70,6 +85,10 @@ def test_sparsity(self, sparsity, dtype):
     if not supported_dtype(dtype):
         return

+    if dtype in torch_complex_dtypes:
+        # Complex hypervectors don't support sparsity.
+        return
+
     generator = torch.Generator()
     generator.manual_seed(seed)

@@ -95,12 +114,6 @@ def test_device(self, dtype):

 @pytest.mark.parametrize("dtype", torch_dtypes)
 def test_dtype(self, dtype):
-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.level_hv(3, 26, dtype=dtype)
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.level_hv(3, 26, dtype=dtype)
diff --git a/torchhd/tests/basis_hv/test_random_hv.py b/torchhd/tests/basis_hv/test_random_hv.py
index 72dab2ee..4387d36e 100644
--- a/torchhd/tests/basis_hv/test_random_hv.py
+++ b/torchhd/tests/basis_hv/test_random_hv.py
@@ -42,14 +42,15 @@ def test_value(self, dtype):
     generator = torch.Generator()
     generator.manual_seed(seed)

+    hv = functional.random_hv(100, 10000, dtype=dtype, generator=generator)
+
     if dtype == torch.bool:
-        hv = functional.random_hv(100, 10000, dtype=dtype, generator=generator)
         assert torch.all((hv == False) | (hv == True)).item()
-
-        return
-
-    hv = functional.random_hv(100, 10000, dtype=dtype, generator=generator)
-    assert torch.all((hv == -1) | (hv == 1)).item()
+    elif dtype in torch_complex_dtypes:
+        magnitudes = hv.abs()
+        assert torch.allclose(magnitudes, torch.tensor(1.0, dtype=magnitudes.dtype)), "magnitude must be 1"
+    else:
+        assert torch.all((hv == -1) | (hv == 1)).item()

 @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.756, 1.0])
 @pytest.mark.parametrize("dtype", torch_dtypes)
@@ -57,6 +58,10 @@ def test_sparsity(self, sparsity, dtype):
     if not supported_dtype(dtype):
         return

+    if dtype in torch_complex_dtypes:
+        # Complex hypervectors don't support sparsity.
+        return
+
     generator = torch.Generator()
     generator.manual_seed(seed)

@@ -83,14 +88,24 @@ def test_orthogonality(self, dtype):
     generator = torch.Generator()
     generator.manual_seed(seed)

-    sims = [None] * 100
-    for i in range(100):
-        hv = functional.random_hv(2, 10000, dtype=dtype, generator=generator)
-        sims[i] = functional.hamming_similarity(hv[0], hv[1].unsqueeze(0))
-
-    sims = torch.cat(sims).float() / 10000
-    assert within(sims.mean().item(), 0.5, 0.001)
-    assert sims.std().item() < 0.01
+    if dtype in torch_complex_dtypes:
+        sims = [None] * 100
+        for i in range(100):
+            hv = functional.random_hv(2, 10000, dtype=dtype, generator=generator)
+            sims[i] = functional.cosine_similarity(hv[0], hv[1])
+
+        sims = torch.stack(sims).float()
+        assert within(sims.mean().item(), 0.0, 0.001)
+        assert sims.std().item() < 0.01
+    else:
+        sims = [None] * 100
+        for i in range(100):
+            hv = functional.random_hv(2, 10000, dtype=dtype, generator=generator)
+            sims[i] = functional.hamming_similarity(hv[0], hv[1].unsqueeze(0))
+
+        sims = torch.stack(sims).float() / 10000
+        assert within(sims.mean().item(), 0.5, 0.001)
+        assert sims.std().item() < 0.01

 @pytest.mark.parametrize("dtype", torch_dtypes)
 def test_device(self, dtype):
@@ -103,12 +118,6 @@ def test_device(self, dtype):

 @pytest.mark.parametrize("dtype", torch_dtypes)
 def test_dtype(self, dtype):
-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.random_hv(3, 26, dtype=dtype)
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.random_hv(3, 26, dtype=dtype)
diff --git a/torchhd/tests/test_encodings.py b/torchhd/tests/test_encodings.py
index 630000ca..aa166601 100644
--- a/torchhd/tests/test_encodings.py
+++ b/torchhd/tests/test_encodings.py
@@ -56,12 +56,6 @@ def test_value(self, dtype):
 def test_dtype(self, dtype):
     hv = torch.zeros(23, 1000, dtype=dtype)

-    if dtype in
torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.multiset(hv)
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.multiset(hv)
@@ -128,13 +122,7 @@ def test_value(self, dtype):
 def test_dtype(self, dtype):
     hv = torch.zeros(23, 1000, dtype=dtype)
-
-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.multibind(hv)
-
-        return
-
+
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.multibind(hv)
@@ -175,12 +163,6 @@ def test_value(self):
 def test_dtype(self, dtype):
     hv = torch.zeros(23, 1000, dtype=dtype)

-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.cross_product(hv, hv)
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.cross_product(hv, hv)
@@ -215,12 +197,6 @@ def test_value(self):
 def test_dtype(self, dtype):
     hv = torch.zeros(23, 1000, dtype=dtype)

-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.ngrams(hv)
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.ngrams(hv)
@@ -254,12 +230,6 @@ def test_value(self):
 def test_dtype(self, dtype):
     hv = torch.zeros(23, 1000, dtype=dtype)

-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.hash_table(hv, hv)
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.hash_table(hv, hv)
@@ -294,12 +264,6 @@ def test_value(self):
 def test_dtype(self, dtype):
     hv = torch.zeros(23, 1000, dtype=dtype)

-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.bundle_sequence(hv)
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.bundle_sequence(hv)
@@ -334,12 +298,6 @@ def test_value(self):
 def test_dtype(self, dtype):
     hv = torch.zeros(23, 1000, dtype=dtype)

-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.bind_sequence(hv)
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.bind_sequence(hv)
diff --git a/torchhd/tests/test_operations.py b/torchhd/tests/test_operations.py
index 879b09f5..cfddc7d7 100644
--- a/torchhd/tests/test_operations.py
+++ b/torchhd/tests/test_operations.py
@@ -50,12 +50,6 @@ def test_value(self, dtype):
 def test_dtype(self, dtype):
     hv = torch.zeros(23, 1000, dtype=dtype)

-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.bind(hv[0], hv[1])
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.bind(hv[0], hv[1])
@@ -141,12 +135,6 @@ def test_value(self, dtype):
 def test_dtype(self, dtype):
     hv = torch.zeros(23, 1000, dtype=dtype)

-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.bundle(hv[0], hv[1])
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.bundle(hv[0], hv[1])
@@ -256,12 +244,6 @@ def test_threshold(self):
 def test_dtype(self, dtype):
     hv = torch.zeros(23, 1000, dtype=dtype)

-    if dtype in torch_complex_dtypes:
-        with pytest.raises(NotImplementedError):
-            functional.cleanup(hv[0], hv, threshold=-1)
-
-        return
-
     if dtype == torch.uint8:
         with pytest.raises(ValueError):
             functional.cleanup(hv[0], hv, threshold=-1)
diff --git a/torchhd/tests/utils.py b/torchhd/tests/utils.py
index 7bd7228a..111442c1 100644
--- a/torchhd/tests/utils.py
+++ b/torchhd/tests/utils.py
@@ -49,5 +49,5 @@ def
within(value: number, target: number, delta: number) -> bool: def supported_dtype(dtype: torch.dtype) -> bool: - not_supported = dtype in torch_complex_dtypes or dtype == torch.uint8 + not_supported = dtype == torch.uint8 return not not_supported diff --git a/torchhd/utils.py b/torchhd/utils.py index 45426b3c..1b05be64 100644 --- a/torchhd/utils.py +++ b/torchhd/utils.py @@ -35,9 +35,7 @@ def plot_pair_similarity(memory: Tensor, ax=None, **kwargs): See https://matplotlib.org/stable/users/installing/index.html for more information." ) - similarity = [] - for vector in memory: - similarity.append(functional.cosine_similarity(vector, memory).tolist()) + similarity = functional.cosine_similarity(memory, memory).tolist() if ax is None: ax = plt.gca()
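
A quick end-to-end illustration of the semantics this patch introduces (not part of the diff; variable names and the dimensionality ``d`` are chosen for the example, assuming only ``torch``). Complex hypervectors are unit-magnitude phasors, so ``bind`` is an element-wise product (phases add) and ``unbind`` multiplies by the complex conjugate (phases subtract), which makes ``unbind`` an exact inverse of ``bind`` up to floating-point rounding::

    import math
    import torch

    # Stand-alone sketch mirroring random_hv, bind, and unbind above.
    d = 10000
    angle = torch.empty(2, d).uniform_(-math.pi, math.pi)
    x, y = torch.polar(torch.ones(2, d), angle)  # unit circle, like random_hv

    bound = x * y                 # bind: element-wise product, phases add
    recovered = bound * y.conj()  # unbind: conjugate product, phases subtract

    assert torch.allclose(recovered, x, atol=1e-5)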
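
And a hedged usage sketch of the reworked similarity functions (illustrative session; the printed values depend on the random draw). Both functions now accept batched inputs, and for complex hypervectors ``dot_similarity`` returns the real part of the Hermitian inner product, so a vector's self-similarity is close to its dimensionality while ``cosine_similarity`` normalizes it to one::

    import torch
    from torchhd import functional

    x = functional.random_hv(3, 6, dtype=torch.complex64)
    functional.dot_similarity(x, x)     # shape (3, 3); diagonal close to 6
    functional.cosine_similarity(x, x)  # shape (3, 3); diagonal close to 1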