diff --git a/swift/Arrow/Sources/Arrow/ArrowWriter.swift b/swift/Arrow/Sources/Arrow/ArrowWriter.swift index 54581ba396ff..3aa25b62b49b 100644 --- a/swift/Arrow/Sources/Arrow/ArrowWriter.swift +++ b/swift/Arrow/Sources/Arrow/ArrowWriter.swift @@ -71,11 +71,30 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length public init() {} private func writeField(_ fbb: inout FlatBufferBuilder, field: ArrowField) -> Result { + var fieldsOffset: Offset? + if let nestedField = field.type as? ArrowNestedType { + var offsets = [Offset]() + for field in nestedField.fields { + switch writeField(&fbb, field: field) { + case .success(let offset): + offsets.append(offset) + case .failure(let error): + return .failure(error) + } + } + + fieldsOffset = fbb.createVector(ofOffsets: offsets) + } + let nameOffset = fbb.create(string: field.name) let fieldTypeOffsetResult = toFBType(&fbb, arrowType: field.type) let startOffset = org_apache_arrow_flatbuf_Field.startField(&fbb) org_apache_arrow_flatbuf_Field.add(name: nameOffset, &fbb) org_apache_arrow_flatbuf_Field.add(nullable: field.isNullable, &fbb) + if let childrenOffset = fieldsOffset { + org_apache_arrow_flatbuf_Field.addVectorOf(children: childrenOffset, &fbb) + } + switch toFBTypeEnum(field.type) { case .success(let type): org_apache_arrow_flatbuf_Field.add(typeType: type, &fbb) @@ -101,7 +120,6 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length case .failure(let error): return .failure(error) } - } let fieldsOffset: Offset = fbb.createVector(ofOffsets: fieldOffsets) @@ -126,7 +144,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))} withUnsafeBytes(of: rbResult.1.o.littleEndian) {writer.append(Data($0))} writer.append(rbResult.0) - switch writeRecordBatchData(&writer, batch: batch) { + switch writeRecordBatchData(&writer, fields: batch.schema.fields, columns: batch.columns) { case .success: rbBlocks.append( org_apache_arrow_flatbuf_Block(offset: Int64(startIndex), @@ -143,37 +161,59 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length return .success(rbBlocks) } - private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> { - let schema = batch.schema - var fbb = FlatBufferBuilder() - - // write out field nodes - var fieldNodeOffsets = [Offset]() - fbb.startVector(schema.fields.count, elementSize: MemoryLayout.size) - for index in (0 ..< schema.fields.count).reversed() { - let column = batch.column(index) + private func writeFieldNodes(_ fields: [ArrowField], columns: [ArrowArrayHolder], offsets: inout [Offset], + fbb: inout FlatBufferBuilder) { + for index in (0 ..< fields.count).reversed() { + let column = columns[index] let fieldNode = org_apache_arrow_flatbuf_FieldNode(length: Int64(column.length), nullCount: Int64(column.nullCount)) - fieldNodeOffsets.append(fbb.create(struct: fieldNode)) + offsets.append(fbb.create(struct: fieldNode)) + if let nestedType = column.type as? ArrowNestedType { + let structArray = column.array as? StructArray + writeFieldNodes(nestedType.fields, columns: structArray!.arrowFields!, offsets: &offsets, fbb: &fbb) + } } + } - let nodeOffset = fbb.endVector(len: schema.fields.count) - - // write out buffers - var buffers = [org_apache_arrow_flatbuf_Buffer]() - var bufferOffset = Int(0) - for index in 0 ..< batch.schema.fields.count { - let column = batch.column(index) + private func writeBufferInfo(_ fields: [ArrowField], + columns: [ArrowArrayHolder], + bufferOffset: inout Int, + buffers: inout [org_apache_arrow_flatbuf_Buffer], + fbb: inout FlatBufferBuilder) { + for index in 0 ..< fields.count { + let column = columns[index] let colBufferDataSizes = column.getBufferDataSizes() for var bufferDataSize in colBufferDataSizes { bufferDataSize = getPadForAlignment(bufferDataSize) let buffer = org_apache_arrow_flatbuf_Buffer(offset: Int64(bufferOffset), length: Int64(bufferDataSize)) buffers.append(buffer) bufferOffset += bufferDataSize + if let nestedType = column.type as? ArrowNestedType { + let structArray = column.array as? StructArray + writeBufferInfo(nestedType.fields, columns: structArray!.arrowFields!, + bufferOffset: &bufferOffset, buffers: &buffers, fbb: &fbb) + } } } + } + private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> { + let schema = batch.schema + var fbb = FlatBufferBuilder() + + // write out field nodes + var fieldNodeOffsets = [Offset]() + fbb.startVector(schema.fields.count, elementSize: MemoryLayout.size) + writeFieldNodes(schema.fields, columns: batch.columns, offsets: &fieldNodeOffsets, fbb: &fbb) + let nodeOffset = fbb.endVector(len: fieldNodeOffsets.count) + + // write out buffers + var buffers = [org_apache_arrow_flatbuf_Buffer]() + var bufferOffset = Int(0) + writeBufferInfo(schema.fields, columns: batch.columns, + bufferOffset: &bufferOffset, buffers: &buffers, + fbb: &fbb) org_apache_arrow_flatbuf_RecordBatch.startVectorOfBuffers(batch.schema.fields.count, in: &fbb) for buffer in buffers.reversed() { fbb.create(struct: buffer) @@ -196,13 +236,28 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length return .success((fbb.data, Offset(offset: UInt32(fbb.data.count)))) } - private func writeRecordBatchData(_ writer: inout DataWriter, batch: RecordBatch) -> Result { - for index in 0 ..< batch.schema.fields.count { - let column = batch.column(index) + private func writeRecordBatchData( + _ writer: inout DataWriter, fields: [ArrowField], + columns: [ArrowArrayHolder]) + -> Result { + for index in 0 ..< fields.count { + let column = columns[index] let colBufferData = column.getBufferData() for var bufferData in colBufferData { addPadForAlignment(&bufferData) writer.append(bufferData) + if let nestedType = column.type as? ArrowNestedType { + guard let structArray = column.array as? StructArray else { + return .failure(.invalid("Struct type array expected for nested type")) + } + + switch writeRecordBatchData(&writer, fields: nestedType.fields, columns: structArray.arrowFields!) { + case .success: + continue + case .failure(let error): + return .failure(error) + } + } } } @@ -226,11 +281,10 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length org_apache_arrow_flatbuf_Footer.addVectorOf(recordBatches: rbBlkEnd, &fbb) let footerOffset = org_apache_arrow_flatbuf_Footer.endFooter(&fbb, start: footerStartOffset) fbb.finish(offset: footerOffset) + return .success(fbb.data) case .failure(let error): return .failure(error) } - - return .success(fbb.data) } private func writeFile(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result { @@ -265,7 +319,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length return .success(true) } - public func writeSteaming(_ info: ArrowWriter.Info) -> Result { + public func writeStreaming(_ info: ArrowWriter.Info) -> Result { let writer: any DataWriter = InMemDataWriter() switch toMessage(info.schema) { case .success(let schemaData): @@ -343,7 +397,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length writer.append(message.0) addPadForAlignment(&writer) var dataWriter: any DataWriter = InMemDataWriter() - switch writeRecordBatchData(&dataWriter, batch: batch) { + switch writeRecordBatchData(&dataWriter, fields: batch.schema.fields, columns: batch.columns) { case .success: return .success([ (writer as! InMemDataWriter).data, // swiftlint:disable:this force_cast @@ -377,3 +431,4 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length return .success(fbb.data) } } +// swiftlint:disable:this file_length diff --git a/swift/Arrow/Sources/Arrow/ArrowWriterHelper.swift b/swift/Arrow/Sources/Arrow/ArrowWriterHelper.swift index fdc72ef6e7f1..4d63192585f9 100644 --- a/swift/Arrow/Sources/Arrow/ArrowWriterHelper.swift +++ b/swift/Arrow/Sources/Arrow/ArrowWriterHelper.swift @@ -25,67 +25,69 @@ extension Data { } func toFBTypeEnum(_ arrowType: ArrowType) -> Result { - let infoType = arrowType.info - if infoType == ArrowType.ArrowInt8 || infoType == ArrowType.ArrowInt16 || - infoType == ArrowType.ArrowInt64 || infoType == ArrowType.ArrowUInt8 || - infoType == ArrowType.ArrowUInt16 || infoType == ArrowType.ArrowUInt32 || - infoType == ArrowType.ArrowUInt64 || infoType == ArrowType.ArrowInt32 { + let typeId = arrowType.id + switch typeId { + case .int8, .int16, .int32, .int64, .uint8, .uint16, .uint32, .uint64: return .success(org_apache_arrow_flatbuf_Type_.int) - } else if infoType == ArrowType.ArrowFloat || infoType == ArrowType.ArrowDouble { + case .float, .double: return .success(org_apache_arrow_flatbuf_Type_.floatingpoint) - } else if infoType == ArrowType.ArrowString { + case .string: return .success(org_apache_arrow_flatbuf_Type_.utf8) - } else if infoType == ArrowType.ArrowBinary { + case .binary: return .success(org_apache_arrow_flatbuf_Type_.binary) - } else if infoType == ArrowType.ArrowBool { + case .boolean: return .success(org_apache_arrow_flatbuf_Type_.bool) - } else if infoType == ArrowType.ArrowDate32 || infoType == ArrowType.ArrowDate64 { + case .date32, .date64: return .success(org_apache_arrow_flatbuf_Type_.date) - } else if infoType == ArrowType.ArrowTime32 || infoType == ArrowType.ArrowTime64 { + case .time32, .time64: return .success(org_apache_arrow_flatbuf_Type_.time) + case .strct: + return .success(org_apache_arrow_flatbuf_Type_.struct_) + default: + return .failure(.unknownType("Unable to find flatbuf type for Arrow type: \(typeId)")) } - return .failure(.unknownType("Unable to find flatbuf type for Arrow type: \(infoType)")) } -func toFBType( // swiftlint:disable:this cyclomatic_complexity +func toFBType( // swiftlint:disable:this cyclomatic_complexity function_body_length _ fbb: inout FlatBufferBuilder, arrowType: ArrowType ) -> Result { let infoType = arrowType.info - if infoType == ArrowType.ArrowInt8 || infoType == ArrowType.ArrowUInt8 { + switch arrowType.id { + case .int8, .uint8: return .success(org_apache_arrow_flatbuf_Int.createInt( &fbb, bitWidth: 8, isSigned: infoType == ArrowType.ArrowInt8)) - } else if infoType == ArrowType.ArrowInt16 || infoType == ArrowType.ArrowUInt16 { + case .int16, .uint16: return .success(org_apache_arrow_flatbuf_Int.createInt( &fbb, bitWidth: 16, isSigned: infoType == ArrowType.ArrowInt16)) - } else if infoType == ArrowType.ArrowInt32 || infoType == ArrowType.ArrowUInt32 { + case .int32, .uint32: return .success(org_apache_arrow_flatbuf_Int.createInt( &fbb, bitWidth: 32, isSigned: infoType == ArrowType.ArrowInt32)) - } else if infoType == ArrowType.ArrowInt64 || infoType == ArrowType.ArrowUInt64 { + case .int64, .uint64: return .success(org_apache_arrow_flatbuf_Int.createInt( &fbb, bitWidth: 64, isSigned: infoType == ArrowType.ArrowInt64)) - } else if infoType == ArrowType.ArrowFloat { + case .float: return .success(org_apache_arrow_flatbuf_FloatingPoint.createFloatingPoint(&fbb, precision: .single)) - } else if infoType == ArrowType.ArrowDouble { + case .double: return .success(org_apache_arrow_flatbuf_FloatingPoint.createFloatingPoint(&fbb, precision: .double)) - } else if infoType == ArrowType.ArrowString { + case .string: return .success(org_apache_arrow_flatbuf_Utf8.endUtf8( &fbb, start: org_apache_arrow_flatbuf_Utf8.startUtf8(&fbb))) - } else if infoType == ArrowType.ArrowBinary { + case .binary: return .success(org_apache_arrow_flatbuf_Binary.endBinary( &fbb, start: org_apache_arrow_flatbuf_Binary.startBinary(&fbb))) - } else if infoType == ArrowType.ArrowBool { + case .boolean: return .success(org_apache_arrow_flatbuf_Bool.endBool( &fbb, start: org_apache_arrow_flatbuf_Bool.startBool(&fbb))) - } else if infoType == ArrowType.ArrowDate32 { + case .date32: let startOffset = org_apache_arrow_flatbuf_Date.startDate(&fbb) org_apache_arrow_flatbuf_Date.add(unit: .day, &fbb) return .success(org_apache_arrow_flatbuf_Date.endDate(&fbb, start: startOffset)) - } else if infoType == ArrowType.ArrowDate64 { + case .date64: let startOffset = org_apache_arrow_flatbuf_Date.startDate(&fbb) org_apache_arrow_flatbuf_Date.add(unit: .millisecond, &fbb) return .success(org_apache_arrow_flatbuf_Date.endDate(&fbb, start: startOffset)) - } else if infoType == ArrowType.ArrowTime32 { + case .time32: let startOffset = org_apache_arrow_flatbuf_Time.startTime(&fbb) if let timeType = arrowType as? ArrowTypeTime32 { org_apache_arrow_flatbuf_Time.add(unit: timeType.unit == .seconds ? .second : .millisecond, &fbb) @@ -93,7 +95,7 @@ func toFBType( // swiftlint:disable:this cyclomatic_complexity } return .failure(.invalid("Unable to case to Time32")) - } else if infoType == ArrowType.ArrowTime64 { + case .time64: let startOffset = org_apache_arrow_flatbuf_Time.startTime(&fbb) if let timeType = arrowType as? ArrowTypeTime64 { org_apache_arrow_flatbuf_Time.add(unit: timeType.unit == .microseconds ? .microsecond : .nanosecond, &fbb) @@ -101,9 +103,12 @@ func toFBType( // swiftlint:disable:this cyclomatic_complexity } return .failure(.invalid("Unable to case to Time64")) + case .strct: + let startOffset = org_apache_arrow_flatbuf_Struct_.startStruct_(&fbb) + return .success(org_apache_arrow_flatbuf_Struct_.endStruct_(&fbb, start: startOffset)) + default: + return .failure(.unknownType("Unable to add flatbuf type for Arrow type: \(infoType)")) } - - return .failure(.unknownType("Unable to add flatbuf type for Arrow type: \(infoType)")) } func addPadForAlignment(_ data: inout Data, alignment: Int = 8) { diff --git a/swift/Arrow/Sources/Arrow/ProtoUtil.swift b/swift/Arrow/Sources/Arrow/ProtoUtil.swift index ac61030c08b0..88cfb0bfcde4 100644 --- a/swift/Arrow/Sources/Arrow/ProtoUtil.swift +++ b/swift/Arrow/Sources/Arrow/ProtoUtil.swift @@ -17,7 +17,7 @@ import Foundation -func fromProto( // swiftlint:disable:this cyclomatic_complexity +func fromProto( // swiftlint:disable:this cyclomatic_complexity function_body_length field: org_apache_arrow_flatbuf_Field ) -> ArrowField { let type = field.typeType @@ -65,7 +65,13 @@ func fromProto( // swiftlint:disable:this cyclomatic_complexity arrowType = ArrowTypeTime64(arrowUnit) } case .struct_: - arrowType = ArrowType(ArrowType.ArrowStruct) + var children = [ArrowField]() + for index in 0..) throws -> [RecordBatch] { let recordBatches: [RecordBatch] @@ -55,6 +73,37 @@ func checkBoolRecordBatch(_ result: Result) throws -> [RecordBatch] { + let recordBatches: [RecordBatch] + switch result { + case .success(let result): + recordBatches = result.batches + case .failure(let error): + throw error + } + + XCTAssertEqual(recordBatches.count, 1) + for recordBatch in recordBatches { + XCTAssertEqual(recordBatch.length, 3) + XCTAssertEqual(recordBatch.columns.count, 1) + XCTAssertEqual(recordBatch.schema.fields.count, 1) + XCTAssertEqual(recordBatch.schema.fields[0].name, "my struct") + XCTAssertEqual(recordBatch.schema.fields[0].type.id, .strct) + let structArray = recordBatch.columns[0].array as? StructArray + XCTAssertEqual(structArray!.arrowFields!.count, 2) + XCTAssertEqual(structArray!.arrowFields![0].type.id, .string) + XCTAssertEqual(structArray!.arrowFields![1].type.id, .boolean) + let column = recordBatch.columns[0] + let str = column.array as? AsString + XCTAssertEqual("\(str!.asString(0))", "{0,false}") + XCTAssertEqual("\(str!.asString(1))", "{1,true}") + XCTAssertTrue(column.array.asAny(2) == nil) + } + + return recordBatches +} + func currentDirectory(path: String = #file) -> URL { return URL(fileURLWithPath: path).deletingLastPathComponent() } @@ -69,6 +118,47 @@ func makeSchema() -> ArrowSchema { .finish() } +func makeStructSchema() -> ArrowSchema { + let testObj = StructTest() + var fields = [ArrowField]() + let buildStructType = {() -> ArrowNestedType in + let mirror = Mirror(reflecting: testObj) + for (property, value) in mirror.children { + let arrowType = ArrowType(ArrowType.infoForType(type(of: value))) + fields.append(ArrowField(property!, type: arrowType, isNullable: true)) + } + + return ArrowNestedType(ArrowType.ArrowStruct, fields: fields) + } + + return ArrowSchema.Builder() + .addField("struct1", type: buildStructType(), isNullable: true) + .finish() +} + +func makeStructRecordBatch() throws -> RecordBatch { + let testData = StructTest() + let dateNow = Date.now + let structBuilder = try ArrowArrayBuilders.loadStructArrayBuilderForType(testData) + structBuilder.append([true, Int8(1), Int16(2), Int32(3), Int64(4), + UInt8(5), UInt16(6), UInt32(7), UInt64(8), Double(9.9), + Float(10.10), "11", Data("12".utf8), dateNow]) + structBuilder.append(nil) + structBuilder.append([true, Int8(13), Int16(14), Int32(15), Int64(16), + UInt8(17), UInt16(18), UInt32(19), UInt64(20), Double(21.21), + Float(22.22), "23", Data("24".utf8), dateNow]) + let structHolder = ArrowArrayHolderImpl(try structBuilder.finish()) + let result = RecordBatch.Builder() + .addColumn("struct1", arrowArray: structHolder) + .finish() + switch result { + case .success(let recordBatch): + return recordBatch + case .failure(let error): + throw error + } +} + func makeRecordBatch() throws -> RecordBatch { let uint8Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() uint8Builder.append(10) @@ -124,7 +214,7 @@ final class IPCStreamReaderTests: XCTestCase { let recordBatch = try makeRecordBatch() let arrowWriter = ArrowWriter() let writerInfo = ArrowWriter.Info(.recordbatch, schema: schema, batches: [recordBatch]) - switch arrowWriter.writeSteaming(writerInfo) { + switch arrowWriter.writeStreaming(writerInfo) { case .success(let writeData): let arrowReader = ArrowReader() switch arrowReader.readStreaming(writeData) { @@ -173,43 +263,6 @@ final class IPCStreamReaderTests: XCTestCase { } final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body_length - func testFileReader_struct() throws { - let fileURL = currentDirectory().appendingPathComponent("../../testdata_struct.arrow") - let arrowReader = ArrowReader() - let result = arrowReader.fromFile(fileURL) - let recordBatches: [RecordBatch] - switch result { - case .success(let result): - recordBatches = result.batches - case .failure(let error): - throw error - } - - XCTAssertEqual(recordBatches.count, 1) - for recordBatch in recordBatches { - XCTAssertEqual(recordBatch.length, 3) - XCTAssertEqual(recordBatch.columns.count, 1) - XCTAssertEqual(recordBatch.schema.fields.count, 1) - XCTAssertEqual(recordBatch.schema.fields[0].type.info, ArrowType.ArrowStruct) - let column = recordBatch.columns[0] - XCTAssertNotNil(column.array as? StructArray) - if let structArray = column.array as? StructArray { - XCTAssertEqual(structArray.arrowFields?.count, 2) - XCTAssertEqual(structArray.arrowFields?[0].type.info, ArrowType.ArrowString) - XCTAssertEqual(structArray.arrowFields?[1].type.info, ArrowType.ArrowBool) - for index in 0.. ArrowNestedType in + let mirror = Mirror(reflecting: testObj) + for (property, value) in mirror.children { + let arrowType = ArrowType(ArrowType.infoForType(type(of: value))) + fields.append(ArrowField(property!, type: arrowType, isNullable: true)) + } + + return ArrowNestedType(ArrowType.ArrowStruct, fields: fields) + } + + let structType = buildStructType() + XCTAssertEqual(structType.id, ArrowTypeId.strct) + XCTAssertEqual(structType.fields.count, 14) + XCTAssertEqual(structType.fields[0].type.id, ArrowTypeId.boolean) + XCTAssertEqual(structType.fields[1].type.id, ArrowTypeId.int8) + XCTAssertEqual(structType.fields[2].type.id, ArrowTypeId.int16) + XCTAssertEqual(structType.fields[3].type.id, ArrowTypeId.int32) + XCTAssertEqual(structType.fields[4].type.id, ArrowTypeId.int64) + XCTAssertEqual(structType.fields[5].type.id, ArrowTypeId.uint8) + XCTAssertEqual(structType.fields[6].type.id, ArrowTypeId.uint16) + XCTAssertEqual(structType.fields[7].type.id, ArrowTypeId.uint32) + XCTAssertEqual(structType.fields[8].type.id, ArrowTypeId.uint64) + XCTAssertEqual(structType.fields[9].type.id, ArrowTypeId.double) + XCTAssertEqual(structType.fields[10].type.id, ArrowTypeId.float) + XCTAssertEqual(structType.fields[11].type.id, ArrowTypeId.string) + XCTAssertEqual(structType.fields[12].type.id, ArrowTypeId.binary) + XCTAssertEqual(structType.fields[13].type.id, ArrowTypeId.date64) + } + func testTable() throws { let doubleBuilder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() doubleBuilder.append(11.11)