@@ -37,7 +37,7 @@ class Testlevel:
3737 @pytest .mark .parametrize ("d" , [84 , 16 ])
3838 @pytest .mark .parametrize ("vsa" , vsa_tensors )
3939 def test_shape (self , n , d , vsa ):
40- if vsa == "BSBC" or vsa == "MCR" :
40+ if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR" :
4141 hv = functional .level (n , d , vsa , block_size = 1024 )
4242
4343 elif vsa == "VTB" and d == 84 :
@@ -49,7 +49,7 @@ def test_shape(self, n, d, vsa):
4949 else :
5050 hv = functional .level (n , d , vsa )
5151
52- if vsa == "BSBC" or vsa == "MCR" :
52+ if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR" :
5353 assert hv .block_size == 1024
5454
5555 assert hv .dim () == 2
@@ -60,15 +60,15 @@ def test_shape(self, n, d, vsa):
6060 def test_generator (self , vsa ):
6161 generator = torch .Generator ()
6262 generator .manual_seed (seed )
63- if vsa == "BSBC" or vsa == "MCR" :
63+ if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR" :
6464 hv1 = functional .level (20 , 10000 , vsa , generator = generator , block_size = 1024 )
6565 else :
6666 hv1 = functional .level (20 , 10000 , vsa , generator = generator )
6767
6868 generator = torch .Generator ()
6969 generator .manual_seed (seed )
7070
71- if vsa == "BSBC" or vsa == "MCR" :
71+ if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR" :
7272 hv2 = functional .level (20 , 10000 , vsa , generator = generator , block_size = 1024 )
7373 else :
7474 hv2 = functional .level (20 , 10000 , vsa , generator = generator )
@@ -79,7 +79,7 @@ def test_generator(self, vsa):
7979 def test_value (self , dtype , vsa ):
8080 if not supported_dtype (dtype , vsa ):
8181 with pytest .raises (ValueError ):
82- if vsa == "BSBC" or vsa == "MCR" :
82+ if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR" :
8383 functional .level (3 , 25 , vsa , dtype = dtype , block_size = 1024 )
8484 else :
8585 functional .level (3 , 25 , vsa , dtype = dtype )
@@ -89,7 +89,7 @@ def test_value(self, dtype, vsa):
8989 generator = torch .Generator ()
9090 generator .manual_seed (seed )
9191
92- if vsa == "BSBC" or vsa == "MCR" :
92+ if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR" :
9393 hv = functional .level (
9494 50 , 25921 , vsa , dtype = dtype , generator = generator , block_size = 1024
9595 )
@@ -103,7 +103,7 @@ def test_value(self, dtype, vsa):
103103 if vsa == "BSC" :
104104 assert torch .all ((hv == False ) | (hv == True )).item ()
105105
106- elif vsa == "BSBC" or vsa == "MCR" :
106+ elif vsa == "BSBC" or vsa == "MCR" or vsa == "CGR" :
107107 assert torch .all ((hv >= 0 ) | (hv < 1024 )).item ()
108108
109109 elif vsa == "MAP" :
@@ -123,7 +123,7 @@ def test_value(self, dtype, vsa):
123123 sims_diff = sims [:- 1 ] - sims [1 :]
124124 assert torch .all (sims_diff > 0 ).item (), "similarity must be decreasing"
125125
126- if vsa == "BSBC" or vsa == "MCR" :
126+ if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR" :
127127 hv = functional .level (
128128 5 , 1000000 , vsa , generator = generator , dtype = dtype , block_size = 1024
129129 )
@@ -163,7 +163,7 @@ def test_device(self, dtype, vsa):
163163 return
164164
165165 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
166- if vsa == "BSBC" or vsa == "MCR" :
166+ if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR" :
167167 hv = functional .level (
168168 3 , 49 , vsa , device = device , dtype = dtype , block_size = 1024
169169 )
0 commit comments