|
19 | 19 |
|
20 | 20 | #include <algorithm>
|
21 | 21 | #include <cerrno>
|
| 22 | +#include <condition_variable> |
22 | 23 | #include <cstring>
|
23 | 24 | #include <memory>
|
| 25 | +#include <mutex> |
| 26 | +#include <queue> |
24 | 27 | #include <string>
|
25 | 28 | #include <string_view>
|
26 | 29 | #include <utility>
|
|
37 | 40 | #include "arrow/result.h"
|
38 | 41 | #include "arrow/stl_allocator.h"
|
39 | 42 | #include "arrow/type_traits.h"
|
| 43 | +#include "arrow/util/async_generator.h" |
40 | 44 | #include "arrow/util/bit_util.h"
|
41 | 45 | #include "arrow/util/checked_cast.h"
|
| 46 | +#include "arrow/util/future.h" |
42 | 47 | #include "arrow/util/key_value_metadata.h"
|
43 | 48 | #include "arrow/util/logging.h"
|
44 | 49 | #include "arrow/util/macros.h"
|
@@ -2511,4 +2516,347 @@ Result<std::shared_ptr<ChunkedArray>> ImportDeviceChunkedArray(
|
2511 | 2516 | return ImportChunked</*IsDevice=*/true>(stream, mapper);
|
2512 | 2517 | }
|
2513 | 2518 |
|
| 2519 | +namespace { |
| 2520 | + |
| 2521 | +class AsyncRecordBatchIterator { |
| 2522 | + public: |
| 2523 | + struct TaskWithMetadata { |
| 2524 | + ArrowAsyncTask task_; |
| 2525 | + std::shared_ptr<KeyValueMetadata> metadata_; |
| 2526 | + }; |
| 2527 | + |
| 2528 | + struct State { |
| 2529 | + State(uint64_t queue_size, DeviceMemoryMapper mapper) |
| 2530 | + : queue_size_{queue_size}, mapper_{std::move(mapper)} {} |
| 2531 | + |
| 2532 | + Result<RecordBatchWithMetadata> next() { |
| 2533 | + TaskWithMetadata task; |
| 2534 | + { |
| 2535 | + std::unique_lock<std::mutex> lock(mutex_); |
| 2536 | + cv_.wait(lock, |
| 2537 | + [&] { return !error_.ok() || !batches_.empty() || end_of_stream_; }); |
| 2538 | + if (!error_.ok()) { |
| 2539 | + return error_; |
| 2540 | + } |
| 2541 | + |
| 2542 | + if (batches_.empty() && end_of_stream_) { |
| 2543 | + return IterationEnd<RecordBatchWithMetadata>(); |
| 2544 | + } |
| 2545 | + |
| 2546 | + task = std::move(batches_.front()); |
| 2547 | + batches_.pop(); |
| 2548 | + } |
| 2549 | + |
| 2550 | + producer_->request(producer_, 1); |
| 2551 | + ArrowDeviceArray out; |
| 2552 | + if (task.task_.extract_data(&task.task_, &out) != 0) { |
| 2553 | + std::unique_lock<std::mutex> lock(mutex_); |
| 2554 | + cv_.wait(lock, [&] { return !error_.ok(); }); |
| 2555 | + return error_; |
| 2556 | + } |
| 2557 | + |
| 2558 | + ARROW_ASSIGN_OR_RAISE(auto batch, ImportDeviceRecordBatch(&out, schema_, mapper_)); |
| 2559 | + return RecordBatchWithMetadata{std::move(batch), std::move(task.metadata_)}; |
| 2560 | + } |
| 2561 | + |
| 2562 | + const uint64_t queue_size_; |
| 2563 | + const DeviceMemoryMapper mapper_; |
| 2564 | + ArrowAsyncProducer* producer_; |
| 2565 | + DeviceAllocationType device_type_{DeviceAllocationType::kCPU}; |
| 2566 | + |
| 2567 | + std::mutex mutex_; |
| 2568 | + std::shared_ptr<Schema> schema_; |
| 2569 | + std::condition_variable cv_; |
| 2570 | + std::queue<TaskWithMetadata> batches_; |
| 2571 | + bool end_of_stream_ = false; |
| 2572 | + Status error_{Status::OK()}; |
| 2573 | + }; |
| 2574 | + |
| 2575 | + AsyncRecordBatchIterator(uint64_t queue_size, DeviceMemoryMapper mapper) |
| 2576 | + : state_{std::make_shared<State>(queue_size, std::move(mapper))} {} |
| 2577 | + |
| 2578 | + explicit AsyncRecordBatchIterator(std::shared_ptr<State> state) |
| 2579 | + : state_{std::move(state)} {} |
| 2580 | + |
| 2581 | + const std::shared_ptr<Schema>& schema() const { return state_->schema_; } |
| 2582 | + |
| 2583 | + DeviceAllocationType device_type() const { return state_->device_type_; } |
| 2584 | + |
| 2585 | + Result<RecordBatchWithMetadata> Next() { return state_->next(); } |
| 2586 | + |
| 2587 | + static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make( |
| 2588 | + AsyncRecordBatchIterator& iterator, struct ArrowAsyncDeviceStreamHandler* handler) { |
| 2589 | + auto iterator_fut = Future<std::shared_ptr<AsyncRecordBatchIterator::State>>::Make(); |
| 2590 | + |
| 2591 | + auto private_data = new PrivateData{iterator.state_}; |
| 2592 | + private_data->fut_iterator_ = iterator_fut; |
| 2593 | + |
| 2594 | + handler->private_data = private_data; |
| 2595 | + handler->on_schema = on_schema; |
| 2596 | + handler->on_next_task = on_next_task; |
| 2597 | + handler->on_error = on_error; |
| 2598 | + handler->release = release; |
| 2599 | + return iterator_fut; |
| 2600 | + } |
| 2601 | + |
| 2602 | + private: |
| 2603 | + struct PrivateData { |
| 2604 | + explicit PrivateData(std::shared_ptr<State> state) : state_(std::move(state)) {} |
| 2605 | + |
| 2606 | + std::shared_ptr<State> state_; |
| 2607 | + Future<std::shared_ptr<AsyncRecordBatchIterator::State>> fut_iterator_; |
| 2608 | + ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData); |
| 2609 | + }; |
| 2610 | + |
| 2611 | + static int on_schema(struct ArrowAsyncDeviceStreamHandler* self, |
| 2612 | + struct ArrowSchema* stream_schema) { |
| 2613 | + auto* private_data = reinterpret_cast<PrivateData*>(self->private_data); |
| 2614 | + if (self->producer != nullptr) { |
| 2615 | + private_data->state_->producer_ = self->producer; |
| 2616 | + private_data->state_->device_type_ = |
| 2617 | + static_cast<DeviceAllocationType>(self->producer->device_type); |
| 2618 | + } |
| 2619 | + |
| 2620 | + auto maybe_schema = ImportSchema(stream_schema); |
| 2621 | + if (!maybe_schema.ok()) { |
| 2622 | + private_data->fut_iterator_.MarkFinished(maybe_schema.status()); |
| 2623 | + return EINVAL; |
| 2624 | + } |
| 2625 | + |
| 2626 | + private_data->state_->schema_ = maybe_schema.MoveValueUnsafe(); |
| 2627 | + private_data->fut_iterator_.MarkFinished(private_data->state_); |
| 2628 | + self->producer->request(self->producer, |
| 2629 | + static_cast<int64_t>(private_data->state_->queue_size_)); |
| 2630 | + return 0; |
| 2631 | + } |
| 2632 | + |
| 2633 | + static int on_next_task(ArrowAsyncDeviceStreamHandler* self, ArrowAsyncTask* task, |
| 2634 | + const char* metadata) { |
| 2635 | + auto* private_data = reinterpret_cast<PrivateData*>(self->private_data); |
| 2636 | + |
| 2637 | + if (task == nullptr) { |
| 2638 | + std::unique_lock<std::mutex> lock(private_data->state_->mutex_); |
| 2639 | + private_data->state_->end_of_stream_ = true; |
| 2640 | + lock.unlock(); |
| 2641 | + private_data->state_->cv_.notify_one(); |
| 2642 | + return 0; |
| 2643 | + } |
| 2644 | + |
| 2645 | + std::shared_ptr<KeyValueMetadata> kvmetadata; |
| 2646 | + if (metadata != nullptr) { |
| 2647 | + auto maybe_decoded = DecodeMetadata(metadata); |
| 2648 | + if (!maybe_decoded.ok()) { |
| | + // Store the error under the mutex so the consumer thread observes it safely. |
| | + std::lock_guard<std::mutex> lock(private_data->state_->mutex_); |
| 2649 | + private_data->state_->error_ = std::move(maybe_decoded).status(); |
| 2650 | + private_data->state_->cv_.notify_one(); |
| 2651 | + return EINVAL; |
| 2652 | + } |
| 2653 | + |
| 2654 | + kvmetadata = std::move(maybe_decoded->metadata); |
| 2655 | + } |
| 2656 | + |
| 2657 | + std::unique_lock<std::mutex> lock(private_data->state_->mutex_); |
| 2658 | + private_data->state_->batches_.push({*task, std::move(kvmetadata)}); |
| 2659 | + lock.unlock(); |
| 2660 | + private_data->state_->cv_.notify_one(); |
| 2661 | + return 0; |
| 2662 | + } |
| 2663 | + |
| 2664 | + static void on_error(ArrowAsyncDeviceStreamHandler* self, int code, const char* message, |
| 2665 | + const char* metadata) { |
| 2666 | + auto* private_data = reinterpret_cast<PrivateData*>(self->private_data); |
| 2667 | + std::string message_str, metadata_str; |
| 2668 | + if (message != nullptr) { |
| 2669 | + message_str = message; |
| 2670 | + } |
| 2671 | + if (metadata != nullptr) { |
| 2672 | + metadata_str = metadata; |
| 2673 | + } |
| 2674 | + |
| 2675 | + Status error = Status::FromDetailAndArgs( |
| 2676 | + StatusCode::UnknownError, |
| 2677 | + std::make_shared<AsyncErrorDetail>(code, message_str, std::move(metadata_str)), |
| 2678 | + message_str); // no move: message_str is also copied into the detail above |
| 2679 | + |
| 2680 | + if (!private_data->fut_iterator_.is_finished()) { |
| 2681 | + private_data->fut_iterator_.MarkFinished(error); |
| 2682 | + return; |
| 2683 | + } |
| 2684 | + |
| 2685 | + std::unique_lock<std::mutex> lock(private_data->state_->mutex_); |
| 2686 | + private_data->state_->error_ = std::move(error); |
| 2687 | + lock.unlock(); |
| 2688 | + private_data->state_->cv_.notify_one(); |
| 2689 | + } |
| 2690 | + |
| 2691 | + static void release(ArrowAsyncDeviceStreamHandler* self) { |
| 2692 | + delete reinterpret_cast<PrivateData*>(self->private_data); |
| 2693 | + } |
| 2694 | + |
| 2695 | + std::shared_ptr<State> state_; |
| 2696 | +}; |
| 2697 | + |
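The consumer half above is a standard bounded handoff: the stream callbacks push tasks into `batches_` under `mutex_` and signal `cv_`, while `State::next()` blocks until a task, an error, or end-of-stream arrives, then re-requests one unit of credit. As a reference point, here is a minimal, self-contained sketch of the same mutex/condition_variable queue pattern; the names are illustrative and nothing below is part of this patch:

```cpp
#include <condition_variable>
#include <mutex>
#include <optional>
#include <queue>

template <typename T>
class BlockingQueue {
 public:
  // Producer side: enqueue an item and wake one waiting consumer.
  void Push(T item) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      items_.push(std::move(item));
    }
    cv_.notify_one();
  }

  // Producer side: signal end-of-stream (the analogue of a null task).
  void Close() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      closed_ = true;
    }
    cv_.notify_all();
  }

  // Consumer side: block until an item arrives; std::nullopt means the
  // queue was closed and drained.
  std::optional<T> Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [&] { return closed_ || !items_.empty(); });
    if (items_.empty()) return std::nullopt;
    T item = std::move(items_.front());
    items_.pop();
    return item;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<T> items_;
  bool closed_ = false;
};
```

`Close()` plays the role of the null-task end-of-stream signal; an error slot could be added exactly as `State::error_` is, checked inside the same wait predicate.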
| 2698 | +struct AsyncProducer { |
| 2699 | + struct State { |
| 2700 | + struct ArrowAsyncProducer producer_; |
| 2701 | + |
| 2702 | + std::mutex mutex_; |
| 2703 | + std::condition_variable cv_; |
| 2704 | + uint64_t pending_requests_{0}; |
| 2705 | + Status error_{Status::OK()}; |
| 2706 | + }; |
| 2707 | + |
| 2708 | + AsyncProducer(DeviceAllocationType device_type, struct ArrowSchema* schema, |
| 2709 | + struct ArrowAsyncDeviceStreamHandler* handler) |
| 2710 | + : handler_{handler}, state_{std::make_shared<State>()} { |
| 2711 | + state_->producer_.device_type = static_cast<ArrowDeviceType>(device_type); |
| 2712 | + state_->producer_.private_data = reinterpret_cast<void*>(state_.get()); |
| 2713 | + state_->producer_.request = AsyncProducer::request; |
| 2714 | + state_->producer_.cancel = AsyncProducer::cancel; |
| 2715 | + handler_->producer = &state_->producer_; |
| 2716 | + |
| 2717 | + if (int status = handler_->on_schema(handler_, schema); status != 0) { |
| 2718 | + state_->error_ = |
| 2719 | + Status::UnknownError("Received error from handler::on_schema ", status); |
| 2720 | + } |
| 2721 | + } |
| 2722 | + |
| 2723 | + struct PrivateTaskData { |
| 2724 | + PrivateTaskData(std::shared_ptr<State> producer, std::shared_ptr<RecordBatch> record) |
| 2725 | + : producer_{std::move(producer)}, record_(std::move(record)) {} |
| 2726 | + |
| 2727 | + std::shared_ptr<State> producer_; |
| 2728 | + std::shared_ptr<RecordBatch> record_; |
| 2729 | + ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateTaskData); |
| 2730 | + }; |
| 2731 | + |
| 2732 | + Status operator()(const std::shared_ptr<RecordBatch>& record) { |
| 2733 | + std::unique_lock<std::mutex> lock(state_->mutex_); |
| 2734 | + if (state_->pending_requests_ == 0) { |
| 2735 | + state_->cv_.wait(lock, [this]() -> bool { |
| 2736 | + return !state_->error_.ok() || state_->pending_requests_ > 0; |
| 2737 | + }); |
| 2738 | + } |
| 2739 | + |
| 2740 | + if (!state_->error_.ok()) { |
| 2741 | + return state_->error_; |
| 2742 | + } |
| 2743 | + |
| 2744 | + if (state_->pending_requests_ > 0) { |
| 2745 | + state_->pending_requests_--; |
| 2746 | + lock.unlock(); |
| 2747 | + |
| 2748 | + ArrowAsyncTask task; |
| 2749 | + task.private_data = new PrivateTaskData{state_, record}; |
| 2750 | + task.extract_data = AsyncProducer::extract_data; |
| 2751 | + |
| 2752 | + if (int status = handler_->on_next_task(handler_, &task, nullptr); status != 0) { |
| 2753 | + delete reinterpret_cast<PrivateTaskData*>(task.private_data); |
| 2754 | + return Status::UnknownError("Received error from handler::on_next_task ", status); |
| 2755 | + } |
| 2756 | + } |
| 2757 | + |
| 2758 | + return Status::OK(); |
| 2759 | + } |
| 2760 | + |
| 2761 | + static void request(struct ArrowAsyncProducer* producer, int64_t n) { |
| 2762 | + auto* self = reinterpret_cast<State*>(producer->private_data); |
| 2763 | + { |
| 2764 | + std::lock_guard<std::mutex> lock(self->mutex_); |
| 2765 | + if (!self->error_.ok()) { |
| 2766 | + return; |
| 2767 | + } |
| 2768 | + self->pending_requests_ += n; |
| 2769 | + } |
| 2770 | + self->cv_.notify_all(); |
| 2771 | + } |
| 2772 | + |
| 2773 | + static void cancel(struct ArrowAsyncProducer* producer) { |
| 2774 | + auto* self = reinterpret_cast<State*>(producer->private_data); |
| 2775 | + { |
| 2776 | + std::lock_guard<std::mutex> lock(self->mutex_); |
| 2777 | + if (!self->error_.ok()) { |
| 2778 | + return; |
| 2779 | + } |
| 2780 | + self->error_ = Status::Cancelled("Consumer requested cancellation"); |
| 2781 | + } |
| 2782 | + self->cv_.notify_all(); |
| 2783 | + } |
| 2784 | + |
| 2785 | + static int extract_data(struct ArrowAsyncTask* task, struct ArrowDeviceArray* out) { |
| 2786 | + std::unique_ptr<PrivateTaskData> private_data{ |
| 2787 | + reinterpret_cast<PrivateTaskData*>(task->private_data)}; |
| 2788 | + int ret = 0; |
| 2789 | + if (out != nullptr) { |
| 2790 | + auto status = ExportDeviceRecordBatch(*private_data->record_, |
| 2791 | + private_data->record_->GetSyncEvent(), out); |
| 2792 | + if (!status.ok()) { |
| | + ret = EIO; // report the failure to the caller, not only via the stored error |
| 2793 | + std::lock_guard<std::mutex> lock(private_data->producer_->mutex_); |
| 2794 | + private_data->producer_->error_ = status; |
| 2795 | + } |
| 2796 | + } |
| 2797 | + |
| 2798 | + return ret; |
| 2799 | + } |
| 2800 | + |
| 2801 | + struct ArrowAsyncDeviceStreamHandler* handler_; |
| 2802 | + std::shared_ptr<State> state_; |
| 2803 | +}; |
| 2804 | + |
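`AsyncProducer::operator()` implements pull-based backpressure: each record blocks until the consumer has granted credit via `ArrowAsyncProducer::request`, mirroring the `queue_size` the consumer requests up front in `on_schema`. A stripped-down sketch of that credit mechanism, with hypothetical names and the error/cancel paths omitted:

```cpp
#include <condition_variable>
#include <cstdint>
#include <mutex>

class Credits {
 public:
  // Consumer side: grant permission for n more items (cf. request()).
  void Grant(int64_t n) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      available_ += n;
    }
    cv_.notify_all();
  }

  // Producer side: block until at least one credit exists, then spend it
  // (cf. the wait on pending_requests_ in operator()).
  void Acquire() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [&] { return available_ > 0; });
    --available_;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  int64_t available_ = 0;
};
```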
| 2805 | +} // namespace |
| 2806 | + |
| 2807 | +Future<AsyncRecordBatchGenerator> CreateAsyncDeviceStreamHandler( |
| 2808 | + struct ArrowAsyncDeviceStreamHandler* handler, internal::Executor* executor, |
| 2809 | + uint64_t queue_size, DeviceMemoryMapper mapper) { |
| 2810 | + auto iterator = |
| 2811 | + std::make_shared<AsyncRecordBatchIterator>(queue_size, std::move(mapper)); |
| 2812 | + return AsyncRecordBatchIterator::Make(*iterator, handler) |
| 2813 | + .Then([executor](std::shared_ptr<AsyncRecordBatchIterator::State> state) |
| 2814 | + -> Result<AsyncRecordBatchGenerator> { |
| 2815 | + AsyncRecordBatchGenerator gen{state->schema_, state->device_type_, nullptr}; |
| 2816 | + auto it = |
| 2817 | + Iterator<RecordBatchWithMetadata>(AsyncRecordBatchIterator{std::move(state)}); |
| 2818 | + ARROW_ASSIGN_OR_RAISE(gen.generator, |
| 2819 | + MakeBackgroundGenerator(std::move(it), executor)); |
| 2820 | + return gen; |
| 2821 | + }); |
| 2822 | +} |
| 2823 | + |
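For context, a hypothetical consumer-side use of `CreateAsyncDeviceStreamHandler`; `handler`, `executor`, and `MyMapper` are placeholders supplied by the application, and the `RecordBatchWithMetadata` field names (`batch`, `custom_metadata`) follow the existing Arrow declaration:

```cpp
// Sketch only: handler, executor, and MyMapper are application-provided.
Future<AsyncRecordBatchGenerator> fut = CreateAsyncDeviceStreamHandler(
    &handler, executor, /*queue_size=*/4, MyMapper);

Future<> done = fut.Then([](const AsyncRecordBatchGenerator& gen) {
  // gen.generator yields batches on the background executor until the
  // producer signals end-of-stream with a null task.
  return VisitAsyncGenerator(
      gen.generator, [](const RecordBatchWithMetadata& item) {
        // Consume item.batch / item.custom_metadata here.
        return Status::OK();
      });
});
```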
| 2824 | +Future<> ExportAsyncRecordBatchReader( |
| 2825 | + std::shared_ptr<Schema> schema, |
| 2826 | + AsyncGenerator<std::shared_ptr<RecordBatch>> generator, |
| 2827 | + DeviceAllocationType device_type, struct ArrowAsyncDeviceStreamHandler* handler) { |
| 2828 | + if (!schema) { |
| 2829 | + handler->on_error(handler, EINVAL, "Schema is null", nullptr); |
| 2830 | + handler->release(handler); |
| 2831 | + return Future<>::MakeFinished(Status::Invalid("Schema is null")); |
| 2832 | + } |
| 2833 | + |
| 2834 | + struct ArrowSchema c_schema = {}; // zero-init keeps the export guard safe if ExportSchema fails |
| 2835 | + SchemaExportGuard guard(&c_schema); |
| 2836 | + |
| 2837 | + auto status = ExportSchema(*schema, &c_schema); |
| 2838 | + if (!status.ok()) { |
| 2839 | + handler->on_error(handler, EINVAL, status.message().c_str(), nullptr); |
| 2840 | + handler->release(handler); |
| 2841 | + return Future<>::MakeFinished(status); |
| 2842 | + } |
| 2843 | + |
| 2844 | + return VisitAsyncGenerator(generator, AsyncProducer{device_type, &c_schema, handler}) |
| 2845 | + .Then( |
| 2846 | + [handler]() -> Status { |
| 2847 | + int status = handler->on_next_task(handler, nullptr, nullptr); |
| 2848 | + handler->release(handler); |
| 2849 | + if (status != 0) { |
| 2850 | + return Status::UnknownError("Received error from handler::on_next_task ", |
| 2851 | + status); |
| 2852 | + } |
| 2853 | + return Status::OK(); |
| 2854 | + }, |
| 2855 | + [handler](const Status& status) -> Status { |
| 2856 | + handler->on_error(handler, EINVAL, status.message().c_str(), nullptr); |
| 2857 | + handler->release(handler); |
| 2858 | + return status; |
| 2859 | + }); |
| 2860 | +} |
| 2861 | + |
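And a hypothetical producer-side use of `ExportAsyncRecordBatchReader`, assuming the application owns a vector of CPU batches and a consumer-populated `handler`:

```cpp
// Sketch only: schema, batches, and handler come from the application.
std::vector<std::shared_ptr<RecordBatch>> batches = /* ... */;
AsyncGenerator<std::shared_ptr<RecordBatch>> gen =
    MakeVectorGenerator(std::move(batches));

Future<> done = ExportAsyncRecordBatchReader(schema, std::move(gen),
                                             DeviceAllocationType::kCPU, &handler);
// Completes once the consumer drains the stream or an error occurs; the
// handler is released on both paths, so it must not be reused afterwards.
ARROW_RETURN_NOT_OK(done.status());
```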
2514 | 2862 | } // namespace arrow
|