Skip to content

Commit 81b2dd9

Browse files
committed
refactor: Address review comments for CSV union schema feature
Addresses all review feedback from PR #17553 to improve the CSV schema union implementation that allows reading CSV files with different column counts. Changes based on review: - Moved unit tests from separate tests.rs to bottom of file_format.rs - Updated documentation wording from "now supports" to "can handle" - Removed all println statements from integration test - Added comprehensive assertions for actual row content verification - Simplified HashSet initialization using HashSet::from([...]) syntax - Updated truncated_rows config documentation to reflect expanded purpose - Removed unnecessary min() calculation in column processing loop - Fixed clippy warnings by using enumerate() instead of range loop Technical improvements: - Tests now verify null patterns correctly across union schema - Cleaner iteration logic without redundant bounds checking - Better documentation explaining union schema behavior The feature continues to work as designed: - Creates union schema from all CSV files in a directory - Files with fewer columns have nulls for missing fields - Requires explicit opt-in via truncated_rows(true) - Maintains full backward compatibility
1 parent f27028e commit 81b2dd9

File tree

5 files changed

+154
-159
lines changed

5 files changed

+154
-159
lines changed

datafusion/common/src/config.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2526,9 +2526,12 @@ config_namespace! {
25262526
// The input regex for Nulls when loading CSVs.
25272527
pub null_regex: Option<String>, default = None
25282528
pub comment: Option<u8>, default = None
2529-
// Whether to allow truncated rows when parsing.
2530-
// By default this is set to false and will error if the CSV rows have different lengths.
2531-
// When set to true then it will allow records with less than the expected number of columns
2529+
/// Whether to allow CSV files with varying numbers of columns.
2530+
/// By default this is set to false and will error if the CSV rows have different lengths.
2531+
/// When set to true:
2532+
/// - Allows reading multiple CSV files with different column counts
2533+
/// - Creates a union schema during inference containing all columns found across files
2534+
/// - Files with fewer columns will have missing columns filled with null values
25322535
pub truncated_rows: Option<bool>, default = None
25332536
}
25342537
}

