@@ -148,6 +148,34 @@ async def test_codecs_use_of_gpu_prototype() -> None:
148148 assert cp .array_equal (expect , got )
149149
150150
151+ @gpu_test
152+ @pytest .mark .asyncio
153+ async def test_sharding_use_of_gpu_prototype () -> None :
154+ with zarr .config .enable_gpu ():
155+ expect = cp .zeros ((10 , 10 ), dtype = "uint16" , order = "F" )
156+
157+ a = await zarr .api .asynchronous .create_array (
158+ StorePath (MemoryStore ()) / "test_codecs_use_of_gpu_prototype" ,
159+ shape = expect .shape ,
160+ chunks = (5 , 5 ),
161+ shards = (10 , 10 ),
162+ dtype = expect .dtype ,
163+ fill_value = 0 ,
164+ )
165+ expect [:] = cp .arange (100 ).reshape (10 , 10 )
166+
167+ await a .setitem (
168+ selection = (slice (0 , 10 ), slice (0 , 10 )),
169+ value = expect [:],
170+ prototype = gpu .buffer_prototype ,
171+ )
172+ got = await a .getitem (
173+ selection = (slice (0 , 10 ), slice (0 , 10 )), prototype = gpu .buffer_prototype
174+ )
175+ assert isinstance (got , cp .ndarray )
176+ assert cp .array_equal (expect , got )
177+
178+
151179def test_numpy_buffer_prototype () -> None :
152180 buffer = cpu .buffer_prototype .buffer .create_zero_length ()
153181 ndbuffer = cpu .buffer_prototype .nd_buffer .create (shape = (1 , 2 ), dtype = np .dtype ("int64" ))
@@ -157,6 +185,16 @@ def test_numpy_buffer_prototype() -> None:
157185 ndbuffer .as_scalar ()
158186
159187
188+ @gpu_test
189+ def test_gpu_buffer_prototype () -> None :
190+ buffer = gpu .buffer_prototype .buffer .create_zero_length ()
191+ ndbuffer = gpu .buffer_prototype .nd_buffer .create (shape = (1 , 2 ), dtype = cp .dtype ("int64" ))
192+ assert isinstance (buffer .as_array_like (), cp .ndarray )
193+ assert isinstance (ndbuffer .as_ndarray_like (), cp .ndarray )
194+ with pytest .raises (ValueError , match = "Buffer does not contain a single scalar value" ):
195+ ndbuffer .as_scalar ()
196+
197+
160198# TODO: the same test for other buffer classes
161199def test_cpu_buffer_as_scalar () -> None :
162200 buf = cpu .buffer_prototype .nd_buffer .create (shape = (), dtype = "int64" )
0 commit comments