12 changes: 9 additions & 3 deletions datafusion/common/src/config.rs
@@ -2535,9 +2535,15 @@ config_namespace! {
// The input regex for Nulls when loading CSVs.
pub null_regex: Option<String>, default = None
pub comment: Option<u8>, default = None
// Whether to allow truncated rows when parsing.
// By default this is set to false and will error if the CSV rows have different lengths.
// When set to true then it will allow records with less than the expected number of columns
/// Whether to allow truncated rows when parsing, both within a single file and across files.
///
/// When set to false (the default), reading a single CSV file whose rows have different lengths
/// will error; reading multiple CSV files with different numbers of columns will also fail.
///
/// When set to true, reading a single CSV file with rows of different lengths will pad the
/// truncated rows with null values for the missing columns; when reading multiple CSV files with
/// different numbers of columns, a union schema containing all columns found across the files is
/// inferred, and rows from files that are missing columns are padded with null values.
pub truncated_rows: Option<bool>, default = None
}
}
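For reference, a minimal sketch of how a reader might opt into this behavior from the DataFrame API (it mirrors the integration test added below; the directory path is illustrative):

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Allow rows/files with fewer columns than the union schema;
    // missing columns are filled with null values.
    let df = ctx
        .read_csv("data/services/", CsvReadOptions::new().truncated_rows(true))
        .await?;
    df.show().await?;
    Ok(())
}
```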
122 changes: 122 additions & 0 deletions datafusion/core/tests/csv_schema_fix_test.rs
@@ -0,0 +1,122 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Test for CSV schema inference with different column counts (GitHub issue #17516)
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::test_util::batches_to_sort_string;
use insta::assert_snapshot;
use std::fs;
use tempfile::TempDir;

#[tokio::test]
async fn test_csv_schema_inference_different_column_counts() -> Result<()> {
// Create temporary directory for test files
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let temp_path = temp_dir.path();

// Create CSV file 1 with 3 columns (simulating older railway services format)
let csv1_content = r#"service_id,route_type,agency_id
1,bus,agency1
2,rail,agency2
3,bus,agency3
"#;
fs::write(temp_path.join("services_2024.csv"), csv1_content)?;

// Create CSV file 2 with 6 columns (simulating newer railway services format)
let csv2_content = r#"service_id,route_type,agency_id,stop_platform_change,stop_planned_platform,stop_actual_platform
4,rail,agency2,true,Platform A,Platform B
5,bus,agency1,false,Stop 1,Stop 1
6,rail,agency3,true,Platform C,Platform D
"#;
fs::write(temp_path.join("services_2025.csv"), csv2_content)?;

// Create DataFusion context
let ctx = SessionContext::new();

// This should now work (previously would have failed with column count mismatch)
// Enable truncated_rows to handle files with different column counts
let df = ctx
.read_csv(
temp_path.to_str().unwrap(),
CsvReadOptions::new().truncated_rows(true),
)
.await
.expect("Should successfully read CSV directory with different column counts");

// Verify the schema contains all 6 columns (union of both files)
let df_clone = df.clone();
let schema = df_clone.schema();
assert_eq!(
schema.fields().len(),
6,
"Schema should contain all 6 columns"
);

// Check that we have all expected columns
let field_names: Vec<&str> =
schema.fields().iter().map(|f| f.name().as_str()).collect();
assert!(field_names.contains(&"service_id"));
assert!(field_names.contains(&"route_type"));
assert!(field_names.contains(&"agency_id"));
assert!(field_names.contains(&"stop_platform_change"));
assert!(field_names.contains(&"stop_planned_platform"));
assert!(field_names.contains(&"stop_actual_platform"));

// All fields should be nullable, since columns missing from one of the files are filled with nulls
for field in schema.fields() {
assert!(
field.is_nullable(),
"Field {} should be nullable",
field.name()
);
}

// Verify we can actually read the data
let results = df.collect().await?;

// Calculate total rows across all batches
let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum();
assert_eq!(total_rows, 6, "Should have 6 total rows across all batches");

// All batches should have 6 columns (the union schema)
for batch in &results {
assert_eq!(batch.num_columns(), 6, "All batches should have 6 columns");
assert_eq!(
batch.schema().fields().len(),
6,
"Each batch should use the union schema with 6 fields"
);
}

// Verify the actual content of the data using snapshot testing
assert_snapshot!(batches_to_sort_string(&results), @r"
+------------+------------+-----------+----------------------+-----------------------+----------------------+
| service_id | route_type | agency_id | stop_platform_change | stop_planned_platform | stop_actual_platform |
+------------+------------+-----------+----------------------+-----------------------+----------------------+
| 1 | bus | agency1 | | | |
| 2 | rail | agency2 | | | |
| 3 | bus | agency3 | | | |
| 4 | rail | agency2 | true | Platform A | Platform B |
| 5 | bus | agency1 | false | Stop 1 | Stop 1 |
| 6 | rail | agency3 | true | Platform C | Platform D |
+------------+------------+-----------+----------------------+-----------------------+----------------------+
");

Ok(())
}
124 changes: 116 additions & 8 deletions datafusion/datasource-csv/src/file_format.rs
@@ -497,7 +497,20 @@ impl FileFormat for CsvFormat {
impl CsvFormat {
/// Return the inferred schema reading up to records_to_read from a
/// stream of delimited chunks returning the inferred schema and the
/// number of lines that were read.
///
/// This method can handle CSV files with different numbers of columns when
/// `truncated_rows` is enabled. The inferred schema is the union of all columns
/// found across all files; files with fewer columns have the missing columns
/// filled with null values.
///
/// # Example
///
/// If you have two CSV files:
/// - `file1.csv`: `col1,col2,col3`
/// - `file2.csv`: `col1,col2,col3,col4,col5`
///
/// The inferred schema will contain all 5 columns; rows from `file1.csv`, which
/// lacks `col4` and `col5`, will have null values for those columns.
pub async fn infer_schema_from_stream(
&self,
state: &dyn Session,
@@ -560,21 +573,37 @@ impl CsvFormat {
})
.unzip();
} else {
if fields.len() != column_type_possibilities.len() {
if fields.len() != column_type_possibilities.len()
&& !self.options.truncated_rows.unwrap_or(false)
{
return exec_err!(
"Encountered unequal lengths between records on CSV file whilst inferring schema. \
Expected {} fields, found {} fields at record {}",
column_type_possibilities.len(),
fields.len(),
record_number + 1
);
}

// First update type possibilities for existing columns using zip
column_type_possibilities.iter_mut().zip(&fields).for_each(
|(possibilities, field)| {
possibilities.insert(field.data_type().clone());
},
);

// Handle files with different numbers of columns by extending the schema
if fields.len() > column_type_possibilities.len() {
// New columns found - extend our tracking structures
for field in fields.iter().skip(column_type_possibilities.len()) {
column_names.push(field.name().clone());
let mut possibilities = HashSet::new();
if records_read > 0 {
possibilities.insert(field.data_type().clone());
}
column_type_possibilities.push(possibilities);
}
}
}

if records_to_read == 0 {
@@ -761,3 +790,82 @@ impl DataSink for CsvSink {
FileSink::write_all(self, data, context).await
}
}

#[cfg(test)]
mod tests {
use super::build_schema_helper;
use arrow::datatypes::DataType;
use std::collections::HashSet;

#[test]
fn test_build_schema_helper_different_column_counts() {
// Test the core schema building logic with different column counts
let mut column_names =
vec!["col1".to_string(), "col2".to_string(), "col3".to_string()];

// Simulate adding two more columns from another file
column_names.push("col4".to_string());
column_names.push("col5".to_string());

let column_type_possibilities = vec![
HashSet::from([DataType::Int64]),
HashSet::from([DataType::Utf8]),
HashSet::from([DataType::Float64]),
HashSet::from([DataType::Utf8]), // col4
HashSet::from([DataType::Utf8]), // col5
];

let schema = build_schema_helper(column_names, &column_type_possibilities);

// Verify schema has 5 columns
assert_eq!(schema.fields().len(), 5);
assert_eq!(schema.field(0).name(), "col1");
assert_eq!(schema.field(1).name(), "col2");
assert_eq!(schema.field(2).name(), "col3");
assert_eq!(schema.field(3).name(), "col4");
assert_eq!(schema.field(4).name(), "col5");

// All fields should be nullable
for field in schema.fields() {
assert!(
field.is_nullable(),
"Field {} should be nullable",
field.name()
);
}
}

#[test]
fn test_build_schema_helper_type_merging() {
// Test type merging logic
let column_names = vec!["col1".to_string(), "col2".to_string()];

let column_type_possibilities = vec![
HashSet::from([DataType::Int64, DataType::Float64]), // Should resolve to Float64
HashSet::from([DataType::Utf8]), // Should remain Utf8
];

let schema = build_schema_helper(column_names, &column_type_possibilities);

// col1 should be Float64 due to Int64 + Float64 = Float64
assert_eq!(*schema.field(0).data_type(), DataType::Float64);

// col2 should remain Utf8
assert_eq!(*schema.field(1).data_type(), DataType::Utf8);
}

#[test]
fn test_build_schema_helper_conflicting_types() {
// Test when we have incompatible types - should default to Utf8
let column_names = vec!["col1".to_string()];

let column_type_possibilities = vec![
HashSet::from([DataType::Boolean, DataType::Int64, DataType::Utf8]), // Should resolve to Utf8 due to conflicts
];

let schema = build_schema_helper(column_names, &column_type_possibilities);

// Should default to Utf8 for conflicting types
assert_eq!(*schema.field(0).data_type(), DataType::Utf8);
}
}
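To make the schema-extension step in `infer_schema_from_stream` easier to picture, here is a self-contained sketch of the same idea (not DataFusion's actual code; `Ty` is a hypothetical stand-in for Arrow's `DataType`): each file's header either refines the type candidates of columns already seen or appends new columns to the running union.

```rust
use std::collections::HashSet;

// Hypothetical stand-in for Arrow's DataType, for illustration only.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
enum Ty {
    Int64,
    Float64,
    Utf8,
}

/// Merge one file's (column name, inferred type) pairs into the running union schema.
fn merge_file(
    column_names: &mut Vec<String>,
    column_type_possibilities: &mut Vec<HashSet<Ty>>,
    fields: &[(String, Ty)],
) {
    // Refine type candidates for columns that have already been seen.
    for (possibilities, (_, ty)) in column_type_possibilities.iter_mut().zip(fields) {
        possibilities.insert(ty.clone());
    }
    // Append any new columns this file introduces (the union-schema case).
    for (name, ty) in fields.iter().skip(column_names.len()) {
        column_names.push(name.clone());
        column_type_possibilities.push(HashSet::from([ty.clone()]));
    }
}

fn main() {
    let mut names = Vec::new();
    let mut types: Vec<HashSet<Ty>> = Vec::new();

    // First file has 2 columns, second file has 3.
    merge_file(
        &mut names,
        &mut types,
        &[("col1".to_string(), Ty::Int64), ("col2".to_string(), Ty::Utf8)],
    );
    merge_file(
        &mut names,
        &mut types,
        &[
            ("col1".to_string(), Ty::Float64),
            ("col2".to_string(), Ty::Utf8),
            ("col3".to_string(), Ty::Utf8),
        ],
    );

    assert_eq!(names, vec!["col1", "col2", "col3"]);
    // col1 has seen both Int64 and Float64; a later step resolves this to Float64.
    assert_eq!(types[0], HashSet::from([Ty::Int64, Ty::Float64]));
}
```

The real implementation additionally gates this path behind the `truncated_rows` option and later resolves the accumulated type candidates in `build_schema_helper`, as exercised by the unit tests above.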