Skip to content

Commit 58b75d4

Browse files
committed
Add CGR into tests
The CGR should behave almost the same as MCR, but diverges in bundling. Implement a custom bundling test for it.
1 parent 9c471ca commit 58b75d4

File tree

9 files changed

+127
-112
lines changed

9 files changed

+127
-112
lines changed

torchhd/tests/basis_hv/test_circular_hv.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_shape(self, n, d, vsa):
4141
if vsa == "HRR" or vsa == "VTB":
4242
return
4343

44-
if vsa == "BSBC" or vsa == "MCR":
44+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
4545
hv = functional.circular(n, d, vsa, block_size=1024)
4646
else:
4747
hv = functional.circular(n, d, vsa)
@@ -57,7 +57,7 @@ def test_generator(self, vsa):
5757

5858
generator = torch.Generator()
5959
generator.manual_seed(seed)
60-
if vsa == "BSBC" or vsa == "MCR":
60+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
6161
hv1 = functional.circular(
6262
20, 10000, vsa, generator=generator, block_size=1024
6363
)
@@ -66,7 +66,7 @@ def test_generator(self, vsa):
6666

6767
generator = torch.Generator()
6868
generator.manual_seed(seed)
69-
if vsa == "BSBC" or vsa == "MCR":
69+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
7070
hv2 = functional.circular(
7171
20, 10000, vsa, generator=generator, block_size=1024
7272
)
@@ -79,7 +79,7 @@ def test_generator(self, vsa):
7979
def test_value(self, dtype, vsa):
8080
if not supported_dtype(dtype, vsa):
8181
with pytest.raises(ValueError):
82-
if vsa == "BSBC" or vsa == "MCR":
82+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
8383
functional.circular(3, 26, vsa, dtype=dtype, block_size=1024)
8484
else:
8585
functional.circular(3, 26, vsa, dtype=dtype)
@@ -95,7 +95,7 @@ def test_value(self, dtype, vsa):
9595
generator = torch.Generator()
9696
generator.manual_seed(seed)
9797

98-
if vsa == "BSBC" or vsa == "MCR":
98+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
9999
hv = functional.circular(
100100
50, 26569, vsa, dtype=dtype, generator=generator, block_size=1024
101101
)
@@ -118,16 +118,16 @@ def test_value(self, dtype, vsa):
118118
mag, torch.tensor(1.0, dtype=mag.dtype), rtol=0.0001, atol=0.0001
119119
)
120120

121-
elif vsa == "BSBC" or vsa == "MCR":
121+
elif vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
122122
assert torch.all((hv >= 0) | (hv < 1024)).item()
123123

124-
if vsa == "BSBC" or vsa == "MCR":
124+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
125125
hv = functional.circular(
126126
8, 1000000, vsa, generator=generator, dtype=dtype, block_size=1024
127127
)
128128
else:
129129
hv = functional.circular(8, 1000000, vsa, generator=generator, dtype=dtype)
130-
130+
131131
for i in range(8-1):
132132
sims = functional.cosine_similarity(hv[0], hv)
133133
sims_diff = sims[:-1] - sims[1:]
@@ -180,7 +180,7 @@ def test_device(self, dtype, vsa):
180180
return
181181

182182
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
183-
if vsa == "BSBC" or vsa == "MCR":
183+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
184184
hv = functional.circular(
185185
3, 52, vsa, device=device, dtype=dtype, block_size=1024
186186
)

torchhd/tests/basis_hv/test_empty_hv.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Testempty:
4141
@pytest.mark.parametrize("d", [84, 16])
4242
@pytest.mark.parametrize("vsa", vsa_tensors)
4343
def test_shape(self, n, d, vsa):
44-
if vsa == "BSBC" or vsa == "MCR":
44+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
4545
hv = functional.empty(n, d, vsa, block_size=1024)
4646
elif vsa == "VTB" and d == 84:
4747
with pytest.raises(ValueError):
@@ -60,14 +60,14 @@ def test_shape(self, n, d, vsa):
6060
def test_value(self, dtype, vsa):
6161
if not supported_dtype(dtype, vsa):
6262
with pytest.raises(ValueError):
63-
if vsa == "BSBC" or vsa == "MCR":
63+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
6464
functional.empty(3, 25, vsa, dtype=dtype, block_size=1024)
6565
else:
6666
functional.empty(3, 25, vsa, dtype=dtype)
6767

6868
return
6969

70-
if vsa == "BSBC" or vsa == "MCR":
70+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
7171
hv = functional.empty(8, 25, vsa, dtype=dtype, block_size=1024)
7272
else:
7373
hv = functional.empty(8, 25, vsa, dtype=dtype)
@@ -80,7 +80,7 @@ def test_value(self, dtype, vsa):
8080
if vsa == "BSC":
8181
assert torch.all((hv == False) | (hv == True)).item()
8282

