@@ -149,36 +149,7 @@ where
149
149
data : * const u8 ,
150
150
data_len : usize ,
151
151
) -> Result < Self , Error > {
152
- if dims == 0 {
153
- return Err ( fmt_error ! (
154
- ArrayError ,
155
- "Zero-dimensional arrays are not supported" ,
156
- ) ) ;
157
- }
158
- if data_len > MAX_ARRAY_BUFFER_SIZE {
159
- return Err ( fmt_error ! (
160
- ArrayError ,
161
- "Array buffer size too big: {}, maximum: {}" ,
162
- data_len,
163
- MAX_ARRAY_BUFFER_SIZE
164
- ) ) ;
165
- }
166
- let shape = slice:: from_raw_parts ( shape, dims) ;
167
- let size = shape
168
- . iter ( )
169
- . try_fold ( std:: mem:: size_of :: < T > ( ) , |acc, & dim| {
170
- acc. checked_mul ( dim)
171
- . ok_or_else ( || fmt_error ! ( ArrayError , "Array buffer size too big" ) )
172
- } ) ?;
173
-
174
- if size != data_len {
175
- return Err ( fmt_error ! (
176
- ArrayError ,
177
- "Array buffer length mismatch (actual: {}, expected: {})" ,
178
- data_len,
179
- size
180
- ) ) ;
181
- }
152
+ let shape = check_array_shape :: < T > ( dims, shape, data_len) ?;
182
153
let strides = slice:: from_raw_parts ( strides, dims) ;
183
154
let mut slice = None ;
184
155
if data_len != 0 {
@@ -359,36 +330,7 @@ where
359
330
data : * const u8 ,
360
331
data_len : usize ,
361
332
) -> Result < Self , Error > {
362
- if dims == 0 {
363
- return Err ( fmt_error ! (
364
- ArrayError ,
365
- "Zero-dimensional arrays are not supported" ,
366
- ) ) ;
367
- }
368
- if data_len > MAX_ARRAY_BUFFER_SIZE {
369
- return Err ( fmt_error ! (
370
- ArrayError ,
371
- "Array buffer size too big: {}, maximum: {}" ,
372
- data_len,
373
- MAX_ARRAY_BUFFER_SIZE
374
- ) ) ;
375
- }
376
- let shape = slice:: from_raw_parts ( shape, dims) ;
377
- let size = shape
378
- . iter ( )
379
- . try_fold ( std:: mem:: size_of :: < T > ( ) , |acc, & dim| {
380
- acc. checked_mul ( dim)
381
- . ok_or_else ( || fmt_error ! ( ArrayError , "Array buffer size too big" ) )
382
- } ) ?;
383
-
384
- if size != data_len {
385
- return Err ( fmt_error ! (
386
- ArrayError ,
387
- "Array buffer length mismatch (actual: {}, expected: {})" ,
388
- data_len,
389
- size
390
- ) ) ;
391
- }
333
+ let shape = check_array_shape :: < T > ( dims, shape, data_len) ?;
392
334
let mut slice = None ;
393
335
if data_len != 0 {
394
336
slice = Some ( slice:: from_raw_parts ( data, data_len) ) ;
@@ -402,6 +344,45 @@ where
402
344
}
403
345
}
404
346
347
+ fn check_array_shape < T > (
348
+ dims : usize ,
349
+ shape : * const usize ,
350
+ data_len : usize ,
351
+ ) -> Result < & ' static [ usize ] , Error > {
352
+ if dims == 0 {
353
+ return Err ( fmt_error ! (
354
+ ArrayError ,
355
+ "Zero-dimensional arrays are not supported" ,
356
+ ) ) ;
357
+ }
358
+ if data_len > MAX_ARRAY_BUFFER_SIZE {
359
+ return Err ( fmt_error ! (
360
+ ArrayError ,
361
+ "Array buffer size too big: {}, maximum: {}" ,
362
+ data_len,
363
+ MAX_ARRAY_BUFFER_SIZE
364
+ ) ) ;
365
+ }
366
+ let shape = unsafe { slice:: from_raw_parts ( shape, dims) } ;
367
+
368
+ let size = shape
369
+ . iter ( )
370
+ . try_fold ( std:: mem:: size_of :: < T > ( ) , |acc, & dim| {
371
+ acc. checked_mul ( dim)
372
+ . ok_or_else ( || fmt_error ! ( ArrayError , "Array buffer size too big" ) )
373
+ } ) ?;
374
+
375
+ if size != data_len {
376
+ return Err ( fmt_error ! (
377
+ ArrayError ,
378
+ "Array buffer length mismatch (actual: {}, expected: {})" ,
379
+ data_len,
380
+ size
381
+ ) ) ;
382
+ }
383
+ Ok ( shape)
384
+ }
385
+
405
386
#[ cfg( test) ]
406
387
mod tests {
407
388
use super :: * ;
@@ -909,4 +890,78 @@ mod tests {
909
890
assert_eq ! ( buf, expected) ;
910
891
Ok ( ( ) )
911
892
}
893
+
894
+ #[ test]
895
+ fn test_c_major_array_basic ( ) -> TestResult {
896
+ let test_data = [ 1.1 , 2.2 , 3.3 , 4.4 ] ;
897
+ let array_view: CMajorArrayView < ' _ , f64 > = unsafe {
898
+ CMajorArrayView :: new (
899
+ 2 ,
900
+ [ 2 , 2 ] . as_ptr ( ) ,
901
+ test_data. as_ptr ( ) as * const u8 ,
902
+ test_data. len ( ) * 8usize ,
903
+ )
904
+ } ?;
905
+ let mut buffer = Buffer :: new ( ProtocolVersion :: V2 ) ;
906
+ buffer. table ( "my_test" ) ?;
907
+ buffer. column_arr ( "temperature" , & array_view) ?;
908
+ let data = buffer. as_bytes ( ) ;
909
+ assert_eq ! ( & data[ 0 ..7 ] , b"my_test" ) ;
910
+ assert_eq ! ( & data[ 8 ..19 ] , b"temperature" ) ;
911
+ assert_eq ! (
912
+ & data[ 19 ..24 ] ,
913
+ & [
914
+ b'=' , b'=' , 14u8 , // ARRAY_BINARY_FORMAT_TYPE
915
+ 10u8 , // ArrayColumnTypeTag::Double.into()
916
+ 2u8
917
+ ]
918
+ ) ;
919
+ assert_eq ! (
920
+ & data[ 24 ..32 ] ,
921
+ [ 2i32 . to_le_bytes( ) , 2i32 . to_le_bytes( ) ] . concat( )
922
+ ) ;
923
+ assert_eq ! (
924
+ & data[ 32 ..64 ] ,
925
+ & [
926
+ 1.1f64 . to_ne_bytes( ) ,
927
+ 2.2f64 . to_le_bytes( ) ,
928
+ 3.3f64 . to_le_bytes( ) ,
929
+ 4.4f64 . to_le_bytes( ) ,
930
+ ]
931
+ . concat( )
932
+ ) ;
933
+ Ok ( ( ) )
934
+ }
935
+
936
+ #[ test]
937
+ fn test_c_major_empty_array ( ) -> TestResult {
938
+ let test_data = [ ] ;
939
+ let array_view: CMajorArrayView < ' _ , f64 > = unsafe {
940
+ CMajorArrayView :: new (
941
+ 2 ,
942
+ [ 2 , 0 ] . as_ptr ( ) ,
943
+ test_data. as_ptr ( ) ,
944
+ test_data. len ( ) * 8usize ,
945
+ )
946
+ } ?;
947
+ let mut buffer = Buffer :: new ( ProtocolVersion :: V2 ) ;
948
+ buffer. table ( "my_test" ) ?;
949
+ buffer. column_arr ( "temperature" , & array_view) ?;
950
+ let data = buffer. as_bytes ( ) ;
951
+ assert_eq ! ( & data[ 0 ..7 ] , b"my_test" ) ;
952
+ assert_eq ! ( & data[ 8 ..19 ] , b"temperature" ) ;
953
+ assert_eq ! (
954
+ & data[ 19 ..24 ] ,
955
+ & [
956
+ b'=' , b'=' , 14u8 , // ARRAY_BINARY_FORMAT_TYPE
957
+ 10u8 , // ArrayColumnTypeTag::Double.into()
958
+ 2u8
959
+ ]
960
+ ) ;
961
+ assert_eq ! (
962
+ & data[ 24 ..32 ] ,
963
+ [ 2i32 . to_le_bytes( ) , 0i32 . to_le_bytes( ) ] . concat( )
964
+ ) ;
965
+ Ok ( ( ) )
966
+ }
912
967
}
0 commit comments