Skip to content

Commit a958108

Browse files
authored
feat[gpu]: arrow device array decimal export (#8155)
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 1552135 commit a958108

5 files changed

Lines changed: 633 additions & 55 deletions

File tree

vortex-array/src/dtype/arrow.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,18 @@ mod test {
450450
);
451451
}
452452

453+
#[rstest]
454+
#[case(1, DataType::Decimal128(1, 0))]
455+
#[case(38, DataType::Decimal128(38, 0))]
456+
#[case(39, DataType::Decimal256(39, 0))]
457+
#[case(76, DataType::Decimal256(76, 0))]
458+
fn test_decimal_dtype_to_arrow(#[case] precision: u8, #[case] expected: DataType) {
459+
use crate::dtype::DecimalDType;
460+
461+
let dtype = DType::Decimal(DecimalDType::new(precision, 0), Nullability::NonNullable);
462+
assert_eq!(dtype.to_arrow_dtype().unwrap(), expected);
463+
}
464+
453465
#[test]
454466
fn test_variant_dtype_to_arrow_dtype_errors() {
455467
let err = DType::Variant(Nullability::NonNullable)
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#include "config.cuh"
5+
#include "types.cuh"
6+
#include <stdint.h>
7+
#include <type_traits>
8+
9+
// Arrow decimal schemas fix the physical values buffer width:
10+
// - Decimal32: 4 bytes per value.
11+
// - Decimal64: 8 bytes per value.
12+
// - Decimal128: 16 bytes per value.
13+
// - Decimal256: 32 bytes per value.
14+
//
15+
// Vortex storage width can differ, so export casts to the schema-implied width.
16+
// Rust-side export rejects narrowing casts because detecting overflow on-device
17+
// would require synchronizing an overflow flag back to the host.
18+
19+
// Extract the low 64 bits of a decimal storage value; used for Decimal32 and
// Decimal64 outputs.
//
// Wide inputs yield their least-significant limb (`lo` for int128_t,
// `parts[0]` for int256_t). Every other input type is converted with a plain
// static_cast to int64_t.
template <typename Input>
__device__ __forceinline__ int64_t decimal_to_i64(Input value) {
    constexpr bool kIsInt128 = std::is_same_v<Input, int128_t>;
    constexpr bool kIsInt256 = std::is_same_v<Input, int256_t>;

    if constexpr (kIsInt128) {
        return value.lo;
    } else if constexpr (kIsInt256) {
        return value.parts[0];
    } else {
        return static_cast<int64_t>(value);
    }
}
30+
31+
// Produce the 128-bit representation of a decimal storage value; used for
// Decimal128 outputs.
//
// - int128_t inputs pass through unchanged.
// - int256_t inputs keep their two lowest limbs (parts[0], parts[1]).
// - Any other input is sign-extended: the value itself becomes the low half
//   and the high half is -1 (all ones) for negative values, 0 otherwise.
template <typename Input>
__device__ __forceinline__ int128_t decimal_to_i128(Input value) {
    if constexpr (std::is_same_v<Input, int128_t>) {
        return value;
    } else if constexpr (std::is_same_v<Input, int256_t>) {
        return int128_t {value.parts[0], value.parts[1]};
    } else {
        int64_t high_half = 0;
        if (value < 0) {
            high_half = -1;
        }
        const int64_t low_half = static_cast<int64_t>(value);
        return int128_t {low_half, high_half};
    }
}
44+
45+
// Convert a single decimal value from its storage type to `Output`, the
// physical value type implied by the Arrow schema.
//
// The target is selected at compile time:
// - int256_t: identity for int256_t inputs; otherwise the value is first
//   converted to 128 bits and the two upper limbs are filled with its sign.
// - int128_t: 128-bit conversion via decimal_to_i128.
// - int64_t:  low 64 bits via decimal_to_i64.
// - int32_t:  low 64 bits via decimal_to_i64, then truncated to 32 bits.
// Any other `Output` is rejected by the static_assert.
template <typename Output, typename Input>
__device__ __forceinline__ Output decimal_cast_value(Input value) {
    if constexpr (std::is_same_v<Output, int256_t>) {
        if constexpr (std::is_same_v<Input, int256_t>) {
            return value;
        } else {
            const int128_t narrow = decimal_to_i128(value);
            const int64_t fill = narrow.hi < 0 ? -1 : 0;
            return int256_t {{narrow.lo, narrow.hi, fill, fill}};
        }
    } else if constexpr (std::is_same_v<Output, int128_t>) {
        return decimal_to_i128(value);
    } else if constexpr (std::is_same_v<Output, int64_t>) {
        return decimal_to_i64(value);
    } else {
        static_assert(std::is_same_v<Output, int32_t>);
        return static_cast<int32_t>(decimal_to_i64(value));
    }
}
65+
66+
// Cast a contiguous values buffer to the Arrow schema's physical width.
//
// Each thread derives a flat 1D worker id from its block and thread index and
// converts the elements in [start_elem(worker, array_len),
// stop_elem(worker, array_len)), writing each result to the same index in
// `output`.
//
// input:     source values in the storage type `Input`.
// output:    destination values of type `Output`.
// array_len: element count forwarded to start_elem/stop_elem; threads whose
//            start index is at or beyond it do nothing.
template <typename Input, typename Output>
__device__ void
decimal_cast_device(const Input *__restrict input, Output *__restrict output, uint64_t array_len) {
    // Widen before multiplying: blockIdx.x and blockDim.x are 32-bit unsigned,
    // so their product would wrap for large grids before reaching 64 bits.
    const uint64_t worker = static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const uint64_t startElem = start_elem(worker, array_len);
    const uint64_t stopElem = stop_elem(worker, array_len);

    if (startElem >= array_len) {
        return;
    }

    for (uint64_t idx = startElem; idx < stopElem; idx++) {
        output[idx] = decimal_cast_value<Output>(input[idx]);
    }
}
82+
83+
// Define the four cast kernels for one input storage type.
//
// For a given `input_suffix` / `InputType` pair this expands to the
// extern "C" kernels
//   decimal_cast_<input_suffix>_i32   (output: int32_t)
//   decimal_cast_<input_suffix>_i64   (output: int64_t)
//   decimal_cast_<input_suffix>_i128  (output: int128_t)
//   decimal_cast_<input_suffix>_i256  (output: int256_t)
// each taking (input, output, array_len) and forwarding to
// decimal_cast_device.
#define GENERATE_DECIMAL_CAST_KERNELS(input_suffix, InputType)                                     \
    extern "C" __global__ void decimal_cast_##input_suffix##_i32(                                  \
        const InputType *__restrict input, int32_t *__restrict output, uint64_t array_len) {       \
        decimal_cast_device<InputType, int32_t>(input, output, array_len);                         \
    }                                                                                              \
    extern "C" __global__ void decimal_cast_##input_suffix##_i64(                                  \
        const InputType *__restrict input, int64_t *__restrict output, uint64_t array_len) {       \
        decimal_cast_device<InputType, int64_t>(input, output, array_len);                         \
    }                                                                                              \
    extern "C" __global__ void decimal_cast_##input_suffix##_i128(                                 \
        const InputType *__restrict input, int128_t *__restrict output, uint64_t array_len) {      \
        decimal_cast_device<InputType, int128_t>(input, output, array_len);                        \
    }                                                                                              \
    extern "C" __global__ void decimal_cast_##input_suffix##_i256(                                 \
        const InputType *__restrict input, int256_t *__restrict output, uint64_t array_len) {      \
        decimal_cast_device<InputType, int256_t>(input, output, array_len);                        \
    }
105+
106+
// Emit the cast kernels for every type enumerated by FOR_EACH_SIGNED_INT.
// NOTE(review): that macro is defined outside this file; presumably it invokes
// its argument once per signed integer storage type — confirm.
FOR_EACH_SIGNED_INT(GENERATE_DECIMAL_CAST_KERNELS)
107+
// Emit the cast kernels for every type enumerated by FOR_EACH_LARGE_DECIMAL.
// NOTE(review): that macro is defined outside this file; presumably it covers
// the wide storage types (int128_t, int256_t) — confirm.
FOR_EACH_LARGE_DECIMAL(GENERATE_DECIMAL_CAST_KERNELS)

0 commit comments

Comments
 (0)