Skip to content

Commit

Permalink
Add MCRTensor
Browse files Browse the repository at this point in the history
  • Loading branch information
milad2073 committed Oct 6, 2024
1 parent 132ee17 commit 345a7c0
Show file tree
Hide file tree
Showing 13 changed files with 569 additions and 110 deletions.
2 changes: 2 additions & 0 deletions torchhd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from torchhd.tensors.fhrr import FHRRTensor
from torchhd.tensors.bsbc import BSBCTensor
from torchhd.tensors.vtb import VTBTensor
from torchhd.tensors.mcr import MCRTensor

from torchhd.functional import (
ensure_vsa_tensor,
Expand Down Expand Up @@ -90,6 +91,7 @@
"FHRRTensor",
"BSBCTensor",
"VTBTensor",
"MCRTensor",
"functional",
"embeddings",
"structures",
Expand Down
7 changes: 5 additions & 2 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from torchhd.tensors.fhrr import FHRRTensor
from torchhd.tensors.bsbc import BSBCTensor
from torchhd.tensors.vtb import VTBTensor
from torchhd.tensors.mcr import MCRTensor
from torchhd.types import VSAOptions


Expand Down Expand Up @@ -90,6 +91,8 @@ def get_vsa_tensor_class(vsa: VSAOptions) -> Type[VSATensor]:
return BSBCTensor
elif vsa == "VTB":
return VTBTensor
elif vsa == "MCR":
return MCRTensor

raise ValueError(f"Provided VSA model is not supported, specified: {vsa}")

Expand Down Expand Up @@ -358,7 +361,7 @@ def level(
device=span_hv.device,
).as_subclass(vsa_tensor)

if vsa == "BSBC":
if vsa == "BSBC" or vsa == "MCR":
hv.block_size = span_hv.block_size

for i in range(num_vectors):
Expand Down Expand Up @@ -585,7 +588,7 @@ def circular(
device=span_hv.device,
).as_subclass(vsa_tensor)

if vsa == "BSBC":
if vsa == "BSBC" or vsa == "MCR":
hv.block_size = span_hv.block_size

mutation_history = deque()
Expand Down
Loading

0 comments on commit 345a7c0

Please sign in to comment.