Skip to content

Commit a513455

Browse files
authored
Add Cyclic Group Representation (CGR) class (#182)
* Add BaseMCRTensor Add a new base class to support future implementation of variations of MCR proposals. * Fix bug in BaseMCRTensor:__torch_function__() There is a bug when args is a collection of collections instead of a plain tuple. In this case, the old args parsing was unable to search for block_size in nested structures containing BaseMCRTensor. For instance, let `a` and `b` be BaseMCRTensor variables. Calling `torch.stack((a, b))` results in error, as the `args` received in `__torch_function__()` is a nested tuple. * Add CGRTensor Add a new VSA class named Cyclic Group Representation (CGR). This class is similar to MCR, but differs in bundling. * Add CGRTensor to functional embeddings Allow its usage in level and circular embeddings as done with MCRTensor. * Fix CGR bundle Ensure both inputs are in the same shape. * Add CGR into tests The CGR should behave almost the same as MCR, but diverges in bundling. Implement a custom bundling test for it. * Use Tensor.dim() in CGRTensor Shorter syntax and more readable. * Remove asserts in CGR/MCR classes No need to check block size in CGR/MCR functions as the __torch_function__() in BaseMCRTensor already checks it. * Improve exception in BaseMCRTensor if multiple block_sizes - Raise ValueError instead of RuntimeError * Add CGRTensor to docs Include CGRTensor in the README and in the built docs. * Add BaseMCRTensor to docs * Fix typo in README.md * Improve documentation for MCRTensor and CGRTensor Replicate function signatures in inherited classes (MCRTensor and CGRTensor) from their parent class (BaseMCRTensor) to let sphinx print the documentation of each parent function.
1 parent 6b7478d commit a513455

17 files changed

+736
-393
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ Currently, the library supports the following HDC/VSA models:
9595
- [Holographic Reduced Representations (HRR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.HRRTensor.html)
9696
- [Fourier Holographic Reduced Representations (FHRR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.FHRRTensor.html)
9797
- [Binary Sparse Block Codes (B-SBC)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.BSBCTensor.html)
98+
- [Cyclic Group Representation (CGR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.CGRTensor.html)
9899
- [Modular Composite Representation (MCR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.MCRTensor.html)
99100
- [Vector-Derived Transformation Binding (VTB)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.VTBTensor.html)
100101

docs/torchhd.rst

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ This module consists of the basic hypervector generation functions and operation
1010
Basis-hypervector sets
1111
----------------------------------
1212

13-
.. autosummary::
13+
.. autosummary::
1414
:toctree: generated/
1515
:template: function.rst
1616

@@ -20,7 +20,7 @@ Basis-hypervector sets
2020
level
2121
thermometer
2222
circular
23-
23+
2424

2525
Operations
2626
--------------------
@@ -51,7 +51,7 @@ Similarities
5151
.. autosummary::
5252
:toctree: generated/
5353
:template: function.rst
54-
54+
5555
cosine_similarity
5656
dot_similarity
5757
hamming_similarity
@@ -87,7 +87,9 @@ VSA Models
8787
HRRTensor
8888
FHRRTensor
8989
BSBCTensor
90+
BaseMCRTensor
9091
MCRTensor
92+
CGRTensor
9193
VTBTensor
9294

9395

@@ -97,7 +99,7 @@ Utilities
9799
.. autosummary::
98100
:toctree: generated/
99101
:template: function.rst
100-
102+
101103
ensure_vsa_tensor
102104
map_range
103105
value_to_index

torchhd/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
from torchhd.tensors.fhrr import FHRRTensor
3838
from torchhd.tensors.bsbc import BSBCTensor
3939
from torchhd.tensors.vtb import VTBTensor
40+
from torchhd.tensors.basemcr import BaseMCRTensor
4041
from torchhd.tensors.mcr import MCRTensor
42+
from torchhd.tensors.cgr import CGRTensor
4143

4244
from torchhd.functional import (
4345
ensure_vsa_tensor,
@@ -91,7 +93,9 @@
9193
"FHRRTensor",
9294
"BSBCTensor",
9395
"VTBTensor",
96+
"BaseMCRTensor",
9497
"MCRTensor",
98+
"CGRTensor",
9599
"functional",
96100
"embeddings",
97101
"structures",

torchhd/functional.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from torchhd.tensors.bsbc import BSBCTensor
3737
from torchhd.tensors.vtb import VTBTensor
3838
from torchhd.tensors.mcr import MCRTensor
39+
from torchhd.tensors.cgr import CGRTensor
3940
from torchhd.types import VSAOptions
4041

4142

@@ -93,6 +94,8 @@ def get_vsa_tensor_class(vsa: VSAOptions) -> Type[VSATensor]:
9394
return VTBTensor
9495
elif vsa == "MCR":
9596
return MCRTensor
97+
elif vsa == "CGR":
98+
return CGRTensor
9699

97100
raise ValueError(f"Provided VSA model is not supported, specified: {vsa}")
98101

@@ -361,7 +364,7 @@ def level(
361364
device=span_hv.device,
362365
).as_subclass(vsa_tensor)
363366

364-
if vsa == "BSBC" or vsa == "MCR":
367+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
365368
hv.block_size = span_hv.block_size
366369

367370
for i in range(num_vectors):
@@ -588,7 +591,7 @@ def circular(
588591
device=span_hv.device,
589592
).as_subclass(vsa_tensor)
590593

591-
if vsa == "BSBC" or vsa == "MCR":
594+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
592595
hv.block_size = span_hv.block_size
593596

594597
mutation_history = deque()

0 commit comments

Comments
 (0)