From d6343281c618f7859412abf7280ad03e4e39f727 Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Tue, 27 May 2025 10:11:54 -0300 Subject: [PATCH 01/13] Add BaseMCRTensor Add a new base class to support future implementation of variations of MCR proposals. --- torchhd/__init__.py | 2 + torchhd/tensors/basemcr.py | 379 +++++++++++++++++++++++++++++++++++++ torchhd/tensors/mcr.py | 349 +--------------------------------- 3 files changed, 389 insertions(+), 341 deletions(-) create mode 100644 torchhd/tensors/basemcr.py diff --git a/torchhd/__init__.py b/torchhd/__init__.py index e1aa3c2..8586ce3 100644 --- a/torchhd/__init__.py +++ b/torchhd/__init__.py @@ -37,6 +37,7 @@ from torchhd.tensors.fhrr import FHRRTensor from torchhd.tensors.bsbc import BSBCTensor from torchhd.tensors.vtb import VTBTensor +from torchhd.tensors.basemcr import BaseMCRTensor from torchhd.tensors.mcr import MCRTensor from torchhd.functional import ( @@ -91,6 +92,7 @@ "FHRRTensor", "BSBCTensor", "VTBTensor", + "BaseMCRTensor", "MCRTensor", "functional", "embeddings", diff --git a/torchhd/tensors/basemcr.py b/torchhd/tensors/basemcr.py new file mode 100644 index 0000000..3d07286 --- /dev/null +++ b/torchhd/tensors/basemcr.py @@ -0,0 +1,379 @@ +# +# MIT License +# +# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +import torch +from torch import Tensor +import torch.nn.functional as F +from typing import Set + +from torchhd.tensors.base import VSATensor + + +class BaseMCRTensor(VSATensor): + r"""Base class for VSA Modular Composite Representations (MCR) + + Proposed in `Modular Composite Representation `_, this model works with modular integer vectors. The base class is used as template for the MCR and the Cyclic Group Representation (CGR), which is very similar to MCR but uses a different bundling operation. + """ + + block_size: int + supported_dtypes: Set[torch.dtype] = { + torch.float32, + torch.float64, + torch.int16, + torch.int32, + torch.int64, + } + + @classmethod + def empty( + cls, + num_vectors: int, + dimensions: int, + *, + block_size: int, + generator=None, + dtype=torch.int64, + device=None, + requires_grad=False, + ) -> "BaseMCRTensor": + r"""Creates a set of hypervectors representing empty sets. + + When bundled with a hypervector :math:`x`, the result is :math:`x`. 
+        Because of the low precision of the MCR model an empty set cannot be explicitly represented; the returned hypervectors are therefore identical to random hypervectors.
+
+        Args:
+            num_vectors (int): the number of hypervectors to generate.
+            dimensions (int): the dimensionality of the hypervectors.
+            block_size (int): the number of elements per block which controls the angular granularity.
+            generator (``torch.Generator``, optional): a pseudorandom number generator for sampling.
+            dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: ``int64``.
+            device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
+            requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.
+
+        Examples::
+
+            >>> torchhd.BaseMCRTensor.empty(3, 6, block_size=64)
+            BaseMCRTensor([[54,  3, 22, 27, 41, 21],
+                           [17, 31, 55,  3, 44, 52],
+                           [42, 37, 60, 54, 13, 41]])
+
+        """
+
+        if dtype is None:
+            dtype = torch.int64
+
+        if dtype not in cls.supported_dtypes:
+            name = cls.__name__
+            options = ", ".join([str(x) for x in cls.supported_dtypes])
+            raise ValueError(
+                f"{name} vectors must be one of dtype {options}, got {dtype}."
+            )
+
+        result = torch.randint(
+            0,
+            block_size,
+            (num_vectors, dimensions),
+            generator=generator,
+            dtype=dtype,
+            device=device,
+            requires_grad=requires_grad,
+        )
+
+        result = result.as_subclass(cls)
+        result.block_size = block_size
+        return result
+
+    @classmethod
+    def identity(
+        cls,
+        num_vectors: int,
+        dimensions: int,
+        *,
+        block_size: int,
+        dtype=torch.int64,
+        device=None,
+        requires_grad=False,
+    ) -> "BaseMCRTensor":
+        r"""Creates a set of identity hypervectors.
+
+        When bound with a random hypervector :math:`x`, the result is :math:`x`.
+
+        Args:
+            num_vectors (int): the number of hypervectors to generate.
+            dimensions (int): the dimensionality of the hypervectors.
+            block_size (int): the number of elements per block which controls the angular granularity.
+            dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: ``int64``.
+            device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
+            requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.
+
+        Examples::
+
+            >>> torchhd.BaseMCRTensor.identity(3, 6, block_size=64)
+            BaseMCRTensor([[0, 0, 0, 0, 0, 0],
+                           [0, 0, 0, 0, 0, 0],
+                           [0, 0, 0, 0, 0, 0]])
+
+        """
+        if dtype is None:
+            dtype = torch.int64
+
+        if dtype not in cls.supported_dtypes:
+            name = cls.__name__
+            options = ", ".join([str(x) for x in cls.supported_dtypes])
+            raise ValueError(
+                f"{name} vectors must be one of dtype {options}, got {dtype}."
+ ) + + result = torch.zeros( + num_vectors, + dimensions, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + result = result.as_subclass(cls) + result.block_size = block_size + return result + + @classmethod + def random( + cls, + num_vectors: int, + dimensions: int, + *, + block_size: int, + generator=None, + dtype=torch.int64, + device=None, + requires_grad=False, + ) -> "BaseMCRTensor": + r"""Creates a set of random independent hypervectors. + + The resulting hypervectors sample uniformly random integers between 0 and ``block_size``. + + Args: + num_vectors (int): the number of hypervectors to generate. + dimensions (int): the dimensionality of the hypervectors. + block_size (int): the number of elements per block which controls the angular granularity. + generator (``torch.Generator``, optional): a pseudorandom number generator for sampling. + dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: ``int64``. + device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``. + + Examples:: + + >>> torchhd.BaseMCRTensor.random(3, 6, block_size=64) + BaseMCRTensor([[ 7, 1, 39, 8, 55, 22], + [51, 38, 59, 45, 13, 29], + [19, 26, 30, 5, 15, 51]]) + >>> torchhd.BaseMCRTensor.random(3, 6, block_size=128, dtype=torch.float32) + BaseMCRTensor([[116., 25., 100., 10., 21., 86.], + [ 69., 49., 2., 56., 78., 70.], + [ 77., 47., 37., 106., 8., 30.]]) + + """ + if dtype == None: + dtype = torch.int64 + + if dtype not in cls.supported_dtypes: + name = cls.__name__ + options = ", ".join([str(x) for x in cls.supported_dtypes]) + raise ValueError( + f"{name} vectors must be one of dtype {options}, got {dtype}." + ) + + result = torch.randint( + 0, + block_size, + (num_vectors, dimensions), + generator=generator, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + result = result.as_subclass(cls) + result.block_size = block_size + return result + + def to_complex_unit(self): + angles = 2 * torch.pi * self / self.block_size + return torch.polar(torch.ones_like(self, dtype=angles.dtype), angles) + + def bundle(self, other: "BaseMCRTensor") -> "BaseMCRTensor": + """Bundle the hypervector with other""" + raise NotImplementedError + + def multibundle(self) -> "BaseMCRTensor": + """Bundle multiple hypervectors""" + return super().multibundle() + + def bind(self, other: "BaseMCRTensor") -> "BaseMCRTensor": + r"""Bind the hypervector with other using circular convolution. + + This produces a hypervector dissimilar to both. + + Binding is used to associate information, for instance, to assign values to variables. 
+
+        Args:
+            other (BaseMCRTensor): other input hypervector
+
+        Shapes:
+            - Self: :math:`(*)`
+            - Other: :math:`(*)`
+            - Output: :math:`(*)`
+
+        Examples::
+
+            >>> a, b = torchhd.BaseMCRTensor.random(2, 10, block_size=64)
+            >>> a
+            BaseMCRTensor([18, 55, 40, 62, 39, 26, 35, 24, 49, 41])
+            >>> b
+            BaseMCRTensor([46, 36, 21, 23, 25, 12, 29, 53, 54, 41])
+            >>> a.bind(b)
+            BaseMCRTensor([ 0, 27, 61, 21,  0, 38,  0, 13, 39, 18])
+
+        """
+        assert self.block_size == other.block_size
+        return torch.remainder(torch.add(self, other), self.block_size)
+
+    def multibind(self) -> "BaseMCRTensor":
+        """Bind multiple hypervectors"""
+        return torch.remainder(
+            torch.sum(self, dim=-2, dtype=self.dtype), self.block_size
+        )
+
+    def inverse(self) -> "BaseMCRTensor":
+        r"""Invert the hypervector for binding.
+
+        Shapes:
+            - Self: :math:`(*)`
+            - Output: :math:`(*)`
+
+        Examples::
+
+            >>> a = torchhd.BaseMCRTensor.random(1, 10, block_size=64)
+            >>> a
+            BaseMCRTensor([[ 5, 30, 15, 43, 19, 36,  4, 14, 57, 34]])
+            >>> a.inverse()
+            BaseMCRTensor([[59, 34, 49, 21, 45, 28, 60, 50,  7, 30]])
+
+        """
+        return torch.remainder(torch.negative(self), self.block_size)
+
+    def permute(self, shifts: int = 1) -> "BaseMCRTensor":
+        r"""Permute the hypervector.
+
+        The permutation operator is commonly used to assign an order to hypervectors.
+
+        Args:
+            shifts (int, optional): The number of places by which the elements of the tensor are shifted.
+
+        Shapes:
+            - Self: :math:`(*)`
+            - Output: :math:`(*)`
+
+        Examples::
+
+            >>> a = torchhd.BaseMCRTensor.random(1, 10, block_size=64)
+            >>> a
+            BaseMCRTensor([[33, 24,  1, 36,  2, 57, 11, 59, 33,  3]])
+            >>> a.permute(4)
+            BaseMCRTensor([[11, 59, 33,  3, 33, 24,  1, 36,  2, 57]])
+
+        """
+        return torch.roll(self, shifts=shifts, dims=-1)
+
+    def normalize(self) -> "BaseMCRTensor":
+        r"""Normalize the hypervector.
+
+        Every operation on MCR hypervectors keeps them normalized, so this returns a copy of self.
+
+        Shapes:
+            - Self: :math:`(*)`
+            - Output: :math:`(*)`
+
+        Examples::
+
+            >>> x = torchhd.BaseMCRTensor.random(4, 6, block_size=64).multibundle()
+            >>> x
+            BaseMCRTensor([28, 27, 20, 44, 57, 18])
+            >>> x.normalize()
+            BaseMCRTensor([28, 27, 20, 44, 57, 18])
+
+        """
+        return self.clone()
+
+    def dot_similarity(self, others: "BaseMCRTensor", *, dtype=None) -> Tensor:
+        """Similarity based on the Manhattan distance in a modular space:
+        the distance between two elements divided by the average distance between two random elements.
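+
+        For example, with ``block_size = 8`` the average circular distance between two random elements is ``8 / 4 = 2``. The elements ``1`` and ``7`` are at distance ``min((1 - 7) % 8, (7 - 1) % 8) = 2`` and so contribute ``1 - 2/2 = 0`` to the sum; identical elements contribute ``1`` and maximally distant ones ``-1``.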
+ """ + if dtype is None: + dtype = torch.get_default_dtype() + + random_distance = self.block_size/4 + + + if self.dim() > 1 and others.dim() > 1: + aminusb = torch.remainder(self.unsqueeze(-2) - others.unsqueeze(-3), self.block_size) + bminusa = torch.remainder(others.unsqueeze(-3) - self.unsqueeze(-2), self.block_size) + else: + aminusb = torch.remainder(self - others, self.block_size) + bminusa = torch.remainder(others - self, self.block_size) + distance = torch.min(aminusb,bminusa) + normalized_distance = distance/random_distance + + return torch.sum(1-normalized_distance, dim=-1, dtype=dtype) + + def cosine_similarity(self, others: "BaseMCRTensor", *, dtype=None) -> Tensor: + """Cosine similarity with other hypervectors""" + magnitude = self.size(-1) + return self.dot_similarity(others, dtype=dtype) / magnitude + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + # Ensure that all the build-in torch operations on this Tensor subclass maintain the block_size property + + if kwargs is None: + kwargs = {} + + block_sizes = set(a.block_size for a in args if hasattr(a, "block_size")) + if len(block_sizes) != 1: + raise RuntimeError( + f"Call to {func} must contain exactly one block size, got {list(block_sizes)}" + ) + + # Call with super to avoid infinite recursion + ret = super().__torch_function__(func, types, args, kwargs) + + if isinstance(ret, BaseMCRTensor): + ret.block_size = list(block_sizes)[0] + elif isinstance(ret, (tuple, list)): + for x in ret: + if isinstance(x, BaseMCRTensor): + x.block_size = list(block_sizes)[0] + + # TODO: handle more return types + return ret diff --git a/torchhd/tensors/mcr.py b/torchhd/tensors/mcr.py index 6801a49..3037999 100644 --- a/torchhd/tensors/mcr.py +++ b/torchhd/tensors/mcr.py @@ -27,200 +27,15 @@ from typing import Set from torchhd.tensors.base import VSATensor +from torchhd.tensors.basemcr import BaseMCRTensor -class MCRTensor(VSATensor): +class MCRTensor(BaseMCRTensor): r"""Modular Composite Representation (MCR) Proposed in `Modular Composite Representation `_, this model works with modular integer vectors. """ - block_size: int - supported_dtypes: Set[torch.dtype] = { - torch.float32, - torch.float64, - torch.int16, - torch.int32, - torch.int64, - } - - @classmethod - def empty( - cls, - num_vectors: int, - dimensions: int, - *, - block_size: int, - generator=None, - dtype=torch.int64, - device=None, - requires_grad=False, - ) -> "MCRTensor": - r"""Creates a set of hypervectors representing empty sets. - - When bundled with a hypervector :math:`x`, the result is :math:`x`. - Because of the low precession of the MCR model an empty set cannot be explicitly represented, therefore the returned hypervectors are identical to random-hypervectors. - - Args: - num_vectors (int): the number of hypervectors to generate. - dimensions (int): the dimensionality of the hypervectors. - block_size (int): the number of elements per block which controls the angular granularity. - generator (``torch.Generator``, optional): a pseudorandom number generator for sampling. - dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: ``int64``. - device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. 
- requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``. - - Examples:: - - >>> torchhd.MCRTensor.empty(3, 6, block_size=64) - MCRTensor([[54, 3, 22, 27, 41, 21], - [17, 31, 55, 3, 44, 52], - [42, 37, 60, 54, 13, 41]]) - - """ - - if dtype == None: - dtype = torch.int64 - - if dtype not in cls.supported_dtypes: - name = cls.__name__ - options = ", ".join([str(x) for x in cls.supported_dtypes]) - raise ValueError( - f"{name} vectors must be one of dtype {options}, got {dtype}." - ) - - result = torch.randint( - 0, - block_size, - (num_vectors, dimensions), - generator=generator, - dtype=dtype, - device=device, - requires_grad=requires_grad, - ) - - result = result.as_subclass(cls) - result.block_size = block_size - return result - - @classmethod - def identity( - cls, - num_vectors: int, - dimensions: int, - *, - block_size: int, - dtype=torch.int64, - device=None, - requires_grad=False, - ) -> "MCRTensor": - r"""Creates a set of identity hypervectors. - - When bound with a random-hypervector :math:`x`, the result is :math:`x`. - - Args: - num_vectors (int): the number of hypervectors to generate. - dimensions (int): the dimensionality of the hypervectors. - block_size (int): the number of elements per block which controls the angular granularity. - dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: if ``int64`` depends on VSATensor. - device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``. - - Examples:: - - >>> torchhd.MCRTensor.identity(3, 6, block_size=64) - MCRTensor([[0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0]]) - - """ - if dtype == None: - dtype = torch.int64 - - if dtype not in cls.supported_dtypes: - name = cls.__name__ - options = ", ".join([str(x) for x in cls.supported_dtypes]) - raise ValueError( - f"{name} vectors must be one of dtype {options}, got {dtype}." - ) - - result = torch.zeros( - num_vectors, - dimensions, - dtype=dtype, - device=device, - requires_grad=requires_grad, - ) - - result = result.as_subclass(cls) - result.block_size = block_size - return result - - @classmethod - def random( - cls, - num_vectors: int, - dimensions: int, - *, - block_size: int, - generator=None, - dtype=torch.int64, - device=None, - requires_grad=False, - ) -> "MCRTensor": - r"""Creates a set of random independent hypervectors. - - The resulting hypervectors sample uniformly random integers between 0 and ``block_size``. - - Args: - num_vectors (int): the number of hypervectors to generate. - dimensions (int): the dimensionality of the hypervectors. - block_size (int): the number of elements per block which controls the angular granularity. - generator (``torch.Generator``, optional): a pseudorandom number generator for sampling. - dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: ``int64``. - device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. 
- requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``. - - Examples:: - - >>> torchhd.MCRTensor.random(3, 6, block_size=64) - MCRTensor([[ 7, 1, 39, 8, 55, 22], - [51, 38, 59, 45, 13, 29], - [19, 26, 30, 5, 15, 51]]) - >>> torchhd.MCRTensor.random(3, 6, block_size=128, dtype=torch.float32) - MCRTensor([[116., 25., 100., 10., 21., 86.], - [ 69., 49., 2., 56., 78., 70.], - [ 77., 47., 37., 106., 8., 30.]]) - - """ - if dtype == None: - dtype = torch.int64 - - if dtype not in cls.supported_dtypes: - name = cls.__name__ - options = ", ".join([str(x) for x in cls.supported_dtypes]) - raise ValueError( - f"{name} vectors must be one of dtype {options}, got {dtype}." - ) - - result = torch.randint( - 0, - block_size, - (num_vectors, dimensions), - generator=generator, - dtype=dtype, - device=device, - requires_grad=requires_grad, - ) - - result = result.as_subclass(cls) - result.block_size = block_size - return result - - def to_complex_unit(self): - angles = 2 * torch.pi * self / self.block_size - return torch.polar(torch.ones_like(self, dtype=angles.dtype), angles) - def bundle(self, other: "MCRTensor") -> "MCRTensor": r"""Bundle the hypervector with normalized complex vector addition. @@ -251,11 +66,11 @@ def bundle(self, other: "MCRTensor") -> "MCRTensor": self_phasor = self.to_complex_unit() other_phasor = other.to_complex_unit() - + # Adding the vectors of each element sum_of_phasors = self_phasor + other_phasor - - # To define the ultimate number that the summation will land on + + # To define the ultimate number that the summation will land on # we first find the theta of summation then quantize it to block_size angels = torch.angle(sum_of_phasors) result = self.block_size * (angels / (2 * torch.pi)) @@ -265,7 +80,7 @@ def bundle(self, other: "MCRTensor") -> "MCRTensor": # We return the average of two operands in such a case. is_zero = torch.isclose(sum_of_phasors, torch.zeros_like(sum_of_phasors)) result = torch.where(is_zero, (self + other) / 2, result).round() - + return torch.remainder(result, self.block_size).type(self.dtype) def multibundle(self) -> "MCRTensor": @@ -274,7 +89,7 @@ def multibundle(self) -> "MCRTensor": self_phasor = self.to_complex_unit() sum_of_phasors = torch.sum(self_phasor, dim=-2) - # To define the ultimate number that the summation will land on + # To define the ultimate number that the summation will land on # we first find the theta of summation then quantize it to block_size angels = torch.angle(sum_of_phasors) result = self.block_size * (angels / (2 * torch.pi)) @@ -284,154 +99,6 @@ def multibundle(self) -> "MCRTensor": # We return the average of two operands in such a case. is_zero = torch.isclose(sum_of_phasors, torch.zeros_like(sum_of_phasors)) result = torch.where(is_zero, torch.mean(self, dim=-2, dtype=torch.float), result).round() - - return torch.remainder(result, self.block_size).type(self.dtype) - - def bind(self, other: "MCRTensor") -> "MCRTensor": - r"""Bind the hypervector with other using circular convolution. - - This produces a hypervector dissimilar to both. - - Binding is used to associate information, for instance, to assign values to variables. 
- - Args: - other (MCRTensor): other input hypervector - - Shapes: - - Self: :math:`(*)` - - Other: :math:`(*)` - - Output: :math:`(*)` - - Examples:: - >>> a, b = torchhd.MCRTensor.random(2, 10, block_size=64) - >>> a - MCRTensor([18, 55, 40, 62, 39, 26, 35, 24, 49, 41]) - >>> b - MCRTensor([46, 36, 21, 23, 25, 12, 29, 53, 54, 41]) - >>> a.bind(b) - MCRTensor([ 0, 27, 61, 21, 0, 38, 0, 13, 39, 18]) - - """ - assert self.block_size == other.block_size - return torch.remainder(torch.add(self, other), self.block_size) - - def multibind(self) -> "MCRTensor": - """Bind multiple hypervectors""" - return torch.remainder( - torch.sum(self, dim=-2, dtype=self.dtype), self.block_size - ) - - def inverse(self) -> "MCRTensor": - r"""Invert the hypervector for binding. - - Shapes: - - Self: :math:`(*)` - - Output: :math:`(*)` - - Examples:: - - >>> a = torchhd.MCRTensor.random(1, 10, block_size=64) - >>> a - MCRTensor([[ 5, 30, 15, 43, 19, 36, 4, 14, 57, 34]]) - >>> a.inverse() - MCRTensor([[59, 34, 49, 21, 45, 28, 60, 50, 7, 30]]) - - """ - - return torch.remainder(torch.negative(self), self.block_size) - - def permute(self, shifts: int = 1) -> "MCRTensor": - r"""Permute the hypervector. - - The permutation operator is commonly used to assign an order to hypervectors. - - Args: - shifts (int, optional): The number of places by which the elements of the tensor are shifted. - - Shapes: - - Self: :math:`(*)` - - Output: :math:`(*)` - - Examples:: - - >>> a = torchhd.MCRTensor.random(1, 10, block_size=64) - >>> a - MCRTensor([[33, 24, 1, 36, 2, 57, 11, 59, 33, 3]]) - >>> a.permute(4) - MCRTensor([[11, 59, 33, 3, 33, 24, 1, 36, 2, 57]]) - - """ - return torch.roll(self, shifts=shifts, dims=-1) - - def normalize(self) -> "MCRTensor": - r"""Normalize the hypervector. - - Each operation on MCR hypervectors ensures it remains normalized, so this returns a copy of self. - - Shapes: - - Self: :math:`(*)` - - Output: :math:`(*)` - - Examples:: - - >>> x = torchhd.MCRTensor.random(4, 6, block_size=64).multibundle() - >>> x - MCRTensor([28, 27, 20, 44, 57, 18]) - >>> x.normalize() - MCRTensor([28, 27, 20, 44, 57, 18]) - - """ - return self.clone() - - def dot_similarity(self, others: "MCRTensor", *, dtype=None) -> Tensor: - """Based on 'Manhattan Distance in a Modular Space'. - Distance of two elements devided by the avearage distance of two random numbers. 
- """ - if dtype is None: - dtype = torch.get_default_dtype() - - random_distance = self.block_size/4 - - - if self.dim() > 1 and others.dim() > 1: - aminusb = torch.remainder(self.unsqueeze(-2) - others.unsqueeze(-3), self.block_size) - bminusa = torch.remainder(others.unsqueeze(-3) - self.unsqueeze(-2), self.block_size) - else: - aminusb = torch.remainder(self - others, self.block_size) - bminusa = torch.remainder(others - self, self.block_size) - distance = torch.min(aminusb,bminusa) - normalized_distance = distance/random_distance - - return torch.sum(1-normalized_distance, dim=-1, dtype=dtype) - - def cosine_similarity(self, others: "MCRTensor", *, dtype=None) -> Tensor: - """Cosine similarity with other hypervectors""" - magnitude = self.size(-1) - return self.dot_similarity(others, dtype=dtype) / magnitude - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - # Ensure that all the build-in torch operations on this Tensor subclass maintain the block_size property - - if kwargs is None: - kwargs = {} - - block_sizes = set(a.block_size for a in args if hasattr(a, "block_size")) - if len(block_sizes) != 1: - raise RuntimeError( - f"Call to {func} must contain exactly one block size, got {list(block_sizes)}" - ) - - # Call with super to avoid infinite recursion - ret = super().__torch_function__(func, types, args, kwargs) - - if isinstance(ret, MCRTensor): - ret.block_size = list(block_sizes)[0] - elif isinstance(ret, (tuple, list)): - for x in ret: - if isinstance(x, MCRTensor): - x.block_size = list(block_sizes)[0] + return torch.remainder(result, self.block_size).type(self.dtype) - # TODO: handle more return types - return ret From c22f2a3283c7a1c4c2795bed18421f9a9012a5fd Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Tue, 27 May 2025 13:35:48 -0300 Subject: [PATCH 02/13] Fix bug in BaseMCRTensor:__torch_function__() There is a bug when args is a collection of collections instead of a plain tuple. In this case, the old args parsing was unable to search for block_size in nested structures containing BaseMCRTensor. For instance, let `a` and `b` be BaseMCRTensor variables. Calling `torch.stack((a, b))` results in error, as the `args` received in `__torch_function__()` is a nested tuple. --- torchhd/tensors/basemcr.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/torchhd/tensors/basemcr.py b/torchhd/tensors/basemcr.py index 3d07286..b3313ce 100644 --- a/torchhd/tensors/basemcr.py +++ b/torchhd/tensors/basemcr.py @@ -359,7 +359,20 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - block_sizes = set(a.block_size for a in args if hasattr(a, "block_size")) + def _parse_container_for_attr(container, attr): + s = set() + for a in container: + if type(a) is tuple or type(a) is list: + s |= _parse_container_for_attr(a, attr) + else: + if hasattr(a, attr): + s.add(a.block_size) + return s + + # Args is a tuple that can contain other tuples or lists. Parse it + # reccursively to find any BaseMCRTensor object + block_sizes = _parse_container_for_attr(args, "block_size") + if len(block_sizes) != 1: raise RuntimeError( f"Call to {func} must contain exactly one block size, got {list(block_sizes)}" From e4d1e995daa8dd4b15bd5959124c57fdf895c521 Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Tue, 27 May 2025 14:14:04 -0300 Subject: [PATCH 03/13] Add CGRTensor Add a new VSA class named Cyclic Group Representation (CGR). This class is similar to MCR, but differs in bundling. 
--- torchhd/__init__.py | 2 ++ torchhd/functional.py | 3 ++ torchhd/tensors/cgr.py | 75 ++++++++++++++++++++++++++++++++++++++++++ torchhd/types.py | 2 +- 4 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 torchhd/tensors/cgr.py diff --git a/torchhd/__init__.py b/torchhd/__init__.py index 8586ce3..e5517df 100644 --- a/torchhd/__init__.py +++ b/torchhd/__init__.py @@ -39,6 +39,7 @@ from torchhd.tensors.vtb import VTBTensor from torchhd.tensors.basemcr import BaseMCRTensor from torchhd.tensors.mcr import MCRTensor +from torchhd.tensors.cgr import CGRTensor from torchhd.functional import ( ensure_vsa_tensor, @@ -94,6 +95,7 @@ "VTBTensor", "BaseMCRTensor", "MCRTensor", + "CGRTensor", "functional", "embeddings", "structures", diff --git a/torchhd/functional.py b/torchhd/functional.py index 1d075c5..464a9b4 100644 --- a/torchhd/functional.py +++ b/torchhd/functional.py @@ -36,6 +36,7 @@ from torchhd.tensors.bsbc import BSBCTensor from torchhd.tensors.vtb import VTBTensor from torchhd.tensors.mcr import MCRTensor +from torchhd.tensors.cgr import CGRTensor from torchhd.types import VSAOptions @@ -93,6 +94,8 @@ def get_vsa_tensor_class(vsa: VSAOptions) -> Type[VSATensor]: return VTBTensor elif vsa == "MCR": return MCRTensor + elif vsa == "CGR": + return CGRTensor raise ValueError(f"Provided VSA model is not supported, specified: {vsa}") diff --git a/torchhd/tensors/cgr.py b/torchhd/tensors/cgr.py new file mode 100644 index 0000000..96591c8 --- /dev/null +++ b/torchhd/tensors/cgr.py @@ -0,0 +1,75 @@ +# +# MIT License +# +# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +import torch +from torch import Tensor +import torch.nn.functional as F +from typing import Set + +from torchhd.tensors.basemcr import BaseMCRTensor + + +class CGRTensor(BaseMCRTensor): + r"""Cyclic Group Representation (CGR) + + First introduced in `Modular Composite Representation `_ and then better elaborated in `Understanding hyperdimensional computing for parallel single-pass learning `_, this model works with modular integer vectors. It works similar to the MCR class, but uses a bundling based on element-wise mode instead of addition of complex numbers. + """ + + def bundle(self, other: "CGRTensor") -> "CGRTensor": + r"""Bundle the hypervector with majority voting. Ties might be broken at random. However, the expected result is that the tie representing the lowest value wins. 
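+
+        For example, bundling a single occurrence of ``3`` with a single occurrence of ``5`` is a tie that is expected to resolve to ``3``, the lower of the two values.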
+ + This produces a hypervector maximally similar to both. + + The bundling operation is used to aggregate information into a single hypervector. + + Args: + other (CGR): other input hypervector + + Shapes: + - Self: :math:`(*)` + - Other: :math:`(*)` + - Output: :math:`(*)` + + Examples:: + + >>> a, b = torchhd.CGRTensor.random(2, 10, block_size=64) + >>> a + CGRTensor([32, 26, 22, 22, 34, 30, 2, 4, 40, 43]) + >>> b + CGRTensor([32, 26, 39, 54, 27, 60, 60, 4, 40, 5]) + >>> a.bundle(b) + CGRTensor([32, 26, 39, 22, 27, 60, 2, 4, 40, 5]) + + """ + assert self.block_size == other.block_size + + t = torch.stack((self, other), dim=-2) + val, _ = torch.mode(t, dim=-2) + return val + + def multibundle(self) -> "CGRTensor": + """Bundle multiple hypervectors""" + + val, _ = torch.mode(self, dim=-2) + return val + diff --git a/torchhd/types.py b/torchhd/types.py index 85cb017..a7d9cab 100644 --- a/torchhd/types.py +++ b/torchhd/types.py @@ -23,4 +23,4 @@ # from typing import Literal -VSAOptions = Literal["BSC", "MAP", "HRR", "FHRR", "BSBC", "VTB", "MCR"] +VSAOptions = Literal["BSC", "MAP", "HRR", "FHRR", "BSBC", "VTB", "MCR", "CGR"] From 24c2676ff6f0b29d5a4d68ce8b77a6f3b6dff9da Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Tue, 27 May 2025 14:31:35 -0300 Subject: [PATCH 04/13] Add CGRTensor to functional embeddings Allow its usage in level and circular embeddings as done with MCRTensor. --- torchhd/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchhd/functional.py b/torchhd/functional.py index 464a9b4..a851f50 100644 --- a/torchhd/functional.py +++ b/torchhd/functional.py @@ -364,7 +364,7 @@ def level( device=span_hv.device, ).as_subclass(vsa_tensor) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv.block_size = span_hv.block_size for i in range(num_vectors): @@ -591,7 +591,7 @@ def circular( device=span_hv.device, ).as_subclass(vsa_tensor) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv.block_size = span_hv.block_size mutation_history = deque() From 9c471ca6a6626ce85fd332baea1600b97e9da91f Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Tue, 27 May 2025 17:14:01 -0300 Subject: [PATCH 05/13] Fix CGR bundle Ensure both inputs are in the same shape. --- torchhd/tensors/cgr.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/torchhd/tensors/cgr.py b/torchhd/tensors/cgr.py index 96591c8..759c340 100644 --- a/torchhd/tensors/cgr.py +++ b/torchhd/tensors/cgr.py @@ -63,13 +63,30 @@ def bundle(self, other: "CGRTensor") -> "CGRTensor": """ assert self.block_size == other.block_size - t = torch.stack((self, other), dim=-2) - val, _ = torch.mode(t, dim=-2) + # Ensure hypervectors are in the same shape, i.e., [..., 1, DIM] + t1 = self + if len(t1.shape) == 1: + t1 = t1.unsqueeze(0) + t2 = other + if len(t2.shape) == 1: + t2 = t2.unsqueeze(0) + + t = torch.stack((t1, t2), dim=-2) + val = t.multibundle() + + # Convert shape back to [DIM] if inputs are plain hypervectors + need_squeeze = len(self.shape) == 1 and len(other.shape) == 1 + if need_squeeze: + return val.squeeze(0) + return val def multibundle(self) -> "CGRTensor": """Bundle multiple hypervectors""" + # The use of torch.mode() makes untying deterministic as it always + # returns the lowest index among the ties. For example, if there is an + # equal number amount of 0s and 1s in a bundle, 0 is returned. 
val, _ = torch.mode(self, dim=-2) return val From 58b75d4e47ba74c168f030893cb8a9d09d69ca10 Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Tue, 27 May 2025 17:15:15 -0300 Subject: [PATCH 06/13] Add CGR into tests The CGR should behave almost the same as MCR, but diverges in bundling. Implement a custom bundling test for it. --- torchhd/tests/basis_hv/test_circular_hv.py | 18 +++---- torchhd/tests/basis_hv/test_empty_hv.py | 10 ++-- torchhd/tests/basis_hv/test_identity_hv.py | 10 ++-- torchhd/tests/basis_hv/test_level_hv.py | 18 +++---- torchhd/tests/basis_hv/test_random_hv.py | 16 +++--- torchhd/tests/test_embeddings.py | 60 +++++++++++----------- torchhd/tests/test_operations.py | 58 +++++++++++++-------- torchhd/tests/test_similarities.py | 48 ++++++++--------- torchhd/tests/utils.py | 1 + 9 files changed, 127 insertions(+), 112 deletions(-) diff --git a/torchhd/tests/basis_hv/test_circular_hv.py b/torchhd/tests/basis_hv/test_circular_hv.py index 6989f06..e7c72ae 100644 --- a/torchhd/tests/basis_hv/test_circular_hv.py +++ b/torchhd/tests/basis_hv/test_circular_hv.py @@ -41,7 +41,7 @@ def test_shape(self, n, d, vsa): if vsa == "HRR" or vsa == "VTB": return - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.circular(n, d, vsa, block_size=1024) else: hv = functional.circular(n, d, vsa) @@ -57,7 +57,7 @@ def test_generator(self, vsa): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv1 = functional.circular( 20, 10000, vsa, generator=generator, block_size=1024 ) @@ -66,7 +66,7 @@ def test_generator(self, vsa): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv2 = functional.circular( 20, 10000, vsa, generator=generator, block_size=1024 ) @@ -79,7 +79,7 @@ def test_generator(self, vsa): def test_value(self, dtype, vsa): if not supported_dtype(dtype, vsa): with pytest.raises(ValueError): - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": functional.circular(3, 26, vsa, dtype=dtype, block_size=1024) else: functional.circular(3, 26, vsa, dtype=dtype) @@ -95,7 +95,7 @@ def test_value(self, dtype, vsa): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.circular( 50, 26569, vsa, dtype=dtype, generator=generator, block_size=1024 ) @@ -118,16 +118,16 @@ def test_value(self, dtype, vsa): mag, torch.tensor(1.0, dtype=mag.dtype), rtol=0.0001, atol=0.0001 ) - elif vsa == "BSBC" or vsa == "MCR": + elif vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": assert torch.all((hv >= 0) | (hv < 1024)).item() - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.circular( 8, 1000000, vsa, generator=generator, dtype=dtype, block_size=1024 ) else: hv = functional.circular(8, 1000000, vsa, generator=generator, dtype=dtype) - + for i in range(8-1): sims = functional.cosine_similarity(hv[0], hv) sims_diff = sims[:-1] - sims[1:] @@ -180,7 +180,7 @@ def test_device(self, dtype, vsa): return device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.circular( 3, 52, vsa, device=device, dtype=dtype, block_size=1024 ) diff --git a/torchhd/tests/basis_hv/test_empty_hv.py 
b/torchhd/tests/basis_hv/test_empty_hv.py index e9aff1c..2e3f5ea 100644 --- a/torchhd/tests/basis_hv/test_empty_hv.py +++ b/torchhd/tests/basis_hv/test_empty_hv.py @@ -41,7 +41,7 @@ class Testempty: @pytest.mark.parametrize("d", [84, 16]) @pytest.mark.parametrize("vsa", vsa_tensors) def test_shape(self, n, d, vsa): - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.empty(n, d, vsa, block_size=1024) elif vsa == "VTB" and d == 84: with pytest.raises(ValueError): @@ -60,14 +60,14 @@ def test_shape(self, n, d, vsa): def test_value(self, dtype, vsa): if not supported_dtype(dtype, vsa): with pytest.raises(ValueError): - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": functional.empty(3, 25, vsa, dtype=dtype, block_size=1024) else: functional.empty(3, 25, vsa, dtype=dtype) return - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.empty(8, 25, vsa, dtype=dtype, block_size=1024) else: hv = functional.empty(8, 25, vsa, dtype=dtype) @@ -80,7 +80,7 @@ def test_value(self, dtype, vsa): if vsa == "BSC": assert torch.all((hv == False) | (hv == True)).item() - elif vsa == "BSBC" or vsa == "MCR": + elif vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": assert torch.all((hv >= 0) | (hv < 1024)).item() else: @@ -94,7 +94,7 @@ def test_device(self, dtype, vsa): return device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.empty( 3, 52, vsa, device=device, dtype=dtype, block_size=1024 ) diff --git a/torchhd/tests/basis_hv/test_identity_hv.py b/torchhd/tests/basis_hv/test_identity_hv.py index e93a32e..f9afca6 100644 --- a/torchhd/tests/basis_hv/test_identity_hv.py +++ b/torchhd/tests/basis_hv/test_identity_hv.py @@ -41,7 +41,7 @@ class Testidentity: @pytest.mark.parametrize("d", [84, 16]) @pytest.mark.parametrize("vsa", vsa_tensors) def test_shape(self, n, d, vsa): - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.identity(n, d, vsa, block_size=1042) elif vsa == "VTB" and d == 84: with pytest.raises(ValueError): @@ -61,14 +61,14 @@ def test_shape(self, n, d, vsa): def test_value(self, dtype, vsa): if not supported_dtype(dtype, vsa): with pytest.raises(ValueError): - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": functional.identity(3, 26, vsa, dtype=dtype, block_size=1042) else: functional.identity(3, 25, vsa, dtype=dtype) return - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.identity(8, 25, vsa, dtype=dtype, block_size=1042) else: hv = functional.identity(8, 25, vsa, dtype=dtype) @@ -86,7 +86,7 @@ def test_value(self, dtype, vsa): x = torch.fft.fft(hv) assert torch.allclose(x, torch.full_like(x, 1.0)) - elif vsa == "BSBC" or vsa == "MCR": + elif vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": assert torch.all(hv == 0) elif vsa == "VTB": @@ -103,7 +103,7 @@ def test_device(self, dtype, vsa): return device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.identity( 3, 52, vsa, device=device, dtype=dtype, block_size=1042 ) diff --git a/torchhd/tests/basis_hv/test_level_hv.py b/torchhd/tests/basis_hv/test_level_hv.py index 67d828e..baf72b0 100644 --- a/torchhd/tests/basis_hv/test_level_hv.py +++ 
b/torchhd/tests/basis_hv/test_level_hv.py @@ -37,7 +37,7 @@ class Testlevel: @pytest.mark.parametrize("d", [84, 16]) @pytest.mark.parametrize("vsa", vsa_tensors) def test_shape(self, n, d, vsa): - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.level(n, d, vsa, block_size=1024) elif vsa == "VTB" and d == 84: @@ -49,7 +49,7 @@ def test_shape(self, n, d, vsa): else: hv = functional.level(n, d, vsa) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": assert hv.block_size == 1024 assert hv.dim() == 2 @@ -60,7 +60,7 @@ def test_shape(self, n, d, vsa): def test_generator(self, vsa): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv1 = functional.level(20, 10000, vsa, generator=generator, block_size=1024) else: hv1 = functional.level(20, 10000, vsa, generator=generator) @@ -68,7 +68,7 @@ def test_generator(self, vsa): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv2 = functional.level(20, 10000, vsa, generator=generator, block_size=1024) else: hv2 = functional.level(20, 10000, vsa, generator=generator) @@ -79,7 +79,7 @@ def test_generator(self, vsa): def test_value(self, dtype, vsa): if not supported_dtype(dtype, vsa): with pytest.raises(ValueError): - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": functional.level(3, 25, vsa, dtype=dtype, block_size=1024) else: functional.level(3, 25, vsa, dtype=dtype) @@ -89,7 +89,7 @@ def test_value(self, dtype, vsa): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.level( 50, 25921, vsa, dtype=dtype, generator=generator, block_size=1024 ) @@ -103,7 +103,7 @@ def test_value(self, dtype, vsa): if vsa == "BSC": assert torch.all((hv == False) | (hv == True)).item() - elif vsa == "BSBC" or vsa == "MCR": + elif vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": assert torch.all((hv >= 0) | (hv < 1024)).item() elif vsa == "MAP": @@ -123,7 +123,7 @@ def test_value(self, dtype, vsa): sims_diff = sims[:-1] - sims[1:] assert torch.all(sims_diff > 0).item(), "similarity must be decreasing" - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.level( 5, 1000000, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -163,7 +163,7 @@ def test_device(self, dtype, vsa): return device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.level( 3, 49, vsa, device=device, dtype=dtype, block_size=1024 ) diff --git a/torchhd/tests/basis_hv/test_random_hv.py b/torchhd/tests/basis_hv/test_random_hv.py index 9c8ba36..3c23d11 100644 --- a/torchhd/tests/basis_hv/test_random_hv.py +++ b/torchhd/tests/basis_hv/test_random_hv.py @@ -41,7 +41,7 @@ class Testrandom: @pytest.mark.parametrize("d", [84, 16]) @pytest.mark.parametrize("vsa", vsa_tensors) def test_shape(self, n, d, vsa): - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random(n, d, vsa, block_size=64) elif vsa == "VTB" and d == 84: @@ -62,7 +62,7 @@ def test_generator(self, vsa): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if 
vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv1 = functional.random(20, 10000, vsa, generator=generator, block_size=64) else: hv1 = functional.random(20, 10000, vsa, generator=generator) @@ -70,7 +70,7 @@ def test_generator(self, vsa): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv2 = functional.random(20, 10000, vsa, generator=generator, block_size=64) else: hv2 = functional.random(20, 10000, vsa, generator=generator) @@ -81,7 +81,7 @@ def test_generator(self, vsa): def test_value(self, dtype, vsa): if not supported_dtype(dtype, vsa): with pytest.raises(ValueError): - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": functional.random(3, 25, vsa, dtype=dtype, block_size=64) else: functional.random(3, 25, vsa, dtype=dtype) @@ -91,7 +91,7 @@ def test_value(self, dtype, vsa): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 8, 25921, vsa, dtype=dtype, generator=generator, block_size=64 ) @@ -122,7 +122,7 @@ def test_value(self, dtype, vsa): mag = hv.abs() assert torch.allclose(mag, torch.tensor(1.0, dtype=mag.dtype)) - elif vsa == "BSBC" or vsa == "MCR": + elif vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": assert torch.all((hv < 64) & (hv >= 0)) @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.756, 1.0]) @@ -155,7 +155,7 @@ def test_orthogonality(self, dtype, vsa): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 100, 10000, vsa, dtype=dtype, generator=generator, block_size=1042 ) @@ -174,7 +174,7 @@ def test_device(self, dtype, vsa): return device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 3, 49, vsa, device=device, dtype=dtype, block_size=64 ) diff --git a/torchhd/tests/test_embeddings.py b/torchhd/tests/test_embeddings.py index 61a6c1f..41e86cd 100644 --- a/torchhd/tests/test_embeddings.py +++ b/torchhd/tests/test_embeddings.py @@ -45,7 +45,7 @@ class TestEmpty: def test_embedding_dim(self, vsa): dimensions = 1024 embedding = 10 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Empty(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Empty(embedding, dimensions, vsa=vsa) @@ -55,7 +55,7 @@ def test_embedding_dim(self, vsa): def test_num_embeddings(self, vsa): dimensions = 1024 embedding = 10 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Empty(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Empty(embedding, dimensions, vsa=vsa) @@ -65,7 +65,7 @@ def test_num_embeddings(self, vsa): def test_dtype(self, vsa): dimensions = 4 embedding = 6 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Empty(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Empty(embedding, dimensions, vsa=vsa) @@ -86,7 +86,7 @@ def test_dtype(self, vsa): def test_value(self, vsa): dimensions = 10000 embedding = 4 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Empty(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = 
embeddings.Empty(embedding, dimensions, vsa=vsa) @@ -113,7 +113,7 @@ class TestIdentity: def test_embedding_dim(self, vsa): dimensions = 1024 embedding = 10 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Identity(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Identity(embedding, dimensions, vsa=vsa) @@ -123,7 +123,7 @@ def test_embedding_dim(self, vsa): def test_num_embeddings(self, vsa): dimensions = 1024 embedding = 10 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Identity(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Identity(embedding, dimensions, vsa=vsa) @@ -135,7 +135,7 @@ def test_dtype(self, vsa): embedding = 6 idx = torch.LongTensor([0, 1, 3]) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Identity(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Identity(embedding, dimensions, vsa=vsa) @@ -151,7 +151,7 @@ def test_value(self, vsa): dimensions = 9 embedding = 4 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Identity(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Identity(embedding, dimensions, vsa=vsa) @@ -216,7 +216,7 @@ class TestRandom: def test_embedding_dim(self, vsa): dimensions = 1024 embedding = 10 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Random(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Random(embedding, dimensions, vsa=vsa) @@ -226,7 +226,7 @@ def test_embedding_dim(self, vsa): def test_num_embeddings(self, vsa): dimensions = 1024 embedding = 10 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Random(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Random(embedding, dimensions, vsa=vsa) @@ -236,7 +236,7 @@ def test_num_embeddings(self, vsa): def test_dtype(self, vsa): dimensions = 4 embedding = 6 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Random(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Random(embedding, dimensions, vsa=vsa) @@ -252,7 +252,7 @@ def test_dtype(self, vsa): def test_value(self, vsa): dimensions = 10000 embedding = 4 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Random(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Random(embedding, dimensions, vsa=vsa) @@ -267,7 +267,7 @@ class TestLevel: def test_embedding_dim(self, vsa): dimensions = 1024 embedding = 10 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Level(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Level(embedding, dimensions, vsa=vsa) @@ -277,7 +277,7 @@ def test_embedding_dim(self, vsa): def test_num_embeddings(self, vsa): dimensions = 1024 embedding = 10 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Level(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Level(embedding, dimensions, vsa=vsa) @@ -287,7 +287,7 @@ def test_num_embeddings(self, vsa): def test_dtype(self, vsa): dimensions = 4 embedding = 6 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" 
or vsa == "CGR": emb = embeddings.Level(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Level(embedding, dimensions, vsa=vsa) @@ -303,7 +303,7 @@ def test_dtype(self, vsa): def test_value(self, vsa): dimensions = 99856 embedding = 4 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Level(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Level(embedding, dimensions, vsa=vsa) @@ -332,7 +332,7 @@ def test_embedding_dim(self, vsa): return dimensions = 1024 embedding = 10 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Circular(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Circular(embedding, dimensions, vsa=vsa) @@ -344,7 +344,7 @@ def test_num_embeddings(self, vsa): return dimensions = 1024 embedding = 10 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Circular(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Circular(embedding, dimensions, vsa=vsa) @@ -356,7 +356,7 @@ def test_dtype(self, vsa): return dimensions = 4 embedding = 6 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Circular(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Circular(embedding, dimensions, vsa=vsa) @@ -380,7 +380,7 @@ def test_value(self, vsa): return dimensions = 99856 embedding = 4 - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": emb = embeddings.Circular(embedding, dimensions, vsa=vsa, block_size=1024) else: emb = embeddings.Circular(embedding, dimensions, vsa=vsa) @@ -405,7 +405,7 @@ def test_value(self, vsa): class TestThermometer: @pytest.mark.parametrize("vsa", vsa_tensors) def test_embedding_dim(self, vsa): - if vsa in {"HRR", "BSBC", "MCR"}: + if vsa in {"HRR", "BSBC", "MCR", "CGR"}: return dimensions = 1024 embedding = 10 @@ -479,7 +479,7 @@ def test_value(self, vsa): class TestProjection: @pytest.mark.parametrize("vsa", vsa_tensors) def test_in_features(self, vsa): - if vsa in {"BSC", "FHRR", "BSBC", "MCR"}: + if vsa in {"BSC", "FHRR", "BSBC", "MCR", "CGR"}: return in_features = 1020 out_features = 16 @@ -488,7 +488,7 @@ def test_in_features(self, vsa): @pytest.mark.parametrize("vsa", vsa_tensors) def test_out_features(self, vsa): - if vsa in {"BSC", "FHRR", "BSBC", "MCR"}: + if vsa in {"BSC", "FHRR", "BSBC", "MCR", "CGR"}: return in_features = 1020 out_features = 16 @@ -497,7 +497,7 @@ def test_out_features(self, vsa): @pytest.mark.parametrize("vsa", vsa_tensors) def test_dtype(self, vsa): - if vsa in {"BSC", "FHRR", "BSBC", "MCR"}: + if vsa in {"BSC", "FHRR", "BSBC", "MCR", "CGR"}: return in_features = 1000 out_features = 16 @@ -510,7 +510,7 @@ def test_dtype(self, vsa): @pytest.mark.parametrize("vsa", vsa_tensors) def test_value(self, vsa): - if vsa in {"BSC", "FHRR", "BSBC", "MCR"}: + if vsa in {"BSC", "FHRR", "BSBC", "MCR", "CGR"}: return in_features = 100000 out_features = 100 @@ -524,7 +524,7 @@ def test_value(self, vsa): class TestSinusoid: @pytest.mark.parametrize("vsa", vsa_tensors) def test_in_features(self, vsa): - if vsa in {"BSC", "FHRR", "BSBC", "MCR"}: + if vsa in {"BSC", "FHRR", "BSBC", "MCR", "CGR"}: return in_features = 1000 out_features = 16 @@ -533,7 +533,7 @@ def test_in_features(self, vsa): @pytest.mark.parametrize("vsa", vsa_tensors) def test_out_features(self, vsa): - if vsa in {"BSC", "FHRR", "BSBC", 
"MCR"}: + if vsa in {"BSC", "FHRR", "BSBC", "MCR", "CGR"}: return in_features = 1000 out_features = 16 @@ -542,7 +542,7 @@ def test_out_features(self, vsa): @pytest.mark.parametrize("vsa", vsa_tensors) def test_dtype(self, vsa): - if vsa in {"BSC", "FHRR", "BSBC", "MCR"}: + if vsa in {"BSC", "FHRR", "BSBC", "MCR", "CGR"}: return in_features = 1000 out_features = 16 @@ -555,7 +555,7 @@ def test_dtype(self, vsa): @pytest.mark.parametrize("vsa", vsa_tensors) def test_value(self, vsa): - if vsa in {"BSC", "FHRR", "BSBC", "MCR"}: + if vsa in {"BSC", "FHRR", "BSBC", "MCR", "CGR"}: return in_features = 100000 out_features = 16 @@ -574,7 +574,7 @@ def test_embedding_dim(self, vsa): embedding = 16 if vsa not in {"BSC", "MAP", "FHRR"}: - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": with pytest.raises(ValueError): emb = embeddings.Density( embedding, dimensions, vsa=vsa, block_size=1024 diff --git a/torchhd/tests/test_operations.py b/torchhd/tests/test_operations.py index ab9437c..d21d29b 100644 --- a/torchhd/tests/test_operations.py +++ b/torchhd/tests/test_operations.py @@ -41,7 +41,7 @@ def test_value(self, vsa, dtype): if not supported_dtype(dtype, vsa): return - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.empty(2, 10, vsa, dtype=dtype, block_size=1024) else: hv = functional.empty(2, 16, vsa, dtype=dtype) @@ -56,7 +56,7 @@ def test_value(self, vsa, dtype): assert torch.all(res == ifft(torch.mul(fft(hv[0]), fft(hv[1])))).item() elif vsa == "BSBC": assert torch.all(res == ((hv[0] + hv[1]) % 1024)) - elif vsa == "MCR": + elif vsa == "MCR" or vsa == "CGR": assert torch.all(res == ((hv[0] + hv[1]) % 1024)) assert dtype == res.dtype @@ -80,7 +80,7 @@ def test_value(self, vsa, dtype): if not supported_dtype(dtype, vsa): return - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random(2, 10, vsa, dtype=dtype, block_size=1024) else: hv = functional.random(2, 16, vsa, dtype=dtype) @@ -112,7 +112,7 @@ def test_value(self, vsa, dtype): assert (res[i].item() == hv[0][i].item()) or ( res[i].item() == hv[1][i].item() ) - + if vsa == "MCR": x = torch.tensor([1, 3, 5, 7, 9, 0, 2, 4, 6, 8], dtype=dtype) x = functional.ensure_vsa_tensor(x,'MCR') @@ -122,10 +122,24 @@ def test_value(self, vsa, dtype): y.block_size = 10 res = functional.bundle(x, y) - + possible_values = [[0,1], [1,2], [3,4], [5], [6,7,1,2], [2,3,7,8], [4],[5,6], [7], [8,9]] for i in range(10): - assert (res[i].item() in possible_values[i]) + assert (res[i].item() in possible_values[i]) + + if vsa == "CGR": + x = torch.tensor([1, 3, 5, 7, 9, 0, 2, 4, 6, 8], dtype=dtype) + x = functional.ensure_vsa_tensor(x,'CGR') + x.block_size = 10 + y = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=dtype) + y = functional.ensure_vsa_tensor(y,'CGR') + y.block_size = 10 + + res = functional.bundle(x, y) + + possible_values = [[1,0], [3,1], [5,2], [7,3], [9,4], [0,5], [2,6], [4,7], [6,8], [8,9]] + for i in range(10): + assert (res[i].item() in possible_values[i]) assert res.dtype == dtype @@ -149,7 +163,7 @@ def test_value(self, vsa, dtype): if not supported_dtype(dtype, vsa): return - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random(2, 100, vsa, dtype=dtype, block_size=1024) else: hv = functional.random(2, 100, vsa, dtype=dtype) @@ -185,7 +199,7 @@ def test_value(self, vsa, dtype): 0 ), "all element must not be the same" - if vsa == "BSBC" or vsa == 
"MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random(1, 10000, vsa, dtype=dtype, block_size=1024) else: hv = functional.random(1, 10000, vsa, dtype=dtype) @@ -214,7 +228,7 @@ def test_value(self, vsa, dtype): if not supported_dtype(dtype, vsa): return - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random(12, 900, vsa, dtype=dtype, block_size=1024) else: hv = functional.random(12, 900, vsa, dtype=dtype) @@ -266,13 +280,13 @@ def test_value(self, vsa, dtype): generator = torch.Generator() generator.manual_seed(2147483644) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 5, 100, vsa, dtype=dtype, generator=generator, block_size=1024 ) else: hv = functional.random(5, 100, vsa, dtype=dtype, generator=generator) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": noise = functional.random( 1, 100, vsa, dtype=dtype, generator=generator, block_size=1024 ) @@ -290,13 +304,13 @@ def test_threshold(self, vsa, dtype): generator = torch.Generator() generator.manual_seed(2147483644) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 5, 100, vsa, dtype=dtype, generator=generator, block_size=1024 ) else: hv = functional.random(5, 100, vsa, dtype=dtype, generator=generator) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": noise = functional.random( 1, 100, vsa, dtype=dtype, generator=generator, block_size=1024 ) @@ -312,7 +326,7 @@ def test_device(self, vsa, dtype): return device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 5, 100, vsa, dtype=dtype, device=device, block_size=1024 ) @@ -331,7 +345,7 @@ def test_value(self, vsa, dtype): generator = torch.Generator() generator.manual_seed(2147483644) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": a, b = functional.random( 2, 1000, vsa, dtype=dtype, generator=generator, block_size=1024 ) @@ -340,7 +354,7 @@ def test_value(self, vsa, dtype): res = functional.randsel(a, b, p=0, generator=generator) assert torch.all(b == res) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": a, b = functional.random( 2, 1000, vsa, dtype=dtype, generator=generator, block_size=1024 ) @@ -349,7 +363,7 @@ def test_value(self, vsa, dtype): res = functional.randsel(a, b, p=1, generator=generator) assert torch.all(a == res) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": a, b = functional.random( 2, 1000, vsa, dtype=dtype, generator=generator, block_size=1024 ) @@ -366,7 +380,7 @@ def test_device(self, vsa, dtype): return device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": a, b = functional.random( 2, 100, vsa, dtype=dtype, device=device, block_size=1024 ) @@ -389,7 +403,7 @@ def test_value(self, vsa, dtype): generator = torch.Generator() generator.manual_seed(2147483644) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": x = functional.random(4, 1000, vsa, dtype=dtype, block_size=1024) else: x = functional.random(4, 1024, vsa, dtype=dtype) @@ -399,7 +413,7 @@ def test_value(self, vsa, dtype): ) assert torch.all(x[2] == 
res) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": x = functional.random(4, 1000, vsa, dtype=dtype, block_size=1024) else: x = functional.random(4, 1024, vsa, dtype=dtype) @@ -408,7 +422,7 @@ def test_value(self, vsa, dtype): ) assert torch.all((x[0] == res) | (x[2] == res)) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": x = functional.random(4, 1000, vsa, dtype=dtype, block_size=1024) else: x = functional.random(4, 1024, vsa, dtype=dtype) @@ -435,7 +449,7 @@ def test_value(self, vsa, dtype): if not supported_dtype(dtype, vsa): return - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": x = functional.random(4, 100, vsa, block_size=1024) else: x = functional.random(4, 100, vsa) diff --git a/torchhd/tests/test_similarities.py b/torchhd/tests/test_similarities.py index 0a3681d..b92bffa 100644 --- a/torchhd/tests/test_similarities.py +++ b/torchhd/tests/test_similarities.py @@ -49,7 +49,7 @@ def test_shape(self, vsa, dtype): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 2, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -58,7 +58,7 @@ def test_shape(self, vsa, dtype): similarity = functional.dot_similarity(hv[0], hv[1]) assert similarity.shape == () - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 2, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -67,13 +67,13 @@ def test_shape(self, vsa, dtype): similarity = functional.dot_similarity(hv[0], hv) assert similarity.shape == (2,) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 2, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) else: hv = functional.random(2, 100, vsa, generator=generator, dtype=dtype) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv2 = functional.random( 4, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -82,7 +82,7 @@ def test_shape(self, vsa, dtype): similarity = functional.dot_similarity(hv, hv2) assert similarity.shape == (2, 4) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv1 = functional.random( 6, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ).view(2, 3, 100) @@ -90,7 +90,7 @@ def test_shape(self, vsa, dtype): hv1 = functional.random(6, 100, vsa, generator=generator, dtype=dtype).view( 2, 3, 100 ) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv2 = functional.random( 4, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -179,7 +179,7 @@ def test_dtype(self, vsa, dtype): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 3, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -239,7 +239,7 @@ def test_device(self, vsa, dtype): generator = torch.Generator(device) generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 3, 100, @@ -269,7 +269,7 @@ def test_shape(self, vsa, dtype): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = 
functional.random( 2, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -278,7 +278,7 @@ def test_shape(self, vsa, dtype): similarity = functional.cosine_similarity(hv[0], hv[1]) assert similarity.shape == () - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 2, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -287,13 +287,13 @@ def test_shape(self, vsa, dtype): similarity = functional.cosine_similarity(hv[0], hv) assert similarity.shape == (2,) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 2, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) else: hv = functional.random(2, 100, vsa, generator=generator, dtype=dtype) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv2 = functional.random( 4, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -302,7 +302,7 @@ def test_shape(self, vsa, dtype): similarity = functional.cosine_similarity(hv, hv2) assert similarity.shape == (2, 4) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv1 = functional.random( 6, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ).view(2, 3, 100) @@ -311,7 +311,7 @@ def test_shape(self, vsa, dtype): 2, 3, 100 ) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv2 = functional.random( 4, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -400,7 +400,7 @@ def test_dtype(self, vsa, dtype): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 3, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -429,7 +429,7 @@ def test_device(self, vsa, dtype): generator = torch.Generator(device) generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 3, 100, @@ -459,7 +459,7 @@ def test_shape(self, vsa, dtype): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 2, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -468,7 +468,7 @@ def test_shape(self, vsa, dtype): similarity = functional.hamming_similarity(hv[0], hv[1]) assert similarity.shape == () - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 2, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -477,13 +477,13 @@ def test_shape(self, vsa, dtype): similarity = functional.hamming_similarity(hv[0], hv) assert similarity.shape == (2,) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 2, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) else: hv = functional.random(2, 100, vsa, generator=generator, dtype=dtype) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv2 = functional.random( 4, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -492,7 +492,7 @@ def test_shape(self, vsa, dtype): similarity = functional.hamming_similarity(hv, hv2) assert similarity.shape == (2, 4) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv1 = functional.random( 6, 100, vsa, generator=generator, 
dtype=dtype, block_size=1024 ).view(2, 3, 100) @@ -501,7 +501,7 @@ def test_shape(self, vsa, dtype): 2, 3, 100 ) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv2 = functional.random( 4, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -589,7 +589,7 @@ def test_dtype(self, vsa, dtype): generator = torch.Generator() generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 3, 100, vsa, generator=generator, dtype=dtype, block_size=1024 ) @@ -610,7 +610,7 @@ def test_device(self, vsa, dtype): generator = torch.Generator(device) generator.manual_seed(seed) - if vsa == "BSBC" or vsa == "MCR": + if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR": hv = functional.random( 3, 100, diff --git a/torchhd/tests/utils.py b/torchhd/tests/utils.py index 2518ec8..c26a1ff 100644 --- a/torchhd/tests/utils.py +++ b/torchhd/tests/utils.py @@ -95,4 +95,5 @@ def supported_dtype( "BSBC", "VTB", "MCR", + "CGR" ] From 124312a5bb387a001421ad64d0bea3015808edd3 Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Thu, 29 May 2025 12:50:17 -0300 Subject: [PATCH 07/13] Use Tensor.dim() in CGRTensor Shorter syntax and more readable. --- torchhd/tensors/cgr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchhd/tensors/cgr.py b/torchhd/tensors/cgr.py index 759c340..767905b 100644 --- a/torchhd/tensors/cgr.py +++ b/torchhd/tensors/cgr.py @@ -65,17 +65,17 @@ def bundle(self, other: "CGRTensor") -> "CGRTensor": # Ensure hypervectors are in the same shape, i.e., [..., 1, DIM] t1 = self - if len(t1.shape) == 1: + if t1.dim() == 1: t1 = t1.unsqueeze(0) t2 = other - if len(t2.shape) == 1: + if t2.dim() == 1: t2 = t2.unsqueeze(0) t = torch.stack((t1, t2), dim=-2) val = t.multibundle() # Convert shape back to [DIM] if inputs are plain hypervectors - need_squeeze = len(self.shape) == 1 and len(other.shape) == 1 + need_squeeze = self.dim() == 1 and other.dim() == 1 if need_squeeze: return val.squeeze(0) From 0290270b1a9a93ba8c636cccc282cc10eb82d3d3 Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Thu, 29 May 2025 12:54:21 -0300 Subject: [PATCH 08/13] Remove asserts in CGR/MCR classes No need to check block size in CGR/MCR functions as the __torch_function__() in BaseMCRTensor already checks it. 
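For context, every torch-level operation on these tensors is routed through
BaseMCRTensor.__torch_function__, which collects the block_size attribute of
all arguments and raises when more than one distinct value is found. A minimal
sketch of the behavior the removal relies on (values are illustrative, not
taken from the patch):

    import torchhd

    x = torchhd.MCRTensor.random(1, 10, block_size=64)
    y = torchhd.MCRTensor.random(1, 10, block_size=32)

    # bind() reduces to torch.add + torch.remainder, so the call is
    # intercepted by __torch_function__, which rejects the mismatched
    # block sizes before a per-method assert would ever run.
    x.bind(y)  # raises: exactly one block size is expected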
--- torchhd/tensors/basemcr.py | 1 - torchhd/tensors/cgr.py | 3 --- torchhd/tensors/mcr.py | 2 -- 3 files changed, 6 deletions(-) diff --git a/torchhd/tensors/basemcr.py b/torchhd/tensors/basemcr.py index b3313ce..ea996ea 100644 --- a/torchhd/tensors/basemcr.py +++ b/torchhd/tensors/basemcr.py @@ -255,7 +255,6 @@ def bind(self, other: "BaseMCRTensor") -> "BaseMCRTensor": BaseMCRTensor([ 0, 27, 61, 21, 0, 38, 0, 13, 39, 18]) """ - assert self.block_size == other.block_size return torch.remainder(torch.add(self, other), self.block_size) def multibind(self) -> "BaseMCRTensor": diff --git a/torchhd/tensors/cgr.py b/torchhd/tensors/cgr.py index 767905b..60a9178 100644 --- a/torchhd/tensors/cgr.py +++ b/torchhd/tensors/cgr.py @@ -61,8 +61,6 @@ def bundle(self, other: "CGRTensor") -> "CGRTensor": CGRTensor([32, 26, 39, 22, 27, 60, 2, 4, 40, 5]) """ - assert self.block_size == other.block_size - # Ensure hypervectors are in the same shape, i.e., [..., 1, DIM] t1 = self if t1.dim() == 1: @@ -83,7 +81,6 @@ def bundle(self, other: "CGRTensor") -> "CGRTensor": def multibundle(self) -> "CGRTensor": """Bundle multiple hypervectors""" - # The use of torch.mode() makes untying deterministic as it always # returns the lowest index among the ties. For example, if there is an # equal number amount of 0s and 1s in a bundle, 0 is returned. diff --git a/torchhd/tensors/mcr.py b/torchhd/tensors/mcr.py index 3037999..7559915 100644 --- a/torchhd/tensors/mcr.py +++ b/torchhd/tensors/mcr.py @@ -62,8 +62,6 @@ def bundle(self, other: "MCRTensor") -> "MCRTensor": MCRTensor([32, 26, 39, 54, 27, 60, 2, 4, 40, 5]) """ - assert self.block_size == other.block_size - self_phasor = self.to_complex_unit() other_phasor = other.to_complex_unit() From 7e556321961178f548f015ce41328cd65454addc Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Thu, 29 May 2025 12:59:32 -0300 Subject: [PATCH 09/13] Improve exception in BaseMCRTensor if multiple block_sizes - Raise ValueError instead of RuntimeError --- torchhd/tensors/basemcr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchhd/tensors/basemcr.py b/torchhd/tensors/basemcr.py index ea996ea..838bf62 100644 --- a/torchhd/tensors/basemcr.py +++ b/torchhd/tensors/basemcr.py @@ -373,8 +373,8 @@ def _parse_container_for_attr(container, attr): block_sizes = _parse_container_for_attr(args, "block_size") if len(block_sizes) != 1: - raise RuntimeError( - f"Call to {func} must contain exactly one block size, got {list(block_sizes)}" + raise ValueError( + f"Call to {func} must contain exactly one block size, got {block_sizes}" ) # Call with super to avoid infinite recursion From dacb780817e1ff509b3ae2778d39eca43cf1ddaf Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Thu, 29 May 2025 13:57:29 -0300 Subject: [PATCH 10/13] Add CGRTensor to docs Include CGRTensor in the README and in the built docs. 
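A minimal usage sketch of the newly documented model (dimensionality and
block size are illustrative):

    import torch
    from torchhd import functional

    # Two random CGR hypervectors with components modulo 1024
    x, y = functional.random(2, 10000, "CGR", block_size=1024)

    s = functional.bundle(x, y)  # element-wise majority vote (mode)
    b = functional.bind(x, y)    # element-wise addition mod 1024

    # Unbinding should recover x exactly, since binding is modular
    # addition and inverse() is the additive inverse mod block_size.
    assert torch.all(b.bind(y.inverse()) == x)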
--- README.md | 1 + docs/torchhd.rst | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 996b00a..3f73eea 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,7 @@ Currently, the library supports the following HDC/VSA models: - [Holographic Reduced Representations (HRR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.HRRTensor.html) - [Fourier Holographic Reduced Representations (FHRR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.FHRRTensor.html) - [Binary Sparse Block Codes (B-SBC)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.BSBCTensor.html) +- [Modular Composite Representation (CGR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.CGRTensor.html) - [Modular Composite Representation (MCR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.MCRTensor.html) - [Vector-Derived Transformation Binding (VTB)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.VTBTensor.html) diff --git a/docs/torchhd.rst b/docs/torchhd.rst index 73342e4..088c1e3 100644 --- a/docs/torchhd.rst +++ b/docs/torchhd.rst @@ -10,7 +10,7 @@ This module consists of the basic hypervector generation functions and operation Basis-hypervector sets ---------------------------------- -.. autosummary:: +.. autosummary:: :toctree: generated/ :template: function.rst @@ -20,7 +20,7 @@ Basis-hypervector sets level thermometer circular - + Operations -------------------- @@ -51,7 +51,7 @@ Similarities .. autosummary:: :toctree: generated/ :template: function.rst - + cosine_similarity dot_similarity hamming_similarity @@ -87,6 +87,7 @@ VSA Models HRRTensor FHRRTensor BSBCTensor + CGRTensor MCRTensor VTBTensor @@ -97,7 +98,7 @@ Utilities .. autosummary:: :toctree: generated/ :template: function.rst - + ensure_vsa_tensor map_range value_to_index From f65c7273813f93862e3443a67ee53e0980a0baab Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Thu, 29 May 2025 14:02:44 -0300 Subject: [PATCH 11/13] Add BaseMCRTensor to docs --- docs/torchhd.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/torchhd.rst b/docs/torchhd.rst index 088c1e3..32add35 100644 --- a/docs/torchhd.rst +++ b/docs/torchhd.rst @@ -87,8 +87,9 @@ VSA Models HRRTensor FHRRTensor BSBCTensor - CGRTensor + BaseMCRTensor MCRTensor + CGRTensor VTBTensor From b7ac97cfb846b33ee6afadc5d38786cd2024a81d Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Thu, 29 May 2025 15:27:48 -0300 Subject: [PATCH 12/13] Fix typo in README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3f73eea..29cdecd 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ Currently, the library supports the following HDC/VSA models: - [Holographic Reduced Representations (HRR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.HRRTensor.html) - [Fourier Holographic Reduced Representations (FHRR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.FHRRTensor.html) - [Binary Sparse Block Codes (B-SBC)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.BSBCTensor.html) -- [Modular Composite Representation (CGR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.CGRTensor.html) +- [Cyclic Group Representation (CGR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.CGRTensor.html) - [Modular Composite Representation (MCR)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.MCRTensor.html) - [Vector-Derived Transformation Binding 
(VTB)](https://torchhd.readthedocs.io/en/stable/generated/torchhd.VTBTensor.html) From 2f796b7bd4ced5da8aa54777177ecf2b7714d784 Mon Sep 17 00:00:00 2001 From: Caio Vieira Date: Thu, 12 Jun 2025 16:42:14 -0300 Subject: [PATCH 13/13] Improve documentation for MCRTensor and CGRTensor Replicate function signatures in inherited classes (MCRTensor and CGRTensor) from their parent class (BaseMCRTensor) to let sphinx print the documentation of each parent function. --- torchhd/tensors/cgr.py | 89 +++++++++++++++++++++++++++++++++++++++++- torchhd/tensors/mcr.py | 86 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 1 deletion(-) diff --git a/torchhd/tensors/cgr.py b/torchhd/tensors/cgr.py index 60a9178..35d292b 100644 --- a/torchhd/tensors/cgr.py +++ b/torchhd/tensors/cgr.py @@ -35,6 +35,71 @@ class CGRTensor(BaseMCRTensor): First introduced in `Modular Composite Representation `_ and then better elaborated in `Understanding hyperdimensional computing for parallel single-pass learning `_, this model works with modular integer vectors. It works similar to the MCR class, but uses a bundling based on element-wise mode instead of addition of complex numbers. """ + @classmethod + def empty( + cls, + num_vectors: int, + dimensions: int, + *, + block_size: int, + generator=None, + dtype=torch.int64, + device=None, + requires_grad=False, + ) -> "CGRTensor": + return super().empty( + num_vectors, + dimensions, + block_size=block_size, + generator=generator, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + @classmethod + def identity( + cls, + num_vectors: int, + dimensions: int, + *, + block_size: int, + dtype=torch.int64, + device=None, + requires_grad=False, + ) -> "CGRTensor": + return super().identity( + num_vectors, + dimensions, + block_size=block_size, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + @classmethod + def random( + cls, + num_vectors: int, + dimensions: int, + *, + block_size: int, + generator=None, + dtype=torch.int64, + device=None, + requires_grad=False, + ) -> "CGRTensor": + return super().random( + num_vectors, + dimensions, + block_size=block_size, + generator=generator, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + def bundle(self, other: "CGRTensor") -> "CGRTensor": r"""Bundle the hypervector with majority voting. Ties might be broken at random. However, the expected result is that the tie representing the lowest value wins. @@ -83,7 +148,29 @@ def multibundle(self) -> "CGRTensor": """Bundle multiple hypervectors""" # The use of torch.mode() makes untying deterministic as it always # returns the lowest index among the ties. For example, if there is an - # equal number amount of 0s and 1s in a bundle, 0 is returned. + # equal amount of 0s and 1s in a bundle, 0 is returned. 
val, _ = torch.mode(self, dim=-2) return val + def bind(self, other: "CGRTensor") -> "CGRTensor": + return super().bind(other) + + def multibind(self) -> "CGRTensor": + """Bind multiple hypervectors""" + return super().multibind() + + def inverse(self) -> "CGRTensor": + return super().inverse() + + def permute(self, shifts: int = 1) -> "CGRTensor": + return super().permute(shifts=shifts) + + def normalize(self) -> "CGRTensor": + return super().normalize() + + def dot_similarity(self, others: "CGRTensor", *, dtype=None) -> Tensor: + return super().dot_similarity(others, dtype=dtype) + + def cosine_similarity(self, others: "CGRTensor", *, dtype=None) -> Tensor: + return super().cosine_similarity(others, dtype=dtype) + diff --git a/torchhd/tensors/mcr.py b/torchhd/tensors/mcr.py index 7559915..ebca6d3 100644 --- a/torchhd/tensors/mcr.py +++ b/torchhd/tensors/mcr.py @@ -36,6 +36,70 @@ class MCRTensor(BaseMCRTensor): Proposed in `Modular Composite Representation `_, this model works with modular integer vectors. """ + @classmethod + def empty( + cls, + num_vectors: int, + dimensions: int, + *, + block_size: int, + generator=None, + dtype=torch.int64, + device=None, + requires_grad=False, + ) -> "MCRTensor": + return super().empty( + num_vectors, + dimensions, + block_size=block_size, + generator=generator, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + @classmethod + def identity( + cls, + num_vectors: int, + dimensions: int, + *, + block_size: int, + dtype=torch.int64, + device=None, + requires_grad=False, + ) -> "MCRTensor": + return super().identity( + num_vectors, + dimensions, + block_size=block_size, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + @classmethod + def random( + cls, + num_vectors: int, + dimensions: int, + *, + block_size: int, + generator=None, + dtype=torch.int64, + device=None, + requires_grad=False, + ) -> "MCRTensor": + return super().random( + num_vectors, + dimensions, + block_size=block_size, + generator=generator, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + def bundle(self, other: "MCRTensor") -> "MCRTensor": r"""Bundle the hypervector with normalized complex vector addition. @@ -100,3 +164,25 @@ def multibundle(self) -> "MCRTensor": return torch.remainder(result, self.block_size).type(self.dtype) + def bind(self, other: "MCRTensor") -> "MCRTensor": + return super().bind(other) + + def multibind(self) -> "MCRTensor": + """Bind multiple hypervectors""" + return super().multibind() + + def inverse(self) -> "MCRTensor": + return super().inverse() + + def permute(self, shifts: int = 1) -> "MCRTensor": + return super().permute(shifts=shifts) + + def normalize(self) -> "MCRTensor": + return super().normalize() + + def dot_similarity(self, others: "MCRTensor", *, dtype=None) -> Tensor: + return super().dot_similarity(others, dtype=dtype) + + def cosine_similarity(self, others: "MCRTensor", *, dtype=None) -> Tensor: + return super().cosine_similarity(others, dtype=dtype) +
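The mechanism this last patch relies on, in miniature (a schematic sketch;
class and method names are illustrative): re-declaring a method on the
subclass makes it a member that autodoc documents, and because the override
carries no docstring of its own, Sphinx falls back to the parent's docstring
(autodoc_inherit_docstrings defaults to True):

    class Base:
        def bind(self, other: "Base") -> "Base":
            """Bind two hypervectors element-wise."""
            raise NotImplementedError

    class Child(Base):
        # No docstring here: Sphinx autodoc inherits Base.bind's docstring
        # while printing the signature re-declared on Child.
        def bind(self, other: "Child") -> "Child":
            return super().bind(other)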