Skip to content

Commit 5867f9e

Browse files
committed
add non-contiguous minus strides array.
1 parent 4057bbe commit 5867f9e

File tree

2 files changed

+123
-69
lines changed

2 files changed

+123
-69
lines changed

questdb-rs-ffi/src/lib.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -879,15 +879,14 @@ pub unsafe extern "C" fn line_sender_buffer_column_f64_arr_c_major(
879879
) -> bool {
880880
let buffer = unwrap_buffer_mut(buffer);
881881
let name = name.as_name();
882-
let view =
883-
match CMajorArrayView::<f64>::new(rank, shape, data_buffer, data_buffer_len) {
884-
Ok(value) => value,
885-
Err(err) => {
886-
let err_ptr = Box::into_raw(Box::new(line_sender_error(err)));
887-
*err_out = err_ptr;
888-
return false;
889-
}
890-
};
882+
let view = match CMajorArrayView::<f64>::new(rank, shape, data_buffer, data_buffer_len) {
883+
Ok(value) => value,
884+
Err(err) => {
885+
let err_ptr = Box::into_raw(Box::new(line_sender_error(err)));
886+
*err_out = err_ptr;
887+
return false;
888+
}
889+
};
891890
bubble_err_to_c!(
892891
err_out,
893892
buffer.column_arr::<ColumnName<'_>, CMajorArrayView<'_, f64>, f64>(name, &view)

questdb-rs-ffi/src/ndarr.rs

Lines changed: 115 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -149,36 +149,7 @@ where
149149
data: *const u8,
150150
data_len: usize,
151151
) -> Result<Self, Error> {
152-
if dims == 0 {
153-
return Err(fmt_error!(
154-
ArrayError,
155-
"Zero-dimensional arrays are not supported",
156-
));
157-
}
158-
if data_len > MAX_ARRAY_BUFFER_SIZE {
159-
return Err(fmt_error!(
160-
ArrayError,
161-
"Array buffer size too big: {}, maximum: {}",
162-
data_len,
163-
MAX_ARRAY_BUFFER_SIZE
164-
));
165-
}
166-
let shape = slice::from_raw_parts(shape, dims);
167-
let size = shape
168-
.iter()
169-
.try_fold(std::mem::size_of::<T>(), |acc, &dim| {
170-
acc.checked_mul(dim)
171-
.ok_or_else(|| fmt_error!(ArrayError, "Array buffer size too big"))
172-
})?;
173-
174-
if size != data_len {
175-
return Err(fmt_error!(
176-
ArrayError,
177-
"Array buffer length mismatch (actual: {}, expected: {})",
178-
data_len,
179-
size
180-
));
181-
}
152+
let shape = check_array_shape::<T>(dims, shape, data_len)?;
182153
let strides = slice::from_raw_parts(strides, dims);
183154
let mut slice = None;
184155
if data_len != 0 {
@@ -359,36 +330,7 @@ where
359330
data: *const u8,
360331
data_len: usize,
361332
) -> Result<Self, Error> {
362-
if dims == 0 {
363-
return Err(fmt_error!(
364-
ArrayError,
365-
"Zero-dimensional arrays are not supported",
366-
));
367-
}
368-
if data_len > MAX_ARRAY_BUFFER_SIZE {
369-
return Err(fmt_error!(
370-
ArrayError,
371-
"Array buffer size too big: {}, maximum: {}",
372-
data_len,
373-
MAX_ARRAY_BUFFER_SIZE
374-
));
375-
}
376-
let shape = slice::from_raw_parts(shape, dims);
377-
let size = shape
378-
.iter()
379-
.try_fold(std::mem::size_of::<T>(), |acc, &dim| {
380-
acc.checked_mul(dim)
381-
.ok_or_else(|| fmt_error!(ArrayError, "Array buffer size too big"))
382-
})?;
383-
384-
if size != data_len {
385-
return Err(fmt_error!(
386-
ArrayError,
387-
"Array buffer length mismatch (actual: {}, expected: {})",
388-
data_len,
389-
size
390-
));
391-
}
333+
let shape = check_array_shape::<T>(dims, shape, data_len)?;
392334
let mut slice = None;
393335
if data_len != 0 {
394336
slice = Some(slice::from_raw_parts(data, data_len));
@@ -402,6 +344,45 @@ where
402344
}
403345
}
404346

347+
fn check_array_shape<T>(
348+
dims: usize,
349+
shape: *const usize,
350+
data_len: usize,
351+
) -> Result<&'static [usize], Error> {
352+
if dims == 0 {
353+
return Err(fmt_error!(
354+
ArrayError,
355+
"Zero-dimensional arrays are not supported",
356+
));
357+
}
358+
if data_len > MAX_ARRAY_BUFFER_SIZE {
359+
return Err(fmt_error!(
360+
ArrayError,
361+
"Array buffer size too big: {}, maximum: {}",
362+
data_len,
363+
MAX_ARRAY_BUFFER_SIZE
364+
));
365+
}
366+
let shape = unsafe { slice::from_raw_parts(shape, dims) };
367+
368+
let size = shape
369+
.iter()
370+
.try_fold(std::mem::size_of::<T>(), |acc, &dim| {
371+
acc.checked_mul(dim)
372+
.ok_or_else(|| fmt_error!(ArrayError, "Array buffer size too big"))
373+
})?;
374+
375+
if size != data_len {
376+
return Err(fmt_error!(
377+
ArrayError,
378+
"Array buffer length mismatch (actual: {}, expected: {})",
379+
data_len,
380+
size
381+
));
382+
}
383+
Ok(shape)
384+
}
385+
405386
#[cfg(test)]
406387
mod tests {
407388
use super::*;
@@ -909,4 +890,78 @@ mod tests {
909890
assert_eq!(buf, expected);
910891
Ok(())
911892
}
893+
894+
#[test]
895+
fn test_c_major_array_basic() -> TestResult {
896+
let test_data = [1.1, 2.2, 3.3, 4.4];
897+
let array_view: CMajorArrayView<'_, f64> = unsafe {
898+
CMajorArrayView::new(
899+
2,
900+
[2, 2].as_ptr(),
901+
test_data.as_ptr() as *const u8,
902+
test_data.len() * 8usize,
903+
)
904+
}?;
905+
let mut buffer = Buffer::new(ProtocolVersion::V2);
906+
buffer.table("my_test")?;
907+
buffer.column_arr("temperature", &array_view)?;
908+
let data = buffer.as_bytes();
909+
assert_eq!(&data[0..7], b"my_test");
910+
assert_eq!(&data[8..19], b"temperature");
911+
assert_eq!(
912+
&data[19..24],
913+
&[
914+
b'=', b'=', 14u8, // ARRAY_BINARY_FORMAT_TYPE
915+
10u8, // ArrayColumnTypeTag::Double.into()
916+
2u8
917+
]
918+
);
919+
assert_eq!(
920+
&data[24..32],
921+
[2i32.to_le_bytes(), 2i32.to_le_bytes()].concat()
922+
);
923+
assert_eq!(
924+
&data[32..64],
925+
&[
926+
1.1f64.to_ne_bytes(),
927+
2.2f64.to_le_bytes(),
928+
3.3f64.to_le_bytes(),
929+
4.4f64.to_le_bytes(),
930+
]
931+
.concat()
932+
);
933+
Ok(())
934+
}
935+
936+
#[test]
937+
fn test_c_major_empty_array() -> TestResult {
938+
let test_data = [];
939+
let array_view: CMajorArrayView<'_, f64> = unsafe {
940+
CMajorArrayView::new(
941+
2,
942+
[2, 0].as_ptr(),
943+
test_data.as_ptr(),
944+
test_data.len() * 8usize,
945+
)
946+
}?;
947+
let mut buffer = Buffer::new(ProtocolVersion::V2);
948+
buffer.table("my_test")?;
949+
buffer.column_arr("temperature", &array_view)?;
950+
let data = buffer.as_bytes();
951+
assert_eq!(&data[0..7], b"my_test");
952+
assert_eq!(&data[8..19], b"temperature");
953+
assert_eq!(
954+
&data[19..24],
955+
&[
956+
b'=', b'=', 14u8, // ARRAY_BINARY_FORMAT_TYPE
957+
10u8, // ArrayColumnTypeTag::Double.into()
958+
2u8
959+
]
960+
);
961+
assert_eq!(
962+
&data[24..32],
963+
[2i32.to_le_bytes(), 0i32.to_le_bytes()].concat()
964+
);
965+
Ok(())
966+
}
912967
}

0 commit comments

Comments
 (0)