
Commit aab7d81

zeroshade, lidavidm, and bkietz authored
GH-43631: [C++] Add C++ implementation of Async C Data Interface (apache#44495)
### Rationale for this change

Building on apache#43632, which created the async C Data structures, this adds functions to `bridge.h`/`bridge.cc` implementing helpers for managing the async C Data interfaces.

### What changes are included in this PR?

Two functions are added to `bridge.h`:

1. `CreateAsyncDeviceStreamHandler` populates an `ArrowAsyncDeviceStreamHandler` and, given an `Executor`, provides a future that resolves to an `AsyncRecordBatchGenerator` producing record batches as they are pushed asynchronously. The `ArrowAsyncDeviceStreamHandler` can then be passed to any asynchronous producer.
2. `ExportAsyncRecordBatchReader` takes a record batch generator and a schema, along with an `ArrowAsyncDeviceStreamHandler` whose callbacks it invokes to push data as the generator makes it available.

### Are these changes tested?

Unit tests are added (currently only one test; more are to be added).

### Are there any user-facing changes?

No.

* GitHub Issue: apache#43631

Lead-authored-by: Matt Topol <[email protected]>
Co-authored-by: David Li <[email protected]>
Co-authored-by: Benjamin Kietzman <[email protected]>
Signed-off-by: Matt Topol <[email protected]>
1 parent d7e982c commit aab7d81
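Taken together, the two helpers give a consumer-side and a producer-side entry point. A minimal consumer-side sketch (not part of this commit: `ConnectToProducer` is a hypothetical asynchronous producer, and the thread pool, queue size, and `DefaultDeviceMemoryMapper` arguments are illustrative assumptions):

    // Sketch: receive batches from any producer driving the handler.
    ArrowAsyncDeviceStreamHandler handler;
    auto gen_fut = arrow::CreateAsyncDeviceStreamHandler(
        &handler, arrow::internal::GetCpuThreadPool(), /*queue_size=*/4,
        arrow::DefaultDeviceMemoryMapper);
    ConnectToProducer(&handler);  // hand the handler to the asynchronous producer
    auto gen = gen_fut.result().ValueOrDie();  // resolves once on_schema fires
    // gen.generator yields RecordBatchWithMetadata items as they are pushed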

7 files changed: +558 −1 lines

cpp/src/arrow/c/abi.h (+4 −1)
@@ -287,7 +287,7 @@ struct ArrowAsyncTask {
   // calling this, and so it must be released separately.
   //
   // It is only valid to call this method exactly once.
-  int (*extract_data)(struct ArrowArrayTask* self, struct ArrowDeviceArray* out);
+  int (*extract_data)(struct ArrowAsyncTask* self, struct ArrowDeviceArray* out);
 
   // opaque task-specific data
   void* private_data;
@@ -298,6 +298,9 @@ struct ArrowAsyncTask {
 // control on the asynchronous stream processing. This object must be owned by the
 // producer who creates it, and thus is responsible for cleaning it up.
 struct ArrowAsyncProducer {
+  // The device type that this stream produces data on.
+  ArrowDeviceType device_type;
+
   // A consumer must call this function to start receiving on_next_task calls.
   //
   // It *must* be valid to call this synchronously from within `on_next_task` or
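The corrected signature matters to consumers: the task they receive is an `ArrowAsyncTask`, and `extract_data` must be invoked on that same struct. A hedged sketch of a consumer callback using the new `device_type` field and the fixed signature (the names are illustrative, and a real consumer would typically defer extraction to another thread rather than call it inline):

    // Sketch: a consumer's on_next_task callback (illustrative only).
    int my_on_next_task(struct ArrowAsyncDeviceStreamHandler* self,
                        struct ArrowAsyncTask* task, const char* metadata) {
      if (task == nullptr) return 0;  // end of stream
      // The stream's device type is now available up front:
      ArrowDeviceType dev = self->producer->device_type;  // pick an import path
      struct ArrowDeviceArray out;
      // extract_data now correctly takes the ArrowAsyncTask itself:
      return task->extract_data(task, &out);
    }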

cpp/src/arrow/c/bridge.cc (+348)
@@ -19,8 +19,11 @@
 
 #include <algorithm>
 #include <cerrno>
+#include <condition_variable>
 #include <cstring>
 #include <memory>
+#include <mutex>
+#include <queue>
 #include <string>
 #include <string_view>
 #include <utility>
@@ -37,8 +40,10 @@
 #include "arrow/result.h"
 #include "arrow/stl_allocator.h"
 #include "arrow/type_traits.h"
+#include "arrow/util/async_generator.h"
 #include "arrow/util/bit_util.h"
 #include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
 #include "arrow/util/key_value_metadata.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/macros.h"
@@ -2511,4 +2516,347 @@ Result<std::shared_ptr<ChunkedArray>> ImportDeviceChunkedArray(
  return ImportChunked</*IsDevice=*/true>(stream, mapper);
}

namespace {

class AsyncRecordBatchIterator {
 public:
  struct TaskWithMetadata {
    ArrowAsyncTask task_;
    std::shared_ptr<KeyValueMetadata> metadata_;
  };
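  // Shared rendezvous between the producer's callbacks and the consuming
  // iterator: on_next_task() enqueues tasks under mutex_ and signals cv_;
  // next() blocks on cv_ until a task, an error, or end-of-stream arrives,
  // then asks the producer for one more item to keep the queue topped up.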
  struct State {
    State(uint64_t queue_size, DeviceMemoryMapper mapper)
        : queue_size_{queue_size}, mapper_{std::move(mapper)} {}

    Result<RecordBatchWithMetadata> next() {
      TaskWithMetadata task;
      {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock,
                 [&] { return !error_.ok() || !batches_.empty() || end_of_stream_; });
        if (!error_.ok()) {
          return error_;
        }

        if (batches_.empty() && end_of_stream_) {
          return IterationEnd<RecordBatchWithMetadata>();
        }

        task = std::move(batches_.front());
        batches_.pop();
      }

      producer_->request(producer_, 1);
      ArrowDeviceArray out;
      if (task.task_.extract_data(&task.task_, &out) != 0) {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [&] { return !error_.ok(); });
        return error_;
      }

      ARROW_ASSIGN_OR_RAISE(auto batch, ImportDeviceRecordBatch(&out, schema_, mapper_));
      return RecordBatchWithMetadata{std::move(batch), std::move(task.metadata_)};
    }

    const uint64_t queue_size_;
    const DeviceMemoryMapper mapper_;
    ArrowAsyncProducer* producer_;
    DeviceAllocationType device_type_;

    std::mutex mutex_;
    std::shared_ptr<Schema> schema_;
    std::condition_variable cv_;
    std::queue<TaskWithMetadata> batches_;
    bool end_of_stream_ = false;
    Status error_{Status::OK()};
  };
  AsyncRecordBatchIterator(uint64_t queue_size, DeviceMemoryMapper mapper)
      : state_{std::make_shared<State>(queue_size, std::move(mapper))} {}

  explicit AsyncRecordBatchIterator(std::shared_ptr<State> state)
      : state_{std::move(state)} {}

  const std::shared_ptr<Schema>& schema() const { return state_->schema_; }

  DeviceAllocationType device_type() const { return state_->device_type_; }

  Result<RecordBatchWithMetadata> Next() { return state_->next(); }
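  // Fills in the handler's callbacks and returns a future that resolves to the
  // shared State once the producer has delivered the schema (or an error).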
  static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make(
      AsyncRecordBatchIterator& iterator, struct ArrowAsyncDeviceStreamHandler* handler) {
    auto iterator_fut = Future<std::shared_ptr<AsyncRecordBatchIterator::State>>::Make();

    auto private_data = new PrivateData{iterator.state_};
    private_data->fut_iterator_ = iterator_fut;

    handler->private_data = private_data;
    handler->on_schema = on_schema;
    handler->on_next_task = on_next_task;
    handler->on_error = on_error;
    handler->release = release;
    return iterator_fut;
  }

 private:
  struct PrivateData {
    explicit PrivateData(std::shared_ptr<State> state) : state_(std::move(state)) {}

    std::shared_ptr<State> state_;
    Future<std::shared_ptr<AsyncRecordBatchIterator::State>> fut_iterator_;
    ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData);
  };
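  // Producer callback: captures the ArrowAsyncProducer and its device type,
  // imports the schema, resolves the pending future, and issues the initial
  // request for queue_size_ tasks.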
  static int on_schema(struct ArrowAsyncDeviceStreamHandler* self,
                       struct ArrowSchema* stream_schema) {
    auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
    if (self->producer != nullptr) {
      private_data->state_->producer_ = self->producer;
      private_data->state_->device_type_ =
          static_cast<DeviceAllocationType>(self->producer->device_type);
    }

    auto maybe_schema = ImportSchema(stream_schema);
    if (!maybe_schema.ok()) {
      private_data->fut_iterator_.MarkFinished(maybe_schema.status());
      return EINVAL;
    }

    private_data->state_->schema_ = maybe_schema.MoveValueUnsafe();
    private_data->fut_iterator_.MarkFinished(private_data->state_);
    self->producer->request(self->producer,
                            static_cast<int64_t>(private_data->state_->queue_size_));
    return 0;
  }
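  // Producer callback: a null task signals end-of-stream; otherwise the task
  // and any decoded key-value metadata are queued for next() to consume.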
  static int on_next_task(ArrowAsyncDeviceStreamHandler* self, ArrowAsyncTask* task,
                          const char* metadata) {
    auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);

    if (task == nullptr) {
      std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
      private_data->state_->end_of_stream_ = true;
      lock.unlock();
      private_data->state_->cv_.notify_one();
      return 0;
    }

    std::shared_ptr<KeyValueMetadata> kvmetadata;
    if (metadata != nullptr) {
      auto maybe_decoded = DecodeMetadata(metadata);
      if (!maybe_decoded.ok()) {
        private_data->state_->error_ = std::move(maybe_decoded).status();
        private_data->state_->cv_.notify_one();
        return EINVAL;
      }

      kvmetadata = std::move(maybe_decoded->metadata);
    }

    std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
    private_data->state_->batches_.push({*task, std::move(kvmetadata)});
    lock.unlock();
    private_data->state_->cv_.notify_one();
    return 0;
  }
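  // Producer callback: wraps the reported code/message/metadata in an
  // AsyncErrorDetail; fails the schema future if it is still pending,
  // otherwise records the error and wakes the iterator.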
  static void on_error(ArrowAsyncDeviceStreamHandler* self, int code, const char* message,
                       const char* metadata) {
    auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
    std::string message_str, metadata_str;
    if (message != nullptr) {
      message_str = message;
    }
    if (metadata != nullptr) {
      metadata_str = metadata;
    }

    Status error = Status::FromDetailAndArgs(
        StatusCode::UnknownError,
        std::make_shared<AsyncErrorDetail>(code, message_str, std::move(metadata_str)),
        std::move(message_str));

    if (!private_data->fut_iterator_.is_finished()) {
      private_data->fut_iterator_.MarkFinished(error);
      return;
    }

    std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
    private_data->state_->error_ = std::move(error);
    lock.unlock();
    private_data->state_->cv_.notify_one();
  }

  static void release(ArrowAsyncDeviceStreamHandler* self) {
    delete reinterpret_cast<PrivateData*>(self->private_data);
  }

  std::shared_ptr<State> state_;
};
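// Exporter-side adapter: visits an AsyncGenerator of record batches and drives
// a consumer-supplied ArrowAsyncDeviceStreamHandler, honoring the consumer's
// request()/cancel() calls for backpressure.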
struct AsyncProducer {
  struct State {
    struct ArrowAsyncProducer producer_;

    std::mutex mutex_;
    std::condition_variable cv_;
    uint64_t pending_requests_{0};
    Status error_{Status::OK()};
  };

  AsyncProducer(DeviceAllocationType device_type, struct ArrowSchema* schema,
                struct ArrowAsyncDeviceStreamHandler* handler)
      : handler_{handler}, state_{std::make_shared<State>()} {
    state_->producer_.device_type = static_cast<ArrowDeviceType>(device_type);
    state_->producer_.private_data = reinterpret_cast<void*>(state_.get());
    state_->producer_.request = AsyncProducer::request;
    state_->producer_.cancel = AsyncProducer::cancel;
    handler_->producer = &state_->producer_;

    if (int status = handler_->on_schema(handler_, schema) != 0) {
      state_->error_ =
          Status::UnknownError("Received error from handler::on_schema ", status);
    }
  }

  struct PrivateTaskData {
    PrivateTaskData(std::shared_ptr<State> producer, std::shared_ptr<RecordBatch> record)
        : producer_{std::move(producer)}, record_(std::move(record)) {}

    std::shared_ptr<State> producer_;
    std::shared_ptr<RecordBatch> record_;
    ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateTaskData);
  };
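  // Called for each batch produced by the generator: waits until the consumer
  // has outstanding requests (or an error/cancellation occurred), then wraps
  // the batch in an ArrowAsyncTask and hands it to on_next_task.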
  Status operator()(const std::shared_ptr<RecordBatch>& record) {
    std::unique_lock<std::mutex> lock(state_->mutex_);
    if (state_->pending_requests_ == 0) {
      state_->cv_.wait(lock, [this]() -> bool {
        return !state_->error_.ok() || state_->pending_requests_ > 0;
      });
    }

    if (!state_->error_.ok()) {
      return state_->error_;
    }

    if (state_->pending_requests_ > 0) {
      state_->pending_requests_--;
      lock.unlock();

      ArrowAsyncTask task;
      task.private_data = new PrivateTaskData{state_, record};
      task.extract_data = AsyncProducer::extract_data;

      if (int status = handler_->on_next_task(handler_, &task, nullptr) != 0) {
        delete reinterpret_cast<PrivateTaskData*>(task.private_data);
        return Status::UnknownError("Received error from handler::on_next_task ", status);
      }
    }

    return Status::OK();
  }
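  // Callbacks exposed to the consumer via ArrowAsyncProducer: request() raises
  // the pending count, cancel() records a Cancelled status; both wake any
  // operator() call blocked above.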
  static void request(struct ArrowAsyncProducer* producer, int64_t n) {
    auto* self = reinterpret_cast<State*>(producer->private_data);
    {
      std::lock_guard<std::mutex> lock(self->mutex_);
      if (!self->error_.ok()) {
        return;
      }
      self->pending_requests_ += n;
    }
    self->cv_.notify_all();
  }

  static void cancel(struct ArrowAsyncProducer* producer) {
    auto* self = reinterpret_cast<State*>(producer->private_data);
    {
      std::lock_guard<std::mutex> lock(self->mutex_);
      if (!self->error_.ok()) {
        return;
      }
      self->error_ = Status::Cancelled("Consumer requested cancellation");
    }
    self->cv_.notify_all();
  }
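  // Task callback: exports the captured batch into the caller's
  // ArrowDeviceArray (a null `out` simply discards it); the PrivateTaskData
  // is freed either way, and export failures are stored on the shared State.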
  static int extract_data(struct ArrowAsyncTask* task, struct ArrowDeviceArray* out) {
    std::unique_ptr<PrivateTaskData> private_data{
        reinterpret_cast<PrivateTaskData*>(task->private_data)};
    int ret = 0;
    if (out != nullptr) {
      auto status = ExportDeviceRecordBatch(*private_data->record_,
                                            private_data->record_->GetSyncEvent(), out);
      if (!status.ok()) {
        std::lock_guard<std::mutex> lock(private_data->producer_->mutex_);
        private_data->producer_->error_ = status;
      }
    }

    return ret;
  }

  struct ArrowAsyncDeviceStreamHandler* handler_;
  std::shared_ptr<State> state_;
};

}  // namespace
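// Consumer-side entry point: populates `handler` and returns a future that
// resolves to an AsyncRecordBatchGenerator backed by a background generator
// running on `executor`.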
Future<AsyncRecordBatchGenerator> CreateAsyncDeviceStreamHandler(
    struct ArrowAsyncDeviceStreamHandler* handler, internal::Executor* executor,
    uint64_t queue_size, DeviceMemoryMapper mapper) {
  auto iterator =
      std::make_shared<AsyncRecordBatchIterator>(queue_size, std::move(mapper));
  return AsyncRecordBatchIterator::Make(*iterator, handler)
      .Then([executor](std::shared_ptr<AsyncRecordBatchIterator::State> state)
                -> Result<AsyncRecordBatchGenerator> {
        AsyncRecordBatchGenerator gen{state->schema_, state->device_type_, nullptr};
        auto it =
            Iterator<RecordBatchWithMetadata>(AsyncRecordBatchIterator{std::move(state)});
        ARROW_ASSIGN_OR_RAISE(gen.generator,
                              MakeBackgroundGenerator(std::move(it), executor));
        return gen;
      });
}
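// Producer-side entry point: exports `schema` through on_schema, pushes each
// batch from `generator` to the handler, and finishes with a null task on
// success or on_error on failure.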
Future<> ExportAsyncRecordBatchReader(
    std::shared_ptr<Schema> schema,
    AsyncGenerator<std::shared_ptr<RecordBatch>> generator,
    DeviceAllocationType device_type, struct ArrowAsyncDeviceStreamHandler* handler) {
  if (!schema) {
    handler->on_error(handler, EINVAL, "Schema is null", nullptr);
    handler->release(handler);
    return Future<>::MakeFinished(Status::Invalid("Schema is null"));
  }

  struct ArrowSchema c_schema;
  SchemaExportGuard guard(&c_schema);

  auto status = ExportSchema(*schema, &c_schema);
  if (!status.ok()) {
    handler->on_error(handler, EINVAL, status.message().c_str(), nullptr);
    handler->release(handler);
    return Future<>::MakeFinished(status);
  }

  return VisitAsyncGenerator(generator, AsyncProducer{device_type, &c_schema, handler})
      .Then(
          [handler]() -> Status {
            int status = handler->on_next_task(handler, nullptr, nullptr);
            handler->release(handler);
            if (status != 0) {
              return Status::UnknownError("Received error from handler::on_next_task ",
                                          status);
            }
            return Status::OK();
          },
          [handler](const Status status) -> Status {
            handler->on_error(handler, EINVAL, status.message().c_str(), nullptr);
            handler->release(handler);
            return status;
          });
}

}  // namespace arrow
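Conversely, a rough producer-side sketch of the second helper (again not from the commit: `schema`, `MakeBatchGenerator`, and the CPU device type are assumed inputs, and `handler` would be supplied by the consumer):

    // Sketch: push a stream of CPU record batches through a consumer's handler.
    arrow::Future<> done = arrow::ExportAsyncRecordBatchReader(
        schema, MakeBatchGenerator(), arrow::DeviceAllocationType::kCPU, handler);
    done.Wait();  // finishes after the end-of-stream (null task) or on_error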
