2626import torch
2727from torch import LongTensor , FloatTensor , Tensor
2828from collections import deque
29+ import warnings
2930
3031from torchhd .tensors .base import VSATensor
3132from torchhd .tensors .bsc import BSCTensor
5051 "permute" ,
5152 "inverse" ,
5253 "negative" ,
54+ "normalize" ,
5355 "cleanup" ,
5456 "create_random_permute" ,
5557 "hard_quantize" ,
@@ -673,6 +675,11 @@ def bundle(input: VSATensor, other: VSATensor) -> VSATensor:
673675
674676 \oplus: \mathcal{H} \times \mathcal{H} \to \mathcal{H}
675677
678+ .. note::
679+
680+ This operation does not normalize the resulting hypervectors.
681+ Normalized hypervectors can be obtained with :func:`~torchhd.normalize`.
682+
676683 Args:
677684 input (VSATensor): input hypervector
678685 other (VSATensor): other input hypervector
@@ -885,6 +892,12 @@ def hard_quantize(input: Tensor):
885892 tensor([ 1., -1., -1., -1., 1., -1.])
886893
887894 """
895+ warnings .warn (
896+ "torchhd.hard_quantize is deprecated, consider using torchhd.normalize instead." ,
897+ DeprecationWarning ,
898+ stacklevel = 2 ,
899+ )
900+
888901 # Make sure that the output tensor has the same dtype and device
889902 # as the input tensor.
890903 positive = torch .tensor (1.0 , dtype = input .dtype , device = input .device )
@@ -893,6 +906,35 @@ def hard_quantize(input: Tensor):
893906 return torch .where (input > 0 , positive , negative )
894907
895908
909+ def normalize (input : VSATensor ) -> VSATensor :
910+ """Normalize the input hypervectors.
911+
912+ Args:
913+ input (Tensor): input tensor
914+
915+ Shapes:
916+ - Input: :math:`(*)`
917+ - Output: :math:`(*)`
918+
919+ Examples::
920+
921+ >>> x = torchhd.random(4, 10, "MAP").multibundle()
922+ >>> x
923+ MAPTensor([ 0., 0., -2., -2., 2., -2., 2., 2., 2., 0.])
924+ >>> torchhd.normalize(x)
925+ MAPTensor([-1., -1., -1., -1., 1., -1., 1., 1., 1., -1.])
926+
927+ >>> x = torchhd.random(4, 10, "HRR").multibundle()
928+ >>> x
929+ HRRTensor([-0.2999, 0.4686, 0.1797, -0.4830, 0.2718, -0.3663, 0.3079, 0.2558, -1.5157, -0.5196])
930+ >>> torchhd.normalize(x)
931+ HRRTensor([-0.1601, 0.2501, 0.0959, -0.2578, 0.1451, -0.1955, 0.1643, 0.1365, -0.8089, -0.2773])
932+
933+ """
934+ input = ensure_vsa_tensor (input )
935+ return input .normalize ()
936+
937+
896938def dot_similarity (input : VSATensor , others : VSATensor , ** kwargs ) -> VSATensor :
897939 """Dot product between the input vector and each vector in others.
898940
@@ -1037,6 +1079,11 @@ def multiset(input: VSATensor) -> VSATensor:
10371079
10381080 \bigoplus_{i=0}^{n-1} V_i
10391081
1082+ .. note::
1083+
1084+ This operation does not normalize the resulting or intermediate hypervectors.
1085+ Normalized hypervectors can be obtained with :func:`~torchhd.normalize`.
1086+
10401087 Args:
10411088 input (VSATensor): input hypervector tensor
10421089
0 commit comments