83-
elif vsa == "BSBC" or vsa == "MCR":
83+
elif vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
8484
assert torch.all((hv >= 0) | (hv < 1024)).item()
8585

8686
else:
@@ -94,7 +94,7 @@ def test_device(self, dtype, vsa):
9494
return
9595

9696
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
97-
if vsa == "BSBC" or vsa == "MCR":
97+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
9898
hv = functional.empty(
9999
3, 52, vsa, device=device, dtype=dtype, block_size=1024
100100
)

torchhd/tests/basis_hv/test_identity_hv.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Testidentity:
4141
@pytest.mark.parametrize("d", [84, 16])
4242
@pytest.mark.parametrize("vsa", vsa_tensors)
4343
def test_shape(self, n, d, vsa):
44-
if vsa == "BSBC" or vsa == "MCR":
44+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
4545
hv = functional.identity(n, d, vsa, block_size=1042)
4646
elif vsa == "VTB" and d == 84:
4747
with pytest.raises(ValueError):
@@ -61,14 +61,14 @@ def test_shape(self, n, d, vsa):
6161
def test_value(self, dtype, vsa):
6262
if not supported_dtype(dtype, vsa):
6363
with pytest.raises(ValueError):
64-
if vsa == "BSBC" or vsa == "MCR":
64+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
6565
functional.identity(3, 26, vsa, dtype=dtype, block_size=1042)
6666
else:
6767
functional.identity(3, 25, vsa, dtype=dtype)
6868

6969
return
7070

71-
if vsa == "BSBC" or vsa == "MCR":
71+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
7272
hv = functional.identity(8, 25, vsa, dtype=dtype, block_size=1042)
7373
else:
7474
hv = functional.identity(8, 25, vsa, dtype=dtype)
@@ -86,7 +86,7 @@ def test_value(self, dtype, vsa):
8686
x = torch.fft.fft(hv)
8787
assert torch.allclose(x, torch.full_like(x, 1.0))
8888

89-
elif vsa == "BSBC" or vsa == "MCR":
89+
elif vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
9090
assert torch.all(hv == 0)
9191

9292
elif vsa == "VTB":
@@ -103,7 +103,7 @@ def test_device(self, dtype, vsa):
103103
return
104104

105105
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106-
if vsa == "BSBC" or vsa == "MCR":
106+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
107107
hv = functional.identity(
108108
3, 52, vsa, device=device, dtype=dtype, block_size=1042
109109
)

torchhd/tests/basis_hv/test_level_hv.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class Testlevel:
3737
@pytest.mark.parametrize("d", [84, 16])
3838
@pytest.mark.parametrize("vsa", vsa_tensors)
3939
def test_shape(self, n, d, vsa):
40-
if vsa == "BSBC" or vsa == "MCR":
40+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
4141
hv = functional.level(n, d, vsa, block_size=1024)
4242

4343
elif vsa == "VTB" and d == 84:
@@ -49,7 +49,7 @@ def test_shape(self, n, d, vsa):
4949
else:
5050
hv = functional.level(n, d, vsa)
5151

52-
if vsa == "BSBC" or vsa == "MCR":
52+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
5353
assert hv.block_size == 1024
5454

5555
assert hv.dim() == 2
@@ -60,15 +60,15 @@ def test_shape(self, n, d, vsa):
6060
def test_generator(self, vsa):
6161
generator = torch.Generator()
6262
generator.manual_seed(seed)
63-
if vsa == "BSBC" or vsa == "MCR":
63+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
6464
hv1 = functional.level(20, 10000, vsa, generator=generator, block_size=1024)
6565
else:
6666
hv1 = functional.level(20, 10000, vsa, generator=generator)
6767

6868
generator = torch.Generator()
6969
generator.manual_seed(seed)
7070

71-
if vsa == "BSBC" or vsa == "MCR":
71+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
7272
hv2 = functional.level(20, 10000, vsa, generator=generator, block_size=1024)
7373
else:
7474
hv2 = functional.level(20, 10000, vsa, generator=generator)
@@ -79,7 +79,7 @@ def test_generator(self, vsa):
7979
def test_value(self, dtype, vsa):
8080
if not supported_dtype(dtype, vsa):
8181
with pytest.raises(ValueError):
82-
if vsa == "BSBC" or vsa == "MCR":
82+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
8383
functional.level(3, 25, vsa, dtype=dtype, block_size=1024)
8484
else:
8585
functional.level(3, 25, vsa, dtype=dtype)
@@ -89,7 +89,7 @@ def test_value(self, dtype, vsa):
8989
generator = torch.Generator()
9090
generator.manual_seed(seed)
9191

