From 2e32ebb225e0269e28f5e0b06987e84961d543d2 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Fri, 3 Oct 2025 16:00:00 +0200 Subject: [PATCH 01/11] Add output streams and serializers --- CMakeLists.txt | 14 +- .../chunk_memory_output_stream.hpp | 81 +++ .../sparrow_ipc/chunk_memory_serializer.hpp | 73 +++ include/sparrow_ipc/encapsulated_message.hpp | 1 + include/sparrow_ipc/file_output_stream.hpp | 36 ++ include/sparrow_ipc/flatbuffer_utils.hpp | 217 +++++++ include/sparrow_ipc/memory_output_stream.hpp | 64 ++ include/sparrow_ipc/output_stream.hpp | 100 +++ include/sparrow_ipc/serialize.hpp | 63 +- include/sparrow_ipc/serialize_utils.hpp | 336 ++-------- include/sparrow_ipc/serializer.hpp | 136 ++++ include/sparrow_ipc/utils.hpp | 14 +- src/chunk_memory_serializer.cpp | 54 ++ src/file_output_stream.cpp | 57 ++ src/flatbuffer_utils.cpp | 586 +++++++++++++++++ src/serialize.cpp | 31 + src/serialize_utils.cpp | 299 ++------- src/serializer.cpp | 61 ++ src/utils.cpp | 452 +------------ tests/CMakeLists.txt | 41 +- tests/include/sparrow_ipc_tests_helpers.hpp | 13 +- tests/test_chunk_memory_output_stream.cpp | 570 +++++++++++++++++ tests/test_chunk_memory_serializer.cpp | 381 +++++++++++ tests/test_de_serialization_with_files.cpp | 8 +- tests/test_file_output_stream.cpp | 604 ++++++++++++++++++ tests/test_flatbuffer_utils.cpp | 535 ++++++++++++++++ tests/test_memory_output_streams.cpp | 372 +++++++++++ tests/test_serialize_utils.cpp | 547 ++++++++-------- tests/test_utils.cpp | 333 +--------- 29 files changed, 4461 insertions(+), 1618 deletions(-) create mode 100644 include/sparrow_ipc/chunk_memory_output_stream.hpp create mode 100644 include/sparrow_ipc/chunk_memory_serializer.hpp create mode 100644 include/sparrow_ipc/file_output_stream.hpp create mode 100644 include/sparrow_ipc/flatbuffer_utils.hpp create mode 100644 include/sparrow_ipc/memory_output_stream.hpp create mode 100644 include/sparrow_ipc/output_stream.hpp create mode 100644 
include/sparrow_ipc/serializer.hpp create mode 100644 src/chunk_memory_serializer.cpp create mode 100644 src/file_output_stream.cpp create mode 100644 src/flatbuffer_utils.cpp create mode 100644 src/serialize.cpp create mode 100644 src/serializer.cpp create mode 100644 tests/test_chunk_memory_output_stream.cpp create mode 100644 tests/test_chunk_memory_serializer.cpp create mode 100644 tests/test_file_output_stream.cpp create mode 100644 tests/test_flatbuffer_utils.cpp create mode 100644 tests/test_memory_output_streams.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 86b80e3..4008df7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -100,6 +100,8 @@ set(SPARROW_IPC_HEADERS ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/arrow_interface/arrow_array/private_data.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/arrow_interface/arrow_schema.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/arrow_interface/arrow_schema/private_data.hpp + ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/chunk_memory_output_stream.hpp + ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/chunk_memory_serializer.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/config.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/sparrow_ipc_version.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_variable_size_binary_array.hpp @@ -109,24 +111,34 @@ set(SPARROW_IPC_HEADERS ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_variable_size_binary_array.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/encapsulated_message.hpp + ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/file_output_stream.hpp + ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/flatbuffer_utils.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/magic_values.hpp + ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/memory_output_stream.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/metadata.hpp + ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/output_stream.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/serialize_utils.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/serialize.hpp 
+ ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/serializer.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/utils.hpp ) set(SPARROW_IPC_SRC - ${SPARROW_IPC_SOURCE_DIR}/serialize_utils.cpp ${SPARROW_IPC_SOURCE_DIR}/arrow_interface/arrow_array.cpp ${SPARROW_IPC_SOURCE_DIR}/arrow_interface/arrow_array/private_data.cpp ${SPARROW_IPC_SOURCE_DIR}/arrow_interface/arrow_schema.cpp ${SPARROW_IPC_SOURCE_DIR}/arrow_interface/arrow_schema/private_data.cpp + ${SPARROW_IPC_SOURCE_DIR}/chunk_memory_serializer.cpp ${SPARROW_IPC_SOURCE_DIR}/deserialize_fixedsizebinary_array.cpp ${SPARROW_IPC_SOURCE_DIR}/deserialize_utils.cpp ${SPARROW_IPC_SOURCE_DIR}/deserialize.cpp ${SPARROW_IPC_SOURCE_DIR}/encapsulated_message.cpp + ${SPARROW_IPC_SOURCE_DIR}/file_output_stream.cpp + ${SPARROW_IPC_SOURCE_DIR}/flatbuffer_utils.cpp ${SPARROW_IPC_SOURCE_DIR}/metadata.cpp + ${SPARROW_IPC_SOURCE_DIR}/serialize_utils.cpp + ${SPARROW_IPC_SOURCE_DIR}/serialize.cpp + ${SPARROW_IPC_SOURCE_DIR}/serializer.cpp ${SPARROW_IPC_SOURCE_DIR}/utils.cpp ) diff --git a/include/sparrow_ipc/chunk_memory_output_stream.hpp b/include/sparrow_ipc/chunk_memory_output_stream.hpp new file mode 100644 index 0000000..bad94ef --- /dev/null +++ b/include/sparrow_ipc/chunk_memory_output_stream.hpp @@ -0,0 +1,81 @@ +#pragma once + +#include +#include + +#include "sparrow_ipc/output_stream.hpp" + +namespace sparrow_ipc +{ + template + requires std::ranges::random_access_range + && std::ranges::random_access_range> + && std::same_as::value_type, uint8_t> + class chuncked_memory_output_stream final : public output_stream + { + public: + + explicit chuncked_memory_output_stream(R& chunks) + : m_chunks(&chunks) {}; + + std::size_t write(std::span span) override + { + m_chunks->emplace_back(span.begin(), span.end()); + return span.size(); + } + + std::size_t write(std::vector&& buffer) + { + m_chunks->emplace_back(std::move(buffer)); + return m_chunks->back().size(); + } + + std::size_t write(uint8_t value, std::size_t count) override + { + 
m_chunks->emplace_back(count, value); + return count; + } + + void reserve(std::size_t size) override + { + m_chunks->reserve(size); + } + + void reserve(const std::function& calculate_reserve_size) override + { + m_chunks->reserve(calculate_reserve_size()); + } + + size_t size() const override + { + return std::accumulate( + m_chunks->begin(), + m_chunks->end(), + 0, + [](size_t acc, const auto& chunk) + { + return acc + chunk.size(); + } + ); + } + + void flush() override + { + // Implementation for flushing memory + } + + void close() override + { + // Implementation for closing the stream + } + + bool is_open() const override + { + return true; + } + + private: + + R* m_chunks; + }; +} \ No newline at end of file diff --git a/include/sparrow_ipc/chunk_memory_serializer.hpp b/include/sparrow_ipc/chunk_memory_serializer.hpp new file mode 100644 index 0000000..a897354 --- /dev/null +++ b/include/sparrow_ipc/chunk_memory_serializer.hpp @@ -0,0 +1,73 @@ +#pragma once + +#include + +#include "sparrow_ipc/chunk_memory_output_stream.hpp" +#include "sparrow_ipc/memory_output_stream.hpp" +#include "sparrow_ipc/serialize.hpp" +#include "sparrow_ipc/serialize_utils.hpp" + +namespace sparrow_ipc +{ + class chunk_serializer + { + public: + + chunk_serializer( + const sparrow::record_batch& rb, + chuncked_memory_output_stream>>& stream + ); + + template + requires std::same_as, sparrow::record_batch> + chunk_serializer( + const R& record_batches, + chuncked_memory_output_stream>>& stream + ) + : m_pstream(&stream) + { + if (record_batches.empty()) + { + throw std::invalid_argument("Record batches collection is empty"); + } + m_dtypes = get_column_dtypes(record_batches[0]); + + m_pstream->reserve(record_batches.size() + 1); + std::vector buffer; + memory_output_stream schema_stream(buffer); + serialize_schema_message(record_batches[0], schema_stream); + m_pstream->write(std::move(buffer)); + append(record_batches); + } + + void append(const sparrow::record_batch& rb); + + 
template + requires std::same_as, sparrow::record_batch> + void append(const R& record_batches) + { + if (m_ended) + { + throw std::runtime_error("Cannot append to a serializer that has been ended"); + } + + m_pstream->reserve(m_pstream->size() + record_batches.size()); + + for (const auto& rb : record_batches) + { + std::vector buffer; + memory_output_stream stream(buffer); + serialize_record_batch(rb, stream); + m_pstream->write(std::move(buffer)); + } + } + + void end(); + + private: + + std::vector m_dtypes; + chuncked_memory_output_stream>>* m_pstream; + bool m_ended{false}; + }; +} \ No newline at end of file diff --git a/include/sparrow_ipc/encapsulated_message.hpp b/include/sparrow_ipc/encapsulated_message.hpp index 7e95339..cea09a6 100644 --- a/include/sparrow_ipc/encapsulated_message.hpp +++ b/include/sparrow_ipc/encapsulated_message.hpp @@ -2,6 +2,7 @@ #include #include +#include #include "Message_generated.h" diff --git a/include/sparrow_ipc/file_output_stream.hpp b/include/sparrow_ipc/file_output_stream.hpp new file mode 100644 index 0000000..9057caa --- /dev/null +++ b/include/sparrow_ipc/file_output_stream.hpp @@ -0,0 +1,36 @@ +#include +#include + +#include "sparrow_ipc/output_stream.hpp" + + +namespace sparrow_ipc +{ + class SPARROW_IPC_API file_output_stream final : public output_stream + { + public: + + explicit file_output_stream(std::ofstream& file); + + std::size_t write(std::span span) override; + + std::size_t write(uint8_t value, std::size_t count = 1) override; + + size_t size() const override; + + void reserve(std::size_t size) override; + + void reserve(const std::function& calculate_reserve_size) override; + + void flush() override; + + void close() override; + + bool is_open() const override; + + private: + + std::ofstream& m_file; + size_t m_written_bytes = 0; + }; +} \ No newline at end of file diff --git a/include/sparrow_ipc/flatbuffer_utils.hpp b/include/sparrow_ipc/flatbuffer_utils.hpp new file mode 100644 index 0000000..4ec4ef7 
--- /dev/null +++ b/include/sparrow_ipc/flatbuffer_utils.hpp @@ -0,0 +1,217 @@ +#pragma once +#include +#include + +#include +#include + +namespace sparrow_ipc +{ + // Creates a Flatbuffers Decimal type from a format string + // The format string is expected to be in the format "d:precision,scale" + [[nodiscard]] std::pair> + get_flatbuffer_decimal_type( + flatbuffers::FlatBufferBuilder& builder, + std::string_view format_str, + const int32_t bitWidth + ); + + // Creates a Flatbuffers type from a format string + // This function maps a sparrow data type to the corresponding Flatbuffers type + [[nodiscard]] std::pair> + get_flatbuffer_type(flatbuffers::FlatBufferBuilder& builder, std::string_view format_str); + + /** + * @brief Creates a FlatBuffers vector of KeyValue pairs from ArrowSchema metadata. + * + * This function converts metadata from an ArrowSchema into a FlatBuffers representation + * suitable for serialization. It processes key-value pairs from the schema's metadata + * and creates corresponding FlatBuffers KeyValue objects. + * + * @param builder Reference to the FlatBufferBuilder used for creating FlatBuffers objects + * @param arrow_schema The ArrowSchema containing metadata to be serialized + * + * @return A FlatBuffers offset to a vector of KeyValue pairs. Returns 0 if the schema + * has no metadata (metadata is nullptr). + * + * @note The function reserves memory for the vector based on the metadata size for + * optimal performance. + */ + [[nodiscard]] flatbuffers::Offset>> + create_metadata(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema); + + /** + * @brief Creates a FlatBuffer Field object from an ArrowSchema. + * + * This function converts an ArrowSchema structure into a FlatBuffer Field representation + * suitable for Apache Arrow IPC serialization. It handles the creation of all necessary + * components including field name, type information, metadata, children, and nullable flag. 
+ * + * @param builder Reference to the FlatBufferBuilder used for creating FlatBuffer objects + * @param arrow_schema The ArrowSchema structure containing the field definition to convert + * + * @return A FlatBuffer offset to the created Field object that can be used in further + * FlatBuffer construction operations + * + * @note Dictionary encoding is not currently supported (TODO item) + * @note The function checks the NULLABLE flag from the ArrowSchema flags to determine nullability + */ + [[nodiscard]] ::flatbuffers::Offset + create_field(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema); + + /** + * @brief Creates a FlatBuffers vector of Field objects from an ArrowSchema's children. + * + * This function iterates through all children of the given ArrowSchema and converts + * each child to a FlatBuffers Field object. The resulting fields are collected into + * a FlatBuffers vector. + * + * @param builder Reference to the FlatBufferBuilder used for creating FlatBuffers objects + * @param arrow_schema The ArrowSchema containing the children to convert + * + * @return A FlatBuffers offset to a vector of Field objects, or 0 if no children exist + * + * @throws std::invalid_argument If any child pointer in the ArrowSchema is null + * + * @note The function reserves space for all children upfront for performance optimization + * @note Returns 0 (null offset) when the schema has no children, otherwise returns a valid vector offset + */ + [[nodiscard]] ::flatbuffers::Offset< + ::flatbuffers::Vector<::flatbuffers::Offset>> + create_children(flatbuffers::FlatBufferBuilder& builder, sparrow::record_batch::column_range columns); + + /** + * @brief Creates a FlatBuffers vector of Field objects from a range of columns. + * + * This function iterates through the provided column range, extracts the Arrow schema + * from each column's proxy, and creates corresponding FlatBuffers Field objects. 
+ * The resulting fields are collected into a vector and converted to a FlatBuffers + * vector offset. + * + * @param builder Reference to the FlatBuffers builder used for creating the vector + * @param columns Range of columns to process, each containing an Arrow schema proxy + * + * @return FlatBuffers offset to a vector of Field objects, or 0 if the input range is empty + * + * @note The function reserves space in the children vector based on the column count + * for performance optimization + */ + [[nodiscard]] ::flatbuffers::Offset< + ::flatbuffers::Vector<::flatbuffers::Offset>> + create_children(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema); + + /** + * @brief Creates a FlatBuffer builder containing a serialized Arrow schema message. + * + * This function constructs an Arrow IPC schema message from a record batch by: + * 1. Creating field definitions from the record batch columns + * 2. Building a Schema flatbuffer with little-endian byte order + * 3. Wrapping the schema in a Message with metadata version V5 + * 4. Finalizing the buffer for serialization + * + * @param record_batch The source record batch containing column definitions + * @return flatbuffers::FlatBufferBuilder A completed FlatBuffer containing the schema message, + * ready for Arrow IPC serialization + * + * @note The schema message has zero body length as it contains only metadata + * @note Currently uses little-endian byte order (marked as TODO for configurability) + */ + [[nodiscard]] flatbuffers::FlatBufferBuilder + get_schema_message_builder(const sparrow::record_batch& record_batch); + + /** + * @brief Recursively fills a vector of FieldNode objects from an arrow_proxy and its children. + * + * This function creates FieldNode objects containing length and null count information + * from the given arrow_proxy and recursively processes all its children, appending + * them to the provided nodes vector in depth-first order. 
+ * + * @param arrow_proxy The arrow proxy object containing array metadata (length, null_count) + * and potential child arrays + * @param nodes Reference to a vector that will be populated with FieldNode objects. + * Each FieldNode contains the length and null count of the corresponding array. + * + * @note The function reserves space in the nodes vector to optimize memory allocation + * when processing children arrays. + * @note The traversal order is depth-first, with parent nodes added before their children. + */ + void fill_fieldnodes( + const sparrow::arrow_proxy& arrow_proxy, + std::vector& nodes + ); + + /** + * @brief Creates a vector of Apache Arrow FieldNode objects from a record batch. + * + * This function iterates through all columns in the provided record batch and + * generates corresponding FieldNode flatbuffer objects. Each column's arrow proxy + * is used to populate the field nodes vector through the fill_fieldnodes function. + * + * @param record_batch The sparrow record batch containing columns to process + * @return std::vector Vector of FieldNode + * objects representing the structure and metadata of each column + */ + [[nodiscard]] std::vector + create_fieldnodes(const sparrow::record_batch& record_batch); + + + /** + * @brief Recursively fills a vector of FlatBuffer Buffer objects with buffer information from an Arrow + * proxy. + * + * This function traverses an Arrow proxy structure and creates FlatBuffer Buffer entries for each buffer + * found in the proxy and its children. The buffers are processed in a depth-first manner, first handling + * the buffers of the current proxy, then recursively processing all child proxies. 
+ * + * @param arrow_proxy The Arrow proxy object containing buffers and potential child proxies to process + * @param flatbuf_buffers Vector of FlatBuffer Buffer objects to be populated with buffer information + * @param offset Reference to the current byte offset, updated as buffers are processed and aligned to + * 8-byte boundaries + * + * @note The offset is automatically aligned to 8-byte boundaries using utils::align_to_8() for each + * buffer + * @note This function modifies both the flatbuf_buffers vector and the offset parameter + */ + void fill_buffers( + const sparrow::arrow_proxy& arrow_proxy, + std::vector& flatbuf_buffers, + int64_t& offset + ); + + /** + * @brief Extracts buffer information from a record batch for serialization. + * + * This function iterates through all columns in the provided record batch and + * collects their buffer information into a vector of Arrow FlatBuffer Buffer objects. + * The buffers are processed sequentially with cumulative offset tracking. + * + * @param record_batch The sparrow record batch containing columns to extract buffers from + * @return std::vector A vector containing all buffer + * descriptors from the record batch columns, with properly calculated offsets + * + * @note This function relies on the fill_buffers helper function to process individual + * column buffers and maintain offset consistency across all buffers. + */ + [[nodiscard]] std::vector + get_buffers(const sparrow::record_batch& record_batch); + + /** + * @brief Creates a FlatBuffer message containing a serialized Apache Arrow RecordBatch. + * + * This function builds a complete Arrow IPC message by serializing a record batch + * along with its metadata (field nodes and buffer information) into a FlatBuffer + * format that conforms to the Arrow IPC specification. 
+ * + * @param record_batch The source record batch containing the data to be serialized + * + * @return A FlatBufferBuilder containing the complete serialized message ready for + * transmission or storage. The builder is finished and ready to be accessed + * via GetBufferPointer() and GetSize(). + * + * @note The returned message uses Arrow IPC format version V5 + * @note Compression and variadic buffer counts are not currently implemented (set to 0) + * @note The body size is automatically calculated based on the record batch contents + */ + [[nodiscard]] flatbuffers::FlatBufferBuilder + get_record_batch_message_builder(const sparrow::record_batch& record_batch); +} \ No newline at end of file diff --git a/include/sparrow_ipc/memory_output_stream.hpp b/include/sparrow_ipc/memory_output_stream.hpp new file mode 100644 index 0000000..768b672 --- /dev/null +++ b/include/sparrow_ipc/memory_output_stream.hpp @@ -0,0 +1,64 @@ +#include +#include +#include + +#include "sparrow_ipc/output_stream.hpp" + +namespace sparrow_ipc +{ + template + requires std::ranges::random_access_range && std::same_as + class memory_output_stream final : public output_stream + { + public: + + memory_output_stream(R& buffer) + : m_buffer(&buffer) {}; + + std::size_t write(std::span span) override + { + m_buffer->insert(m_buffer->end(), span.begin(), span.end()); + return span.size(); + } + + std::size_t write(uint8_t value, std::size_t count) override + { + m_buffer->insert(m_buffer->end(), count, value); + return count; + } + + void reserve(std::size_t size) override + { + m_buffer->reserve(size); + } + + void reserve(const std::function& calculate_reserve_size) override + { + m_buffer->reserve(calculate_reserve_size()); + } + + size_t size() const override + { + return m_buffer->size(); + } + + void flush() override + { + // Implementation for flushing memory + } + + void close() override + { + // Implementation for closing the stream + } + + bool is_open() const override + { + return true; 
+ } + + private: + + R* m_buffer; + }; +} diff --git a/include/sparrow_ipc/output_stream.hpp b/include/sparrow_ipc/output_stream.hpp new file mode 100644 index 0000000..6a36bf8 --- /dev/null +++ b/include/sparrow_ipc/output_stream.hpp @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include +#include + +#include "sparrow_ipc/config/config.hpp" + +namespace sparrow_ipc +{ + /** + * @brief Abstract interface for output streams used in sparrow-ipc serialization. + * + * This interface provides a generic way to write binary data during serialization + * operations. Implementations can target different destinations such as files, + * memory buffers, network streams, etc. + */ + class SPARROW_IPC_API output_stream + { + public: + + virtual ~output_stream() = default; + + /** + * @brief Writes a span of bytes to the output stream. + * + * This method attempts to write all bytes from the provided span to the + * underlying destination. It returns the number of bytes actually written, + * which may be less than the requested size in case of errors or partial writes. + * + * @param span A span of bytes to write + * @return Number of bytes successfully written + * @throws std::runtime_error if a write error occurs + */ + virtual std::size_t write(std::span span) = 0; + + virtual std::size_t write(uint8_t value, std::size_t count = 1) = 0; + + void add_padding() + { + const size_t current_size = size(); + const size_t padding_needed = (8 - (current_size % 8)) % 8; + if (padding_needed > 0) + { + write(uint8_t{0}, padding_needed); + } + } + + /** + * @brief Reserves capacity in the output stream if supported. + * + * This is a hint to the implementation that at least `size` bytes + * will be written. Implementations may use this to optimize memory + * allocation or buffer management. 
+ * + * @param size Number of bytes to reserve + * @return true if reservation was successful or not needed, false otherwise + */ + virtual void reserve(std::size_t size) = 0; + + virtual void reserve(const std::function& calculate_reserve_size) = 0; + + virtual size_t size() const = 0; + + /** + * @brief Flushes any buffered data to the underlying destination. + * + * Ensures that all previously written data is committed to the + * underlying storage or transmitted. + * + * @throws std::runtime_error if flush operation fails + */ + virtual void flush() = 0; + + /** + * @brief Closes the output stream and releases any resources. + * + * After calling close(), no further operations should be performed + * on the stream. After calling close(), is_open() should return false. + * Multiple calls to close() should be safe. + * + * @throws std::runtime_error if close operation fails + */ + virtual void close() = 0; + + /** + * @brief Checks if the stream is still open and writable. + * + * @return true if the stream is open, false if closed or in error state + */ + virtual bool is_open() const = 0; + + // Convenience method for writing single bytes + std::size_t write(std::uint8_t byte) + { + return write(std::span{&byte, 1}); + } + }; +} // namespace sparrow_ipc diff --git a/include/sparrow_ipc/serialize.hpp b/include/sparrow_ipc/serialize.hpp index 1ab8003..b35770c 100644 --- a/include/sparrow_ipc/serialize.hpp +++ b/include/sparrow_ipc/serialize.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -9,6 +8,7 @@ #include "Message_generated.h" #include "sparrow_ipc/config/config.hpp" #include "sparrow_ipc/magic_values.hpp" +#include "sparrow_ipc/output_stream.hpp" #include "sparrow_ipc/serialize_utils.hpp" #include "sparrow_ipc/utils.hpp" @@ -26,9 +26,7 @@ namespace sparrow_ipc * @tparam R Container type that holds record batches (must support empty(), operator[], begin(), end()) * @param record_batches Collection of record batches to serialize. 
All batches must have identical * schemas. - * - * @return std::vector Binary serialized data containing schema, record batches, and - * end-of-stream marker. Returns empty vector if input collection is empty. + * @param stream The output stream where the serialized data will be written. * * @throws std::invalid_argument If record batches have inconsistent schemas or if the collection * contains batches that cannot be serialized together. @@ -38,27 +36,60 @@ namespace sparrow_ipc */ template requires std::same_as, sparrow::record_batch> - std::vector serialize(const R& record_batches) + void serialize_record_batches_to_ipc_stream(const R& record_batches, output_stream& stream) { if (record_batches.empty()) { - return {}; + return; } + if (!utils::check_record_batches_consistency(record_batches)) { throw std::invalid_argument( "All record batches must have the same schema to be serialized together." ); } - std::vector serialized_schema = serialize_schema_message(record_batches[0]); - std::vector serialized_record_batches = serialize_record_batches_without_schema_message(record_batches); - serialized_schema.insert( - serialized_schema.end(), - std::make_move_iterator(serialized_record_batches.begin()), - std::make_move_iterator(serialized_record_batches.end()) - ); - // End of stream message - serialized_schema.insert(serialized_schema.end(), end_of_stream.begin(), end_of_stream.end()); - return serialized_schema; + serialize_schema_message(record_batches[0], stream); + for (const auto& rb : record_batches) + { + serialize_record_batch(rb, stream); + } + stream.write(end_of_stream); } + + /** + * @brief Serializes a record batch into a binary format following the Arrow IPC specification. 
+ * + * This function converts a sparrow record batch into a serialized byte vector that includes: + * - A continuation marker + * - The record batch message length (4 bytes) + * - The flatbuffer-encoded record batch metadata + * - Padding to align to 8-byte boundaries + * - The record batch body containing the actual data buffers + * + * @param record_batch The sparrow record batch to serialize + * @param stream The output stream where the serialized record batch will be written + * + * @note The output follows Arrow IPC message format with proper alignment and + * includes both metadata and data portions of the record batch + */ + SPARROW_IPC_API void + serialize_record_batch(const sparrow::record_batch& record_batch, output_stream& stream); + + /** + * @brief Serializes a schema message for a record batch into a byte buffer. + * + * This function creates a serialized schema message following the Arrow IPC format. + * The resulting buffer contains: + * 1. Continuation bytes at the beginning + * 2. A 4-byte length prefix indicating the size of the schema message + * 3. The actual FlatBuffer schema message bytes + * 4. 
Padding bytes to align the total size to 8-byte boundaries + * + * @param record_batch The record batch containing the schema to serialize + * @param stream The output stream where the serialized schema message will be written + */ + SPARROW_IPC_API void + serialize_schema_message(const sparrow::record_batch& record_batch, output_stream& stream); + } diff --git a/include/sparrow_ipc/serialize_utils.hpp b/include/sparrow_ipc/serialize_utils.hpp index 9ead8ea..0ae9832 100644 --- a/include/sparrow_ipc/serialize_utils.hpp +++ b/include/sparrow_ipc/serialize_utils.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -8,7 +7,7 @@ #include "Message_generated.h" #include "sparrow_ipc/config/config.hpp" -#include "sparrow_ipc/magic_values.hpp" +#include "sparrow_ipc/output_stream.hpp" #include "sparrow_ipc/utils.hpp" namespace sparrow_ipc @@ -21,11 +20,10 @@ namespace sparrow_ipc * The resulting format follows the Arrow IPC specification for schema messages. * * @param record_batch The record batch containing the schema to be serialized - * @return std::vector A byte vector containing the complete serialized schema message - * with continuation bytes, 4-byte length prefix, schema data, and 8-byte alignment padding + * @param stream The output stream where the serialized schema message will be written */ - [[nodiscard]] SPARROW_IPC_API std::vector - serialize_schema_message(const sparrow::record_batch& record_batch); + SPARROW_IPC_API void + serialize_schema_message(const sparrow::record_batch& record_batch, output_stream& stream); /** * @brief Serializes a record batch into a binary format following the Arrow IPC specification. @@ -41,236 +39,83 @@ namespace sparrow_ipc * consists of a metadata section followed by a body section containing the actual data. 
* * @param record_batch The sparrow record batch to be serialized - * @return std::vector A byte vector containing the complete serialized record batch - * in Arrow IPC format, ready for transmission or storage + * @param stream The output stream where the serialized record batch will be written */ - [[nodiscard]] SPARROW_IPC_API std::vector - serialize_record_batch(const sparrow::record_batch& record_batch); - - template - requires std::same_as, sparrow::record_batch> - /** - * @brief Serializes a collection of record batches into a single byte vector. - * - * This function takes a range or container of record batches and serializes each one - * individually, then concatenates all the serialized data into a single output vector. - * The serialization is performed by calling serialize_record_batch() for each record batch - * in the input collection. - * - * @tparam R The type of the record batch container/range (must be iterable) - * @param record_batches A collection of record batches to be serialized - * @return std::vector A byte vector containing the serialized data of all record batches - * - * @note The function uses move iterators to efficiently transfer the serialized data - * from individual record batches to the output vector. - */ - [[nodiscard]] std::vector serialize_record_batches_without_schema_message(const R& record_batches) - { - std::vector output; - for (const auto& record_batch : record_batches) - { - const auto rb_serialized = serialize_record_batch(record_batch); - output.insert( - output.end(), - std::make_move_iterator(rb_serialized.begin()), - std::make_move_iterator(rb_serialized.end()) - ); - } - return output; - } - - /** - * @brief Creates a FlatBuffers vector of KeyValue pairs from ArrowSchema metadata. - * - * This function converts metadata from an ArrowSchema into a FlatBuffers representation - * suitable for serialization. 
It processes key-value pairs from the schema's metadata - * and creates corresponding FlatBuffers KeyValue objects. - * - * @param builder Reference to the FlatBufferBuilder used for creating FlatBuffers objects - * @param arrow_schema The ArrowSchema containing metadata to be serialized - * - * @return A FlatBuffers offset to a vector of KeyValue pairs. Returns 0 if the schema - * has no metadata (metadata is nullptr). - * - * @note The function reserves memory for the vector based on the metadata size for - * optimal performance. - */ - [[nodiscard]] SPARROW_IPC_API - flatbuffers::Offset>> - create_metadata(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema); - - /** - * @brief Creates a FlatBuffer Field object from an ArrowSchema. - * - * This function converts an ArrowSchema structure into a FlatBuffer Field representation - * suitable for Apache Arrow IPC serialization. It handles the creation of all necessary - * components including field name, type information, metadata, children, and nullable flag. - * - * @param builder Reference to the FlatBufferBuilder used for creating FlatBuffer objects - * @param arrow_schema The ArrowSchema structure containing the field definition to convert - * - * @return A FlatBuffer offset to the created Field object that can be used in further - * FlatBuffer construction operations - * - * @note Dictionary encoding is not currently supported (TODO item) - * @note The function checks the NULLABLE flag from the ArrowSchema flags to determine nullability - */ - [[nodiscard]] SPARROW_IPC_API ::flatbuffers::Offset - create_field(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema); + SPARROW_IPC_API void + serialize_record_batch(const sparrow::record_batch& record_batch, output_stream& stream); /** - * @brief Creates a FlatBuffers vector of Field objects from an ArrowSchema's children. + * @brief Calculates the total serialized size of a schema message. 
* - * This function iterates through all children of the given ArrowSchema and converts - * each child to a FlatBuffers Field object. The resulting fields are collected into - * a FlatBuffers vector. + * This function computes the complete size that would be produced by serialize_schema_message(), + * including: + * - Continuation bytes (4 bytes) + * - Message length prefix (4 bytes) + * - FlatBuffer schema message data + * - Padding to 8-byte alignment * - * @param builder Reference to the FlatBufferBuilder used for creating FlatBuffers objects - * @param arrow_schema The ArrowSchema containing the children to convert - * - * @return A FlatBuffers offset to a vector of Field objects, or 0 if no children exist - * - * @throws std::invalid_argument If any child pointer in the ArrowSchema is null - * - * @note The function reserves space for all children upfront for performance optimization - * @note Returns 0 (null offset) when the schema has no children, otherwise returns a valid vector offset + * @param record_batch The record batch containing the schema to be measured + * @return The total size in bytes that the serialized schema message would occupy */ - [[nodiscard]] SPARROW_IPC_API ::flatbuffers::Offset< - ::flatbuffers::Vector<::flatbuffers::Offset>> - create_children(flatbuffers::FlatBufferBuilder& builder, sparrow::record_batch::column_range columns); + [[nodiscard]] SPARROW_IPC_API std::size_t + calculate_schema_message_size(const sparrow::record_batch& record_batch); /** - * @brief Creates a FlatBuffers vector of Field objects from a range of columns. - * - * This function iterates through the provided column range, extracts the Arrow schema - * from each column's proxy, and creates corresponding FlatBuffers Field objects. - * The resulting fields are collected into a vector and converted to a FlatBuffers - * vector offset. 
- * - * @param builder Reference to the FlatBuffers builder used for creating the vector - * @param columns Range of columns to process, each containing an Arrow schema proxy + * @brief Calculates the total serialized size of a record batch message. * - * @return FlatBuffers offset to a vector of Field objects, or 0 if the input range is empty + * This function computes the complete size that would be produced by serialize_record_batch(), + * including: + * - Continuation bytes (4 bytes) + * - Message length prefix (4 bytes) + * - FlatBuffer record batch metadata + * - Padding to 8-byte alignment after metadata + * - Body data with 8-byte alignment between buffers * - * @note The function reserves space in the children vector based on the column count - * for performance optimization + * @param record_batch The record batch to be measured + * @return The total size in bytes that the serialized record batch would occupy */ - [[nodiscard]] SPARROW_IPC_API ::flatbuffers::Offset< - ::flatbuffers::Vector<::flatbuffers::Offset>> - create_children(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema); + [[nodiscard]] SPARROW_IPC_API std::size_t + calculate_record_batch_message_size(const sparrow::record_batch& record_batch); /** - * @brief Creates a FlatBuffer builder containing a serialized Arrow schema message. + * @brief Calculates the total serialized size for a collection of record batches. * - * This function constructs an Arrow IPC schema message from a record batch by: - * 1. Creating field definitions from the record batch columns - * 2. Building a Schema flatbuffer with little-endian byte order - * 3. Wrapping the schema in a Message with metadata version V5 - * 4. Finalizing the buffer for serialization + * This function computes the complete size that would be produced by serializing + * a schema message followed by all record batch messages in the collection. 
* - * @param record_batch The source record batch containing column definitions - * @return flatbuffers::FlatBufferBuilder A completed FlatBuffer containing the schema message, - * ready for Arrow IPC serialization - * - * @note The schema message has zero body length as it contains only metadata - * @note Currently uses little-endian byte order (marked as TODO for configurability) - */ - [[nodiscard]] SPARROW_IPC_API flatbuffers::FlatBufferBuilder - get_schema_message_builder(const sparrow::record_batch& record_batch); - - /** - * @brief Serializes a schema message for a record batch into a byte buffer. - * - * This function creates a serialized schema message following the Arrow IPC format. - * The resulting buffer contains: - * 1. Continuation bytes at the beginning - * 2. A 4-byte length prefix indicating the size of the schema message - * 3. The actual FlatBuffer schema message bytes - * 4. Padding bytes to align the total size to 8-byte boundaries - * - * @param record_batch The record batch containing the schema to serialize - * @return std::vector A byte buffer containing the complete serialized schema message + * @tparam R Range type containing sparrow::record_batch objects + * @param record_batches Collection of record batches to be measured + * @return The total size in bytes for the complete serialized output + * @throws std::invalid_argument if record batches have inconsistent schemas */ - [[nodiscard]] SPARROW_IPC_API std::vector - serialize_schema_message(const sparrow::record_batch& record_batch); + template + requires std::same_as, sparrow::record_batch> + [[nodiscard]] std::size_t calculate_total_serialized_size(const R& record_batches) + { + if (record_batches.empty()) + { + return 0; + } - /** - * @brief Recursively fills a vector of FieldNode objects from an arrow_proxy and its children. 
- * - * This function creates FieldNode objects containing length and null count information - * from the given arrow_proxy and recursively processes all its children, appending - * them to the provided nodes vector in depth-first order. - * - * @param arrow_proxy The arrow proxy object containing array metadata (length, null_count) - * and potential child arrays - * @param nodes Reference to a vector that will be populated with FieldNode objects. - * Each FieldNode contains the length and null count of the corresponding array. - * - * @note The function reserves space in the nodes vector to optimize memory allocation - * when processing children arrays. - * @note The traversal order is depth-first, with parent nodes added before their children. - */ - SPARROW_IPC_API void fill_fieldnodes( - const sparrow::arrow_proxy& arrow_proxy, - std::vector& nodes - ); + if (!utils::check_record_batches_consistency(record_batches)) + { + throw std::invalid_argument("Record batches have inconsistent schemas"); + } - /** - * @brief Creates a vector of Apache Arrow FieldNode objects from a record batch. - * - * This function iterates through all columns in the provided record batch and - * generates corresponding FieldNode flatbuffer objects. Each column's arrow proxy - * is used to populate the field nodes vector through the fill_fieldnodes function. - * - * @param record_batch The sparrow record batch containing columns to process - * @return std::vector Vector of FieldNode - * objects representing the structure and metadata of each column - */ - [[nodiscard]] SPARROW_IPC_API std::vector - create_fieldnodes(const sparrow::record_batch& record_batch); + // Calculate schema message size (only once) + std::size_t total_size = calculate_schema_message_size(record_batches[0]); - /** - * @brief Recursively fills a vector of FlatBuffer Buffer objects with buffer information from an Arrow - * proxy. 
- * - * This function traverses an Arrow proxy structure and creates FlatBuffer Buffer entries for each buffer - * found in the proxy and its children. The buffers are processed in a depth-first manner, first handling - * the buffers of the current proxy, then recursively processing all child proxies. - * - * @param arrow_proxy The Arrow proxy object containing buffers and potential child proxies to process - * @param flatbuf_buffers Vector of FlatBuffer Buffer objects to be populated with buffer information - * @param offset Reference to the current byte offset, updated as buffers are processed and aligned to - * 8-byte boundaries - * - * @note The offset is automatically aligned to 8-byte boundaries using utils::align_to_8() for each - * buffer - * @note This function modifies both the flatbuf_buffers vector and the offset parameter - */ - SPARROW_IPC_API void fill_buffers( - const sparrow::arrow_proxy& arrow_proxy, - std::vector& flatbuf_buffers, - int64_t& offset - ); + // Calculate record batch message sizes + for (const auto& record_batch : record_batches) + { + total_size += calculate_record_batch_message_size(record_batch); + } - /** - * @brief Extracts buffer information from a record batch for serialization. - * - * This function iterates through all columns in the provided record batch and - * collects their buffer information into a vector of Arrow FlatBuffer Buffer objects. - * The buffers are processed sequentially with cumulative offset tracking. - * - * @param record_batch The sparrow record batch containing columns to extract buffers from - * @return std::vector A vector containing all buffer - * descriptors from the record batch columns, with properly calculated offsets - * - * @note This function relies on the fill_buffers helper function to process individual - * column buffers and maintain offset consistency across all buffers. 
- */ - [[nodiscard]] SPARROW_IPC_API std::vector - get_buffers(const sparrow::record_batch& record_batch); + return total_size; + } /** - * @brief Fills the body vector with buffer data from an arrow proxy and its children. + * @brief Fills the body vector with serialized data from an arrow proxy and its children. * * This function recursively processes an arrow proxy by: * 1. Iterating through all buffers in the proxy and appending their data to the body vector @@ -282,9 +127,9 @@ namespace sparrow_ipc * format compliance. * * @param arrow_proxy The arrow proxy containing buffers and potential child proxies to serialize - * @param body Reference to the vector where the serialized buffer data will be appended + * @param stream The output stream where the serialized body data will be written */ - SPARROW_IPC_API void fill_body(const sparrow::arrow_proxy& arrow_proxy, std::vector& body); + SPARROW_IPC_API void fill_body(const sparrow::arrow_proxy& arrow_proxy, output_stream& stream); /** * @brief Generates a serialized body from a record batch. @@ -294,9 +139,9 @@ namespace sparrow_ipc * single byte vector that forms the body of the serialized data. * * @param record_batch The record batch containing columns to be serialized - * @return std::vector A byte vector containing the serialized body data + * @param stream The output stream where the serialized body will be written */ - [[nodiscard]] SPARROW_IPC_API std::vector generate_body(const sparrow::record_batch& record_batch); + SPARROW_IPC_API void generate_body(const sparrow::record_batch& record_batch, output_stream& stream); /** * @brief Calculates the total size of the body section for an Arrow array. @@ -322,60 +167,17 @@ namespace sparrow_ipc */ [[nodiscard]] SPARROW_IPC_API int64_t calculate_body_size(const sparrow::record_batch& record_batch); - /** - * @brief Creates a FlatBuffer message containing a serialized Apache Arrow RecordBatch. 
- * - * This function builds a complete Arrow IPC message by serializing a record batch - * along with its metadata (field nodes and buffer information) into a FlatBuffer - * format that conforms to the Arrow IPC specification. - * - * @param record_batch The source record batch containing the data to be serialized - * @param nodes Vector of field nodes describing the structure and null counts of columns - * @param buffers Vector of buffer descriptors containing offset and length information - * for the data buffers - * - * @return A FlatBufferBuilder containing the complete serialized message ready for - * transmission or storage. The builder is finished and ready to be accessed - * via GetBufferPointer() and GetSize(). - * - * @note The returned message uses Arrow IPC format version V5 - * @note Compression and variadic buffer counts are not currently implemented (set to 0) - * @note The body size is automatically calculated based on the record batch contents - */ - [[nodiscard]] SPARROW_IPC_API flatbuffers::FlatBufferBuilder get_record_batch_message_builder( - const sparrow::record_batch& record_batch, - const std::vector& nodes, - const std::vector& buffers - ); - - /** - * @brief Serializes a record batch into a binary format following the Arrow IPC specification. 
- * - * This function converts a sparrow record batch into a serialized byte vector that includes: - * - A continuation marker - * - The record batch message length (4 bytes) - * - The flatbuffer-encoded record batch metadata - * - Padding to align to 8-byte boundaries - * - The record batch body containing the actual data buffers - * - * @param record_batch The sparrow record batch to serialize - * @return std::vector A byte vector containing the serialized record batch - * in Arrow IPC format, ready for transmission or storage - * - * @note The output follows Arrow IPC message format with proper alignment and - * includes both metadata and data portions of the record batch - */ - [[nodiscard]] SPARROW_IPC_API std::vector - serialize_record_batch(const sparrow::record_batch& record_batch); /** - * @brief Adds padding bytes to a buffer to ensure 8-byte alignment. + * @brief Adds padding bytes to an output stream to ensure 8-byte alignment. * - * This function appends zero bytes to the end of the provided buffer until + * This function appends zero bytes to the end of the provided stream until * its size is a multiple of 8. This is often required for proper memory * alignment in binary formats such as Apache Arrow IPC. * - * @param buffer The byte vector to which padding will be added + * @param stream The output stream where padding bytes will be added */ - void add_padding(std::vector& buffer); + void add_padding(output_stream& stream); + + std::vector get_column_dtypes(const sparrow::record_batch& rb); } diff --git a/include/sparrow_ipc/serializer.hpp b/include/sparrow_ipc/serializer.hpp new file mode 100644 index 0000000..f4e0fd5 --- /dev/null +++ b/include/sparrow_ipc/serializer.hpp @@ -0,0 +1,136 @@ +#include +#include + +#include + +#include "sparrow_ipc/output_stream.hpp" +#include "sparrow_ipc/serialize_utils.hpp" + +namespace sparrow_ipc +{ + /** + * @brief A class for serializing Apache Arrow record batches to an output stream. 
+ * + * The serializer class provides functionality to serialize single or multiple record batches + * into a binary format suitable for storage or transmission. It ensures schema consistency + * across multiple record batches and optimizes memory allocation by pre-calculating required + * buffer sizes. + * + * @details The serializer supports two main usage patterns: + * 1. Construction with a collection of record batches for batch serialization + * 2. Construction with a single record batch followed by incremental appends + * + * The class validates that all record batches have consistent schemas and throws + * std::invalid_argument if inconsistencies are detected or if an empty collection + * is provided. + * + * Memory efficiency is achieved through: + * - Pre-calculation of total serialization size + * - Stream reservation to minimize memory reallocations + * - Lazy evaluation of size calculations using lambda functions + */ + class serializer + { + public: + + serializer(const sparrow::record_batch& rb, output_stream& stream); + + template + requires std::same_as, sparrow::record_batch> + serializer(const R& record_batches, output_stream& stream) + : m_pstream(&stream) + , m_dtypes(get_column_dtypes(record_batches[0])) + { + if (record_batches.empty()) + { + throw std::invalid_argument("Record batches collection is empty"); + } + + const auto reserve_function = [&record_batches]() + { + return calculate_schema_message_size(record_batches[0]) + + std::accumulate( + record_batches.cbegin(), + record_batches.cend(), + 0, + [](size_t acc, const sparrow::record_batch& rb) + { + return acc + calculate_record_batch_message_size(rb); + } + ); + }; + m_pstream->reserve(reserve_function); + serialize_schema_message(record_batches[0], *m_pstream); + append(record_batches); + } + + /** + * Appends a record batch to the serializer. 
+ * + * @param rb The record batch to append to the serializer + */ + void append(const sparrow::record_batch& rb); + + /** + * @brief Appends a collection of record batches to the stream. + * + * This method efficiently adds multiple record batches to the serialization stream + * by first calculating the total required size and reserving memory space to minimize + * reallocations during the append operations. + * + * @tparam R The type of the record batch collection (must be iterable) + * @param record_batches A collection of record batches to append to the stream + * + * The method performs the following operations: + * 1. Calculates the total size needed for all record batches + * 2. Reserves the required memory space in the stream + * 3. Iterates through each record batch and adds it to the stream + */ + template + requires std::same_as, sparrow::record_batch> + void append(const R& record_batches) + { + if (m_ended) + { + throw std::runtime_error("Cannot append to a serializer that has been ended"); + } + const auto reserve_function = [&record_batches, this]() + { + return std::accumulate( + record_batches.cbegin(), + record_batches.cend(), + m_pstream->size(), + [this](size_t acc, const sparrow::record_batch& rb) + { + return acc + calculate_record_batch_message_size(rb); + } + ); + }; + m_pstream->reserve(reserve_function); + for (const auto& rb : record_batches) + { + serialize_record_batch(rb, *m_pstream); + } + } + + /** + * @brief Finalizes the serialization process by writing end-of-stream marker. + * + * This method writes an end-of-stream marker to the output stream and flushes + * any buffered data. It can be called multiple times safely as it tracks + * whether the stream has already been ended to prevent duplicate operations. + * + * @note This method is idempotent - calling it multiple times has no additional effect. + * @post After calling this method, m_ended will be set to true. 
+ */ + void end(); + + private: + + static std::vector get_column_dtypes(const sparrow::record_batch& rb); + + std::vector m_dtypes; + output_stream* m_pstream; + bool m_ended{false}; + }; +} \ No newline at end of file diff --git a/include/sparrow_ipc/utils.hpp b/include/sparrow_ipc/utils.hpp index 0c80f9c..3da1e54 100644 --- a/include/sparrow_ipc/utils.hpp +++ b/include/sparrow_ipc/utils.hpp @@ -3,22 +3,15 @@ #include #include #include -#include #include -#include "Schema_generated.h" #include "sparrow_ipc/config/config.hpp" namespace sparrow_ipc::utils { // Aligns a value to the next multiple of 8, as required by the Arrow IPC format for message bodies - SPARROW_IPC_API int64_t align_to_8(const int64_t n); - - // Creates a Flatbuffers type from a format string - // This function maps a sparrow data type to the corresponding Flatbuffers type - SPARROW_IPC_API std::pair> - get_flatbuffer_type(flatbuffers::FlatBufferBuilder& builder, std::string_view format_str); + size_t align_to_8(const size_t n); /** * @brief Checks if all record batches in a collection have consistent structure. 
@@ -39,7 +32,7 @@ namespace sparrow_ipc::utils requires std::same_as, sparrow::record_batch> bool check_record_batches_consistency(const R& record_batches) { - if (record_batches.empty()) + if (record_batches.empty() || record_batches.size() == 1) { return true; } @@ -67,5 +60,8 @@ namespace sparrow_ipc::utils return true; } + // Parse the format string + // The format string is expected to be "w:size", "+w:size", "d:precision,scale", etc + std::optional parse_format(std::string_view format_str, std::string_view sep); // size_t calculate_output_serialized_size(const sparrow::record_batch& record_batch); } diff --git a/src/chunk_memory_serializer.cpp b/src/chunk_memory_serializer.cpp new file mode 100644 index 0000000..46e073d --- /dev/null +++ b/src/chunk_memory_serializer.cpp @@ -0,0 +1,54 @@ +#include "sparrow_ipc/chunk_memory_serializer.hpp" + +#include "sparrow_ipc/serialize.hpp" +#include "sparrow_ipc/serialize_utils.hpp" + +namespace sparrow_ipc +{ + chunk_serializer::chunk_serializer( + const sparrow::record_batch& rb, + chuncked_memory_output_stream>>& stream + ) + : m_pstream(&stream) + , m_dtypes(get_column_dtypes(rb)) + { + m_pstream->reserve(2); + std::vector schema_buffer; + memory_output_stream schema_stream(schema_buffer); + serialize_schema_message(rb, schema_stream); + m_pstream->write(std::move(schema_buffer)); + + std::vector batch_buffer; + memory_output_stream batch_stream(batch_buffer); + serialize_record_batch(rb, batch_stream); + m_pstream->write(std::move(batch_buffer)); + } + + void chunk_serializer::append(const sparrow::record_batch& rb) + { + if (m_ended) + { + throw std::runtime_error("Cannot append to a serializer that has been ended"); + } + if (get_column_dtypes(rb) != m_dtypes) + { + throw std::invalid_argument("Record batch has different schema than previous ones"); + } + m_pstream->reserve(m_pstream->size() + 1); + std::vector buffer; + memory_output_stream stream(buffer); + serialize_record_batch(rb, stream); + 
m_pstream->write(std::move(buffer)); + } + + void chunk_serializer::end() + { + if (m_ended) + { + return; + } + std::vector buffer(end_of_stream.begin(), end_of_stream.end()); + m_pstream->write(std::move(buffer)); + m_ended = true; + } +} diff --git a/src/file_output_stream.cpp b/src/file_output_stream.cpp new file mode 100644 index 0000000..14dbff5 --- /dev/null +++ b/src/file_output_stream.cpp @@ -0,0 +1,57 @@ +#include "sparrow_ipc/file_output_stream.hpp" + +namespace sparrow_ipc +{ + file_output_stream::file_output_stream(std::ofstream& file) + : m_file(file) + { + if (!m_file.is_open()) + { + throw std::runtime_error("Failed to open file stream"); + } + } + + std::size_t file_output_stream::write(std::span span) + { + m_file.write(reinterpret_cast(span.data()), span.size()); + m_written_bytes += span.size(); + return span.size(); + } + + std::size_t file_output_stream::write(uint8_t value, std::size_t count) + { + std::fill_n(std::ostreambuf_iterator(m_file), count, value); + m_written_bytes += count; + return count; + } + + size_t file_output_stream::size() const + { + return m_written_bytes; + } + + void file_output_stream::reserve(std::size_t size) + { + // File streams do not support reservation + } + + void file_output_stream::reserve(const std::function& calculate_reserve_size) + { + // File streams do not support reservation + } + + void file_output_stream::flush() + { + m_file.flush(); + } + + void file_output_stream::close() + { + m_file.close(); + } + + bool file_output_stream::is_open() const + { + return m_file.is_open(); + } +} \ No newline at end of file diff --git a/src/flatbuffer_utils.cpp b/src/flatbuffer_utils.cpp new file mode 100644 index 0000000..91d8306 --- /dev/null +++ b/src/flatbuffer_utils.cpp @@ -0,0 +1,586 @@ +#include "sparrow_ipc/flatbuffer_utils.hpp" + +#include "sparrow_ipc/serialize_utils.hpp" +#include "sparrow_ipc/utils.hpp" + +namespace sparrow_ipc +{ + std::pair> + get_flatbuffer_type(flatbuffers::FlatBufferBuilder& 
builder, std::string_view format_str) + { + const auto type = sparrow::format_to_data_type(format_str); + switch (type) + { + case sparrow::data_type::NA: + { + const auto null_type = org::apache::arrow::flatbuf::CreateNull(builder); + return {org::apache::arrow::flatbuf::Type::Null, null_type.Union()}; + } + case sparrow::data_type::BOOL: + { + const auto bool_type = org::apache::arrow::flatbuf::CreateBool(builder); + return {org::apache::arrow::flatbuf::Type::Bool, bool_type.Union()}; + } + case sparrow::data_type::UINT8: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 8, false); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::INT8: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 8, true); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::UINT16: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 16, false); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::INT16: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 16, true); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::UINT32: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 32, false); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::INT32: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 32, true); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::UINT64: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 64, false); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::INT64: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 64, true); + return 
{org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::HALF_FLOAT: + { + const auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( + builder, + org::apache::arrow::flatbuf::Precision::HALF + ); + return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; + } + case sparrow::data_type::FLOAT: + { + const auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( + builder, + org::apache::arrow::flatbuf::Precision::SINGLE + ); + return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; + } + case sparrow::data_type::DOUBLE: + { + const auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( + builder, + org::apache::arrow::flatbuf::Precision::DOUBLE + ); + return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; + } + case sparrow::data_type::STRING: + { + const auto string_type = org::apache::arrow::flatbuf::CreateUtf8(builder); + return {org::apache::arrow::flatbuf::Type::Utf8, string_type.Union()}; + } + case sparrow::data_type::LARGE_STRING: + { + const auto large_string_type = org::apache::arrow::flatbuf::CreateLargeUtf8(builder); + return {org::apache::arrow::flatbuf::Type::LargeUtf8, large_string_type.Union()}; + } + case sparrow::data_type::BINARY: + { + const auto binary_type = org::apache::arrow::flatbuf::CreateBinary(builder); + return {org::apache::arrow::flatbuf::Type::Binary, binary_type.Union()}; + } + case sparrow::data_type::LARGE_BINARY: + { + const auto large_binary_type = org::apache::arrow::flatbuf::CreateLargeBinary(builder); + return {org::apache::arrow::flatbuf::Type::LargeBinary, large_binary_type.Union()}; + } + case sparrow::data_type::STRING_VIEW: + { + const auto string_view_type = org::apache::arrow::flatbuf::CreateUtf8View(builder); + return {org::apache::arrow::flatbuf::Type::Utf8View, string_view_type.Union()}; + } + case sparrow::data_type::BINARY_VIEW: + { + const auto binary_view_type = 
org::apache::arrow::flatbuf::CreateBinaryView(builder); + return {org::apache::arrow::flatbuf::Type::BinaryView, binary_view_type.Union()}; + } + case sparrow::data_type::DATE_DAYS: + { + const auto date_type = org::apache::arrow::flatbuf::CreateDate( + builder, + org::apache::arrow::flatbuf::DateUnit::DAY + ); + return {org::apache::arrow::flatbuf::Type::Date, date_type.Union()}; + } + case sparrow::data_type::DATE_MILLISECONDS: + { + const auto date_type = org::apache::arrow::flatbuf::CreateDate( + builder, + org::apache::arrow::flatbuf::DateUnit::MILLISECOND + ); + return {org::apache::arrow::flatbuf::Type::Date, date_type.Union()}; + } + case sparrow::data_type::TIMESTAMP_SECONDS: + { + const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( + builder, + org::apache::arrow::flatbuf::TimeUnit::SECOND + ); + return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; + } + case sparrow::data_type::TIMESTAMP_MILLISECONDS: + { + const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( + builder, + org::apache::arrow::flatbuf::TimeUnit::MILLISECOND + ); + return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; + } + case sparrow::data_type::TIMESTAMP_MICROSECONDS: + { + const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( + builder, + org::apache::arrow::flatbuf::TimeUnit::MICROSECOND + ); + return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; + } + case sparrow::data_type::TIMESTAMP_NANOSECONDS: + { + const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( + builder, + org::apache::arrow::flatbuf::TimeUnit::NANOSECOND + ); + return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; + } + case sparrow::data_type::DURATION_SECONDS: + { + const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( + builder, + org::apache::arrow::flatbuf::TimeUnit::SECOND + ); + return 
{org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; + } + case sparrow::data_type::DURATION_MILLISECONDS: + { + const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( + builder, + org::apache::arrow::flatbuf::TimeUnit::MILLISECOND + ); + return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; + } + case sparrow::data_type::DURATION_MICROSECONDS: + { + const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( + builder, + org::apache::arrow::flatbuf::TimeUnit::MICROSECOND + ); + return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; + } + case sparrow::data_type::DURATION_NANOSECONDS: + { + const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( + builder, + org::apache::arrow::flatbuf::TimeUnit::NANOSECOND + ); + return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; + } + case sparrow::data_type::INTERVAL_MONTHS: + { + const auto interval_type = org::apache::arrow::flatbuf::CreateInterval( + builder, + org::apache::arrow::flatbuf::IntervalUnit::YEAR_MONTH + ); + return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; + } + case sparrow::data_type::INTERVAL_DAYS_TIME: + { + const auto interval_type = org::apache::arrow::flatbuf::CreateInterval( + builder, + org::apache::arrow::flatbuf::IntervalUnit::DAY_TIME + ); + return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; + } + case sparrow::data_type::INTERVAL_MONTHS_DAYS_NANOSECONDS: + { + const auto interval_type = org::apache::arrow::flatbuf::CreateInterval( + builder, + org::apache::arrow::flatbuf::IntervalUnit::MONTH_DAY_NANO + ); + return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; + } + case sparrow::data_type::TIME_SECONDS: + { + const auto time_type = org::apache::arrow::flatbuf::CreateTime( + builder, + org::apache::arrow::flatbuf::TimeUnit::SECOND, + 32 + ); + return {org::apache::arrow::flatbuf::Type::Time, 
time_type.Union()}; + } + case sparrow::data_type::TIME_MILLISECONDS: + { + const auto time_type = org::apache::arrow::flatbuf::CreateTime( + builder, + org::apache::arrow::flatbuf::TimeUnit::MILLISECOND, + 32 + ); + return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; + } + case sparrow::data_type::TIME_MICROSECONDS: + { + const auto time_type = org::apache::arrow::flatbuf::CreateTime( + builder, + org::apache::arrow::flatbuf::TimeUnit::MICROSECOND, + 64 + ); + return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; + } + case sparrow::data_type::TIME_NANOSECONDS: + { + const auto time_type = org::apache::arrow::flatbuf::CreateTime( + builder, + org::apache::arrow::flatbuf::TimeUnit::NANOSECOND, + 64 + ); + return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; + } + case sparrow::data_type::LIST: + { + const auto list_type = org::apache::arrow::flatbuf::CreateList(builder); + return {org::apache::arrow::flatbuf::Type::List, list_type.Union()}; + } + case sparrow::data_type::LARGE_LIST: + { + const auto large_list_type = org::apache::arrow::flatbuf::CreateLargeList(builder); + return {org::apache::arrow::flatbuf::Type::LargeList, large_list_type.Union()}; + } + case sparrow::data_type::LIST_VIEW: + { + const auto list_view_type = org::apache::arrow::flatbuf::CreateListView(builder); + return {org::apache::arrow::flatbuf::Type::ListView, list_view_type.Union()}; + } + case sparrow::data_type::LARGE_LIST_VIEW: + { + const auto large_list_view_type = org::apache::arrow::flatbuf::CreateLargeListView(builder); + return {org::apache::arrow::flatbuf::Type::LargeListView, large_list_view_type.Union()}; + } + case sparrow::data_type::FIXED_SIZED_LIST: + { + // FixedSizeList requires listSize. We need to parse the format_str. 
+ // Format: "+w:size" + const auto list_size = utils::parse_format(format_str, ":"); + if (!list_size.has_value()) + { + throw std::runtime_error( + "Failed to parse FixedSizeList size from format string: " + std::string(format_str) + ); + } + + const auto fixed_size_list_type = org::apache::arrow::flatbuf::CreateFixedSizeList( + builder, + list_size.value() + ); + return {org::apache::arrow::flatbuf::Type::FixedSizeList, fixed_size_list_type.Union()}; + } + case sparrow::data_type::STRUCT: + { + const auto struct_type = org::apache::arrow::flatbuf::CreateStruct_(builder); + return {org::apache::arrow::flatbuf::Type::Struct_, struct_type.Union()}; + } + case sparrow::data_type::MAP: + { + // not sorted keys + const auto map_type = org::apache::arrow::flatbuf::CreateMap(builder, false); + return {org::apache::arrow::flatbuf::Type::Map, map_type.Union()}; + } + case sparrow::data_type::DENSE_UNION: + { + const auto union_type = org::apache::arrow::flatbuf::CreateUnion( + builder, + org::apache::arrow::flatbuf::UnionMode::Dense, + 0 + ); + return {org::apache::arrow::flatbuf::Type::Union, union_type.Union()}; + } + case sparrow::data_type::SPARSE_UNION: + { + const auto union_type = org::apache::arrow::flatbuf::CreateUnion( + builder, + org::apache::arrow::flatbuf::UnionMode::Sparse, + 0 + ); + return {org::apache::arrow::flatbuf::Type::Union, union_type.Union()}; + } + case sparrow::data_type::RUN_ENCODED: + { + const auto run_end_encoded_type = org::apache::arrow::flatbuf::CreateRunEndEncoded(builder); + return {org::apache::arrow::flatbuf::Type::RunEndEncoded, run_end_encoded_type.Union()}; + } + case sparrow::data_type::DECIMAL32: + { + return get_flatbuffer_decimal_type(builder, format_str, 32); + } + case sparrow::data_type::DECIMAL64: + { + return get_flatbuffer_decimal_type(builder, format_str, 64); + } + case sparrow::data_type::DECIMAL128: + { + return get_flatbuffer_decimal_type(builder, format_str, 128); + } + case sparrow::data_type::DECIMAL256: + { + 
return get_flatbuffer_decimal_type(builder, format_str, 256); + } + case sparrow::data_type::FIXED_WIDTH_BINARY: + { + // FixedSizeBinary requires byteWidth. We need to parse the format_str. + // Format: "w:size" + const auto byte_width = utils::parse_format(format_str, ":"); + if (!byte_width.has_value()) + { + throw std::runtime_error( + "Failed to parse FixedWidthBinary size from format string: " + std::string(format_str) + ); + } + + const auto fixed_width_binary_type = org::apache::arrow::flatbuf::CreateFixedSizeBinary( + builder, + byte_width.value() + ); + return {org::apache::arrow::flatbuf::Type::FixedSizeBinary, fixed_width_binary_type.Union()}; + } + default: + { + throw std::runtime_error("Unsupported data type for serialization"); + } + } + } + + // Creates a Flatbuffers Decimal type from a format string + // The format string is expected to be in the format "d:precision,scale" + std::pair> get_flatbuffer_decimal_type( + flatbuffers::FlatBufferBuilder& builder, + std::string_view format_str, + const int32_t bitWidth + ) + { + // Decimal requires precision and scale. We need to parse the format_str. 
+ // Format: "d:precision,scale" + const auto scale = utils::parse_format(format_str, ","); + if (!scale.has_value()) + { + throw std::runtime_error( + "Failed to parse Decimal " + std::to_string(bitWidth) + + " scale from format string: " + std::string(format_str) + ); + } + const size_t comma_pos = format_str.find(','); + const auto precision = utils::parse_format(format_str.substr(0, comma_pos), ":"); + if (!precision.has_value()) + { + throw std::runtime_error( + "Failed to parse Decimal " + std::to_string(bitWidth) + + " precision from format string: " + std::string(format_str) + ); + } + const auto decimal_type = org::apache::arrow::flatbuf::CreateDecimal( + builder, + precision.value(), + scale.value(), + bitWidth + ); + return {org::apache::arrow::flatbuf::Type::Decimal, decimal_type.Union()}; + } + + flatbuffers::Offset>> + create_metadata(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema) + { + if (arrow_schema.metadata == nullptr) + { + return 0; + } + + const auto metadata_view = sparrow::key_value_view(arrow_schema.metadata); + std::vector> kv_offsets; + kv_offsets.reserve(metadata_view.size()); + for (const auto& [key, value] : metadata_view) + { + const auto key_offset = builder.CreateString(std::string(key)); + const auto value_offset = builder.CreateString(std::string(value)); + kv_offsets.push_back(org::apache::arrow::flatbuf::CreateKeyValue(builder, key_offset, value_offset)); + } + return builder.CreateVector(kv_offsets); + } + + ::flatbuffers::Offset + create_field(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema) + { + flatbuffers::Offset fb_name_offset = (arrow_schema.name == nullptr) + ? 
0 + : builder.CreateString(arrow_schema.name); + const auto [type_enum, type_offset] = get_flatbuffer_type(builder, arrow_schema.format); + auto fb_metadata_offset = create_metadata(builder, arrow_schema); + const auto children = create_children(builder, arrow_schema); + const auto fb_field = org::apache::arrow::flatbuf::CreateField( + builder, + fb_name_offset, + (arrow_schema.flags & static_cast(sparrow::ArrowFlag::NULLABLE)) != 0, + type_enum, + type_offset, + 0, // TODO: support dictionary + children, + fb_metadata_offset + ); + return fb_field; + } + + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> + create_children(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema) + { + std::vector> children_vec; + children_vec.reserve(arrow_schema.n_children); + for (size_t i = 0; i < arrow_schema.n_children; ++i) + { + if (arrow_schema.children[i] == nullptr) + { + throw std::invalid_argument("ArrowSchema has null child pointer"); + } + const auto& child = *arrow_schema.children[i]; + flatbuffers::Offset field = create_field(builder, child); + children_vec.emplace_back(field); + } + return children_vec.empty() ? 0 : builder.CreateVector(children_vec); + } + + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> + create_children(flatbuffers::FlatBufferBuilder& builder, sparrow::record_batch::column_range columns) + { + std::vector> children_vec; + children_vec.reserve(columns.size()); + for (const auto& column : columns) + { + const auto& arrow_schema = sparrow::detail::array_access::get_arrow_proxy(column).schema(); + flatbuffers::Offset field = create_field(builder, arrow_schema); + children_vec.emplace_back(field); + } + return children_vec.empty() ? 
0 : builder.CreateVector(children_vec); + } + + flatbuffers::FlatBufferBuilder get_schema_message_builder(const sparrow::record_batch& record_batch) + { + flatbuffers::FlatBufferBuilder schema_builder; + const auto fields_vec = create_children(schema_builder, record_batch.columns()); + const auto schema_offset = org::apache::arrow::flatbuf::CreateSchema( + schema_builder, + org::apache::arrow::flatbuf::Endianness::Little, // TODO: make configurable + fields_vec + ); + const auto schema_message_offset = org::apache::arrow::flatbuf::CreateMessage( + schema_builder, + org::apache::arrow::flatbuf::MetadataVersion::V5, + org::apache::arrow::flatbuf::MessageHeader::Schema, + schema_offset.Union(), + 0, // body length is 0 for schema messages + 0 // custom metadata + ); + schema_builder.Finish(schema_message_offset); + return schema_builder; + } + + void fill_fieldnodes( + const sparrow::arrow_proxy& arrow_proxy, + std::vector& nodes + ) + { + nodes.emplace_back(arrow_proxy.length(), arrow_proxy.null_count()); + nodes.reserve(nodes.size() + arrow_proxy.n_children()); + for (const auto& child : arrow_proxy.children()) + { + fill_fieldnodes(child, nodes); + } + } + + std::vector + create_fieldnodes(const sparrow::record_batch& record_batch) + { + std::vector nodes; + nodes.reserve(record_batch.columns().size()); + for (const auto& column : record_batch.columns()) + { + fill_fieldnodes(sparrow::detail::array_access::get_arrow_proxy(column), nodes); + } + return nodes; + } + + void fill_buffers( + const sparrow::arrow_proxy& arrow_proxy, + std::vector& flatbuf_buffers, + int64_t& offset + ) + { + const auto& buffers = arrow_proxy.buffers(); + for (const auto& buffer : buffers) + { + int64_t size = static_cast(buffer.size()); + flatbuf_buffers.emplace_back(offset, size); + offset += utils::align_to_8(size); + } + for (const auto& child : arrow_proxy.children()) + { + fill_buffers(child, flatbuf_buffers, offset); + } + } + + std::vector get_buffers(const sparrow::record_batch& 
record_batch) + { + std::vector buffers; + std::int64_t offset = 0; + for (const auto& column : record_batch.columns()) + { + const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column); + fill_buffers(arrow_proxy, buffers, offset); + } + return buffers; + } + + flatbuffers::FlatBufferBuilder get_record_batch_message_builder(const sparrow::record_batch& record_batch) + { + const std::vector nodes = create_fieldnodes(record_batch); + const std::vector buffers = get_buffers(record_batch); + flatbuffers::FlatBufferBuilder record_batch_builder; + auto nodes_offset = record_batch_builder.CreateVectorOfStructs(nodes); + auto buffers_offset = record_batch_builder.CreateVectorOfStructs(buffers); + const auto record_batch_offset = org::apache::arrow::flatbuf::CreateRecordBatch( + record_batch_builder, + static_cast(record_batch.nb_rows()), + nodes_offset, + buffers_offset, + 0, // TODO: Compression + 0 // TODO :variadic buffer Counts + ); + + const int64_t body_size = calculate_body_size(record_batch); + const auto record_batch_message_offset = org::apache::arrow::flatbuf::CreateMessage( + record_batch_builder, + org::apache::arrow::flatbuf::MetadataVersion::V5, + org::apache::arrow::flatbuf::MessageHeader::RecordBatch, + record_batch_offset.Union(), + body_size, // body length + 0 // custom metadata + ); + record_batch_builder.Finish(record_batch_message_offset); + return record_batch_builder; + } +} diff --git a/src/serialize.cpp b/src/serialize.cpp new file mode 100644 index 0000000..ec1130b --- /dev/null +++ b/src/serialize.cpp @@ -0,0 +1,31 @@ +#include "sparrow_ipc/serialize.hpp" + +#include "sparrow_ipc/flatbuffer_utils.hpp" + +namespace sparrow_ipc +{ + void common_serialize( + const sparrow::record_batch& record_batch, + const flatbuffers::FlatBufferBuilder& builder, + output_stream& stream + ) + { + stream.write(continuation); + const flatbuffers::uoffset_t size = builder.GetSize(); + const std::span size_span(reinterpret_cast(&size), 
sizeof(uint32_t)); + stream.write(size_span); + stream.write(std::span(builder.GetBufferPointer(), size)); + add_padding(stream); + } + + void serialize_schema_message(const sparrow::record_batch& record_batch, output_stream& stream) + { + common_serialize(record_batch, get_schema_message_builder(record_batch), stream); + } + + void serialize_record_batch(const sparrow::record_batch& record_batch, output_stream& stream) + { + common_serialize(record_batch, get_record_batch_message_builder(record_batch), stream); + generate_body(record_batch, stream); + } +} \ No newline at end of file diff --git a/src/serialize_utils.cpp b/src/serialize_utils.cpp index ac1e026..5bcbca8 100644 --- a/src/serialize_utils.cpp +++ b/src/serialize_utils.cpp @@ -1,208 +1,30 @@ -#include - +#include "sparrow_ipc/flatbuffer_utils.hpp" #include "sparrow_ipc/magic_values.hpp" #include "sparrow_ipc/serialize.hpp" #include "sparrow_ipc/utils.hpp" namespace sparrow_ipc { - - flatbuffers::Offset>> - create_metadata(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema) - { - if (arrow_schema.metadata == nullptr) - { - return 0; - } - - const auto metadata_view = sparrow::key_value_view(arrow_schema.metadata); - std::vector> kv_offsets; - kv_offsets.reserve(metadata_view.size()); - for (const auto& [key, value] : metadata_view) - { - const auto key_offset = builder.CreateString(std::string(key)); - const auto value_offset = builder.CreateString(std::string(value)); - kv_offsets.push_back(org::apache::arrow::flatbuf::CreateKeyValue(builder, key_offset, value_offset)); - } - return builder.CreateVector(kv_offsets); - } - - ::flatbuffers::Offset - create_field(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema) - { - flatbuffers::Offset fb_name_offset = (arrow_schema.name == nullptr) - ? 
0 - : builder.CreateString(arrow_schema.name); - const auto [type_enum, type_offset] = utils::get_flatbuffer_type(builder, arrow_schema.format); - auto fb_metadata_offset = create_metadata(builder, arrow_schema); - const auto children = create_children(builder, arrow_schema); - const auto fb_field = org::apache::arrow::flatbuf::CreateField( - builder, - fb_name_offset, - (arrow_schema.flags & static_cast(sparrow::ArrowFlag::NULLABLE)) != 0, - type_enum, - type_offset, - 0, // TODO: support dictionary - children, - fb_metadata_offset - ); - return fb_field; - } - - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> - create_children(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema) - { - std::vector> children_vec; - children_vec.reserve(arrow_schema.n_children); - for (size_t i = 0; i < arrow_schema.n_children; ++i) - { - if (arrow_schema.children[i] == nullptr) - { - throw std::invalid_argument("ArrowSchema has null child pointer"); - } - const auto& child = *arrow_schema.children[i]; - flatbuffers::Offset field = create_field(builder, child); - children_vec.emplace_back(field); - } - return children_vec.empty() ? 0 : builder.CreateVector(children_vec); - } - - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> - create_children(flatbuffers::FlatBufferBuilder& builder, sparrow::record_batch::column_range columns) - { - std::vector> children_vec; - children_vec.reserve(columns.size()); - for (const auto& column : columns) - { - const auto& arrow_schema = sparrow::detail::array_access::get_arrow_proxy(column).schema(); - flatbuffers::Offset field = create_field(builder, arrow_schema); - children_vec.emplace_back(field); - } - return children_vec.empty() ? 
0 : builder.CreateVector(children_vec); - } - - flatbuffers::FlatBufferBuilder get_schema_message_builder(const sparrow::record_batch& record_batch) - { - flatbuffers::FlatBufferBuilder schema_builder; - const auto fields_vec = create_children(schema_builder, record_batch.columns()); - const auto schema_offset = org::apache::arrow::flatbuf::CreateSchema( - schema_builder, - org::apache::arrow::flatbuf::Endianness::Little, // TODO: make configurable - fields_vec - ); - const auto schema_message_offset = org::apache::arrow::flatbuf::CreateMessage( - schema_builder, - org::apache::arrow::flatbuf::MetadataVersion::V5, - org::apache::arrow::flatbuf::MessageHeader::Schema, - schema_offset.Union(), - 0, // body length is 0 for schema messages - 0 // custom metadata - ); - schema_builder.Finish(schema_message_offset); - return schema_builder; - } - - std::vector serialize_schema_message(const sparrow::record_batch& record_batch) - { - std::vector schema_buffer; - schema_buffer.insert(schema_buffer.end(), continuation.begin(), continuation.end()); - flatbuffers::FlatBufferBuilder schema_builder = get_schema_message_builder(record_batch); - const flatbuffers::uoffset_t schema_len = schema_builder.GetSize(); - schema_buffer.reserve(schema_buffer.size() + sizeof(uint32_t) + schema_len); - // Write the 4-byte length prefix after the continuation bytes - schema_buffer.insert( - schema_buffer.end(), - reinterpret_cast(&schema_len), - reinterpret_cast(&schema_len) + sizeof(uint32_t) - ); - // Append the actual message bytes - schema_buffer.insert( - schema_buffer.end(), - schema_builder.GetBufferPointer(), - schema_builder.GetBufferPointer() + schema_len - ); - add_padding(schema_buffer); - return schema_buffer; - } - - void fill_fieldnodes( - const sparrow::arrow_proxy& arrow_proxy, - std::vector& nodes - ) - { - nodes.emplace_back(arrow_proxy.length(), arrow_proxy.null_count()); - nodes.reserve(nodes.size() + arrow_proxy.n_children()); - for (const auto& child : 
arrow_proxy.children()) - { - fill_fieldnodes(child, nodes); - } - } - - std::vector - create_fieldnodes(const sparrow::record_batch& record_batch) - { - std::vector nodes; - nodes.reserve(record_batch.columns().size()); - for (const auto& column : record_batch.columns()) - { - fill_fieldnodes(sparrow::detail::array_access::get_arrow_proxy(column), nodes); - } - return nodes; - } - - void fill_buffers( - const sparrow::arrow_proxy& arrow_proxy, - std::vector& flatbuf_buffers, - int64_t& offset - ) - { - const auto& buffers = arrow_proxy.buffers(); - for (const auto& buffer : buffers) - { - int64_t size = static_cast(buffer.size()); - flatbuf_buffers.emplace_back(offset, size); - offset += utils::align_to_8(size); - } - for (const auto& child : arrow_proxy.children()) - { - fill_buffers(child, flatbuf_buffers, offset); - } - } - - std::vector get_buffers(const sparrow::record_batch& record_batch) - { - std::vector buffers; - std::int64_t offset = 0; - for (const auto& column : record_batch.columns()) - { - const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column); - fill_buffers(arrow_proxy, buffers, offset); - } - return buffers; - } - - void fill_body(const sparrow::arrow_proxy& arrow_proxy, std::vector& body) + void fill_body(const sparrow::arrow_proxy& arrow_proxy, output_stream& stream) { for (const auto& buffer : arrow_proxy.buffers()) { - body.insert(body.end(), buffer.begin(), buffer.end()); - add_padding(body); + stream.write(buffer); + add_padding(stream); } for (const auto& child : arrow_proxy.children()) { - fill_body(child, body); + fill_body(child, stream); } } - std::vector generate_body(const sparrow::record_batch& record_batch) + void generate_body(const sparrow::record_batch& record_batch, output_stream& stream) { - std::vector body; for (const auto& column : record_batch.columns()) { const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column); - fill_body(arrow_proxy, body); + fill_body(arrow_proxy, 
stream); } - return body; } int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy) @@ -210,7 +32,7 @@ namespace sparrow_ipc int64_t total_size = 0; for (const auto& buffer : arrow_proxy.buffers()) { - total_size += utils::align_to_8(static_cast(buffer.size())); + total_size += utils::align_to_8(buffer.size()); } for (const auto& child : arrow_proxy.children()) { @@ -221,7 +43,7 @@ namespace sparrow_ipc int64_t calculate_body_size(const sparrow::record_batch& record_batch) { - return std::accumulate( + return std::reduce( record_batch.columns().begin(), record_batch.columns().end(), 0, @@ -233,73 +55,60 @@ namespace sparrow_ipc ); } - flatbuffers::FlatBufferBuilder get_record_batch_message_builder( - const sparrow::record_batch& record_batch, - const std::vector& nodes, - const std::vector& buffers - ) + void add_padding(output_stream& stream) { - flatbuffers::FlatBufferBuilder record_batch_builder; + const size_t stream_size = stream.size(); + stream.write(0, utils::align_to_8(stream_size) - stream_size); + } - auto nodes_offset = record_batch_builder.CreateVectorOfStructs(nodes); - auto buffers_offset = record_batch_builder.CreateVectorOfStructs(buffers); - const auto record_batch_offset = org::apache::arrow::flatbuf::CreateRecordBatch( - record_batch_builder, - static_cast(record_batch.nb_rows()), - nodes_offset, - buffers_offset, - 0, // TODO: Compression - 0 // TODO :variadic buffer Counts - ); + std::size_t calculate_schema_message_size(const sparrow::record_batch& record_batch) + { + // Build the schema message to get its exact size + flatbuffers::FlatBufferBuilder schema_builder = get_schema_message_builder(record_batch); + const flatbuffers::uoffset_t schema_len = schema_builder.GetSize(); - const int64_t body_size = calculate_body_size(record_batch); - const auto record_batch_message_offset = org::apache::arrow::flatbuf::CreateMessage( - record_batch_builder, - org::apache::arrow::flatbuf::MetadataVersion::V5, - 
org::apache::arrow::flatbuf::MessageHeader::RecordBatch, - record_batch_offset.Union(), - body_size, // body length - 0 // custom metadata - ); - record_batch_builder.Finish(record_batch_message_offset); - return record_batch_builder; + // Calculate total size: + // - Continuation bytes (4) + // - Message length prefix (4) + // - FlatBuffer schema message data + // - Padding to 8-byte alignment + std::size_t total_size = continuation.size() + sizeof(uint32_t) + schema_len; + return utils::align_to_8(total_size); } - std::vector serialize_record_batch(const sparrow::record_batch& record_batch) + std::size_t calculate_record_batch_message_size(const sparrow::record_batch& record_batch) { - std::vector nodes = create_fieldnodes(record_batch); - std::vector flatbuf_buffers = get_buffers(record_batch); - flatbuffers::FlatBufferBuilder record_batch_builder = get_record_batch_message_builder( - record_batch, - nodes, - flatbuf_buffers - ); - std::vector output; - output.insert(output.end(), continuation.begin(), continuation.end()); + // Build the record batch message to get its exact metadata size + flatbuffers::FlatBufferBuilder record_batch_builder = get_record_batch_message_builder(record_batch); const flatbuffers::uoffset_t record_batch_len = record_batch_builder.GetSize(); - output.insert( - output.end(), - reinterpret_cast(&record_batch_len), - reinterpret_cast(&record_batch_len) + sizeof(record_batch_len) - ); - output.insert( - output.end(), - record_batch_builder.GetBufferPointer(), - record_batch_builder.GetBufferPointer() + record_batch_len - ); - add_padding(output); - std::vector body = generate_body(record_batch); - output.insert(output.end(), std::make_move_iterator(body.begin()), std::make_move_iterator(body.end())); - return output; + + // Calculate body size (already includes 8-byte alignment for each buffer) + const int64_t body_size = calculate_body_size(record_batch); + + // Calculate total size: + // - Continuation bytes (4) + // - Message length 
prefix (4) + // - FlatBuffer record batch metadata + // - Padding after metadata to 8-byte alignment + // - Body data (already aligned) + std::size_t metadata_size = continuation.size() + sizeof(uint32_t) + record_batch_len; + metadata_size = utils::align_to_8(metadata_size); + + return metadata_size + static_cast(body_size); } - void add_padding(std::vector& buffer) + std::vector get_column_dtypes(const sparrow::record_batch& rb) { - buffer.insert( - buffer.end(), - utils::align_to_8(static_cast(buffer.size())) - static_cast(buffer.size()), - 0 + std::vector dtypes; + dtypes.reserve(rb.nb_columns()); + std::ranges::transform( + rb.columns(), + std::back_inserter(dtypes), + [](const auto& col) + { + return col.data_type(); + } ); + return dtypes; } - -} \ No newline at end of file +} diff --git a/src/serializer.cpp b/src/serializer.cpp new file mode 100644 index 0000000..c08e4ad --- /dev/null +++ b/src/serializer.cpp @@ -0,0 +1,61 @@ +#include "sparrow_ipc/serializer.hpp" + +#include + +#include "sparrow_ipc/serialize.hpp" +#include "sparrow_ipc/serialize_utils.hpp" + +namespace sparrow_ipc +{ + serializer::serializer(const sparrow::record_batch& rb, output_stream& stream) + : m_pstream(&stream) + , m_dtypes(get_column_dtypes(rb)) + { + const auto reserve_function = [&rb]() + { + return calculate_schema_message_size(rb) + calculate_record_batch_message_size(rb); + }; + m_pstream->reserve(reserve_function); + serialize_schema_message(rb, *m_pstream); + serialize_record_batch(rb, *m_pstream); + } + + void serializer::append(const sparrow::record_batch& rb) + { + if (m_ended) + { + throw std::runtime_error("Cannot append to a serializer that has been ended"); + } + if (get_column_dtypes(rb) != m_dtypes) + { + throw std::invalid_argument("Record batch has different schema than previous ones"); + } + const auto reserve_function = [&]() + { + return m_pstream->size() + calculate_record_batch_message_size(rb); + }; + serialize_record_batch(rb, *m_pstream); + } + + 
std::vector serializer::get_column_dtypes(const sparrow::record_batch& rb) + { + std::vector dtypes; + dtypes.reserve(rb.nb_columns()); + for (const auto& col : rb.columns()) + { + dtypes.push_back(col.data_type()); + } + return dtypes; + } + + void serializer::end() + { + if (m_ended) + { + return; + } + m_pstream->write(end_of_stream); + m_pstream->flush(); + m_ended = true; + } +} \ No newline at end of file diff --git a/src/utils.cpp b/src/utils.cpp index 3d7b5e7..2fc2490 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -1,448 +1,34 @@ #include "sparrow_ipc/utils.hpp" -#include -#include -#include - -#include "sparrow.hpp" - -namespace sparrow_ipc +namespace sparrow_ipc::utils { - namespace + std::optional parse_format(std::string_view format_str, std::string_view sep) { - // Parse the format string - // The format string is expected to be "w:size", "+w:size", "d:precision,scale", etc - std::optional parse_format(std::string_view format_str, std::string_view sep) + // Find the position of the delimiter + const auto sep_pos = format_str.find(sep); + if (sep_pos == std::string_view::npos) { - // Find the position of the delimiter - const auto sep_pos = format_str.find(sep); - if (sep_pos == std::string_view::npos) - { - return std::nullopt; - } - - std::string_view substr_str(format_str.data() + sep_pos + 1, format_str.size() - sep_pos - 1); + return std::nullopt; + } - int32_t substr_size = 0; - const auto [ptr, ec] = std::from_chars( - substr_str.data(), - substr_str.data() + substr_str.size(), - substr_size - ); + std::string_view substr_str(format_str.data() + sep_pos + 1, format_str.size() - sep_pos - 1); - if (ec != std::errc() || ptr != substr_str.data() + substr_str.size()) - { - return std::nullopt; - } - return substr_size; - } + int32_t substr_size = 0; + const auto [ptr, ec] = std::from_chars( + substr_str.data(), + substr_str.data() + substr_str.size(), + substr_size + ); - // Creates a Flatbuffers Decimal type from a format string - // The format 
string is expected to be in the format "d:precision,scale" - std::pair> get_flatbuffer_decimal_type( - flatbuffers::FlatBufferBuilder& builder, - std::string_view format_str, - const int32_t bitWidth - ) + if (ec != std::errc() || ptr != substr_str.data() + substr_str.size()) { - // Decimal requires precision and scale. We need to parse the format_str. - // Format: "d:precision,scale" - const auto scale = parse_format(format_str, ","); - if (!scale.has_value()) - { - throw std::runtime_error( - "Failed to parse Decimal " + std::to_string(bitWidth) - + " scale from format string: " + std::string(format_str) - ); - } - const size_t comma_pos = format_str.find(','); - const auto precision = parse_format(format_str.substr(0, comma_pos), ":"); - if (!precision.has_value()) - { - throw std::runtime_error( - "Failed to parse Decimal " + std::to_string(bitWidth) - + " precision from format string: " + std::string(format_str) - ); - } - const auto decimal_type = org::apache::arrow::flatbuf::CreateDecimal( - builder, - precision.value(), - scale.value(), - bitWidth - ); - return {org::apache::arrow::flatbuf::Type::Decimal, decimal_type.Union()}; + return std::nullopt; } + return substr_size; } - namespace utils + size_t align_to_8(const size_t n) { - int64_t align_to_8(const int64_t n) - { - return (n + 7) & -8; - } - - std::pair> - get_flatbuffer_type(flatbuffers::FlatBufferBuilder& builder, std::string_view format_str) - { - const auto type = sparrow::format_to_data_type(format_str); - switch (type) - { - case sparrow::data_type::NA: - { - const auto null_type = org::apache::arrow::flatbuf::CreateNull(builder); - return {org::apache::arrow::flatbuf::Type::Null, null_type.Union()}; - } - case sparrow::data_type::BOOL: - { - const auto bool_type = org::apache::arrow::flatbuf::CreateBool(builder); - return {org::apache::arrow::flatbuf::Type::Bool, bool_type.Union()}; - } - case sparrow::data_type::UINT8: - { - const auto int_type = 
org::apache::arrow::flatbuf::CreateInt(builder, 8, false); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::INT8: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 8, true); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::UINT16: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 16, false); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::INT16: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 16, true); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::UINT32: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 32, false); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::INT32: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 32, true); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::UINT64: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 64, false); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::INT64: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 64, true); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::HALF_FLOAT: - { - const auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( - builder, - org::apache::arrow::flatbuf::Precision::HALF - ); - return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; - } - case sparrow::data_type::FLOAT: - { - const auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( - builder, - org::apache::arrow::flatbuf::Precision::SINGLE - ); - return {org::apache::arrow::flatbuf::Type::FloatingPoint, 
fp_type.Union()}; - } - case sparrow::data_type::DOUBLE: - { - const auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( - builder, - org::apache::arrow::flatbuf::Precision::DOUBLE - ); - return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; - } - case sparrow::data_type::STRING: - { - const auto string_type = org::apache::arrow::flatbuf::CreateUtf8(builder); - return {org::apache::arrow::flatbuf::Type::Utf8, string_type.Union()}; - } - case sparrow::data_type::LARGE_STRING: - { - const auto large_string_type = org::apache::arrow::flatbuf::CreateLargeUtf8(builder); - return {org::apache::arrow::flatbuf::Type::LargeUtf8, large_string_type.Union()}; - } - case sparrow::data_type::BINARY: - { - const auto binary_type = org::apache::arrow::flatbuf::CreateBinary(builder); - return {org::apache::arrow::flatbuf::Type::Binary, binary_type.Union()}; - } - case sparrow::data_type::LARGE_BINARY: - { - const auto large_binary_type = org::apache::arrow::flatbuf::CreateLargeBinary(builder); - return {org::apache::arrow::flatbuf::Type::LargeBinary, large_binary_type.Union()}; - } - case sparrow::data_type::STRING_VIEW: - { - const auto string_view_type = org::apache::arrow::flatbuf::CreateUtf8View(builder); - return {org::apache::arrow::flatbuf::Type::Utf8View, string_view_type.Union()}; - } - case sparrow::data_type::BINARY_VIEW: - { - const auto binary_view_type = org::apache::arrow::flatbuf::CreateBinaryView(builder); - return {org::apache::arrow::flatbuf::Type::BinaryView, binary_view_type.Union()}; - } - case sparrow::data_type::DATE_DAYS: - { - const auto date_type = org::apache::arrow::flatbuf::CreateDate( - builder, - org::apache::arrow::flatbuf::DateUnit::DAY - ); - return {org::apache::arrow::flatbuf::Type::Date, date_type.Union()}; - } - case sparrow::data_type::DATE_MILLISECONDS: - { - const auto date_type = org::apache::arrow::flatbuf::CreateDate( - builder, - org::apache::arrow::flatbuf::DateUnit::MILLISECOND - ); - return 
{org::apache::arrow::flatbuf::Type::Date, date_type.Union()}; - } - case sparrow::data_type::TIMESTAMP_SECONDS: - { - const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( - builder, - org::apache::arrow::flatbuf::TimeUnit::SECOND - ); - return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; - } - case sparrow::data_type::TIMESTAMP_MILLISECONDS: - { - const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( - builder, - org::apache::arrow::flatbuf::TimeUnit::MILLISECOND - ); - return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; - } - case sparrow::data_type::TIMESTAMP_MICROSECONDS: - { - const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( - builder, - org::apache::arrow::flatbuf::TimeUnit::MICROSECOND - ); - return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; - } - case sparrow::data_type::TIMESTAMP_NANOSECONDS: - { - const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( - builder, - org::apache::arrow::flatbuf::TimeUnit::NANOSECOND - ); - return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; - } - case sparrow::data_type::DURATION_SECONDS: - { - const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( - builder, - org::apache::arrow::flatbuf::TimeUnit::SECOND - ); - return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; - } - case sparrow::data_type::DURATION_MILLISECONDS: - { - const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( - builder, - org::apache::arrow::flatbuf::TimeUnit::MILLISECOND - ); - return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; - } - case sparrow::data_type::DURATION_MICROSECONDS: - { - const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( - builder, - org::apache::arrow::flatbuf::TimeUnit::MICROSECOND - ); - return {org::apache::arrow::flatbuf::Type::Duration, 
duration_type.Union()}; - } - case sparrow::data_type::DURATION_NANOSECONDS: - { - const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( - builder, - org::apache::arrow::flatbuf::TimeUnit::NANOSECOND - ); - return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; - } - case sparrow::data_type::INTERVAL_MONTHS: - { - const auto interval_type = org::apache::arrow::flatbuf::CreateInterval( - builder, - org::apache::arrow::flatbuf::IntervalUnit::YEAR_MONTH - ); - return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; - } - case sparrow::data_type::INTERVAL_DAYS_TIME: - { - const auto interval_type = org::apache::arrow::flatbuf::CreateInterval( - builder, - org::apache::arrow::flatbuf::IntervalUnit::DAY_TIME - ); - return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; - } - case sparrow::data_type::INTERVAL_MONTHS_DAYS_NANOSECONDS: - { - const auto interval_type = org::apache::arrow::flatbuf::CreateInterval( - builder, - org::apache::arrow::flatbuf::IntervalUnit::MONTH_DAY_NANO - ); - return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; - } - case sparrow::data_type::TIME_SECONDS: - { - const auto time_type = org::apache::arrow::flatbuf::CreateTime( - builder, - org::apache::arrow::flatbuf::TimeUnit::SECOND, - 32 - ); - return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; - } - case sparrow::data_type::TIME_MILLISECONDS: - { - const auto time_type = org::apache::arrow::flatbuf::CreateTime( - builder, - org::apache::arrow::flatbuf::TimeUnit::MILLISECOND, - 32 - ); - return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; - } - case sparrow::data_type::TIME_MICROSECONDS: - { - const auto time_type = org::apache::arrow::flatbuf::CreateTime( - builder, - org::apache::arrow::flatbuf::TimeUnit::MICROSECOND, - 64 - ); - return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; - } - case sparrow::data_type::TIME_NANOSECONDS: - { - 
const auto time_type = org::apache::arrow::flatbuf::CreateTime( - builder, - org::apache::arrow::flatbuf::TimeUnit::NANOSECOND, - 64 - ); - return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; - } - case sparrow::data_type::LIST: - { - const auto list_type = org::apache::arrow::flatbuf::CreateList(builder); - return {org::apache::arrow::flatbuf::Type::List, list_type.Union()}; - } - case sparrow::data_type::LARGE_LIST: - { - const auto large_list_type = org::apache::arrow::flatbuf::CreateLargeList(builder); - return {org::apache::arrow::flatbuf::Type::LargeList, large_list_type.Union()}; - } - case sparrow::data_type::LIST_VIEW: - { - const auto list_view_type = org::apache::arrow::flatbuf::CreateListView(builder); - return {org::apache::arrow::flatbuf::Type::ListView, list_view_type.Union()}; - } - case sparrow::data_type::LARGE_LIST_VIEW: - { - const auto large_list_view_type = org::apache::arrow::flatbuf::CreateLargeListView(builder); - return {org::apache::arrow::flatbuf::Type::LargeListView, large_list_view_type.Union()}; - } - case sparrow::data_type::FIXED_SIZED_LIST: - { - // FixedSizeList requires listSize. We need to parse the format_str. 
- // Format: "+w:size" - const auto list_size = parse_format(format_str, ":"); - if (!list_size.has_value()) - { - throw std::runtime_error( - "Failed to parse FixedSizeList size from format string: " + std::string(format_str) - ); - } - - const auto fixed_size_list_type = org::apache::arrow::flatbuf::CreateFixedSizeList( - builder, - list_size.value() - ); - return {org::apache::arrow::flatbuf::Type::FixedSizeList, fixed_size_list_type.Union()}; - } - case sparrow::data_type::STRUCT: - { - const auto struct_type = org::apache::arrow::flatbuf::CreateStruct_(builder); - return {org::apache::arrow::flatbuf::Type::Struct_, struct_type.Union()}; - } - case sparrow::data_type::MAP: - { - // not sorted keys - const auto map_type = org::apache::arrow::flatbuf::CreateMap(builder, false); - return {org::apache::arrow::flatbuf::Type::Map, map_type.Union()}; - } - case sparrow::data_type::DENSE_UNION: - { - const auto union_type = org::apache::arrow::flatbuf::CreateUnion( - builder, - org::apache::arrow::flatbuf::UnionMode::Dense, - 0 - ); - return {org::apache::arrow::flatbuf::Type::Union, union_type.Union()}; - } - case sparrow::data_type::SPARSE_UNION: - { - const auto union_type = org::apache::arrow::flatbuf::CreateUnion( - builder, - org::apache::arrow::flatbuf::UnionMode::Sparse, - 0 - ); - return {org::apache::arrow::flatbuf::Type::Union, union_type.Union()}; - } - case sparrow::data_type::RUN_ENCODED: - { - const auto run_end_encoded_type = org::apache::arrow::flatbuf::CreateRunEndEncoded(builder); - return {org::apache::arrow::flatbuf::Type::RunEndEncoded, run_end_encoded_type.Union()}; - } - case sparrow::data_type::DECIMAL32: - { - return get_flatbuffer_decimal_type(builder, format_str, 32); - } - case sparrow::data_type::DECIMAL64: - { - return get_flatbuffer_decimal_type(builder, format_str, 64); - } - case sparrow::data_type::DECIMAL128: - { - return get_flatbuffer_decimal_type(builder, format_str, 128); - } - case sparrow::data_type::DECIMAL256: - { - return 
get_flatbuffer_decimal_type(builder, format_str, 256); - } - case sparrow::data_type::FIXED_WIDTH_BINARY: - { - // FixedSizeBinary requires byteWidth. We need to parse the format_str. - // Format: "w:size" - const auto byte_width = parse_format(format_str, ":"); - if (!byte_width.has_value()) - { - throw std::runtime_error( - "Failed to parse FixedWidthBinary size from format string: " - + std::string(format_str) - ); - } - - const auto fixed_width_binary_type = org::apache::arrow::flatbuf::CreateFixedSizeBinary( - builder, - byte_width.value() - ); - return {org::apache::arrow::flatbuf::Type::FixedSizeBinary, fixed_width_binary_type.Union()}; - } - default: - { - throw std::runtime_error("Unsupported data type for serialization"); - } - } - } + return (n + 7) & -8; } } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 4e49c5d..2a8252f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -7,7 +7,12 @@ set(SPARROW_IPC_TESTS_SRC main.cpp test_arrow_array.cpp test_arrow_schema.cpp + test_chunk_memory_output_stream.cpp + test_chunk_memory_serializer.cpp test_de_serialization_with_files.cpp + test_file_output_stream.cpp + test_flatbuffer_utils.cpp + test_memory_output_streams.cpp test_serialize_utils.cpp test_utils.cpp ) @@ -22,23 +27,25 @@ target_link_libraries(${test_target} ) if(WIN32) - find_package(date) # For copying DLLs - add_custom_command( - TARGET ${test_target} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy - "$" - "$" - COMMAND ${CMAKE_COMMAND} -E copy - "$" - "$" - COMMAND ${CMAKE_COMMAND} -E copy - "$" - "$" - COMMAND ${CMAKE_COMMAND} -E copy - "$" - "$" - COMMENT "Copying sparrow and sparrow-ipc DLLs to executable directory" - ) + if(${SPARROW_IPC_BUILD_SHARED}) + find_package(date) # For copying DLLs + add_custom_command( + TARGET ${test_target} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + "$" + "$" + COMMAND ${CMAKE_COMMAND} -E copy + "$" + "$" + COMMAND ${CMAKE_COMMAND} -E copy + "$" + "$" + COMMAND ${CMAKE_COMMAND} -E copy + 
"$" + "$" + COMMENT "Copying sparrow and sparrow-ipc DLLs to executable directory" + ) + endif() endif() target_include_directories(${test_target} diff --git a/tests/include/sparrow_ipc_tests_helpers.hpp b/tests/include/sparrow_ipc_tests_helpers.hpp index ad6db6e..79cc84b 100644 --- a/tests/include/sparrow_ipc_tests_helpers.hpp +++ b/tests/include/sparrow_ipc_tests_helpers.hpp @@ -1,7 +1,9 @@ #pragma once -#include "doctest/doctest.h" -#include "sparrow.hpp" +#include + +#include + namespace sparrow_ipc { @@ -32,7 +34,7 @@ namespace sparrow_ipc } // Helper function to create a simple ArrowSchema for testing - ArrowSchema + inline ArrowSchema create_test_arrow_schema(const char* format, const char* name = "test_field", bool nullable = true) { ArrowSchema schema{}; @@ -49,7 +51,8 @@ namespace sparrow_ipc } // Helper function to create ArrowSchema with metadata - ArrowSchema create_test_arrow_schema_with_metadata(const char* format, const char* name = "test_field") + inline ArrowSchema + create_test_arrow_schema_with_metadata(const char* format, const char* name = "test_field") { auto schema = create_test_arrow_schema(format, name); @@ -59,7 +62,7 @@ namespace sparrow_ipc } // Helper function to create a simple record batch for testing - sp::record_batch create_test_record_batch() + inline sp::record_batch create_test_record_batch() { // Create a simple record batch with integer and string columns using initializer syntax return sp::record_batch( diff --git a/tests/test_chunk_memory_output_stream.cpp b/tests/test_chunk_memory_output_stream.cpp new file mode 100644 index 0000000..5015c4c --- /dev/null +++ b/tests/test_chunk_memory_output_stream.cpp @@ -0,0 +1,570 @@ +#include +#include +#include +#include +#include +#include + +#include + +#include "doctest/doctest.h" + +namespace sparrow_ipc +{ + TEST_SUITE("chuncked_memory_output_stream") + { + TEST_CASE("basic construction") + { + SUBCASE("Construction with empty vector of vectors") + { + std::vector> chunks; + 
chuncked_memory_output_stream stream(chunks); + + CHECK(stream.is_open()); + CHECK_EQ(stream.size(), 0); + CHECK_EQ(chunks.size(), 0); + } + + SUBCASE("Construction with existing chunks") + { + std::vector> chunks = { + {1, 2, 3}, + {4, 5, 6, 7}, + {8, 9} + }; + chuncked_memory_output_stream stream(chunks); + + CHECK(stream.is_open()); + CHECK_EQ(stream.size(), 9); + CHECK_EQ(chunks.size(), 3); + } + } + + TEST_CASE("write operations with span") + { + SUBCASE("Write single byte span") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + uint8_t data[] = {42}; + std::span span(data, 1); + + auto written = stream.write(span); + + CHECK_EQ(written, 1); + CHECK_EQ(stream.size(), 1); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 1); + CHECK_EQ(chunks[0][0], 42); + } + + SUBCASE("Write multiple bytes span") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + uint8_t data[] = {1, 2, 3, 4, 5}; + std::span span(data, 5); + + auto written = stream.write(span); + + CHECK_EQ(written, 5); + CHECK_EQ(stream.size(), 5); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 5); + for (size_t i = 0; i < 5; ++i) + { + CHECK_EQ(chunks[0][i], i + 1); + } + } + + SUBCASE("Write empty span") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + std::span empty_span; + + auto written = stream.write(empty_span); + + CHECK_EQ(written, 0); + CHECK_EQ(stream.size(), 0); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 0); + } + + SUBCASE("Multiple span writes create multiple chunks") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + uint8_t data1[] = {10, 20}; + uint8_t data2[] = {30, 40, 50}; + uint8_t data3[] = {60}; + + stream.write(std::span(data1, 2)); + stream.write(std::span(data2, 3)); + stream.write(std::span(data3, 1)); + + CHECK_EQ(stream.size(), 6); + CHECK_EQ(chunks.size(), 3); + + CHECK_EQ(chunks[0].size(), 2); + CHECK_EQ(chunks[0][0], 10); + 
CHECK_EQ(chunks[0][1], 20); + + CHECK_EQ(chunks[1].size(), 3); + CHECK_EQ(chunks[1][0], 30); + CHECK_EQ(chunks[1][1], 40); + CHECK_EQ(chunks[1][2], 50); + + CHECK_EQ(chunks[2].size(), 1); + CHECK_EQ(chunks[2][0], 60); + } + } + + TEST_CASE("write operations with move") + { + SUBCASE("Write moved vector") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + std::vector buffer = {1, 2, 3, 4, 5}; + auto written = stream.write(std::move(buffer)); + + CHECK_EQ(written, 5); + CHECK_EQ(stream.size(), 5); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 5); + for (size_t i = 0; i < 5; ++i) + { + CHECK_EQ(chunks[0][i], i + 1); + } + } + + SUBCASE("Write multiple moved vectors") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + std::vector buffer1 = {10, 20, 30}; + std::vector buffer2 = {40, 50}; + std::vector buffer3 = {60, 70, 80, 90}; + + stream.write(std::move(buffer1)); + stream.write(std::move(buffer2)); + stream.write(std::move(buffer3)); + + CHECK_EQ(stream.size(), 9); + CHECK_EQ(chunks.size(), 3); + + CHECK_EQ(chunks[0].size(), 3); + CHECK_EQ(chunks[1].size(), 2); + CHECK_EQ(chunks[2].size(), 4); + } + + SUBCASE("Write empty moved vector") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + std::vector empty_buffer; + auto written = stream.write(std::move(empty_buffer)); + + CHECK_EQ(written, 0); + CHECK_EQ(stream.size(), 0); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 0); + } + } + + TEST_CASE("write operations with repeated value") + { + SUBCASE("Write value multiple times") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + auto written = stream.write(static_cast(255), 5); + + CHECK_EQ(written, 5); + CHECK_EQ(stream.size(), 5); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 5); + for (size_t i = 0; i < 5; ++i) + { + CHECK_EQ(chunks[0][i], 255); + } + } + + SUBCASE("Write value zero times") + { + std::vector> chunks; + 
chuncked_memory_output_stream stream(chunks); + + auto written = stream.write(static_cast(42), 0); + + CHECK_EQ(written, 0); + CHECK_EQ(stream.size(), 0); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 0); + } + + SUBCASE("Multiple repeated value writes") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + stream.write(static_cast(100), 3); + stream.write(static_cast(200), 2); + stream.write(static_cast(50), 4); + + CHECK_EQ(stream.size(), 9); + CHECK_EQ(chunks.size(), 3); + + CHECK_EQ(chunks[0].size(), 3); + for (size_t i = 0; i < 3; ++i) + { + CHECK_EQ(chunks[0][i], 100); + } + + CHECK_EQ(chunks[1].size(), 2); + for (size_t i = 0; i < 2; ++i) + { + CHECK_EQ(chunks[1][i], 200); + } + + CHECK_EQ(chunks[2].size(), 4); + for (size_t i = 0; i < 4; ++i) + { + CHECK_EQ(chunks[2][i], 50); + } + } + } + + TEST_CASE("mixed write operations") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + // Write span + uint8_t data[] = {1, 2, 3}; + stream.write(std::span(data, 3)); + + // Write repeated value + stream.write(static_cast(42), 2); + + // Write moved vector + std::vector buffer = {10, 20, 30, 40}; + stream.write(std::move(buffer)); + + CHECK_EQ(stream.size(), 9); + CHECK_EQ(chunks.size(), 3); + + CHECK_EQ(chunks[0].size(), 3); + CHECK_EQ(chunks[0][0], 1); + CHECK_EQ(chunks[0][1], 2); + CHECK_EQ(chunks[0][2], 3); + + CHECK_EQ(chunks[1].size(), 2); + CHECK_EQ(chunks[1][0], 42); + CHECK_EQ(chunks[1][1], 42); + + CHECK_EQ(chunks[2].size(), 4); + CHECK_EQ(chunks[2][0], 10); + CHECK_EQ(chunks[2][1], 20); + CHECK_EQ(chunks[2][2], 30); + CHECK_EQ(chunks[2][3], 40); + } + + TEST_CASE("reserve functionality") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + // Reserve space + stream.reserve(100); + + // Chunks vector should have reserved capacity but size should remain 0 + CHECK_GE(chunks.capacity(), 100); + CHECK_EQ(stream.size(), 0); + CHECK_EQ(chunks.size(), 0); + + // Writing should 
work normally after reserve + uint8_t data[] = {1, 2, 3}; + std::span span(data, 3); + stream.write(span); + + CHECK_EQ(stream.size(), 3); + CHECK_EQ(chunks.size(), 1); + } + + TEST_CASE("size calculation") + { + SUBCASE("Size with empty chunks") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + CHECK_EQ(stream.size(), 0); + } + + SUBCASE("Size with pre-existing chunks") + { + std::vector> chunks = { + {1, 2, 3}, + {4, 5}, + {6, 7, 8, 9} + }; + chuncked_memory_output_stream stream(chunks); + + CHECK_EQ(stream.size(), 9); + } + + SUBCASE("Size updates after writes") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + CHECK_EQ(stream.size(), 0); + + uint8_t data[] = {1, 2, 3}; + stream.write(std::span(data, 3)); + CHECK_EQ(stream.size(), 3); + + stream.write(static_cast(42), 5); + CHECK_EQ(stream.size(), 8); + + std::vector buffer = {10, 20}; + stream.write(std::move(buffer)); + CHECK_EQ(stream.size(), 10); + } + + SUBCASE("Size with chunks of varying sizes") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + stream.write(static_cast(1), 1); + stream.write(static_cast(2), 10); + stream.write(static_cast(3), 100); + stream.write(static_cast(4), 1000); + + CHECK_EQ(stream.size(), 1111); + CHECK_EQ(chunks.size(), 4); + } + } + + TEST_CASE("stream lifecycle") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + SUBCASE("Stream is always open") + { + CHECK(stream.is_open()); + + uint8_t data[] = {1, 2, 3}; + stream.write(std::span(data, 3)); + CHECK(stream.is_open()); + + stream.flush(); + CHECK(stream.is_open()); + + stream.close(); + CHECK(stream.is_open()); + } + + SUBCASE("Flush operation") + { + uint8_t data[] = {1, 2, 3}; + stream.write(std::span(data, 3)); + + // Flush should not throw or change state + CHECK_NOTHROW(stream.flush()); + CHECK(stream.is_open()); + CHECK_EQ(stream.size(), 3); + CHECK_EQ(chunks.size(), 1); + } + + SUBCASE("Close operation") + { 
+ uint8_t data[] = {1, 2, 3}; + stream.write(std::span(data, 3)); + + // Close should not throw + CHECK_NOTHROW(stream.close()); + CHECK(stream.is_open()); + CHECK_EQ(stream.size(), 3); + CHECK_EQ(chunks.size(), 1); + } + } + + TEST_CASE("large data handling") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + SUBCASE("Single large chunk") + { + const size_t large_size = 10000; + std::vector large_data(large_size); + std::iota(large_data.begin(), large_data.end(), 0); + + auto written = stream.write(std::move(large_data)); + + CHECK_EQ(written, large_size); + CHECK_EQ(stream.size(), large_size); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), large_size); + + // Verify data integrity + for (size_t i = 0; i < large_size; ++i) + { + CHECK_EQ(chunks[0][i], static_cast(i)); + } + } + + SUBCASE("Many small chunks") + { + const size_t num_chunks = 1000; + const size_t chunk_size = 10; + + for (size_t i = 0; i < num_chunks; ++i) + { + uint8_t value = static_cast(i); + stream.write(value, chunk_size); + } + + CHECK_EQ(stream.size(), num_chunks * chunk_size); + CHECK_EQ(chunks.size(), num_chunks); + + for (size_t i = 0; i < num_chunks; ++i) + { + CHECK_EQ(chunks[i].size(), chunk_size); + for (size_t j = 0; j < chunk_size; ++j) + { + CHECK_EQ(chunks[i][j], static_cast(i)); + } + } + } + + SUBCASE("Large repeated value write") + { + const size_t count = 50000; + auto written = stream.write(static_cast(123), count); + + CHECK_EQ(written, count); + CHECK_EQ(stream.size(), count); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), count); + + for (size_t i = 0; i < count; ++i) + { + CHECK_EQ(chunks[0][i], 123); + } + } + } + + TEST_CASE("edge cases") + { + SUBCASE("Maximum value writes") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + auto written = stream.write(std::numeric_limits::max(), 255); + + CHECK_EQ(written, 255); + CHECK_EQ(stream.size(), 255); + CHECK_EQ(chunks.size(), 1); + for (size_t i = 
0; i < 255; ++i) + { + CHECK_EQ(chunks[0][i], std::numeric_limits::max()); + } + } + + SUBCASE("Zero byte writes") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + auto written = stream.write(static_cast(0), 100); + + CHECK_EQ(written, 100); + CHECK_EQ(stream.size(), 100); + CHECK_EQ(chunks.size(), 1); + for (size_t i = 0; i < 100; ++i) + { + CHECK_EQ(chunks[0][i], 0); + } + } + + SUBCASE("Interleaved empty and non-empty writes") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + stream.write(static_cast(1), 5); + stream.write(static_cast(2), 0); + stream.write(static_cast(3), 3); + std::vector empty; + stream.write(std::move(empty)); + stream.write(static_cast(4), 2); + + CHECK_EQ(stream.size(), 10); + CHECK_EQ(chunks.size(), 5); + + CHECK_EQ(chunks[0].size(), 5); + CHECK_EQ(chunks[1].size(), 0); + CHECK_EQ(chunks[2].size(), 3); + CHECK_EQ(chunks[3].size(), 0); + CHECK_EQ(chunks[4].size(), 2); + } + } + + TEST_CASE("reference semantics") + { + SUBCASE("Stream modifies original chunks vector") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + uint8_t data[] = {1, 2, 3}; + stream.write(std::span(data, 3)); + + // Verify that the original vector is modified + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 3); + CHECK_EQ(chunks[0][0], 1); + CHECK_EQ(chunks[0][1], 2); + CHECK_EQ(chunks[0][2], 3); + } + + SUBCASE("Multiple streams to same chunks vector") + { + std::vector> chunks; + + { + chuncked_memory_output_stream stream1(chunks); + uint8_t data1[] = {10, 20}; + stream1.write(std::span(data1, 2)); + } + + { + chuncked_memory_output_stream stream2(chunks); + uint8_t data2[] = {30, 40}; + stream2.write(std::span(data2, 2)); + } + + CHECK_EQ(chunks.size(), 2); + CHECK_EQ(chunks[0][0], 10); + CHECK_EQ(chunks[0][1], 20); + CHECK_EQ(chunks[1][0], 30); + CHECK_EQ(chunks[1][1], 40); + } + } + } +} diff --git a/tests/test_chunk_memory_serializer.cpp 
b/tests/test_chunk_memory_serializer.cpp new file mode 100644 index 0000000..a5077ec --- /dev/null +++ b/tests/test_chunk_memory_serializer.cpp @@ -0,0 +1,381 @@ +#include + +#include +#include + +#include "sparrow_ipc/chunk_memory_output_stream.hpp" +#include "sparrow_ipc/chunk_memory_serializer.hpp" +#include "sparrow_ipc_tests_helpers.hpp" + +namespace sparrow_ipc +{ + namespace sp = sparrow; + + TEST_SUITE("chunk_serializer") + { + TEST_CASE("construction with single record batch") + { + SUBCASE("Valid record batch") + { + auto rb = create_test_record_batch(); + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + chunk_serializer serializer(rb, stream); + + // After construction with single record batch, should have schema + record batch + CHECK_EQ(chunks.size(), 2); + CHECK_GT(chunks[0].size(), 0); // Schema message + CHECK_GT(chunks[1].size(), 0); // Record batch message + CHECK_GT(stream.size(), 0); + } + + SUBCASE("Empty record batch") + { + auto empty_batch = sp::record_batch({}); + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + chunk_serializer serializer(empty_batch, stream); + + CHECK_EQ(chunks.size(), 2); + CHECK_GT(chunks[0].size(), 0); + } + } + + TEST_CASE("construction with range of record batches") + { + SUBCASE("Valid record batches") + { + auto array1 = sp::primitive_array({1, 2, 3}); + auto array2 = sp::primitive_array({1.0, 2.0, 3.0}); + auto rb1 = sp::record_batch( + {{"col1", sp::array(std::move(array1))}, {"col2", sp::array(std::move(array2))}} + ); + + auto array3 = sp::primitive_array({4, 5, 6}); + auto array4 = sp::primitive_array({4.0, 5.0, 6.0}); + auto rb2 = sp::record_batch( + {{"col1", sp::array(std::move(array3))}, {"col2", sp::array(std::move(array4))}} + ); + + std::vector record_batches = {rb1, rb2}; + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + chunk_serializer serializer(record_batches, stream); + + // Should have schema + 2 record batches = 3 chunks + 
CHECK_EQ(chunks.size(), 3); + CHECK_GT(chunks[0].size(), 0); // Schema + CHECK_GT(chunks[1].size(), 0); // First record batch + CHECK_GT(chunks[2].size(), 0); // Second record batch + } + + SUBCASE("Empty collection throws exception") + { + std::vector empty_batches; + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + CHECK_THROWS_AS( + chunk_serializer serializer(empty_batches, stream), + std::invalid_argument + ); + } + + SUBCASE("Reserve is called correctly") + { + auto rb = create_test_record_batch(); + std::vector record_batches = {rb}; + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + chunk_serializer serializer(record_batches, stream); + + // Verify that chunks were reserved (capacity should be >= size) + CHECK_GE(chunks.capacity(), chunks.size()); + } + } + + TEST_CASE("append single record batch") + { + SUBCASE("Append after construction with single batch") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + chunk_serializer serializer(rb1, stream); + CHECK_EQ(chunks.size(), 2); // Schema + rb1 + + // Create compatible record batch + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({6, 7, 8}))}, + {"string_col", sp::array(sp::string_array(std::vector{"foo", "bar", "baz"}))}} + ); + + serializer.append(rb2); + + CHECK_EQ(chunks.size(), 3); // Schema + rb1 + rb2 + CHECK_GT(chunks[2].size(), 0); + } + + SUBCASE("Multiple appends") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + chunk_serializer serializer(rb1, stream); + + for (int i = 0; i < 3; ++i) + { + auto rb = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({i}))}, + {"string_col", sp::array(sp::string_array(std::vector{"test"}))}} + ); + serializer.append(rb); + } + + CHECK_EQ(chunks.size(), 5); // Schema + 1 initial + 3 appended + } + } + + TEST_CASE("append range of record 
batches") + { + SUBCASE("Append range after construction") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + chunk_serializer serializer(rb1, stream); + CHECK_EQ(chunks.size(), 2); + + auto array1 = sp::primitive_array({10, 20}); + auto array2 = sp::string_array(std::vector{"a", "b"}); + auto rb2 = sp::record_batch( + {{"int_col", sp::array(std::move(array1))}, + {"string_col", sp::array(std::move(array2))}} + ); + + auto array3 = sp::primitive_array({30, 40}); + auto array4 = sp::string_array(std::vector{"c", "d"}); + auto rb3 = sp::record_batch( + {{"int_col", sp::array(std::move(array3))}, + {"string_col", sp::array(std::move(array4))}} + ); + + std::vector new_batches = {rb2, rb3}; + serializer.append(new_batches); + + CHECK_EQ(chunks.size(), 4); // Schema + rb1 + rb2 + rb3 + CHECK_GT(chunks[2].size(), 0); + CHECK_GT(chunks[3].size(), 0); + } + + SUBCASE("Reserve is called during range append") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + chunk_serializer serializer(rb1, stream); + + auto rb2 = create_test_record_batch(); + auto rb3 = create_test_record_batch(); + std::vector new_batches = {rb2, rb3}; + + size_t old_capacity = chunks.capacity(); + serializer.append(new_batches); + + // Reserve should have been called + CHECK_GE(chunks.capacity(), chunks.size()); + } + + SUBCASE("Empty range append does nothing") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + chunk_serializer serializer(rb1, stream); + size_t initial_size = chunks.size(); + + std::vector empty_batches; + serializer.append(empty_batches); + + CHECK_EQ(chunks.size(), initial_size); + } + } + + TEST_CASE("end serialization") + { + SUBCASE("End after construction") + { + auto rb = create_test_record_batch(); + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + 
chunk_serializer serializer(rb, stream); + size_t initial_size = chunks.size(); + + serializer.end(); + + // End should add end-of-stream marker + CHECK_GT(chunks.size(), initial_size); + } + + SUBCASE("Cannot append after end") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + chunk_serializer serializer(rb1, stream); + serializer.end(); + + auto rb2 = create_test_record_batch(); + CHECK_THROWS_AS(serializer.append(rb2), std::runtime_error); + } + + SUBCASE("Cannot append range after end") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + chunk_serializer serializer(rb1, stream); + serializer.end(); + + std::vector new_batches = {create_test_record_batch()}; + CHECK_THROWS_AS(serializer.append(new_batches), std::runtime_error); + } + } + + TEST_CASE("stream size tracking") + { + SUBCASE("Size increases with each operation") + { + auto rb = create_test_record_batch(); + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + size_t size_before = stream.size(); + chunk_serializer serializer(rb, stream); + size_t size_after_construction = stream.size(); + + CHECK_GT(size_after_construction, size_before); + + serializer.append(rb); + size_t size_after_append = stream.size(); + + CHECK_GT(size_after_append, size_after_construction); + } + } + + TEST_CASE("large number of record batches") + { + SUBCASE("Handle many record batches efficiently") + { + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + std::vector batches; + const int num_batches = 100; + + for (int i = 0; i < num_batches; ++i) + { + auto array = sp::primitive_array({i, i+1, i+2}); + batches.push_back(sp::record_batch({{"col", sp::array(std::move(array))}})); + } + + chunk_serializer serializer(batches, stream); + + // Should have schema + all batches + CHECK_EQ(chunks.size(), num_batches + 1); + CHECK_GT(stream.size(), 0); + + // 
Verify each chunk has data + for (const auto& chunk : chunks) + { + CHECK_GT(chunk.size(), 0); + } + } + } + + TEST_CASE("different column types") + { + SUBCASE("Multiple primitive types") + { + auto int_array = sp::primitive_array({1, 2, 3}); + auto double_array = sp::primitive_array({1.5, 2.5, 3.5}); + auto float_array = sp::primitive_array({1.0f, 2.0f, 3.0f}); + + auto rb = sp::record_batch( + {{"int_col", sp::array(std::move(int_array))}, + {"double_col", sp::array(std::move(double_array))}, + {"float_col", sp::array(std::move(float_array))}} + ); + + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + chunk_serializer serializer(rb, stream); + + CHECK_EQ(chunks.size(), 2); // Schema + record batch + CHECK_GT(chunks[0].size(), 0); + CHECK_GT(chunks[1].size(), 0); + } + } + + TEST_CASE("workflow example") + { + SUBCASE("Typical usage pattern") + { + // Create initial record batch + auto rb1 = create_test_record_batch(); + + // Setup chunked stream + std::vector> chunks; + chuncked_memory_output_stream stream(chunks); + + // Create serializer with initial batch + chunk_serializer serializer(rb1, stream); + CHECK_EQ(chunks.size(), 2); + + // Append more batches + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({10, 20}))}, + {"string_col", sp::array(sp::string_array(std::vector{"x", "y"}))}} + ); + serializer.append(rb2); + CHECK_EQ(chunks.size(), 3); + + // Append range of batches + std::vector more_batches; + for (int i = 0; i < 3; ++i) + { + auto rb = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({i}))}, + {"string_col", sp::array(sp::string_array(std::vector{"test"}))}} + ); + more_batches.push_back(rb); + } + serializer.append(more_batches); + CHECK_EQ(chunks.size(), 6); + + // End serialization + serializer.end(); + CHECK_GT(chunks.size(), 6); + + // Verify all chunks have data + for (const auto& chunk : chunks) + { + CHECK_GT(chunk.size(), 0); + } + } + } + } +} diff --git 
a/tests/test_de_serialization_with_files.cpp b/tests/test_de_serialization_with_files.cpp index 8fe825b..66da594 100644 --- a/tests/test_de_serialization_with_files.cpp +++ b/tests/test_de_serialization_with_files.cpp @@ -14,7 +14,8 @@ #include "doctest/doctest.h" #include "sparrow.hpp" #include "sparrow_ipc/deserialize.hpp" -#include "sparrow_ipc/serialize.hpp" +#include "sparrow_ipc/memory_output_stream.hpp" +#include "sparrow_ipc/serializer.hpp" const std::filesystem::path arrow_testing_data_dir = ARROW_TESTING_DATA_DIR; @@ -162,7 +163,10 @@ TEST_SUITE("Integration tests") std::span(stream_data) ); - const auto serialized_data = sparrow_ipc::serialize(record_batches_from_json); + std::vector serialized_data; + sparrow_ipc::memory_output_stream stream(serialized_data); + sparrow_ipc::serializer serializer(record_batches_from_json, stream); + serializer.end(); const auto deserialized_serialized_data = sparrow_ipc::deserialize_stream( std::span(serialized_data) ); diff --git a/tests/test_file_output_stream.cpp b/tests/test_file_output_stream.cpp new file mode 100644 index 0000000..43a93d2 --- /dev/null +++ b/tests/test_file_output_stream.cpp @@ -0,0 +1,604 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "doctest/doctest.h" + +namespace sparrow_ipc +{ + // Helper class to manage temporary files for testing + class temporary_file + { + private: + + std::filesystem::path m_path; + + public: + + temporary_file(const std::string& prefix = "test_file_output_stream") + { + m_path = std::filesystem::temp_directory_path() + / (prefix + "_" + std::to_string(std::rand()) + ".tmp"); + } + + ~temporary_file() + { + if (std::filesystem::exists(m_path)) + { + std::filesystem::remove(m_path); + } + } + + const std::filesystem::path& path() const + { + return m_path; + } + + std::string path_string() const + { + return m_path.string(); + } + + // Read the entire file content as bytes + std::vector read_content() const + { + 
std::ifstream file(m_path, std::ios::binary); + if (!file) + { + return {}; + } + + file.seekg(0, std::ios::end); + size_t size = file.tellg(); + file.seekg(0, std::ios::beg); + + std::vector content(size); + file.read(reinterpret_cast(content.data()), size); + return content; + } + + // Get file size + size_t file_size() const + { + if (!std::filesystem::exists(m_path)) + { + return 0; + } + return std::filesystem::file_size(m_path); + } + }; + + TEST_SUITE("file_output_stream") + { + TEST_CASE("construction and basic state") + { + SUBCASE("Construction with valid file") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + + file_output_stream stream(file); + + CHECK(stream.is_open()); + CHECK_EQ(stream.size(), 0); + } + + SUBCASE("Construction with closed file throws exception") + { + temporary_file temp_file; + std::ofstream file; // Not opened + + CHECK_THROWS_AS(file_output_stream{file}, std::runtime_error); + } + + SUBCASE("Construction with file that fails to open throws exception") + { + std::ofstream file("/invalid/path/file.tmp"); // Invalid path + + CHECK_THROWS_AS(file_output_stream{file}, std::runtime_error); + } + } + + TEST_CASE("write operations") + { + SUBCASE("Write single byte span") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + uint8_t data[] = {42}; + std::span span(data, 1); + + auto written = stream.write(span); + + CHECK_EQ(written, 1); + CHECK_EQ(stream.size(), 1); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + REQUIRE_EQ(content.size(), 1); + CHECK_EQ(content[0], 42); + } + + SUBCASE("Write multiple bytes span") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + uint8_t data[] = {1, 2, 3, 4, 5}; + std::span span(data, 5); + + auto written = stream.write(span); + + CHECK_EQ(written, 5); + 
CHECK_EQ(stream.size(), 5); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + REQUIRE_EQ(content.size(), 5); + for (size_t i = 0; i < 5; ++i) + { + CHECK_EQ(content[i], i + 1); + } + } + + SUBCASE("Write empty span") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + std::span empty_span; + + auto written = stream.write(empty_span); + + CHECK_EQ(written, 0); + CHECK_EQ(stream.size(), 0); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + CHECK_EQ(content.size(), 0); + } + + SUBCASE("Write single byte using write(uint8_t)") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + auto written = stream.write(static_cast(123), 1); + + CHECK_EQ(written, 1); + CHECK_EQ(stream.size(), 1); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + REQUIRE_EQ(content.size(), 1); + CHECK_EQ(content[0], 123); + } + + SUBCASE("Write value multiple times") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + auto written = stream.write(static_cast(255), 3); + + CHECK_EQ(written, 3); + CHECK_EQ(stream.size(), 3); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + REQUIRE_EQ(content.size(), 3); + CHECK_EQ(content[0], 255); + CHECK_EQ(content[1], 255); + CHECK_EQ(content[2], 255); + } + + SUBCASE("Write value zero times") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + auto written = stream.write(static_cast(42), 0); + + CHECK_EQ(written, 0); + CHECK_EQ(stream.size(), 0); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + CHECK_EQ(content.size(), 0); + } + } + + TEST_CASE("sequential writes") + { + temporary_file temp_file; + 
std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + // First write + uint8_t data1[] = {10, 20, 30}; + std::span span1(data1, 3); + auto written1 = stream.write(span1); + + CHECK_EQ(written1, 3); + CHECK_EQ(stream.size(), 3); + + // Second write + uint8_t data2[] = {40, 50}; + std::span span2(data2, 2); + auto written2 = stream.write(span2); + + CHECK_EQ(written2, 2); + CHECK_EQ(stream.size(), 5); + + // Third write with repeated value + auto written3 = stream.write(static_cast(60), 2); + + CHECK_EQ(written3, 2); + CHECK_EQ(stream.size(), 7); + + stream.flush(); + file.close(); + + // Verify final file content + auto content = temp_file.read_content(); + std::vector expected = {10, 20, 30, 40, 50, 60, 60}; + CHECK_EQ(content, expected); + } + + TEST_CASE("reserve functionality") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + // Reserve space (should be no-op for file streams) + CHECK_NOTHROW(stream.reserve(100)); + + // Size should remain 0 after reserve + CHECK_EQ(stream.size(), 0); + + // Writing should work normally after reserve + uint8_t data[] = {1, 2, 3}; + std::span span(data, 3); + stream.write(span); + + CHECK_EQ(stream.size(), 3); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + CHECK_EQ(content.size(), 3); + } + + TEST_CASE("add_padding functionality") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + SUBCASE("No padding needed when size is multiple of 8") + { + // Write 8 bytes + uint8_t data[] = {1, 2, 3, 4, 5, 6, 7, 8}; + std::span span(data, 8); + stream.write(span); + + auto size_before = stream.size(); + stream.add_padding(); + + CHECK_EQ(stream.size(), size_before); + CHECK_EQ(stream.size(), 8); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + CHECK_EQ(content.size(), 8); + } + + 
SUBCASE("Padding needed when size is not multiple of 8") + { + // Write 5 bytes + uint8_t data[] = {1, 2, 3, 4, 5}; + std::span span(data, 5); + stream.write(span); + + stream.add_padding(); + + CHECK_EQ(stream.size(), 8); // Should be padded to next multiple of 8 + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + CHECK_EQ(content.size(), 8); + + // Check original data is preserved + for (size_t i = 0; i < 5; ++i) + { + CHECK_EQ(content[i], i + 1); + } + // Check padding bytes are zero + CHECK_EQ(content[5], 0); + CHECK_EQ(content[6], 0); + CHECK_EQ(content[7], 0); + } + } + + TEST_CASE("stream lifecycle") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + SUBCASE("Stream is initially open") + { + CHECK(stream.is_open()); + } + + SUBCASE("Flush operation") + { + uint8_t data[] = {1, 2, 3}; + std::span span(data, 3); + stream.write(span); + + CHECK_NOTHROW(stream.flush()); + CHECK(stream.is_open()); + CHECK_EQ(stream.size(), 3); + + // Verify data is written to file after flush + CHECK_EQ(temp_file.file_size(), 3); + } + + SUBCASE("Close operation") + { + uint8_t data[] = {1, 2, 3}; + std::span span(data, 3); + stream.write(span); + + CHECK_NOTHROW(stream.close()); + CHECK_FALSE(stream.is_open()); // File stream should be closed + CHECK_EQ(stream.size(), 3); // Size should remain the same + + // Verify data is written to file after close + CHECK_EQ(temp_file.file_size(), 3); + } + + SUBCASE("Multiple close calls are safe") + { + stream.close(); + CHECK_FALSE(stream.is_open()); + + CHECK_NOTHROW(stream.close()); // Second close should not throw + CHECK_FALSE(stream.is_open()); + } + } + + TEST_CASE("large data handling") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + // Write a large amount of data + const size_t large_size = 10000; + std::vector large_data(large_size); + 
std::iota(large_data.begin(), large_data.end(), 0); // Fill with 0, 1, 2, ... + + std::span span(large_data); + auto written = stream.write(span); + + CHECK_EQ(written, large_size); + CHECK_EQ(stream.size(), large_size); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + CHECK_EQ(content.size(), large_size); + + // Verify data integrity + for (size_t i = 0; i < large_size; ++i) + { + CHECK_EQ(content[i], static_cast(i)); + } + } + + TEST_CASE("edge cases") + { + SUBCASE("Maximum value repeated writes") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + auto written = stream.write(std::numeric_limits::max(), 255); + + CHECK_EQ(written, 255); + CHECK_EQ(stream.size(), 255); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + REQUIRE_EQ(content.size(), 255); + for (size_t i = 0; i < 255; ++i) + { + CHECK_EQ(content[i], std::numeric_limits::max()); + } + } + + SUBCASE("Zero byte repeated writes") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + auto written = stream.write(static_cast(0), 100); + + CHECK_EQ(written, 100); + CHECK_EQ(stream.size(), 100); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + REQUIRE_EQ(content.size(), 100); + for (size_t i = 0; i < 100; ++i) + { + CHECK_EQ(content[i], 0); + } + } + } + + TEST_CASE("different write patterns") + { + SUBCASE("Alternating single and bulk writes") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + // Single byte write + stream.write(static_cast(100)); + + // Bulk write + uint8_t data[] = {200, 201, 202}; + std::span span(data, 3); + stream.write(span); + + // Repeated value write + stream.write(static_cast(150), 2); + + // Single byte write again + stream.write(static_cast(250)); + + 
CHECK_EQ(stream.size(), 7); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + std::vector expected = {100, 200, 201, 202, 150, 150, 250}; + CHECK_EQ(content, expected); + } + + SUBCASE("Binary data with null bytes") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + uint8_t data[] = {0x00, 0xFF, 0x00, 0xAA, 0x00, 0x55}; + std::span span(data, 6); + stream.write(span); + + CHECK_EQ(stream.size(), 6); + + stream.flush(); + file.close(); + + auto content = temp_file.read_content(); + REQUIRE_EQ(content.size(), 6); + CHECK_EQ(content[0], 0x00); + CHECK_EQ(content[1], 0xFF); + CHECK_EQ(content[2], 0x00); + CHECK_EQ(content[3], 0xAA); + CHECK_EQ(content[4], 0x00); + CHECK_EQ(content[5], 0x55); + } + } + + TEST_CASE("error handling") + { + SUBCASE("Writing to closed stream") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + // Close the underlying file + file.close(); + + // Stream should report as closed + CHECK_FALSE(stream.is_open()); + + // Writing should still work but may not persist (depends on implementation) + uint8_t data[] = {1, 2, 3}; + std::span span(data, 3); + auto written = stream.write(span); + + // The write operation itself shouldn't throw, but data may not be written + CHECK_EQ(written, 3); + CHECK_EQ(stream.size(), 3); + } + } + + TEST_CASE("size tracking accuracy") + { + temporary_file temp_file; + std::ofstream file(temp_file.path(), std::ios::binary); + file_output_stream stream(file); + + SUBCASE("Size tracking with various write operations") + { + CHECK_EQ(stream.size(), 0); + + // Write span + uint8_t data1[] = {1, 2}; + stream.write(std::span(data1, 2)); + CHECK_EQ(stream.size(), 2); + + // Write single byte + stream.write(static_cast(3)); + CHECK_EQ(stream.size(), 3); + + // Write repeated value + stream.write(static_cast(4), 3); + CHECK_EQ(stream.size(), 
6); + + // Write empty span (should not change size) + std::span empty; + stream.write(empty); + CHECK_EQ(stream.size(), 6); + + // Write zero count (should not change size) + stream.write(static_cast(5), 0); + CHECK_EQ(stream.size(), 6); + + stream.flush(); + file.close(); + + // Verify file size matches stream size + CHECK_EQ(temp_file.file_size(), 6); + } + } + } +} diff --git a/tests/test_flatbuffer_utils.cpp b/tests/test_flatbuffer_utils.cpp new file mode 100644 index 0000000..02f97cd --- /dev/null +++ b/tests/test_flatbuffer_utils.cpp @@ -0,0 +1,535 @@ +#include +#include + +#include "sparrow_ipc_tests_helpers.hpp" + +namespace sparrow_ipc +{ + TEST_SUITE("flatbuffer_utils") + { + TEST_CASE("create_metadata") + { + flatbuffers::FlatBufferBuilder builder; + + SUBCASE("No metadata (nullptr)") + { + auto schema = create_test_arrow_schema("i"); + auto metadata_offset = create_metadata(builder, schema); + CHECK_EQ(metadata_offset.o, 0); + } + + SUBCASE("With metadata - basic test") + { + auto schema = create_test_arrow_schema_with_metadata("i"); + auto metadata_offset = create_metadata(builder, schema); + // For now just check that it doesn't crash + // TODO: Add proper metadata testing when sparrow metadata is properly handled + } + } + + TEST_CASE("create_field") + { + flatbuffers::FlatBufferBuilder builder; + + SUBCASE("Basic field creation") + { + auto schema = create_test_arrow_schema("i", "int_field", true); + auto field_offset = create_field(builder, schema); + CHECK_NE(field_offset.o, 0); + } + + SUBCASE("Field with null name") + { + auto schema = create_test_arrow_schema("i", nullptr, false); + auto field_offset = create_field(builder, schema); + CHECK_NE(field_offset.o, 0); + } + + SUBCASE("Non-nullable field") + { + auto schema = create_test_arrow_schema("i", "int_field", false); + auto field_offset = create_field(builder, schema); + CHECK_NE(field_offset.o, 0); + } + } + + TEST_CASE("create_children from ArrowSchema") + { + 
flatbuffers::FlatBufferBuilder builder; + + SUBCASE("No children") + { + auto schema = create_test_arrow_schema("i"); + auto children_offset = create_children(builder, schema); + CHECK_EQ(children_offset.o, 0); + } + + SUBCASE("With children") + { + auto parent_schema = create_test_arrow_schema("+s"); + auto child1 = new ArrowSchema(create_test_arrow_schema("i", "child1")); + auto child2 = new ArrowSchema(create_test_arrow_schema("u", "child2")); + + ArrowSchema* children[] = {child1, child2}; + parent_schema.children = children; + parent_schema.n_children = 2; + + auto children_offset = create_children(builder, parent_schema); + CHECK_NE(children_offset.o, 0); + + // Clean up + delete child1; + delete child2; + } + + SUBCASE("Null child pointer throws exception") + { + auto parent_schema = create_test_arrow_schema("+s"); + ArrowSchema* children[] = {nullptr}; + parent_schema.children = children; + parent_schema.n_children = 1; + + CHECK_THROWS_AS( + const auto children_offset = create_children(builder, parent_schema), + std::invalid_argument + ); + } + } + + TEST_CASE("create_children from record_batch columns") + { + flatbuffers::FlatBufferBuilder builder; + + SUBCASE("With valid record batch") + { + auto record_batch = create_test_record_batch(); + auto children_offset = create_children(builder, record_batch.columns()); + CHECK_NE(children_offset.o, 0); + } + + SUBCASE("Empty record batch") + { + auto empty_batch = sp::record_batch({}); + + auto children_offset = create_children(builder, empty_batch.columns()); + CHECK_EQ(children_offset.o, 0); + } + } + + TEST_CASE("get_schema_message_builder") + { + SUBCASE("Valid record batch") + { + auto record_batch = create_test_record_batch(); + auto builder = get_schema_message_builder(record_batch); + + CHECK_GT(builder.GetSize(), 0); + CHECK_NE(builder.GetBufferPointer(), nullptr); + } + } + + TEST_CASE("fill_fieldnodes") + { + SUBCASE("Single array without children") + { + auto array = sp::primitive_array({1, 2, 3, 4, 
5}); + auto proxy = sp::detail::array_access::get_arrow_proxy(array); + + std::vector nodes; + fill_fieldnodes(proxy, nodes); + + CHECK_EQ(nodes.size(), 1); + CHECK_EQ(nodes[0].length(), 5); + CHECK_EQ(nodes[0].null_count(), 0); + } + + SUBCASE("Array with null values") + { + // For now, just test with a simple array without explicit nulls + // Creating arrays with null values requires more complex sparrow setup + auto array = sp::primitive_array({1, 2, 3}); + auto proxy = sp::detail::array_access::get_arrow_proxy(array); + + std::vector nodes; + fill_fieldnodes(proxy, nodes); + + CHECK_EQ(nodes.size(), 1); + CHECK_EQ(nodes[0].length(), 3); + CHECK_EQ(nodes[0].null_count(), 0); + } + } + + TEST_CASE("create_fieldnodes") + { + SUBCASE("Record batch with multiple columns") + { + auto record_batch = create_test_record_batch(); + auto nodes = create_fieldnodes(record_batch); + + CHECK_EQ(nodes.size(), 2); // Two columns + + // Check the first column (integer array) + CHECK_EQ(nodes[0].length(), 5); + CHECK_EQ(nodes[0].null_count(), 0); + + // Check the second column (string array) + CHECK_EQ(nodes[1].length(), 5); + CHECK_EQ(nodes[1].null_count(), 0); + } + } + + TEST_CASE("fill_buffers") + { + SUBCASE("Simple primitive array") + { + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto proxy = sp::detail::array_access::get_arrow_proxy(array); + + std::vector buffers; + int64_t offset = 0; + fill_buffers(proxy, buffers, offset); + + CHECK_GT(buffers.size(), 0); + CHECK_GT(offset, 0); + + // Verify offsets are aligned + for (const auto& buffer : buffers) + { + CHECK_EQ(buffer.offset() % 8, 0); + } + } + } + + TEST_CASE("get_buffers") + { + SUBCASE("Record batch with multiple columns") + { + auto record_batch = create_test_record_batch(); + auto buffers = get_buffers(record_batch); + CHECK_GT(buffers.size(), 0); + // Verify all offsets are properly calculated and aligned + for (size_t i = 1; i < buffers.size(); ++i) + { + CHECK_GE(buffers[i].offset(), buffers[i - 
1].offset() + buffers[i - 1].length()); + } + } + } + + TEST_CASE("get_flatbuffer_type") + { + flatbuffers::FlatBufferBuilder builder; + SUBCASE("Null and Boolean types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::NA)).first, + org::apache::arrow::flatbuf::Type::Null + ); + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::BOOL)).first, + org::apache::arrow::flatbuf::Type::Bool + ); + } + + SUBCASE("Integer types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT8)).first, + org::apache::arrow::flatbuf::Type::Int + ); // INT8 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT8)).first, + org::apache::arrow::flatbuf::Type::Int + ); // UINT8 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT16)).first, + org::apache::arrow::flatbuf::Type::Int + ); // INT16 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT16)).first, + org::apache::arrow::flatbuf::Type::Int + ); // UINT16 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT32)).first, + org::apache::arrow::flatbuf::Type::Int + ); // INT32 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT32)).first, + org::apache::arrow::flatbuf::Type::Int + ); // UINT32 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT64)).first, + org::apache::arrow::flatbuf::Type::Int + ); // INT64 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT64)).first, + org::apache::arrow::flatbuf::Type::Int + ); // UINT64 + } + + SUBCASE("Floating Point types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::HALF_FLOAT)).first, + 
org::apache::arrow::flatbuf::Type::FloatingPoint + ); // HALF_FLOAT + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::FLOAT)).first, + org::apache::arrow::flatbuf::Type::FloatingPoint + ); // FLOAT + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::DOUBLE)).first, + org::apache::arrow::flatbuf::Type::FloatingPoint + ); // DOUBLE + } + + SUBCASE("String and Binary types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::STRING)).first, + org::apache::arrow::flatbuf::Type::Utf8 + ); // STRING + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LARGE_STRING)) + .first, + org::apache::arrow::flatbuf::Type::LargeUtf8 + ); // LARGE_STRING + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::BINARY)).first, + org::apache::arrow::flatbuf::Type::Binary + ); // BINARY + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LARGE_BINARY)) + .first, + org::apache::arrow::flatbuf::Type::LargeBinary + ); // LARGE_BINARY + CHECK_EQ( + get_flatbuffer_type(builder, "vu").first, + org::apache::arrow::flatbuf::Type::Utf8View + ); // STRING_VIEW + CHECK_EQ( + get_flatbuffer_type(builder, "vz").first, + org::apache::arrow::flatbuf::Type::BinaryView + ); // BINARY_VIEW + } + + SUBCASE("Date types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::DATE_DAYS)).first, + org::apache::arrow::flatbuf::Type::Date + ); // DATE_DAYS + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::DATE_MILLISECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Date + ); // DATE_MILLISECONDS + } + + SUBCASE("Timestamp types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_SECONDS)) + .first, + 
org::apache::arrow::flatbuf::Type::Timestamp + ); // TIMESTAMP_SECONDS + CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_MILLISECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Timestamp + ); // TIMESTAMP_MILLISECONDS + CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_MICROSECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Timestamp + ); // TIMESTAMP_MICROSECONDS + CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_NANOSECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Timestamp + ); // TIMESTAMP_NANOSECONDS + } + + SUBCASE("Duration types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::DURATION_SECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Duration + ); // DURATION_SECONDS + CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::DURATION_MILLISECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Duration + ); // DURATION_MILLISECONDS + CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::DURATION_MICROSECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Duration + ); // DURATION_MICROSECONDS + CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::DURATION_NANOSECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Duration + ); // DURATION_NANOSECONDS + } + + SUBCASE("Interval types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INTERVAL_MONTHS)) + .first, + org::apache::arrow::flatbuf::Type::Interval + ); // INTERVAL_MONTHS + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INTERVAL_DAYS_TIME)) + .first, + org::apache::arrow::flatbuf::Type::Interval + ); // INTERVAL_DAYS_TIME + 
CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::INTERVAL_MONTHS_DAYS_NANOSECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Interval + ); // INTERVAL_MONTHS_DAYS_NANOSECONDS + } + + SUBCASE("Time types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::TIME_SECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Time + ); // TIME_SECONDS + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::TIME_MILLISECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Time + ); // TIME_MILLISECONDS + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::TIME_MICROSECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Time + ); // TIME_MICROSECONDS + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::TIME_NANOSECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Time + ); // TIME_NANOSECONDS + } + + SUBCASE("List types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LIST)).first, + org::apache::arrow::flatbuf::Type::List + ); // LIST + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LARGE_LIST)).first, + org::apache::arrow::flatbuf::Type::LargeList + ); // LARGE_LIST + CHECK_EQ( + get_flatbuffer_type(builder, "+vl").first, + org::apache::arrow::flatbuf::Type::ListView + ); // LIST_VIEW + CHECK_EQ( + get_flatbuffer_type(builder, "+vL").first, + org::apache::arrow::flatbuf::Type::LargeListView + ); // LARGE_LIST_VIEW + CHECK_EQ( + get_flatbuffer_type(builder, "+w:16").first, + org::apache::arrow::flatbuf::Type::FixedSizeList + ); // FIXED_SIZED_LIST + CHECK_THROWS(get_flatbuffer_type(builder, "+w:")); // Invalid FixedSizeList format + } + + SUBCASE("Struct and Map types") + { + CHECK_EQ( + get_flatbuffer_type(builder, 
sparrow::data_type_to_format(sparrow::data_type::STRUCT)).first, + org::apache::arrow::flatbuf::Type::Struct_ + ); // STRUCT + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::MAP)).first, + org::apache::arrow::flatbuf::Type::Map + ); // MAP + } + + SUBCASE("Union types") + { + CHECK_EQ( + get_flatbuffer_type(builder, "+ud:").first, + org::apache::arrow::flatbuf::Type::Union + ); // DENSE_UNION + CHECK_EQ( + get_flatbuffer_type(builder, "+us:").first, + org::apache::arrow::flatbuf::Type::Union + ); // SPARSE_UNION + } + + SUBCASE("Run-End Encoded type") + { + CHECK_EQ( + get_flatbuffer_type(builder, "+r").first, + org::apache::arrow::flatbuf::Type::RunEndEncoded + ); // RUN_ENCODED + } + + SUBCASE("Decimal types") + { + CHECK_EQ( + get_flatbuffer_type(builder, "d:10,5").first, + org::apache::arrow::flatbuf::Type::Decimal + ); // DECIMAL (general) + CHECK_THROWS(get_flatbuffer_type(builder, "d:10")); // Invalid Decimal format + } + + SUBCASE("Fixed Width Binary type") + { + CHECK_EQ( + get_flatbuffer_type(builder, "w:32").first, + org::apache::arrow::flatbuf::Type::FixedSizeBinary + ); // FIXED_WIDTH_BINARY + CHECK_THROWS(get_flatbuffer_type(builder, "w:")); // Invalid FixedSizeBinary format + } + + SUBCASE("Unsupported type returns Null") + { + CHECK_EQ( + get_flatbuffer_type(builder, "unsupported_format").first, + org::apache::arrow::flatbuf::Type::Null + ); + } + } + + TEST_CASE("get_record_batch_message_builder") + { + SUBCASE("Valid record batch with field nodes and buffers") + { + auto record_batch = create_test_record_batch(); + auto builder = get_record_batch_message_builder(record_batch); + CHECK_GT(builder.GetSize(), 0); + CHECK_NE(builder.GetBufferPointer(), nullptr); + } + } + } +} \ No newline at end of file diff --git a/tests/test_memory_output_streams.cpp b/tests/test_memory_output_streams.cpp new file mode 100644 index 0000000..5e8c107 --- /dev/null +++ b/tests/test_memory_output_streams.cpp @@ -0,0 +1,372 @@ 
+#include +#include +#include +#include +#include +#include + +#include + +#include "doctest/doctest.h" + +namespace sparrow_ipc +{ + TEST_SUITE("memory_output_stream") + { + TEST_CASE("basic construction") + { + SUBCASE("Construction with std::vector") + { + std::vector buffer; + memory_output_stream stream(buffer); + + CHECK(stream.is_open()); + CHECK_EQ(stream.size(), 0); + } + + SUBCASE("Construction with non-empty buffer") + { + std::vector buffer = {1, 2, 3, 4, 5}; + memory_output_stream stream(buffer); + + CHECK(stream.is_open()); + CHECK_EQ(stream.size(), 5); + } + } + + TEST_CASE("write operations") + { + SUBCASE("Write single byte span") + { + std::vector buffer; + memory_output_stream stream(buffer); + + uint8_t data[] = {42}; + std::span span(data, 1); + + auto written = stream.write(span); + + CHECK_EQ(written, 1); + CHECK_EQ(stream.size(), 1); + CHECK_EQ(buffer.size(), 1); + CHECK_EQ(buffer[0], 42); + } + + SUBCASE("Write multiple bytes span") + { + std::vector buffer; + memory_output_stream stream(buffer); + + uint8_t data[] = {1, 2, 3, 4, 5}; + std::span span(data, 5); + + auto written = stream.write(span); + + CHECK_EQ(written, 5); + CHECK_EQ(stream.size(), 5); + CHECK_EQ(buffer.size(), 5); + for (size_t i = 0; i < 5; ++i) + { + CHECK_EQ(buffer[i], i + 1); + } + } + + SUBCASE("Write empty span") + { + std::vector buffer; + memory_output_stream stream(buffer); + + std::span empty_span; + + auto written = stream.write(empty_span); + + CHECK_EQ(written, 0); + CHECK_EQ(stream.size(), 0); + CHECK_EQ(buffer.size(), 0); + } + + SUBCASE("Write single byte (convenience method)") + { + std::vector buffer; + memory_output_stream stream(buffer); + + uint8_t single_byte = 123; + auto written = stream.write(std::span{&single_byte, 1}); + + CHECK_EQ(written, 1); + CHECK_EQ(stream.size(), 1); + CHECK_EQ(buffer.size(), 1); + CHECK_EQ(buffer[0], 123); + } + + SUBCASE("Write value multiple times") + { + std::vector buffer; + memory_output_stream stream(buffer); + + 
auto written = stream.write(static_cast<uint8_t>(255), 3); + + CHECK_EQ(written, 3); + CHECK_EQ(stream.size(), 3); + CHECK_EQ(buffer.size(), 3); + CHECK_EQ(buffer[0], 255); + CHECK_EQ(buffer[1], 255); + CHECK_EQ(buffer[2], 255); + } + + SUBCASE("Write value zero times") + { + std::vector<uint8_t> buffer; + memory_output_stream stream(buffer); + + auto written = stream.write(static_cast<uint8_t>(42), 0); + + CHECK_EQ(written, 0); + CHECK_EQ(stream.size(), 0); + CHECK_EQ(buffer.size(), 0); + } + } + + TEST_CASE("sequential writes") + { + std::vector<uint8_t> buffer; + memory_output_stream stream(buffer); + + // First write + uint8_t data1[] = {10, 20, 30}; + std::span span1(data1, 3); + auto written1 = stream.write(span1); + + CHECK_EQ(written1, 3); + CHECK_EQ(stream.size(), 3); + + // Second write + uint8_t data2[] = {40, 50}; + std::span span2(data2, 2); + auto written2 = stream.write(span2); + + CHECK_EQ(written2, 2); + CHECK_EQ(stream.size(), 5); + + // Third write with repeated value + auto written3 = stream.write(static_cast<uint8_t>(60), 2); + + CHECK_EQ(written3, 2); + CHECK_EQ(stream.size(), 7); + + // Verify final buffer content + std::vector<uint8_t> expected = {10, 20, 30, 40, 50, 60, 60}; + CHECK_EQ(buffer, expected); + } + + TEST_CASE("reserve functionality") + { + std::vector<uint8_t> buffer; + memory_output_stream stream(buffer); + + // Reserve space + stream.reserve(100); + + // Buffer should have reserved capacity but size should remain 0 + CHECK_GE(buffer.capacity(), 100); + CHECK_EQ(stream.size(), 0); + CHECK_EQ(buffer.size(), 0); + + // Writing should work normally after reserve + uint8_t data[] = {1, 2, 3}; + std::span span(data, 3); + stream.write(span); + + CHECK_EQ(stream.size(), 3); + CHECK_EQ(buffer.size(), 3); + } + + TEST_CASE("add_padding functionality") + { + std::vector<uint8_t> buffer; + memory_output_stream stream(buffer); + + SUBCASE("No padding needed when size is multiple of 8") + { + // Write 8 bytes + uint8_t data[] = {1, 2, 3, 4, 5, 6, 7, 8}; + std::span span(data, 8); + stream.write(span); + + 
auto size_before = stream.size(); + stream.add_padding(); + + CHECK_EQ(stream.size(), size_before); + CHECK_EQ(buffer.size(), 8); + } + + SUBCASE("Padding needed when size is not multiple of 8") + { + // Write 5 bytes + uint8_t data[] = {1, 2, 3, 4, 5}; + std::span span(data, 5); + stream.write(span); + + stream.add_padding(); + + CHECK_EQ(stream.size(), 8); // Should be padded to next multiple of 8 + CHECK_EQ(buffer.size(), 8); + + // Check padding bytes are zero + CHECK_EQ(buffer[5], 0); + CHECK_EQ(buffer[6], 0); + CHECK_EQ(buffer[7], 0); + } + + SUBCASE("Padding for different sizes") + { + // Test various sizes and their expected padding + std::vector> test_cases = { + {0, 0}, // 0 -> 0 (no padding needed) + {1, 7}, // 1 -> 8 (7 padding bytes) + {2, 6}, // 2 -> 8 (6 padding bytes) + {3, 5}, // 3 -> 8 (5 padding bytes) + {4, 4}, // 4 -> 8 (4 padding bytes) + {5, 3}, // 5 -> 8 (3 padding bytes) + {6, 2}, // 6 -> 8 (2 padding bytes) + {7, 1}, // 7 -> 8 (1 padding byte) + {8, 0}, // 8 -> 8 (no padding needed) + {9, 7}, // 9 -> 16 (7 padding bytes) + }; + + for (const auto& [initial_size, expected_padding] : test_cases) + { + std::vector test_buffer; + memory_output_stream test_stream(test_buffer); + + // Write initial_size bytes + if (initial_size > 0) + { + std::vector data(initial_size, 42); + std::span span(data); + test_stream.write(span); + } + + auto size_before = test_stream.size(); + test_stream.add_padding(); + auto size_after = test_stream.size(); + + CHECK_EQ(size_before, initial_size); + CHECK_EQ(size_after - size_before, expected_padding); + CHECK_EQ(size_after % 8, 0); // Should always be multiple of 8 + } + } + } + + TEST_CASE("stream lifecycle") + { + std::vector buffer; + memory_output_stream stream(buffer); + + SUBCASE("Stream is initially open") + { + CHECK(stream.is_open()); + } + + SUBCASE("Flush operation") + { + uint8_t data[] = {1, 2, 3}; + std::span span(data, 3); + stream.write(span); + + // Flush should not throw or change state for memory 
stream + CHECK_NOTHROW(stream.flush()); + CHECK(stream.is_open()); + CHECK_EQ(stream.size(), 3); + } + + SUBCASE("Close operation") + { + uint8_t data[] = {1, 2, 3}; + std::span span(data, 3); + stream.write(span); + + // Close should not throw for memory stream + CHECK_NOTHROW(stream.close()); + CHECK(stream.is_open()); // Memory stream should remain open + CHECK_EQ(stream.size(), 3); + } + } + + TEST_CASE("large data handling") + { + std::vector buffer; + memory_output_stream stream(buffer); + + // Write a large amount of data + const size_t large_size = 10000; + std::vector large_data(large_size); + std::iota(large_data.begin(), large_data.end(), 0); // Fill with 0, 1, 2, ... + + std::span span(large_data); + auto written = stream.write(span); + + CHECK_EQ(written, large_size); + CHECK_EQ(stream.size(), large_size); + CHECK_EQ(buffer.size(), large_size); + + // Verify data integrity + for (size_t i = 0; i < large_size; ++i) + { + CHECK_EQ(buffer[i], static_cast(i)); + } + } + + TEST_CASE("edge cases") + { + SUBCASE("Maximum value repeated writes") + { + std::vector buffer; + memory_output_stream stream(buffer); + + auto written = stream.write(std::numeric_limits::max(), 255); + + CHECK_EQ(written, 255); + CHECK_EQ(stream.size(), 255); + for (size_t i = 0; i < 255; ++i) + { + CHECK_EQ(buffer[i], std::numeric_limits::max()); + } + } + + SUBCASE("Zero byte repeated writes") + { + std::vector buffer; + memory_output_stream stream(buffer); + + auto written = stream.write(static_cast(0), 100); + + CHECK_EQ(written, 100); + CHECK_EQ(stream.size(), 100); + for (size_t i = 0; i < 100; ++i) + { + CHECK_EQ(buffer[i], 0); + } + } + } + + TEST_CASE("different container types") + { + SUBCASE("With pre-filled vector") + { + std::vector buffer = {100, 200}; + memory_output_stream stream(buffer); + + CHECK_EQ(stream.size(), 2); + + uint8_t data[] = {1, 2, 3}; + std::span span(data, 3); + stream.write(span); + + CHECK_EQ(stream.size(), 5); + std::vector expected = {100, 200, 1, 
2, 3}; + CHECK_EQ(buffer, expected); + } + } + } +} \ No newline at end of file diff --git a/tests/test_serialize_utils.cpp b/tests/test_serialize_utils.cpp index 2997843..e10eb98 100644 --- a/tests/test_serialize_utils.cpp +++ b/tests/test_serialize_utils.cpp @@ -1,7 +1,10 @@ +#include + #include #include #include "sparrow_ipc/magic_values.hpp" +#include "sparrow_ipc/memory_output_stream.hpp" #include "sparrow_ipc/serialize_utils.hpp" #include "sparrow_ipc/utils.hpp" #include "sparrow_ipc_tests_helpers.hpp" @@ -10,355 +13,317 @@ namespace sparrow_ipc { namespace sp = sparrow; - TEST_CASE("create_metadata") + TEST_SUITE("serialize_utils") { - flatbuffers::FlatBufferBuilder builder; - - SUBCASE("No metadata (nullptr)") - { - auto schema = create_test_arrow_schema("i"); - auto metadata_offset = create_metadata(builder, schema); - CHECK_EQ(metadata_offset.o, 0); - } - - SUBCASE("With metadata - basic test") + TEST_CASE("serialize_schema_message") { - auto schema = create_test_arrow_schema_with_metadata("i"); - auto metadata_offset = create_metadata(builder, schema); - // For now just check that it doesn't crash - // TODO: Add proper metadata testing when sparrow metadata is properly handled - } - } - - TEST_CASE("create_field") - { - flatbuffers::FlatBufferBuilder builder; - - SUBCASE("Basic field creation") - { - auto schema = create_test_arrow_schema("i", "int_field", true); - auto field_offset = create_field(builder, schema); - CHECK_NE(field_offset.o, 0); - } - - SUBCASE("Field with null name") - { - auto schema = create_test_arrow_schema("i", nullptr, false); - auto field_offset = create_field(builder, schema); - CHECK_NE(field_offset.o, 0); + SUBCASE("Valid record batch") + { + std::vector serialized; + memory_output_stream stream(serialized); + auto record_batch = create_test_record_batch(); + serialize_schema_message(record_batch, stream); + + CHECK_GT(serialized.size(), 0); + + // Check that it starts with continuation bytes + CHECK_EQ(serialized.size() >= 
continuation.size(), true); + for (size_t i = 0; i < continuation.size(); ++i) + { + CHECK_EQ(serialized[i], continuation[i]); + } + + // Check that the total size is aligned to 8 bytes + CHECK_EQ(serialized.size() % 8, 0); + } } - SUBCASE("Non-nullable field") + TEST_CASE("fill_body") { - auto schema = create_test_arrow_schema("i", "int_field", false); - auto field_offset = create_field(builder, schema); - CHECK_NE(field_offset.o, 0); + SUBCASE("Simple primitive array") + { + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto proxy = sp::detail::array_access::get_arrow_proxy(array); + std::vector body; + sparrow_ipc::memory_output_stream stream(body); + fill_body(proxy, stream); + CHECK_GT(body.size(), 0); + // Body size should be aligned + CHECK_EQ(body.size() % 8, 0); + } } - } - TEST_CASE("create_children from ArrowSchema") - { - flatbuffers::FlatBufferBuilder builder; - - SUBCASE("No children") + TEST_CASE("generate_body") { - auto schema = create_test_arrow_schema("i"); - auto children_offset = create_children(builder, schema); - CHECK_EQ(children_offset.o, 0); + SUBCASE("Record batch with multiple columns") + { + auto record_batch = create_test_record_batch(); + std::vector serialized; + memory_output_stream stream(serialized); + generate_body(record_batch, stream); + CHECK_GT(serialized.size(), 0); + CHECK_EQ(serialized.size() % 8, 0); + } } - SUBCASE("With children") + TEST_CASE("calculate_body_size") { - auto parent_schema = create_test_arrow_schema("+s"); - auto child1 = new ArrowSchema(create_test_arrow_schema("i", "child1")); - auto child2 = new ArrowSchema(create_test_arrow_schema("u", "child2")); - - ArrowSchema* children[] = {child1, child2}; - parent_schema.children = children; - parent_schema.n_children = 2; - - auto children_offset = create_children(builder, parent_schema); - CHECK_NE(children_offset.o, 0); - - // Clean up - delete child1; - delete child2; - } + SUBCASE("Single array") + { + auto array = sp::primitive_array({1, 2, 3, 4, 
5}); + auto proxy = sp::detail::array_access::get_arrow_proxy(array); - SUBCASE("Null child pointer throws exception") - { - auto parent_schema = create_test_arrow_schema("+s"); - ArrowSchema* children[] = {nullptr}; - parent_schema.children = children; - parent_schema.n_children = 1; + auto size = calculate_body_size(proxy); + CHECK_GT(size, 0); + CHECK_EQ(size % 8, 0); + } - CHECK_THROWS_AS(create_children(builder, parent_schema), std::invalid_argument); + SUBCASE("Record batch") + { + auto record_batch = create_test_record_batch(); + auto size = calculate_body_size(record_batch); + CHECK_GT(size, 0); + CHECK_EQ(size % 8, 0); + std::vector serialized; + memory_output_stream stream(serialized); + generate_body(record_batch, stream); + CHECK_EQ(size, static_cast(serialized.size())); + } } - } - TEST_CASE("create_children from record_batch columns") - { - flatbuffers::FlatBufferBuilder builder; - - SUBCASE("With valid record batch") + TEST_CASE("calculate_schema_message_size") { - auto record_batch = create_test_record_batch(); - auto children_offset = create_children(builder, record_batch.columns()); - CHECK_NE(children_offset.o, 0); - } + SUBCASE("Single column record batch") + { + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto record_batch = sp::record_batch({{"column1", sp::array(std::move(array))}}); - SUBCASE("Empty record batch") - { - auto empty_batch = sp::record_batch({}); + const auto estimated_size = calculate_schema_message_size(record_batch); + CHECK_GT(estimated_size, 0); + CHECK_EQ(estimated_size % 8, 0); - auto children_offset = create_children(builder, empty_batch.columns()); - CHECK_EQ(children_offset.o, 0); - } - } + // Verify by actual serialization + std::vector serialized; + memory_output_stream stream(serialized); + serialize_schema_message(record_batch, stream); - TEST_CASE("get_schema_message_builder") - { - SUBCASE("Valid record batch") - { - auto record_batch = create_test_record_batch(); - auto builder = 
get_schema_message_builder(record_batch); + CHECK_EQ(estimated_size, serialized.size()); + } - CHECK_GT(builder.GetSize(), 0); - CHECK_NE(builder.GetBufferPointer(), nullptr); - } - } + SUBCASE("Multi-column record batch") + { + auto record_batch = create_test_record_batch(); - TEST_CASE("serialize_schema_message") - { - SUBCASE("Valid record batch") - { - auto record_batch = create_test_record_batch(); - auto serialized = serialize_schema_message(record_batch); + auto estimated_size = calculate_schema_message_size(record_batch); + CHECK_GT(estimated_size, 0); + CHECK_EQ(estimated_size % 8, 0); - CHECK_GT(serialized.size(), 0); + std::vector serialized; + memory_output_stream stream(serialized); + serialize_schema_message(record_batch, stream); - // Check that it starts with continuation bytes - CHECK_EQ(serialized.size() >= continuation.size(), true); - for (size_t i = 0; i < continuation.size(); ++i) - { - CHECK_EQ(serialized[i], continuation[i]); + CHECK_EQ(estimated_size, serialized.size()); } - - // Check that the total size is aligned to 8 bytes - CHECK_EQ(serialized.size() % 8, 0); } - } - TEST_CASE("fill_fieldnodes") - { - SUBCASE("Single array without children") + TEST_CASE("calculate_record_batch_message_size") { - auto array = sp::primitive_array({1, 2, 3, 4, 5}); - auto proxy = sp::detail::array_access::get_arrow_proxy(array); - - std::vector nodes; - fill_fieldnodes(proxy, nodes); - - CHECK_EQ(nodes.size(), 1); - CHECK_EQ(nodes[0].length(), 5); - CHECK_EQ(nodes[0].null_count(), 0); - } + SUBCASE("Single column record batch") + { + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto record_batch = sp::record_batch({{"column1", sp::array(std::move(array))}}); - SUBCASE("Array with null values") - { - // For now, just test with a simple array without explicit nulls - // Creating arrays with null values requires more complex sparrow setup - auto array = sp::primitive_array({1, 2, 3}); - auto proxy = sp::detail::array_access::get_arrow_proxy(array); + 
auto estimated_size = calculate_record_batch_message_size(record_batch); + CHECK_GT(estimated_size, 0); + CHECK_EQ(estimated_size % 8, 0); - std::vector nodes; - fill_fieldnodes(proxy, nodes); + std::vector serialized; + memory_output_stream stream(serialized); + serialize_record_batch(record_batch, stream); - CHECK_EQ(nodes.size(), 1); - CHECK_EQ(nodes[0].length(), 3); - CHECK_EQ(nodes[0].null_count(), 0); - } - } + CHECK_EQ(estimated_size, serialized.size()); + } - TEST_CASE("create_fieldnodes") - { - SUBCASE("Record batch with multiple columns") - { - auto record_batch = create_test_record_batch(); - auto nodes = create_fieldnodes(record_batch); + SUBCASE("Multi-column record batch") + { + auto record_batch = create_test_record_batch(); - CHECK_EQ(nodes.size(), 2); // Two columns + auto estimated_size = calculate_record_batch_message_size(record_batch); + CHECK_GT(estimated_size, 0); + CHECK_EQ(estimated_size % 8, 0); - // Check the first column (integer array) - CHECK_EQ(nodes[0].length(), 5); - CHECK_EQ(nodes[0].null_count(), 0); + // Verify by actual serialization + std::vector serialized; + memory_output_stream stream(serialized); + serialize_record_batch(record_batch, stream); - // Check the second column (string array) - CHECK_EQ(nodes[1].length(), 5); - CHECK_EQ(nodes[1].null_count(), 0); + CHECK_EQ(estimated_size, serialized.size()); + } } - } - TEST_CASE("fill_buffers") - { - SUBCASE("Simple primitive array") + TEST_CASE("calculate_total_serialized_size") { - auto array = sp::primitive_array({1, 2, 3, 4, 5}); - auto proxy = sp::detail::array_access::get_arrow_proxy(array); + SUBCASE("Single record batch") + { + auto record_batch = create_test_record_batch(); + std::vector batches = {record_batch}; - std::vector buffers; - int64_t offset = 0; - fill_buffers(proxy, buffers, offset); + auto estimated_size = calculate_total_serialized_size(batches); + CHECK_GT(estimated_size, 0); - CHECK_GT(buffers.size(), 0); - CHECK_GT(offset, 0); + // Should equal schema 
size + record batch size + auto schema_size = calculate_schema_message_size(record_batch); + auto batch_size = calculate_record_batch_message_size(record_batch); + CHECK_EQ(estimated_size, schema_size + batch_size); + } - // Verify offsets are aligned - for (const auto& buffer : buffers) + SUBCASE("Multiple record batches") { - CHECK_EQ(buffer.offset() % 8, 0); + auto array1 = sp::primitive_array({1, 2, 3}); + auto array2 = sp::primitive_array({1.0, 2.0, 3.0}); + auto record_batch1 = sp::record_batch( + {{"col1", sp::array(std::move(array1))}, {"col2", sp::array(std::move(array2))}} + ); + + auto array3 = sp::primitive_array({4, 5, 6}); + auto array4 = sp::primitive_array({4.0, 5.0, 6.0}); + auto record_batch2 = sp::record_batch( + {{"col1", sp::array(std::move(array3))}, {"col2", sp::array(std::move(array4))}} + ); + + std::vector batches = {record_batch1, record_batch2}; + + auto estimated_size = calculate_total_serialized_size(batches); + CHECK_GT(estimated_size, 0); + + // Should equal schema size + sum of record batch sizes + auto schema_size = calculate_schema_message_size(batches[0]); + auto batch1_size = calculate_record_batch_message_size(batches[0]); + auto batch2_size = calculate_record_batch_message_size(batches[1]); + CHECK_EQ(estimated_size, schema_size + batch1_size + batch2_size); } - } - } - TEST_CASE("get_buffers") - { - SUBCASE("Record batch with multiple columns") - { - auto record_batch = create_test_record_batch(); - auto buffers = get_buffers(record_batch); - CHECK_GT(buffers.size(), 0); - // Verify all offsets are properly calculated and aligned - for (size_t i = 1; i < buffers.size(); ++i) + SUBCASE("Empty collection") { - CHECK_GE(buffers[i].offset(), buffers[i - 1].offset() + buffers[i - 1].length()); + std::vector empty_batches; + auto estimated_size = calculate_total_serialized_size(empty_batches); + CHECK_EQ(estimated_size, 0); } - } - } - - TEST_CASE("fill_body") - { - SUBCASE("Simple primitive array") - { - auto array = 
sp::primitive_array({1, 2, 3, 4, 5}); - auto proxy = sp::detail::array_access::get_arrow_proxy(array); - std::vector body; - fill_body(proxy, body); - CHECK_GT(body.size(), 0); - // Body size should be aligned - CHECK_EQ(body.size() % 8, 0); - } - } - - TEST_CASE("generate_body") - { - SUBCASE("Record batch with multiple columns") - { - auto record_batch = create_test_record_batch(); - auto body = generate_body(record_batch); - CHECK_GT(body.size(), 0); - CHECK_EQ(body.size() % 8, 0); - } - } - - TEST_CASE("calculate_body_size") - { - SUBCASE("Single array") - { - auto array = sp::primitive_array({1, 2, 3, 4, 5}); - auto proxy = sp::detail::array_access::get_arrow_proxy(array); + SUBCASE("Inconsistent schemas throw exception") + { + auto array1 = sp::primitive_array({1, 2, 3}); + auto record_batch1 = sp::record_batch({{"col1", sp::array(std::move(array1))}}); - auto size = calculate_body_size(proxy); - CHECK_GT(size, 0); - CHECK_EQ(size % 8, 0); - } + auto array2 = sp::primitive_array({1.0, 2.0, 3.0}); + auto record_batch2 = sp::record_batch( + {{"col2", sp::array(std::move(array2))}} // Different column name + ); - SUBCASE("Record batch") - { - auto record_batch = create_test_record_batch(); - auto size = calculate_body_size(record_batch); - CHECK_GT(size, 0); - CHECK_EQ(size % 8, 0); - auto body = generate_body(record_batch); - CHECK_EQ(size, static_cast(body.size())); - } - } + std::vector batches = {record_batch1, record_batch2}; - TEST_CASE("get_record_batch_message_builder") - { - SUBCASE("Valid record batch with field nodes and buffers") - { - auto record_batch = create_test_record_batch(); - auto nodes = create_fieldnodes(record_batch); - auto buffers = get_buffers(record_batch); - auto builder = get_record_batch_message_builder(record_batch, nodes, buffers); - CHECK_GT(builder.GetSize(), 0); - CHECK_NE(builder.GetBufferPointer(), nullptr); + CHECK_THROWS_AS(auto size = calculate_total_serialized_size(batches), std::invalid_argument); + } } - } - 
TEST_CASE("serialize_record_batch") - { - SUBCASE("Valid record batch") + TEST_CASE("memory_reservation_performance") { - auto record_batch = create_test_record_batch(); - auto serialized = serialize_record_batch(record_batch); - CHECK_GT(serialized.size(), 0); - - // Check that it starts with continuation bytes - CHECK_GE(serialized.size(), continuation.size()); - for (size_t i = 0; i < continuation.size(); ++i) + SUBCASE("Large record batch benefits from size estimation") { - CHECK_EQ(serialized[i], continuation[i]); + // Create a larger record batch for testing memory reservation + std::vector large_data; + large_data.reserve(10000); + for (int i = 0; i < 10000; ++i) + { + large_data.push_back(i); + } + + auto array = sp::primitive_array(large_data); + auto record_batch = sp::record_batch({{"large_column", sp::array(std::move(array))}}); + + // Test with size estimation (current implementation) + std::vector with_estimation; + memory_output_stream stream_with_estimation(with_estimation); + + // Reserve memory based on calculated size to avoid reallocations + const std::size_t total_size = calculate_record_batch_message_size(record_batch); + stream_with_estimation.reserve(stream_with_estimation.size() + total_size); + + auto start_time = std::chrono::high_resolution_clock::now(); + serialize_record_batch(record_batch, stream_with_estimation); + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration_with_estimation = end_time - start_time; + + // Verify size estimation accuracy + auto estimated_size = calculate_record_batch_message_size(record_batch); + CHECK_EQ(estimated_size, with_estimation.size()); + + // The serialization should complete successfully + CHECK_GT(with_estimation.size(), 0); + CHECK_EQ(with_estimation.size() % 8, 0); + + // Test without size estimation (no pre-reservation) + std::vector without_estimation; + memory_output_stream stream_without_estimation(without_estimation); + auto start_time_no_est = 
std::chrono::high_resolution_clock::now(); + serialize_record_batch(record_batch, stream_without_estimation); + auto end_time_no_est = std::chrono::high_resolution_clock::now(); + auto duration_without_estimation = end_time_no_est - start_time_no_est; + DOCTEST_MESSAGE( + "With estimation: " + << std::chrono::duration_cast(duration_with_estimation).count() + << " us, Without estimation: " + << std::chrono::duration_cast(duration_without_estimation).count() + << " us" + ); } - - // Check that the metadata part is aligned to 8 bytes - // Find the end of metadata (before body starts) - size_t continuation_size = continuation.size(); - size_t length_prefix_size = sizeof(uint32_t); - - CHECK_GT(serialized.size(), continuation_size + length_prefix_size); - - // Extract message length - uint32_t message_length; - std::memcpy(&message_length, serialized.data() + continuation_size, sizeof(uint32_t)); - - size_t metadata_end = continuation_size + length_prefix_size + message_length; - size_t aligned_metadata_end = utils::align_to_8(static_cast(metadata_end)); - - // Verify alignment - CHECK_EQ(aligned_metadata_end % 8, 0); - CHECK_LE(aligned_metadata_end, serialized.size()); } - SUBCASE("Empty record batch") - { - auto empty_batch = sp::record_batch({}); - auto serialized = serialize_record_batch(empty_batch); - CHECK_GT(serialized.size(), 0); - CHECK_GE(serialized.size(), continuation.size()); - } - } - - TEST_CASE("Integration test - schema and record batch serialization") - { - SUBCASE("Serialize schema and record batch for same data") + TEST_CASE("serialize_record_batch") { - auto record_batch = create_test_record_batch(); - - auto schema_serialized = serialize_schema_message(record_batch); - auto record_batch_serialized = serialize_record_batch(record_batch); - - CHECK_GT(schema_serialized.size(), 0); - CHECK_GT(record_batch_serialized.size(), 0); - - // Both should start with continuation bytes - CHECK_GE(schema_serialized.size(), continuation.size()); - 
CHECK_GE(record_batch_serialized.size(), continuation.size()); + SUBCASE("Valid record batch") + { + auto record_batch = create_test_record_batch(); + std::vector serialized; + memory_output_stream stream(serialized); + serialize_record_batch(record_batch, stream); + CHECK_GT(serialized.size(), 0); + + // Check that it starts with continuation bytes + CHECK_GE(serialized.size(), continuation.size()); + for (size_t i = 0; i < continuation.size(); ++i) + { + CHECK_EQ(serialized[i], continuation[i]); + } + + // Check that the metadata part is aligned to 8 bytes + // Find the end of metadata (before body starts) + size_t continuation_size = continuation.size(); + size_t length_prefix_size = sizeof(uint32_t); + + CHECK_GT(serialized.size(), continuation_size + length_prefix_size); + + // Extract message length + uint32_t message_length; + std::memcpy(&message_length, serialized.data() + continuation_size, sizeof(uint32_t)); + + size_t metadata_end = continuation_size + length_prefix_size + message_length; + size_t aligned_metadata_end = utils::align_to_8(static_cast(metadata_end)); + + // Verify alignment + CHECK_EQ(aligned_metadata_end % 8, 0); + CHECK_LE(aligned_metadata_end, serialized.size()); + } - // Both should be properly aligned - CHECK_EQ(schema_serialized.size() % 8, 0); + SUBCASE("Empty record batch") + { + auto empty_batch = sp::record_batch({}); + std::vector serialized; + memory_output_stream stream(serialized); + serialize_record_batch(empty_batch, stream); + CHECK_GT(serialized.size(), 0); + CHECK_GE(serialized.size(), continuation.size()); + } } } } \ No newline at end of file diff --git a/tests/test_utils.cpp b/tests/test_utils.cpp index ab9f4a0..0619d68 100644 --- a/tests/test_utils.cpp +++ b/tests/test_utils.cpp @@ -1,9 +1,8 @@ #include -#include -#include "sparrow_ipc/arrow_interface/arrow_array_schema_common_release.hpp" #include "sparrow_ipc/utils.hpp" + namespace sparrow_ipc { TEST_CASE("align_to_8") @@ -16,334 +15,4 @@ namespace sparrow_ipc 
CHECK_EQ(utils::align_to_8(15), 16); CHECK_EQ(utils::align_to_8(16), 16); } - - TEST_CASE("get_flatbuffer_type") - { - flatbuffers::FlatBufferBuilder builder; - SUBCASE("Null and Boolean types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::NA)).first, - org::apache::arrow::flatbuf::Type::Null - ); - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::BOOL)).first, - org::apache::arrow::flatbuf::Type::Bool - ); - } - - SUBCASE("Integer types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT8)).first, - org::apache::arrow::flatbuf::Type::Int - ); // INT8 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT8)).first, - org::apache::arrow::flatbuf::Type::Int - ); // UINT8 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT16)).first, - org::apache::arrow::flatbuf::Type::Int - ); // INT16 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT16)).first, - org::apache::arrow::flatbuf::Type::Int - ); // UINT16 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT32)).first, - org::apache::arrow::flatbuf::Type::Int - ); // INT32 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT32)).first, - org::apache::arrow::flatbuf::Type::Int - ); // UINT32 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT64)).first, - org::apache::arrow::flatbuf::Type::Int - ); // INT64 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT64)).first, - org::apache::arrow::flatbuf::Type::Int - ); // UINT64 - } - - SUBCASE("Floating Point types") - { - CHECK_EQ( - 
utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::HALF_FLOAT)) - .first, - org::apache::arrow::flatbuf::Type::FloatingPoint - ); // HALF_FLOAT - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::FLOAT)).first, - org::apache::arrow::flatbuf::Type::FloatingPoint - ); // FLOAT - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::DOUBLE)).first, - org::apache::arrow::flatbuf::Type::FloatingPoint - ); // DOUBLE - } - - SUBCASE("String and Binary types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::STRING)).first, - org::apache::arrow::flatbuf::Type::Utf8 - ); // STRING - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LARGE_STRING)) - .first, - org::apache::arrow::flatbuf::Type::LargeUtf8 - ); // LARGE_STRING - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::BINARY)).first, - org::apache::arrow::flatbuf::Type::Binary - ); // BINARY - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LARGE_BINARY)) - .first, - org::apache::arrow::flatbuf::Type::LargeBinary - ); // LARGE_BINARY - CHECK_EQ( - utils::get_flatbuffer_type(builder, "vu").first, - org::apache::arrow::flatbuf::Type::Utf8View - ); // STRING_VIEW - CHECK_EQ( - utils::get_flatbuffer_type(builder, "vz").first, - org::apache::arrow::flatbuf::Type::BinaryView - ); // BINARY_VIEW - } - - SUBCASE("Date types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::DATE_DAYS)) - .first, - org::apache::arrow::flatbuf::Type::Date - ); // DATE_DAYS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::DATE_MILLISECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Date - ); // DATE_MILLISECONDS - 
} - - SUBCASE("Timestamp types") - { - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_SECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Timestamp - ); // TIMESTAMP_SECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_MILLISECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Timestamp - ); // TIMESTAMP_MILLISECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_MICROSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Timestamp - ); // TIMESTAMP_MICROSECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_NANOSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Timestamp - ); // TIMESTAMP_NANOSECONDS - } - - SUBCASE("Duration types") - { - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::DURATION_SECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Duration - ); // DURATION_SECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::DURATION_MILLISECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Duration - ); // DURATION_MILLISECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::DURATION_MICROSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Duration - ); // DURATION_MICROSECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::DURATION_NANOSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Duration - ); // DURATION_NANOSECONDS - } - - SUBCASE("Interval types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INTERVAL_MONTHS)) - .first, - 
org::apache::arrow::flatbuf::Type::Interval - ); // INTERVAL_MONTHS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::INTERVAL_DAYS_TIME) - ) - .first, - org::apache::arrow::flatbuf::Type::Interval - ); // INTERVAL_DAYS_TIME - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::INTERVAL_MONTHS_DAYS_NANOSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Interval - ); // INTERVAL_MONTHS_DAYS_NANOSECONDS - } - - SUBCASE("Time types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::TIME_SECONDS)) - .first, - org::apache::arrow::flatbuf::Type::Time - ); // TIME_SECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIME_MILLISECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Time - ); // TIME_MILLISECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIME_MICROSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Time - ); // TIME_MICROSECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIME_NANOSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Time - ); // TIME_NANOSECONDS - } - - SUBCASE("List types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LIST)).first, - org::apache::arrow::flatbuf::Type::List - ); // LIST - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LARGE_LIST)) - .first, - org::apache::arrow::flatbuf::Type::LargeList - ); // LARGE_LIST - CHECK_EQ( - utils::get_flatbuffer_type(builder, "+vl").first, - org::apache::arrow::flatbuf::Type::ListView - ); // LIST_VIEW - CHECK_EQ( - utils::get_flatbuffer_type(builder, "+vL").first, - org::apache::arrow::flatbuf::Type::LargeListView - ); // 
LARGE_LIST_VIEW - CHECK_EQ( - utils::get_flatbuffer_type(builder, "+w:16").first, - org::apache::arrow::flatbuf::Type::FixedSizeList - ); // FIXED_SIZED_LIST - CHECK_THROWS(utils::get_flatbuffer_type(builder, "+w:")); // Invalid FixedSizeList format - } - - SUBCASE("Struct and Map types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::STRUCT)).first, - org::apache::arrow::flatbuf::Type::Struct_ - ); // STRUCT - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::MAP)).first, - org::apache::arrow::flatbuf::Type::Map - ); // MAP - } - - SUBCASE("Union types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, "+ud:").first, - org::apache::arrow::flatbuf::Type::Union - ); // DENSE_UNION - CHECK_EQ( - utils::get_flatbuffer_type(builder, "+us:").first, - org::apache::arrow::flatbuf::Type::Union - ); // SPARSE_UNION - } - - SUBCASE("Run-End Encoded type") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, "+r").first, - org::apache::arrow::flatbuf::Type::RunEndEncoded - ); // RUN_ENCODED - } - - SUBCASE("Decimal types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, "d:10,5").first, - org::apache::arrow::flatbuf::Type::Decimal - ); // DECIMAL (general) - CHECK_THROWS(utils::get_flatbuffer_type(builder, "d:10")); // Invalid Decimal format - } - - SUBCASE("Fixed Width Binary type") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, "w:32").first, - org::apache::arrow::flatbuf::Type::FixedSizeBinary - ); // FIXED_WIDTH_BINARY - CHECK_THROWS(utils::get_flatbuffer_type(builder, "w:")); // Invalid FixedSizeBinary format - } - - SUBCASE("Unsupported type returns Null") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, "unsupported_format").first, - org::apache::arrow::flatbuf::Type::Null - ); - } - } } From c4d42eb6feda9e64c71135e6fa5f54231c4be42f Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Fri, 3 Oct 2025 16:04:23 +0200 Subject: [PATCH 02/11] fix --- 
src/serialize_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serialize_utils.cpp b/src/serialize_utils.cpp index 5bcbca8..96afbfe 100644 --- a/src/serialize_utils.cpp +++ b/src/serialize_utils.cpp @@ -46,7 +46,7 @@ namespace sparrow_ipc return std::reduce( record_batch.columns().begin(), record_batch.columns().end(), - 0, + int64_t{0}, [](int64_t acc, const sparrow::array& arr) { const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(arr); From 4785a7bfefe85f4e8737def71a3eda1c7b8f7114 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Fri, 3 Oct 2025 16:47:14 +0200 Subject: [PATCH 03/11] fix --- src/serialize_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serialize_utils.cpp b/src/serialize_utils.cpp index 96afbfe..377779a 100644 --- a/src/serialize_utils.cpp +++ b/src/serialize_utils.cpp @@ -43,7 +43,7 @@ namespace sparrow_ipc int64_t calculate_body_size(const sparrow::record_batch& record_batch) { - return std::reduce( + return std::accumulate( record_batch.columns().begin(), record_batch.columns().end(), int64_t{0}, From b3beb9a960217cd43e2e05b782c45039ee41dd41 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Mon, 6 Oct 2025 12:23:22 +0200 Subject: [PATCH 04/11] wip --- include/sparrow_ipc/chunk_memory_serializer.hpp | 3 ++- include/sparrow_ipc/serialize_utils.hpp | 4 ++-- include/sparrow_ipc/serializer.hpp | 2 +- include/sparrow_ipc/utils.hpp | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/include/sparrow_ipc/chunk_memory_serializer.hpp b/include/sparrow_ipc/chunk_memory_serializer.hpp index a897354..05e7bf2 100644 --- a/include/sparrow_ipc/chunk_memory_serializer.hpp +++ b/include/sparrow_ipc/chunk_memory_serializer.hpp @@ -6,10 +6,11 @@ #include "sparrow_ipc/memory_output_stream.hpp" #include "sparrow_ipc/serialize.hpp" #include "sparrow_ipc/serialize_utils.hpp" +#include "sparrow_ipc/config/config.hpp" namespace sparrow_ipc { - class chunk_serializer + 
class SPARROW_IPC_API chunk_serializer { public: diff --git a/include/sparrow_ipc/serialize_utils.hpp b/include/sparrow_ipc/serialize_utils.hpp index 0ae9832..ad1ace7 100644 --- a/include/sparrow_ipc/serialize_utils.hpp +++ b/include/sparrow_ipc/serialize_utils.hpp @@ -177,7 +177,7 @@ namespace sparrow_ipc * * @param stream The output stream where padding bytes will be added */ - void add_padding(output_stream& stream); + SPARROW_IPC_API void add_padding(output_stream& stream); - std::vector get_column_dtypes(const sparrow::record_batch& rb); + SPARROW_IPC_API std::vector get_column_dtypes(const sparrow::record_batch& rb); } diff --git a/include/sparrow_ipc/serializer.hpp b/include/sparrow_ipc/serializer.hpp index f4e0fd5..fe54b4a 100644 --- a/include/sparrow_ipc/serializer.hpp +++ b/include/sparrow_ipc/serializer.hpp @@ -29,7 +29,7 @@ namespace sparrow_ipc * - Stream reservation to minimize memory reallocations * - Lazy evaluation of size calculations using lambda functions */ - class serializer + class SPARROW_IPC_API serializer { public: diff --git a/include/sparrow_ipc/utils.hpp b/include/sparrow_ipc/utils.hpp index 3da1e54..63f1fb8 100644 --- a/include/sparrow_ipc/utils.hpp +++ b/include/sparrow_ipc/utils.hpp @@ -11,7 +11,7 @@ namespace sparrow_ipc::utils { // Aligns a value to the next multiple of 8, as required by the Arrow IPC format for message bodies - size_t align_to_8(const size_t n); + SPARROW_IPC_API size_t align_to_8(const size_t n); /** * @brief Checks if all record batches in a collection have consistent structure. 
From 08e01838d9642cf3787ecb21fbc4274c1dec4814 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Mon, 6 Oct 2025 13:27:24 +0200 Subject: [PATCH 05/11] fix --- tests/CMakeLists.txt | 2 +- tests/test_flatbuffer_utils.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 2a8252f..b489bbc 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -11,7 +11,7 @@ set(SPARROW_IPC_TESTS_SRC test_chunk_memory_serializer.cpp test_de_serialization_with_files.cpp test_file_output_stream.cpp - test_flatbuffer_utils.cpp + $<$>:test_flatbuffer_utils.cpp> test_memory_output_streams.cpp test_serialize_utils.cpp test_utils.cpp diff --git a/tests/test_flatbuffer_utils.cpp b/tests/test_flatbuffer_utils.cpp index 02f97cd..fd48410 100644 --- a/tests/test_flatbuffer_utils.cpp +++ b/tests/test_flatbuffer_utils.cpp @@ -509,7 +509,7 @@ namespace sparrow_ipc get_flatbuffer_type(builder, "w:32").first, org::apache::arrow::flatbuf::Type::FixedSizeBinary ); // FIXED_WIDTH_BINARY - CHECK_THROWS(get_flatbuffer_type(builder, "w:")); // Invalid FixedSizeBinary format + CHECK_THROWS(static_cast(get_flatbuffer_type(builder, "w:"))); // Invalid FixedSizeBinary format } SUBCASE("Unsupported type returns Null") From 36fe23669a699ad4023c4b77a203a71a3a0e1c4e Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Mon, 6 Oct 2025 15:56:38 +0200 Subject: [PATCH 06/11] Update include/sparrow_ipc/serializer.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- include/sparrow_ipc/serializer.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/sparrow_ipc/serializer.hpp b/include/sparrow_ipc/serializer.hpp index fe54b4a..b571c65 100644 --- a/include/sparrow_ipc/serializer.hpp +++ b/include/sparrow_ipc/serializer.hpp @@ -109,6 +109,10 @@ namespace sparrow_ipc m_pstream->reserve(reserve_function); for (const auto& rb : record_batches) { + if (get_column_dtypes(rb) != m_dtypes) + { + throw 
std::invalid_argument("Record batch schema does not match serializer schema"); + } serialize_record_batch(rb, *m_pstream); } } From b5b3d97ee61af7110d4502847000068944acc8f1 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Mon, 6 Oct 2025 15:59:23 +0200 Subject: [PATCH 07/11] Fix name --- .../chunk_memory_output_stream.hpp | 4 +- .../sparrow_ipc/chunk_memory_serializer.hpp | 6 +-- src/chunk_memory_serializer.cpp | 2 +- tests/test_chunk_memory_output_stream.cpp | 54 +++++++++---------- tests/test_chunk_memory_serializer.cpp | 34 ++++++------ 5 files changed, 50 insertions(+), 50 deletions(-) diff --git a/include/sparrow_ipc/chunk_memory_output_stream.hpp b/include/sparrow_ipc/chunk_memory_output_stream.hpp index bad94ef..4ad516d 100644 --- a/include/sparrow_ipc/chunk_memory_output_stream.hpp +++ b/include/sparrow_ipc/chunk_memory_output_stream.hpp @@ -11,11 +11,11 @@ namespace sparrow_ipc requires std::ranges::random_access_range && std::ranges::random_access_range> && std::same_as::value_type, uint8_t> - class chuncked_memory_output_stream final : public output_stream + class chunked_memory_output_stream final : public output_stream { public: - explicit chuncked_memory_output_stream(R& chunks) + explicit chunked_memory_output_stream(R& chunks) : m_chunks(&chunks) {}; std::size_t write(std::span span) override diff --git a/include/sparrow_ipc/chunk_memory_serializer.hpp b/include/sparrow_ipc/chunk_memory_serializer.hpp index 05e7bf2..09cd38d 100644 --- a/include/sparrow_ipc/chunk_memory_serializer.hpp +++ b/include/sparrow_ipc/chunk_memory_serializer.hpp @@ -16,14 +16,14 @@ namespace sparrow_ipc chunk_serializer( const sparrow::record_batch& rb, - chuncked_memory_output_stream>>& stream + chunked_memory_output_stream>>& stream ); template requires std::same_as, sparrow::record_batch> chunk_serializer( const R& record_batches, - chuncked_memory_output_stream>>& stream + chunked_memory_output_stream>>& stream ) : m_pstream(&stream) { @@ -68,7 +68,7 @@ namespace 
sparrow_ipc private: std::vector m_dtypes; - chuncked_memory_output_stream>>* m_pstream; + chunked_memory_output_stream>>* m_pstream; bool m_ended{false}; }; } \ No newline at end of file diff --git a/src/chunk_memory_serializer.cpp b/src/chunk_memory_serializer.cpp index 46e073d..cdf031c 100644 --- a/src/chunk_memory_serializer.cpp +++ b/src/chunk_memory_serializer.cpp @@ -7,7 +7,7 @@ namespace sparrow_ipc { chunk_serializer::chunk_serializer( const sparrow::record_batch& rb, - chuncked_memory_output_stream>>& stream + chunked_memory_output_stream>>& stream ) : m_pstream(&stream) , m_dtypes(get_column_dtypes(rb)) diff --git a/tests/test_chunk_memory_output_stream.cpp b/tests/test_chunk_memory_output_stream.cpp index 5015c4c..22367ba 100644 --- a/tests/test_chunk_memory_output_stream.cpp +++ b/tests/test_chunk_memory_output_stream.cpp @@ -11,14 +11,14 @@ namespace sparrow_ipc { - TEST_SUITE("chuncked_memory_output_stream") + TEST_SUITE("chunked_memory_output_stream") { TEST_CASE("basic construction") { SUBCASE("Construction with empty vector of vectors") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); CHECK(stream.is_open()); CHECK_EQ(stream.size(), 0); @@ -32,7 +32,7 @@ namespace sparrow_ipc {4, 5, 6, 7}, {8, 9} }; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); CHECK(stream.is_open()); CHECK_EQ(stream.size(), 9); @@ -45,7 +45,7 @@ namespace sparrow_ipc SUBCASE("Write single byte span") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); uint8_t data[] = {42}; std::span span(data, 1); @@ -62,7 +62,7 @@ namespace sparrow_ipc SUBCASE("Write multiple bytes span") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); uint8_t data[] = {1, 2, 3, 4, 5}; std::span span(data, 5); @@ -82,7 +82,7 @@ namespace sparrow_ipc 
SUBCASE("Write empty span") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); std::span empty_span; @@ -97,7 +97,7 @@ namespace sparrow_ipc SUBCASE("Multiple span writes create multiple chunks") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); uint8_t data1[] = {10, 20}; uint8_t data2[] = {30, 40, 50}; @@ -129,7 +129,7 @@ namespace sparrow_ipc SUBCASE("Write moved vector") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); std::vector buffer = {1, 2, 3, 4, 5}; auto written = stream.write(std::move(buffer)); @@ -147,7 +147,7 @@ namespace sparrow_ipc SUBCASE("Write multiple moved vectors") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); std::vector buffer1 = {10, 20, 30}; std::vector buffer2 = {40, 50}; @@ -168,7 +168,7 @@ namespace sparrow_ipc SUBCASE("Write empty moved vector") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); std::vector empty_buffer; auto written = stream.write(std::move(empty_buffer)); @@ -185,7 +185,7 @@ namespace sparrow_ipc SUBCASE("Write value multiple times") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); auto written = stream.write(static_cast(255), 5); @@ -202,7 +202,7 @@ namespace sparrow_ipc SUBCASE("Write value zero times") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); auto written = stream.write(static_cast(42), 0); @@ -215,7 +215,7 @@ namespace sparrow_ipc SUBCASE("Multiple repeated value writes") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); stream.write(static_cast(100), 3); 
stream.write(static_cast(200), 2); @@ -247,7 +247,7 @@ namespace sparrow_ipc TEST_CASE("mixed write operations") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); // Write span uint8_t data[] = {1, 2, 3}; @@ -282,7 +282,7 @@ namespace sparrow_ipc TEST_CASE("reserve functionality") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); // Reserve space stream.reserve(100); @@ -306,7 +306,7 @@ namespace sparrow_ipc SUBCASE("Size with empty chunks") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); CHECK_EQ(stream.size(), 0); } @@ -318,7 +318,7 @@ namespace sparrow_ipc {4, 5}, {6, 7, 8, 9} }; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); CHECK_EQ(stream.size(), 9); } @@ -326,7 +326,7 @@ namespace sparrow_ipc SUBCASE("Size updates after writes") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); CHECK_EQ(stream.size(), 0); @@ -345,7 +345,7 @@ namespace sparrow_ipc SUBCASE("Size with chunks of varying sizes") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); stream.write(static_cast(1), 1); stream.write(static_cast(2), 10); @@ -360,7 +360,7 @@ namespace sparrow_ipc TEST_CASE("stream lifecycle") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); SUBCASE("Stream is always open") { @@ -405,7 +405,7 @@ namespace sparrow_ipc TEST_CASE("large data handling") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); SUBCASE("Single large chunk") { @@ -473,7 +473,7 @@ namespace sparrow_ipc SUBCASE("Maximum value writes") { std::vector> chunks; - chuncked_memory_output_stream 
stream(chunks); + chunked_memory_output_stream stream(chunks); auto written = stream.write(std::numeric_limits::max(), 255); @@ -489,7 +489,7 @@ namespace sparrow_ipc SUBCASE("Zero byte writes") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); auto written = stream.write(static_cast(0), 100); @@ -505,7 +505,7 @@ namespace sparrow_ipc SUBCASE("Interleaved empty and non-empty writes") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); stream.write(static_cast(1), 5); stream.write(static_cast(2), 0); @@ -530,7 +530,7 @@ namespace sparrow_ipc SUBCASE("Stream modifies original chunks vector") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); uint8_t data[] = {1, 2, 3}; stream.write(std::span(data, 3)); @@ -548,13 +548,13 @@ namespace sparrow_ipc std::vector> chunks; { - chuncked_memory_output_stream stream1(chunks); + chunked_memory_output_stream stream1(chunks); uint8_t data1[] = {10, 20}; stream1.write(std::span(data1, 2)); } { - chuncked_memory_output_stream stream2(chunks); + chunked_memory_output_stream stream2(chunks); uint8_t data2[] = {30, 40}; stream2.write(std::span(data2, 2)); } diff --git a/tests/test_chunk_memory_serializer.cpp b/tests/test_chunk_memory_serializer.cpp index a5077ec..cc514c1 100644 --- a/tests/test_chunk_memory_serializer.cpp +++ b/tests/test_chunk_memory_serializer.cpp @@ -19,7 +19,7 @@ namespace sparrow_ipc { auto rb = create_test_record_batch(); std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); chunk_serializer serializer(rb, stream); @@ -34,7 +34,7 @@ namespace sparrow_ipc { auto empty_batch = sp::record_batch({}); std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); chunk_serializer serializer(empty_batch, 
stream); @@ -61,7 +61,7 @@ namespace sparrow_ipc std::vector record_batches = {rb1, rb2}; std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); chunk_serializer serializer(record_batches, stream); @@ -76,7 +76,7 @@ namespace sparrow_ipc { std::vector empty_batches; std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); CHECK_THROWS_AS( chunk_serializer serializer(empty_batches, stream), @@ -89,7 +89,7 @@ namespace sparrow_ipc auto rb = create_test_record_batch(); std::vector record_batches = {rb}; std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); chunk_serializer serializer(record_batches, stream); @@ -104,7 +104,7 @@ namespace sparrow_ipc { auto rb1 = create_test_record_batch(); std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); chunk_serializer serializer(rb1, stream); CHECK_EQ(chunks.size(), 2); // Schema + rb1 @@ -125,7 +125,7 @@ namespace sparrow_ipc { auto rb1 = create_test_record_batch(); std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); chunk_serializer serializer(rb1, stream); @@ -148,7 +148,7 @@ namespace sparrow_ipc { auto rb1 = create_test_record_batch(); std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); chunk_serializer serializer(rb1, stream); CHECK_EQ(chunks.size(), 2); @@ -179,7 +179,7 @@ namespace sparrow_ipc { auto rb1 = create_test_record_batch(); std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); chunk_serializer serializer(rb1, stream); @@ -198,7 +198,7 @@ namespace sparrow_ipc { auto rb1 = create_test_record_batch(); std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + 
chunked_memory_output_stream stream(chunks); chunk_serializer serializer(rb1, stream); size_t initial_size = chunks.size(); @@ -216,7 +216,7 @@ namespace sparrow_ipc { auto rb = create_test_record_batch(); std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); chunk_serializer serializer(rb, stream); size_t initial_size = chunks.size(); @@ -231,7 +231,7 @@ namespace sparrow_ipc { auto rb1 = create_test_record_batch(); std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); chunk_serializer serializer(rb1, stream); serializer.end(); @@ -244,7 +244,7 @@ namespace sparrow_ipc { auto rb1 = create_test_record_batch(); std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); chunk_serializer serializer(rb1, stream); serializer.end(); @@ -260,7 +260,7 @@ namespace sparrow_ipc { auto rb = create_test_record_batch(); std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); size_t size_before = stream.size(); chunk_serializer serializer(rb, stream); @@ -280,7 +280,7 @@ namespace sparrow_ipc SUBCASE("Handle many record batches efficiently") { std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); std::vector batches; const int num_batches = 100; @@ -320,7 +320,7 @@ namespace sparrow_ipc ); std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); chunk_serializer serializer(rb, stream); @@ -339,7 +339,7 @@ namespace sparrow_ipc // Setup chunked stream std::vector> chunks; - chuncked_memory_output_stream stream(chunks); + chunked_memory_output_stream stream(chunks); // Create serializer with initial batch chunk_serializer serializer(rb1, stream); From 458a9cf5f1d214c5ae0caa34469a4272916d6d7b Mon Sep 17 00:00:00 2001 From: 
Alexis Placet Date: Mon, 6 Oct 2025 16:03:24 +0200 Subject: [PATCH 08/11] wip --- include/sparrow_ipc/serialize_utils.hpp | 3 ++- include/sparrow_ipc/serializer.hpp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/include/sparrow_ipc/serialize_utils.hpp b/include/sparrow_ipc/serialize_utils.hpp index ad1ace7..5b96504 100644 --- a/include/sparrow_ipc/serialize_utils.hpp +++ b/include/sparrow_ipc/serialize_utils.hpp @@ -103,7 +103,8 @@ namespace sparrow_ipc } // Calculate schema message size (only once) - std::size_t total_size = calculate_schema_message_size(record_batches[0]); + auto it = std::ranges::begin(record_batches); + std::size_t total_size = calculate_schema_message_size(*it); // Calculate record batch message sizes for (const auto& record_batch : record_batches) diff --git a/include/sparrow_ipc/serializer.hpp b/include/sparrow_ipc/serializer.hpp index fe54b4a..7226a39 100644 --- a/include/sparrow_ipc/serializer.hpp +++ b/include/sparrow_ipc/serializer.hpp @@ -35,7 +35,7 @@ namespace sparrow_ipc serializer(const sparrow::record_batch& rb, output_stream& stream); - template + template requires std::same_as, sparrow::record_batch> serializer(const R& record_batches, output_stream& stream) : m_pstream(&stream) From 75c982715769c121d38be161df4176a2e9acd3fd Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Mon, 6 Oct 2025 15:56:38 +0200 Subject: [PATCH 09/11] Update include/sparrow_ipc/serializer.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- include/sparrow_ipc/serializer.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/sparrow_ipc/serializer.hpp b/include/sparrow_ipc/serializer.hpp index 7226a39..ec75d1d 100644 --- a/include/sparrow_ipc/serializer.hpp +++ b/include/sparrow_ipc/serializer.hpp @@ -109,6 +109,10 @@ namespace sparrow_ipc m_pstream->reserve(reserve_function); for (const auto& rb : record_batches) { + if (get_column_dtypes(rb) != m_dtypes) + { + throw std::invalid_argument("Record 
batch schema does not match serializer schema"); + } serialize_record_batch(rb, *m_pstream); } } From 37ec2fed9d136871e49a0739d7190208a5decbd0 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Wed, 8 Oct 2025 14:43:02 +0200 Subject: [PATCH 10/11] Try fix --- .github/workflows/windows.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index c215557..1aba990 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -39,6 +39,7 @@ jobs: cmake -S ./ -B ./build \ -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX \ -DCMAKE_PREFIX_PATH=$GLOB_PREFIX_PATH \ + -DCMAKE_LIBRARY_PATH=$GLOB_PREFIX_PATH \ -DSPARROW_IPC_BUILD_SHARED=${{ matrix.build_shared }} \ -DSPARROW_IPC_BUILD_TESTS=ON From c69100fcc71fa90bf3ded915683ea17ffe0d191d Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Wed, 8 Oct 2025 15:12:03 +0200 Subject: [PATCH 11/11] fix --- .github/workflows/windows.yml | 1 - tests/CMakeLists.txt | 36 +++++++++++++++++------------------ 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 1aba990..c215557 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -39,7 +39,6 @@ jobs: cmake -S ./ -B ./build \ -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX \ -DCMAKE_PREFIX_PATH=$GLOB_PREFIX_PATH \ - -DCMAKE_LIBRARY_PATH=$GLOB_PREFIX_PATH \ -DSPARROW_IPC_BUILD_SHARED=${{ matrix.build_shared }} \ -DSPARROW_IPC_BUILD_TESTS=ON diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b489bbc..1e51d94 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -27,25 +27,23 @@ target_link_libraries(${test_target} ) if(WIN32) - if(${SPARROW_IPC_BUILD_SHARED}) - find_package(date) # For copying DLLs - add_custom_command( - TARGET ${test_target} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy - "$" - "$" - COMMAND ${CMAKE_COMMAND} -E copy - "$" - "$" - COMMAND ${CMAKE_COMMAND} -E copy - "$" - "$" - COMMAND 
${CMAKE_COMMAND} -E copy - "$" - "$" - COMMENT "Copying sparrow and sparrow-ipc DLLs to executable directory" - ) - endif() + find_package(date) # For copying DLLs + add_custom_command( + TARGET ${test_target} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + "$" + "$" + COMMAND ${CMAKE_COMMAND} -E copy + "$" + "$" + COMMAND ${CMAKE_COMMAND} -E copy + "$" + "$" + COMMAND ${CMAKE_COMMAND} -E copy + "$" + "$" + COMMENT "Copying sparrow and sparrow-ipc DLLs to executable directory" + ) endif() target_include_directories(${test_target}