datafusion/core/tests/csv_schema_fix_test.rs

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,24 @@ async fn test_csv_schema_inference_different_column_counts() -> Result<()> {
5151
// Enable truncated_rows to handle files with different column counts
5252
let df = ctx
5353
.read_csv(
54-
temp_path.to_str().unwrap(),
55-
CsvReadOptions::new().truncated_rows(true)
54+
temp_path.to_str().unwrap(),
55+
CsvReadOptions::new().truncated_rows(true),
5656
)
5757
.await
5858
.expect("Should successfully read CSV directory with different column counts");
5959

6060
// Verify the schema contains all 6 columns (union of both files)
6161
let df_clone = df.clone();
6262
let schema = df_clone.schema();
63-
assert_eq!(schema.fields().len(), 6, "Schema should contain all 6 columns");
63+
assert_eq!(
64+
schema.fields().len(),
65+
6,
66+
"Schema should contain all 6 columns"
67+
);
6468

6569
// Check that we have all expected columns
66-
let field_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
70+
let field_names: Vec<&str> =
71+
schema.fields().iter().map(|f| f.name().as_str()).collect();
6772
assert!(field_names.contains(&"service_id"));
6873
assert!(field_names.contains(&"route_type"));
6974
assert!(field_names.contains(&"agency_id"));
@@ -82,29 +87,63 @@ async fn test_csv_schema_inference_different_column_counts() -> Result<()> {
8287

8388
// Verify we can actually read the data
8489
let results = df.collect().await?;
85-
90+
8691
// Calculate total rows across all batches
8792
let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum();
8893
assert_eq!(total_rows, 6, "Should have 6 total rows across all batches");
8994

9095
// All batches should have 6 columns (the union schema)
9196
for batch in &results {
9297
assert_eq!(batch.num_columns(), 6, "All batches should have 6 columns");
98+
assert_eq!(
99+
batch.schema().fields().len(),
100+
6,
101+
"Each batch should use the union schema with 6 fields"
102+
);
93103
}
94104

95-
// Verify that the union schema is being used correctly
96-
// We should be able to find records from both files
97-
println!("✅ Successfully read {} record batches with {} total rows", results.len(), total_rows);
105+
// Verify the actual content of the data
106+
// Since we don't know the exact order of rows, just verify the overall structure
107+
108+
// Check that all batches have nulls in the correct places
109+
let mut null_count_col3 = 0;
110+
let mut null_count_col4 = 0;
111+
let mut null_count_col5 = 0;
112+
let mut non_null_count_col3 = 0;
113+
let mut non_null_count_col4 = 0;
114+
let mut non_null_count_col5 = 0;
98115

99-
// Verify schema has all expected columns
100116
for batch in &results {
101-
assert_eq!(batch.schema().fields().len(), 6, "Each batch should use the union schema with 6 fields");
117+
// Count nulls and non-nulls for columns 3-5 (platform_number, direction, stop_sequence)
118+
for i in 0..batch.num_rows() {
119+
if batch.column(3).is_null(i) {
120+
null_count_col3 += 1;
121+
} else {
122+
non_null_count_col3 += 1;
123+
}
124+
125+
if batch.column(4).is_null(i) {
126+
null_count_col4 += 1;
127+
} else {
128+
non_null_count_col4 += 1;
129+
}
130+
131+
if batch.column(5).is_null(i) {
132+
null_count_col5 += 1;
133+
} else {
134+
non_null_count_col5 += 1;
135+
}
136+
}
102137
}
103-
104-
println!("✅ Successfully verified CSV schema inference fix!");
105-
println!(" - Read {} files with different column counts (3 vs 6)", temp_dir.path().read_dir().unwrap().count());
106-
println!(" - Inferred schema with {} columns", schema.fields().len());
107-
println!(" - Processed {} total rows", total_rows);
138+
139+
// Verify that we have the expected pattern:
140+
// 3 rows with nulls (from file1) and 3 rows with non-nulls (from file2)
141+
assert_eq!(null_count_col3, 3, "Should have 3 null values in platform_number column");
142+
assert_eq!(non_null_count_col3, 3, "Should have 3 non-null values in platform_number column");
143+
assert_eq!(null_count_col4, 3, "Should have 3 null values in direction column");
144+
assert_eq!(non_null_count_col4, 3, "Should have 3 non-null values in direction column");
145+
assert_eq!(null_count_col5, 3, "Should have 3 null values in stop_sequence column");
146+
assert_eq!(non_null_count_col5, 3, "Should have 3 non-null values in stop_sequence column");
108147

109148
Ok(())
110-
}
149+
}

datafusion/datasource-csv/src/file_format.rs

