Commit d634328

Add BaseMCRTensor
Add a new base class to support future implementations of variations of the MCR proposal.
1 parent 6b7478d commit d634328

File tree

3 files changed: +389 −341 lines changed


torchhd/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,7 @@
 from torchhd.tensors.fhrr import FHRRTensor
 from torchhd.tensors.bsbc import BSBCTensor
 from torchhd.tensors.vtb import VTBTensor
+from torchhd.tensors.basemcr import BaseMCRTensor
 from torchhd.tensors.mcr import MCRTensor
 
 from torchhd.functional import (
@@ -91,6 +92,7 @@
     "FHRRTensor",
     "BSBCTensor",
     "VTBTensor",
+    "BaseMCRTensor",
     "MCRTensor",
     "functional",
     "embeddings",

torchhd/tensors/basemcr.py

Lines changed: 379 additions & 0 deletions
@@ -0,0 +1,379 @@
#
# MIT License
#
# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import torch
from torch import Tensor
import torch.nn.functional as F
from typing import Set

from torchhd.tensors.base import VSATensor


class BaseMCRTensor(VSATensor):
    r"""Base class for VSA Modular Composite Representations (MCR)

    Proposed in `Modular Composite Representation <https://link.springer.com/article/10.1007/s12559-013-9243-y>`_, this model works with modular integer vectors. The base class is used as a template for MCR and for the Cyclic Group Representation (CGR), which is very similar to MCR but uses a different bundling operation.
    """

    block_size: int
    supported_dtypes: Set[torch.dtype] = {
        torch.float32,
        torch.float64,
        torch.int16,
        torch.int32,
        torch.int64,
    }

    @classmethod
    def empty(
        cls,
        num_vectors: int,
        dimensions: int,
        *,
        block_size: int,
        generator=None,
        dtype=torch.int64,
        device=None,
        requires_grad=False,
    ) -> "BaseMCRTensor":
        r"""Creates a set of hypervectors representing empty sets.

        When bundled with a hypervector :math:`x`, the result is :math:`x`.
        Because of the low precision of the MCR model an empty set cannot be explicitly represented, therefore the returned hypervectors are identical to random hypervectors.

        Args:
            num_vectors (int): the number of hypervectors to generate.
            dimensions (int): the dimensionality of the hypervectors.
            block_size (int): the number of elements per block which controls the angular granularity.
            generator (``torch.Generator``, optional): a pseudorandom number generator for sampling.
            dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: ``int64``.
            device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
            requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.

        Examples::

            >>> torchhd.BaseMCRTensor.empty(3, 6, block_size=64)
            BaseMCRTensor([[54,  3, 22, 27, 41, 21],
                           [17, 31, 55,  3, 44, 52],
                           [42, 37, 60, 54, 13, 41]])

        """

        if dtype is None:
            dtype = torch.int64

        if dtype not in cls.supported_dtypes:
            name = cls.__name__
            options = ", ".join([str(x) for x in cls.supported_dtypes])
            raise ValueError(
                f"{name} vectors must be one of dtype {options}, got {dtype}."
            )

        result = torch.randint(
            0,
            block_size,
            (num_vectors, dimensions),
            generator=generator,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
        )

        result = result.as_subclass(cls)
        result.block_size = block_size
        return result

    @classmethod
    def identity(
        cls,
        num_vectors: int,
        dimensions: int,
        *,
        block_size: int,
        dtype=torch.int64,
        device=None,
        requires_grad=False,
    ) -> "BaseMCRTensor":
        r"""Creates a set of identity hypervectors.

        When bound with a random hypervector :math:`x`, the result is :math:`x`.

        Args:
            num_vectors (int): the number of hypervectors to generate.
            dimensions (int): the dimensionality of the hypervectors.
            block_size (int): the number of elements per block which controls the angular granularity.
            dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: ``int64``.
            device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
            requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.

        Examples::

            >>> torchhd.BaseMCRTensor.identity(3, 6, block_size=64)
            BaseMCRTensor([[0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0]])

        """
        if dtype is None:
            dtype = torch.int64

        if dtype not in cls.supported_dtypes:
            name = cls.__name__
            options = ", ".join([str(x) for x in cls.supported_dtypes])
            raise ValueError(
                f"{name} vectors must be one of dtype {options}, got {dtype}."
            )

        result = torch.zeros(
            num_vectors,
            dimensions,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
        )

        result = result.as_subclass(cls)
        result.block_size = block_size
        return result

    @classmethod
    def random(
        cls,
        num_vectors: int,
        dimensions: int,
        *,
        block_size: int,
        generator=None,
        dtype=torch.int64,
        device=None,
        requires_grad=False,
    ) -> "BaseMCRTensor":
        r"""Creates a set of random independent hypervectors.

        The elements of the resulting hypervectors are sampled uniformly at random from the integers between 0 (inclusive) and ``block_size`` (exclusive).

        Args:
            num_vectors (int): the number of hypervectors to generate.
            dimensions (int): the dimensionality of the hypervectors.
            block_size (int): the number of elements per block which controls the angular granularity.
            generator (``torch.Generator``, optional): a pseudorandom number generator for sampling.
            dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: ``int64``.
            device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
            requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.

        Examples::

            >>> torchhd.BaseMCRTensor.random(3, 6, block_size=64)
            BaseMCRTensor([[ 7,  1, 39,  8, 55, 22],
                           [51, 38, 59, 45, 13, 29],
                           [19, 26, 30,  5, 15, 51]])
            >>> torchhd.BaseMCRTensor.random(3, 6, block_size=128, dtype=torch.float32)
            BaseMCRTensor([[116.,  25., 100.,  10.,  21.,  86.],
                           [ 69.,  49.,   2.,  56.,  78.,  70.],
                           [ 77.,  47.,  37., 106.,   8.,  30.]])

        """
        if dtype is None:
            dtype = torch.int64

        if dtype not in cls.supported_dtypes:
            name = cls.__name__
            options = ", ".join([str(x) for x in cls.supported_dtypes])
            raise ValueError(
                f"{name} vectors must be one of dtype {options}, got {dtype}."
            )

        result = torch.randint(
            0,
            block_size,
            (num_vectors, dimensions),
            generator=generator,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
        )

        result = result.as_subclass(cls)
        result.block_size = block_size
        return result

    def to_complex_unit(self):
        angles = 2 * torch.pi * self / self.block_size
        return torch.polar(torch.ones_like(self, dtype=angles.dtype), angles)

    def bundle(self, other: "BaseMCRTensor") -> "BaseMCRTensor":
        """Bundle the hypervector with other"""
        raise NotImplementedError

    def multibundle(self) -> "BaseMCRTensor":
        """Bundle multiple hypervectors"""
        return super().multibundle()

    def bind(self, other: "BaseMCRTensor") -> "BaseMCRTensor":
        r"""Bind the hypervector with other using element-wise addition modulo ``block_size``.

        This produces a hypervector dissimilar to both.

        Binding is used to associate information, for instance, to assign values to variables.

        Args:
            other (BaseMCRTensor): other input hypervector

        Shapes:
            - Self: :math:`(*)`
            - Other: :math:`(*)`
            - Output: :math:`(*)`

        Examples::

            >>> a, b = torchhd.BaseMCRTensor.random(2, 10, block_size=64)
            >>> a
            BaseMCRTensor([18, 55, 40, 62, 39, 26, 35, 24, 49, 41])
            >>> b
            BaseMCRTensor([46, 36, 21, 23, 25, 12, 29, 53, 54, 41])
            >>> a.bind(b)
            BaseMCRTensor([ 0, 27, 61, 21,  0, 38,  0, 13, 39, 18])

        """
        assert self.block_size == other.block_size
        return torch.remainder(torch.add(self, other), self.block_size)

    def multibind(self) -> "BaseMCRTensor":
        """Bind multiple hypervectors"""
        return torch.remainder(
            torch.sum(self, dim=-2, dtype=self.dtype), self.block_size
        )

    def inverse(self) -> "BaseMCRTensor":
        r"""Invert the hypervector for binding.

        Shapes:
            - Self: :math:`(*)`
            - Output: :math:`(*)`

        Examples::

            >>> a = torchhd.BaseMCRTensor.random(1, 10, block_size=64)
            >>> a
            BaseMCRTensor([[ 5, 30, 15, 43, 19, 36,  4, 14, 57, 34]])
            >>> a.inverse()
            BaseMCRTensor([[59, 34, 49, 21, 45, 28, 60, 50,  7, 30]])

        """

        return torch.remainder(torch.negative(self), self.block_size)

    def permute(self, shifts: int = 1) -> "BaseMCRTensor":
        r"""Permute the hypervector.

        The permutation operator is commonly used to assign an order to hypervectors.

        Args:
            shifts (int, optional): The number of places by which the elements of the tensor are shifted.

        Shapes:
            - Self: :math:`(*)`
            - Output: :math:`(*)`

        Examples::

            >>> a = torchhd.BaseMCRTensor.random(1, 10, block_size=64)
            >>> a
            BaseMCRTensor([[33, 24,  1, 36,  2, 57, 11, 59, 33,  3]])
            >>> a.permute(4)
            BaseMCRTensor([[11, 59, 33,  3, 33, 24,  1, 36,  2, 57]])

        """
        return torch.roll(self, shifts=shifts, dims=-1)

    def normalize(self) -> "BaseMCRTensor":
        r"""Normalize the hypervector.

        Each operation on MCR hypervectors ensures they remain normalized, so this simply returns a copy of self.

        Shapes:
            - Self: :math:`(*)`
            - Output: :math:`(*)`

        Examples::

            >>> x = torchhd.BaseMCRTensor.random(4, 6, block_size=64).multibundle()
            >>> x
            BaseMCRTensor([28, 27, 20, 44, 57, 18])
            >>> x.normalize()
            BaseMCRTensor([28, 27, 20, 44, 57, 18])

        """
        return self.clone()

    def dot_similarity(self, others: "BaseMCRTensor", *, dtype=None) -> Tensor:
        """Based on the Manhattan distance in a modular space.

        The distance between two elements is divided by the average distance between two random elements.
        """
        if dtype is None:
            dtype = torch.get_default_dtype()

        random_distance = self.block_size / 4

        if self.dim() > 1 and others.dim() > 1:
            aminusb = torch.remainder(
                self.unsqueeze(-2) - others.unsqueeze(-3), self.block_size
            )
            bminusa = torch.remainder(
                others.unsqueeze(-3) - self.unsqueeze(-2), self.block_size
            )
        else:
            aminusb = torch.remainder(self - others, self.block_size)
            bminusa = torch.remainder(others - self, self.block_size)

        distance = torch.min(aminusb, bminusa)
        normalized_distance = distance / random_distance

        return torch.sum(1 - normalized_distance, dim=-1, dtype=dtype)

    def cosine_similarity(self, others: "BaseMCRTensor", *, dtype=None) -> Tensor:
        """Cosine similarity with other hypervectors"""
        magnitude = self.size(-1)
        return self.dot_similarity(others, dtype=dtype) / magnitude

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Ensure that all the built-in torch operations on this Tensor subclass maintain the block_size property

        if kwargs is None:
            kwargs = {}

        block_sizes = set(a.block_size for a in args if hasattr(a, "block_size"))
        if len(block_sizes) != 1:
            raise RuntimeError(
                f"Call to {func} must contain exactly one block size, got {list(block_sizes)}"
            )

        # Call with super to avoid infinite recursion
        ret = super().__torch_function__(func, types, args, kwargs)

        if isinstance(ret, BaseMCRTensor):
            ret.block_size = list(block_sizes)[0]
        elif isinstance(ret, (tuple, list)):
            for x in ret:
                if isinstance(x, BaseMCRTensor):
                    x.block_size = list(block_sizes)[0]

        # TODO: handle more return types
        return ret
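
A quick sanity check on dot_similarity above (illustrative only, not part of the commit): for two independent uniformly random elements of a block with block_size = 64, the expected circular distance is 64 / 4 = 16, which is exactly the random_distance normalizer, so each element pair contributes 1 - d / 16 to the sum, where d is the shorter way around the ring.

    import torch
    import torchhd

    x, y = torchhd.BaseMCRTensor.random(2, 6, block_size=64)

    # Shorter way around the ring for each element pair, followed by the same
    # normalization that dot_similarity applies internally.
    d = torch.min((x - y) % 64, (y - x) % 64)
    manual = torch.sum(1 - d / (64 / 4), dtype=torch.get_default_dtype())
    assert torch.allclose(manual, x.dot_similarity(y))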

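Since bundle is the one operation the base class leaves abstract, variants of the MCR proposal are expected to differ exactly there. A hypothetical minimal subclass (the name PhaseSumTensor and its bundling rule are illustrative assumptions, not taken from torchhd's own MCR or CGR classes) could bundle by summing elements on the complex unit circle via to_complex_unit and rounding the resulting angle back to a discrete phase:

    import torch
    from torchhd import BaseMCRTensor

    class PhaseSumTensor(BaseMCRTensor):
        """Hypothetical variant: bundle by phasor addition on the unit circle."""

        def bundle(self, other: "PhaseSumTensor") -> "PhaseSumTensor":
            assert self.block_size == other.block_size
            total = self.to_complex_unit() + other.to_complex_unit()
            # Map the summed phasors back to the nearest of the block_size phases.
            levels = torch.round(torch.angle(total) * self.block_size / (2 * torch.pi))
            return torch.remainder(levels.to(self.dtype), self.block_size)

    x, y = PhaseSumTensor.random(2, 10, block_size=64)
    s = x.bundle(y)
    print(s.dot_similarity(x), s.dot_similarity(y))  # s typically stays similar to both inputs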