92-
if vsa == "BSBC" or vsa == "MCR":
92+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
9393
hv = functional.level(
9494
50, 25921, vsa, dtype=dtype, generator=generator, block_size=1024
9595
)
@@ -103,7 +103,7 @@ def test_value(self, dtype, vsa):
103103
if vsa == "BSC":
104104
assert torch.all((hv == False) | (hv == True)).item()
105105

106-
elif vsa == "BSBC" or vsa == "MCR":
106+
elif vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
107107
assert torch.all((hv >= 0) | (hv < 1024)).item()
108108

109109
elif vsa == "MAP":
@@ -123,7 +123,7 @@ def test_value(self, dtype, vsa):
123123
sims_diff = sims[:-1] - sims[1:]
124124
assert torch.all(sims_diff > 0).item(), "similarity must be decreasing"
125125

126-
if vsa == "BSBC" or vsa == "MCR":
126+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
127127
hv = functional.level(
128128
5, 1000000, vsa, generator=generator, dtype=dtype, block_size=1024
129129
)
@@ -163,7 +163,7 @@ def test_device(self, dtype, vsa):
163163
return
164164

165165
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
166-
if vsa == "BSBC" or vsa == "MCR":
166+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
167167
hv = functional.level(
168168
3, 49, vsa, device=device, dtype=dtype, block_size=1024
169169
)

torchhd/tests/basis_hv/test_random_hv.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Testrandom:
4141
@pytest.mark.parametrize("d", [84, 16])
4242
@pytest.mark.parametrize("vsa", vsa_tensors)
4343
def test_shape(self, n, d, vsa):
44-
if vsa == "BSBC" or vsa == "MCR":
44+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
4545
hv = functional.random(n, d, vsa, block_size=64)
4646

4747
elif vsa == "VTB" and d == 84:
@@ -62,15 +62,15 @@ def test_generator(self, vsa):
6262
generator = torch.Generator()
6363
generator.manual_seed(seed)
6464

65-
if vsa == "BSBC" or vsa == "MCR":
65+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
6666
hv1 = functional.random(20, 10000, vsa, generator=generator, block_size=64)
6767
else:
6868
hv1 = functional.random(20, 10000, vsa, generator=generator)
6969

7070
generator = torch.Generator()
7171
generator.manual_seed(seed)
7272

73-
if vsa == "BSBC" or vsa == "MCR":
73+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
7474
hv2 = functional.random(20, 10000, vsa, generator=generator, block_size=64)
7575
else:
7676
hv2 = functional.random(20, 10000, vsa, generator=generator)
@@ -81,7 +81,7 @@ def test_generator(self, vsa):
8181
def test_value(self, dtype, vsa):
8282
if not supported_dtype(dtype, vsa):
8383
with pytest.raises(ValueError):
84-
if vsa == "BSBC" or vsa == "MCR":
84+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
8585
functional.random(3, 25, vsa, dtype=dtype, block_size=64)
8686
else:
8787
functional.random(3, 25, vsa, dtype=dtype)
@@ -91,7 +91,7 @@ def test_value(self, dtype, vsa):
9191
generator = torch.Generator()
9292
generator.manual_seed(seed)
9393

94-
if vsa == "BSBC" or vsa == "MCR":
94+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
9595
hv = functional.random(
9696
8, 25921, vsa, dtype=dtype, generator=generator, block_size=64
9797
)
@@ -122,7 +122,7 @@ def test_value(self, dtype, vsa):
122122
mag = hv.abs()
123123
assert torch.allclose(mag, torch.tensor(1.0, dtype=mag.dtype))
124124

125-
elif vsa == "BSBC" or vsa == "MCR":
125+
elif vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
126126
assert torch.all((hv < 64) & (hv >= 0))
127127

128128
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.756, 1.0])
@@ -155,7 +155,7 @@ def test_orthogonality(self, dtype, vsa):
155155
generator = torch.Generator()
156156
generator.manual_seed(seed)
157157

158-
if vsa == "BSBC" or vsa == "MCR":
158+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
159159
hv = functional.random(
160160
100, 10000, vsa, dtype=dtype, generator=generator, block_size=1042
161161
)
@@ -174,7 +174,7 @@ def test_device(self, dtype, vsa):
174174
return
175175

176176
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
177-
if vsa == "BSBC" or vsa == "MCR":
177+
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
178178
hv = functional.random(
179179
3, 49, vsa, device=device, dtype=dtype, block_size=64
180180
)

0 commit comments

Comments
 (0)