From 85684fe4af1e35233f3ac921ed45b95202cda562 Mon Sep 17 00:00:00 2001 From: abandy Date: Thu, 25 Jul 2024 15:49:52 -0400 Subject: [PATCH] GH-43169: [Swift] Add StructArray to ArrowReader (#43335) ### Rationale for this change Structs have been added for Swift, but the ArrowReader does not currently support them. This PR adds struct support to the ArrowReader. ### What changes are included in this PR? Adding StructArray to ArrowReader. ### Are these changes tested? The next PR for the ArrowWriter will include a test for reading and writing Structs. * GitHub Issue: #43169 Authored-by: Alva Bandy Signed-off-by: Sutou Kouhei --- .../Arrow/Sources/Arrow/ArrowCImporter.swift | 3 +- swift/Arrow/Sources/Arrow/ArrowReader.swift | 199 ++++++++++++------ .../Sources/Arrow/ArrowReaderHelper.swift | 59 +++++- swift/Arrow/Tests/ArrowTests/ArrayTests.swift | 2 +- 4 files changed, 194 insertions(+), 69 deletions(-) diff --git a/swift/Arrow/Sources/Arrow/ArrowCImporter.swift b/swift/Arrow/Sources/Arrow/ArrowCImporter.swift index f55077ef3dc95..e65d78d730be7 100644 --- a/swift/Arrow/Sources/Arrow/ArrowCImporter.swift +++ b/swift/Arrow/Sources/Arrow/ArrowCImporter.swift @@ -153,7 +153,8 @@ public class ArrowCImporter { } } - switch makeArrayHolder(arrowField, buffers: arrowBuffers, nullCount: nullCount) { + switch makeArrayHolder(arrowField, buffers: arrowBuffers, + nullCount: nullCount, children: nil, rbLength: 0) { case .success(let holder): return .success(ImportArrayHolder(holder, cArrayPtr: cArrayPtr)) case .failure(let err): diff --git a/swift/Arrow/Sources/Arrow/ArrowReader.swift b/swift/Arrow/Sources/Arrow/ArrowReader.swift index 237f22dc979e3..ae187e22eef70 100644 --- a/swift/Arrow/Sources/Arrow/ArrowReader.swift +++ b/swift/Arrow/Sources/Arrow/ArrowReader.swift @@ -21,14 +21,46 @@ import Foundation let FILEMARKER = "ARROW1" let CONTINUATIONMARKER = -1 -public class ArrowReader { - private struct DataLoadInfo { +public class ArrowReader { // swiftlint:disable:this type_body_length + 
private class RecordBatchData { + let schema: org_apache_arrow_flatbuf_Schema let recordBatch: org_apache_arrow_flatbuf_RecordBatch - let field: org_apache_arrow_flatbuf_Field - let nodeIndex: Int32 - let bufferIndex: Int32 + private var fieldIndex: Int32 = 0 + private var nodeIndex: Int32 = 0 + private var bufferIndex: Int32 = 0 + init(_ recordBatch: org_apache_arrow_flatbuf_RecordBatch, + schema: org_apache_arrow_flatbuf_Schema) { + self.recordBatch = recordBatch + self.schema = schema + } + + func nextNode() -> org_apache_arrow_flatbuf_FieldNode? { + if nodeIndex >= self.recordBatch.nodesCount {return nil} + defer {nodeIndex += 1} + return self.recordBatch.nodes(at: nodeIndex) + } + + func nextBuffer() -> org_apache_arrow_flatbuf_Buffer? { + if bufferIndex >= self.recordBatch.buffersCount {return nil} + defer {bufferIndex += 1} + return self.recordBatch.buffers(at: bufferIndex) + } + + func nextField() -> org_apache_arrow_flatbuf_Field? { + if fieldIndex >= self.schema.fieldsCount {return nil} + defer {fieldIndex += 1} + return self.schema.fields(at: fieldIndex) + } + + func isDone() -> Bool { + return nodeIndex >= self.recordBatch.nodesCount + } + } + + private struct DataLoadInfo { let fileData: Data let messageOffset: Int64 + var batchData: RecordBatchData } public class ArrowReaderResult { @@ -54,49 +86,104 @@ public class ArrowReader { return .success(builder.finish()) } - private func loadPrimitiveData(_ loadInfo: DataLoadInfo) -> Result { - do { - let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)! - let nullLength = UInt(ceil(Double(node.length) / 8)) - try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex) - let nullBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex)! 
- let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData, - length: nullLength, messageOffset: loadInfo.messageOffset) - try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex + 1) - let valueBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 1)! - let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData, - length: UInt(node.length), messageOffset: loadInfo.messageOffset) - return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowValueBuffer], - nullCount: UInt(node.nullCount)) - } catch let error as ArrowError { - return .failure(error) - } catch { - return .failure(.unknownError("\(error)")) + private func loadStructData(_ loadInfo: DataLoadInfo, + field: org_apache_arrow_flatbuf_Field) + -> Result { + guard let node = loadInfo.batchData.nextNode() else { + return .failure(.invalid("Node not found")) + } + + guard let nullBuffer = loadInfo.batchData.nextBuffer() else { + return .failure(.invalid("Null buffer not found")) + } + + let nullLength = UInt(ceil(Double(node.length) / 8)) + let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData, + length: nullLength, messageOffset: loadInfo.messageOffset) + var children = [ArrowData]() + for index in 0.. Result { - let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)! - do { - let nullLength = UInt(ceil(Double(node.length) / 8)) - try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex) - let nullBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex)! - let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData, - length: nullLength, messageOffset: loadInfo.messageOffset) - try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex + 1) - let offsetBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 1)! 
- let arrowOffsetBuffer = makeBuffer(offsetBuffer, fileData: loadInfo.fileData, - length: UInt(node.length), messageOffset: loadInfo.messageOffset) - try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex + 2) - let valueBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 2)! - let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData, - length: UInt(node.length), messageOffset: loadInfo.messageOffset) - return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowOffsetBuffer, arrowValueBuffer], - nullCount: UInt(node.nullCount)) - } catch let error as ArrowError { - return .failure(error) - } catch { - return .failure(.unknownError("\(error)")) + private func loadPrimitiveData( + _ loadInfo: DataLoadInfo, + field: org_apache_arrow_flatbuf_Field) + -> Result { + guard let node = loadInfo.batchData.nextNode() else { + return .failure(.invalid("Node not found")) + } + + guard let nullBuffer = loadInfo.batchData.nextBuffer() else { + return .failure(.invalid("Null buffer not found")) + } + + guard let valueBuffer = loadInfo.batchData.nextBuffer() else { + return .failure(.invalid("Value buffer not found")) + } + + let nullLength = UInt(ceil(Double(node.length) / 8)) + let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData, + length: nullLength, messageOffset: loadInfo.messageOffset) + let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData, + length: UInt(node.length), messageOffset: loadInfo.messageOffset) + return makeArrayHolder(field, buffers: [arrowNullBuffer, arrowValueBuffer], + nullCount: UInt(node.nullCount), children: nil, + rbLength: UInt(loadInfo.batchData.recordBatch.length)) + } + + private func loadVariableData( + _ loadInfo: DataLoadInfo, + field: org_apache_arrow_flatbuf_Field) + -> Result { + guard let node = loadInfo.batchData.nextNode() else { + return .failure(.invalid("Node not found")) + } + + guard let nullBuffer = loadInfo.batchData.nextBuffer() 
else { + return .failure(.invalid("Null buffer not found")) + } + + guard let offsetBuffer = loadInfo.batchData.nextBuffer() else { + return .failure(.invalid("Offset buffer not found")) + } + + guard let valueBuffer = loadInfo.batchData.nextBuffer() else { + return .failure(.invalid("Value buffer not found")) + } + + let nullLength = UInt(ceil(Double(node.length) / 8)) + let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData, + length: nullLength, messageOffset: loadInfo.messageOffset) + let arrowOffsetBuffer = makeBuffer(offsetBuffer, fileData: loadInfo.fileData, + length: UInt(node.length), messageOffset: loadInfo.messageOffset) + let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData, + length: UInt(node.length), messageOffset: loadInfo.messageOffset) + return makeArrayHolder(field, buffers: [arrowNullBuffer, arrowOffsetBuffer, arrowValueBuffer], + nullCount: UInt(node.nullCount), children: nil, + rbLength: UInt(loadInfo.batchData.recordBatch.length)) + } + + private func loadField( + _ loadInfo: DataLoadInfo, + field: org_apache_arrow_flatbuf_Field) + -> Result { + if isNestedType(field.typeType) { + return loadStructData(loadInfo, field: field) + } else if isFixedPrimitive(field.typeType) { + return loadPrimitiveData(loadInfo, field: field) + } else { + return loadVariableData(loadInfo, field: field) } } @@ -107,23 +194,17 @@ public class ArrowReader { data: Data, messageEndOffset: Int64 ) -> Result { - let nodesCount = recordBatch.nodesCount - var bufferIndex: Int32 = 0 var columns: [ArrowArrayHolder] = [] - for nodeIndex in 0 ..< nodesCount { - let field = schema.fields(at: nodeIndex)! 
- let loadInfo = DataLoadInfo(recordBatch: recordBatch, field: field, - nodeIndex: nodeIndex, bufferIndex: bufferIndex, - fileData: data, messageOffset: messageEndOffset) - var result: Result - if isFixedPrimitive(field.typeType) { - result = loadPrimitiveData(loadInfo) - bufferIndex += 2 - } else { - result = loadVariableData(loadInfo) - bufferIndex += 3 + let batchData = RecordBatchData(recordBatch, schema: schema) + let loadInfo = DataLoadInfo(fileData: data, + messageOffset: messageEndOffset, + batchData: batchData) + while !batchData.isDone() { + guard let field = batchData.nextField() else { + return .failure(.invalid("Field not found")) } + let result = loadField(loadInfo, field: field) switch result { case .success(let holder): columns.append(holder) diff --git a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift index 22c0672b27eac..48c6fd855073a 100644 --- a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift +++ b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift @@ -117,19 +117,42 @@ private func makeFixedHolder( } } + func makeStructHolder( + _ field: ArrowField, + buffers: [ArrowBuffer], + nullCount: UInt, + children: [ArrowData], + rbLength: UInt +) -> Result { + do { + let arrowData = try ArrowData(field.type, + buffers: buffers, children: children, + nullCount: nullCount, length: rbLength) + return .success(ArrowArrayHolderImpl(try StructArray(arrowData))) + } catch let error as ArrowError { + return .failure(error) + } catch { + return .failure(.unknownError("\(error)")) + } +} + func makeArrayHolder( _ field: org_apache_arrow_flatbuf_Field, buffers: [ArrowBuffer], - nullCount: UInt + nullCount: UInt, + children: [ArrowData]?, + rbLength: UInt ) -> Result { let arrowField = fromProto(field: field) - return makeArrayHolder(arrowField, buffers: buffers, nullCount: nullCount) + return makeArrayHolder(arrowField, buffers: buffers, nullCount: nullCount, children: children, rbLength: rbLength) } func 
makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity _ field: ArrowField, buffers: [ArrowBuffer], - nullCount: UInt + nullCount: UInt, + children: [ArrowData]?, + rbLength: UInt ) -> Result { let typeId = field.type.id switch typeId { @@ -159,12 +182,12 @@ func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity return makeStringHolder(buffers, nullCount: nullCount) case .binary: return makeBinaryHolder(buffers, nullCount: nullCount) - case .date32: + case .date32, .date64: return makeDateHolder(field, buffers: buffers, nullCount: nullCount) - case .time32: - return makeTimeHolder(field, buffers: buffers, nullCount: nullCount) - case .time64: + case .time32, .time64: return makeTimeHolder(field, buffers: buffers, nullCount: nullCount) + case .strct: + return makeStructHolder(field, buffers: buffers, nullCount: nullCount, children: children!, rbLength: rbLength) default: return .failure(.unknownType("Type \(typeId) currently not supported")) } @@ -187,7 +210,16 @@ func isFixedPrimitive(_ type: org_apache_arrow_flatbuf_Type_) -> Bool { } } -func findArrowType( // swiftlint:disable:this cyclomatic_complexity +func isNestedType(_ type: org_apache_arrow_flatbuf_Type_) -> Bool { + switch type { + case .struct_: + return true + default: + return false + } +} + +func findArrowType( // swiftlint:disable:this cyclomatic_complexity function_body_length _ field: org_apache_arrow_flatbuf_Field) -> ArrowType { let type = field.typeType switch type { @@ -229,6 +261,17 @@ func findArrowType( // swiftlint:disable:this cyclomatic_complexity } return ArrowTypeTime64(timeType.unit == .microsecond ? .microseconds : .nanoseconds) + case .struct_: + _ = field.type(type: org_apache_arrow_flatbuf_Struct_.self)! + var fields = [ArrowField]() + for index in 0..