Lines changed: 93 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ use arrow::error::ArrowError;
3131
use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
3232
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
3333
use datafusion_common::{
34-
not_impl_err, DataFusionError, GetExt, Result, Statistics,
35-
DEFAULT_CSV_EXTENSION,
34+
not_impl_err, DataFusionError, GetExt, Result, Statistics, DEFAULT_CSV_EXTENSION,
3635
};
3736
use datafusion_common_runtime::SpawnedTask;
3837
use datafusion_datasource::decoder::Decoder;
@@ -499,17 +498,17 @@ impl CsvFormat {
499498
/// stream of delimited chunks returning the inferred schema and the
500499
/// number of lines that were read.
501500
///
502-
/// This method now supports CSV files with different numbers of columns.
501+
/// This method can handle CSV files with different numbers of columns.
503502
/// The inferred schema will be the union of all columns found across all files.
504503
/// Files with fewer columns will have missing columns filled with null values.
505504
///
506505
/// # Example
507-
///
506+
///
508507
/// If you have two CSV files:
509508
/// - `file1.csv`: `col1,col2,col3`
510509
/// - `file2.csv`: `col1,col2,col3,col4,col5`
511510
///
512-
/// The inferred schema will contain all 5 columns, with files that don't
511+
/// The inferred schema will contain all 5 columns, with files that don't
513512
/// have columns 4 and 5 having null values for those columns.
514513
pub async fn infer_schema_from_stream(
515514
&self,
@@ -585,14 +584,13 @@ impl CsvFormat {
585584
column_type_possibilities.push(possibilities);
586585
}
587586
}
588-
587+
589588
// Update type possibilities for columns that exist in this file
590-
// We take the minimum of fields.len() and column_type_possibilities.len()
591-
// to avoid index out of bounds when a file has fewer columns
592-
let max_fields_to_process = fields.len().min(column_type_possibilities.len());
593-
for field_idx in 0..max_fields_to_process {
594-
if let Some(field) = fields.get(field_idx) {
595-
column_type_possibilities[field_idx].insert(field.data_type().clone());
589+
// Only process fields that exist in both the current file and our tracking structures
590+
for (field_idx, field) in fields.iter().enumerate() {
591+
if field_idx < column_type_possibilities.len() {
592+
column_type_possibilities[field_idx]
593+
.insert(field.data_type().clone());
596594
}
597595
}
598596
}
@@ -607,7 +605,10 @@ impl CsvFormat {
607605
}
608606
}
609607

610-
pub(crate) fn build_schema_helper(names: Vec<String>, types: &[HashSet<DataType>]) -> Schema {
608+
pub(crate) fn build_schema_helper(
609+
names: Vec<String>,
610+
types: &[HashSet<DataType>],
611+
) -> Schema {
611612
let fields = names
612613
.into_iter()
613614
.zip(types)
@@ -781,3 +782,82 @@ impl DataSink for CsvSink {
781782
FileSink::write_all(self, data, context).await
782783
}
783784
}
785+
786+
#[cfg(test)]
787+
mod tests {
788+
use super::build_schema_helper;
789+
use arrow::datatypes::DataType;
790+
use std::collections::HashSet;
791+
792+
#[test]
793+
fn test_build_schema_helper_different_column_counts() {
794+
// Test the core schema building logic with different column counts
795+
let mut column_names =
796+
vec!["col1".to_string(), "col2".to_string(), "col3".to_string()];
797+
798+
// Simulate adding two more columns from another file
799+
column_names.push("col4".to_string());
800+
column_names.push("col5".to_string());
801+
802+
let column_type_possibilities = vec![
803+
HashSet::from([DataType::Int64]),
804+
HashSet::from([DataType::Utf8]),
805+
HashSet::from([DataType::Float64]),
806+
HashSet::from([DataType::Utf8]), // col4
807+
HashSet::from([DataType::Utf8]), // col5
808+
];
809+
810+
let schema = build_schema_helper(column_names, &column_type_possibilities);
811+
812+
// Verify schema has 5 columns
813+
assert_eq!(schema.fields().len(), 5);
814+
assert_eq!(schema.field(0).name(), "col1");
815+
assert_eq!(schema.field(1).name(), "col2");
816+
assert_eq!(schema.field(2).name(), "col3");
817+
assert_eq!(schema.field(3).name(), "col4");
818+
assert_eq!(schema.field(4).name(), "col5");
819+
820+
// All fields should be nullable
821+
for field in schema.fields() {
822+
assert!(
823+
field.is_nullable(),
824+
"Field {} should be nullable",
825+
field.name()
826+
);
827+
}
828+
}
829+
830+
#[test]
831+
fn test_build_schema_helper_type_merging() {
832+
// Test type merging logic
833+
let column_names = vec!["col1".to_string(), "col2".to_string()];
834+
835+
let column_type_possibilities = vec![
836+
HashSet::from([DataType::Int64, DataType::Float64]), // Should resolve to Float64
837+
HashSet::from([DataType::Utf8]), // Should remain Utf8
838+
];
839+
840+
let schema = build_schema_helper(column_names, &column_type_possibilities);
841+
842+
// col1 should be Float64 due to Int64 + Float64 = Float64
843+
assert_eq!(*schema.field(0).data_type(), DataType::Float64);
844+
845+
// col2 should remain Utf8
846+
assert_eq!(*schema.field(1).data_type(), DataType::Utf8);
847+
}
848+
849+
#[test]
850+
fn test_build_schema_helper_conflicting_types() {
851+
// Test when we have incompatible types - should default to Utf8
852+
let column_names = vec!["col1".to_string()];
853+
854+
let column_type_possibilities = vec![
855+
HashSet::from([DataType::Boolean, DataType::Int64, DataType::Utf8]), // Should resolve to Utf8 due to conflicts
856+
];
857+
858+
let schema = build_schema_helper(column_names, &column_type_possibilities);
859+
860+
// Should default to Utf8 for conflicting types
861+
assert_eq!(*schema.field(0).data_type(), DataType::Utf8);
862+
}
863+
}

datafusion/datasource-csv/src/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121

2222
pub mod file_format;
2323
pub mod source;
24-
#[cfg(test)]
25-
mod tests;
2624

2725
use std::sync::Arc;
2826

0 commit comments

Comments
 (0)