Skip to content

Commit ca2a398

Browse files
committed
Initial work for file format writer API
1 parent 3f37f56 commit ca2a398

File tree

3 files changed

+208
-74
lines changed

3 files changed

+208
-74
lines changed

pyiceberg/io/fileformat.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""File Format API for reading and writing Iceberg data files."""
19+
20+
from __future__ import annotations
21+
22+
from abc import ABC, abstractmethod
23+
from dataclasses import dataclass
24+
from typing import TYPE_CHECKING, Any
25+
26+
from pyiceberg.io import OutputFile
27+
from pyiceberg.manifest import FileFormat
28+
from pyiceberg.partitioning import PartitionField, PartitionSpec, partition_record_value
29+
from pyiceberg.schema import Schema
30+
from pyiceberg.typedef import Properties, Record
31+
32+
if TYPE_CHECKING:
33+
import pyarrow as pa
34+
35+
from pyiceberg.io.pyarrow import StatsAggregator
36+
37+
38+
@dataclass(frozen=True)
class DataFileStatistics:
    """Statistics collected for a single written data file.

    Produced by file-format writers and consumed when building manifest
    entries; all per-column mappings are keyed by Iceberg field id.
    """

    # Number of records written to the file.
    record_count: int
    # Field id -> column size (presumably bytes on disk — inherited contract).
    column_sizes: dict[int, int]
    # Field id -> number of values.
    value_counts: dict[int, int]
    # Field id -> number of null values.
    null_value_counts: dict[int, int]
    # Field id -> number of NaN values.
    nan_value_counts: dict[int, int]
    # Field id -> min/max aggregator used to derive bounds and partition values.
    column_aggregates: dict[int, StatsAggregator]
    # Offsets of row groups / splits within the file.
    split_offsets: list[int]

    def _partition_value(self, partition_field: PartitionField, schema: Schema) -> Any:
        """Infer the single partition value for *partition_field* from column bounds.

        Returns None when no aggregate was collected for the source column.

        Raises:
            ValueError: If the transform is not order-preserving, or the
                transformed lower and upper bounds disagree (i.e. the file
                spans more than one partition value).
        """
        if partition_field.source_id not in self.column_aggregates:
            return None

        source_field = schema.find_field(partition_field.source_id)
        iceberg_transform = partition_field.transform

        # Only order-preserving (linear) transforms let us derive the partition
        # value from min/max statistics alone.
        if not iceberg_transform.preserves_order:
            raise ValueError(
                f"Cannot infer partition value from parquet metadata for a non-linear Partition Field: "
                f"{partition_field.name} with transform {partition_field.transform}"
            )

        transform_func = iceberg_transform.transform(source_field.field_type)
        aggregate = self.column_aggregates[partition_field.source_id]

        # Transform both bounds; they must collapse to one partition value.
        lower_value, upper_value = (
            transform_func(
                partition_record_value(
                    partition_field=partition_field,
                    value=raw_bound,
                    schema=schema,
                )
            )
            for raw_bound in (aggregate.current_min, aggregate.current_max)
        )
        if lower_value != upper_value:
            raise ValueError(
                f"Cannot infer partition value from parquet metadata as there are more than one partition values "
                f"for Partition Field: {partition_field.name}. {lower_value=}, {upper_value=}"
            )

        return lower_value

    def partition(self, partition_spec: PartitionSpec, schema: Schema) -> Record:
        """Build the partition Record for this file under *partition_spec*."""
        values = [self._partition_value(field, schema) for field in partition_spec.fields]
        return Record(*values)

    def to_serialized_dict(self) -> dict[str, Any]:
        """Serialize the statistics into keyword arguments for a DataFile.

        Bounds are emitted only for columns whose aggregator produced a
        non-None serialized min/max.
        """
        lower_bounds: dict[int, bytes] = {}
        upper_bounds: dict[int, bytes] = {}

        for field_id, aggregate in self.column_aggregates.items():
            if (serialized_min := aggregate.min_as_bytes()) is not None:
                lower_bounds[field_id] = serialized_min
            if (serialized_max := aggregate.max_as_bytes()) is not None:
                upper_bounds[field_id] = serialized_max

        return {
            "record_count": self.record_count,
            "column_sizes": self.column_sizes,
            "value_counts": self.value_counts,
            "null_value_counts": self.null_value_counts,
            "nan_value_counts": self.nan_value_counts,
            "lower_bounds": lower_bounds,
            "upper_bounds": upper_bounds,
            "split_offsets": self.split_offsets,
        }
109+
110+
111+
class FileFormatWriter(ABC):
    """Abstract base for writers that stream data into one physical file.

    Subclasses implement ``write`` and ``close``; this base adds context
    manager support and caches the statistics returned by ``close``.
    """

    # Statistics cached by __exit__ after close(); None until the writer is closed.
    _result: DataFileStatistics | None = None

    @abstractmethod
    def write(self, table: pa.Table) -> None:
        """Write a batch of data. May be called multiple times."""

    @abstractmethod
    def close(self) -> DataFileStatistics:
        """Finalize the file and return statistics."""

    def result(self) -> DataFileStatistics:
        """Return the statistics cached by a previous close().

        Raises:
            RuntimeError: If the writer has not been closed yet.
        """
        cached = self._result
        if cached is None:
            raise RuntimeError("Writer has not been closed yet")
        return cached

    def __enter__(self) -> FileFormatWriter:
        """Enter the context manager."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Close the writer on exit and cache its statistics.

        On the success path, any error from close() propagates. When exiting
        due to an exception, closing is best-effort: errors from close() are
        suppressed so the original exception is the one that propagates.
        """
        if exc_type is None:
            self._result = self.close()
            return
        try:
            self._result = self.close()
        except Exception:
            # Deliberate: don't mask the in-flight exception with a close failure.
            pass
143+
144+
145+
class FileFormatModel(ABC):
    """Describes one file format's capabilities and constructs its writers."""

    @property
    @abstractmethod
    def format(self) -> FileFormat:
        """The FileFormat enum value this model implements."""

    @abstractmethod
    def file_extension(self) -> str:
        """Return file extension without dot, e.g. 'parquet', 'orc'."""

    @abstractmethod
    def create_writer(
        self,
        output_file: OutputFile,
        file_schema: Schema,
        properties: Properties,
    ) -> FileFormatWriter:
        """Create a writer that emits rows of *file_schema* into *output_file*."""
163+
164+
165+
class FileFormatFactory:
    """Process-wide registry mapping each FileFormat to its FileFormatModel."""

    # Shared mutable registry; register() updates it for the whole process.
    _registry: dict[FileFormat, FileFormatModel] = {}

    @classmethod
    def register(cls, model: FileFormatModel) -> None:
        """Register *model* as the implementation for its declared format.

        Re-registering a format replaces the previous model.
        """
        cls._registry[model.format] = model

    @classmethod
    def get(cls, file_format: FileFormat) -> FileFormatModel:
        """Look up the registered model for *file_format*.

        Raises:
            ValueError: If no model has been registered for the format.
        """
        model = cls._registry.get(file_format)
        if model is None:
            # register() never stores None, so a None hit means "unregistered".
            raise ValueError(f"No writer registered for {file_format}. Available: {list(cls._registry.keys())}")
        return model

    @classmethod
    def available_formats(cls) -> list[FileFormat]:
        """Return the formats that currently have a registered model."""
        return [*cls._registry]

pyiceberg/io/pyarrow.py

Lines changed: 2 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,13 @@
120120
OutputFile,
121121
OutputStream,
122122
)
123+
from pyiceberg.io.fileformat import DataFileStatistics
123124
from pyiceberg.manifest import (
124125
DataFile,
125126
DataFileContent,
126127
FileFormat,
127128
)
128-
from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
129+
from pyiceberg.partitioning import PartitionFieldValue, PartitionKey, PartitionSpec
129130
from pyiceberg.schema import (
130131
PartnerAccessor,
131132
PreOrderSchemaVisitor,
@@ -2473,79 +2474,6 @@ def parquet_path_to_id_mapping(
24732474
return result
24742475

24752476

2476-
@dataclass(frozen=True)
2477-
class DataFileStatistics:
2478-
record_count: int
2479-
column_sizes: dict[int, int]
2480-
value_counts: dict[int, int]
2481-
null_value_counts: dict[int, int]
2482-
nan_value_counts: dict[int, int]
2483-
column_aggregates: dict[int, StatsAggregator]
2484-
split_offsets: list[int]
2485-
2486-
def _partition_value(self, partition_field: PartitionField, schema: Schema) -> Any:
2487-
if partition_field.source_id not in self.column_aggregates:
2488-
return None
2489-
2490-
source_field = schema.find_field(partition_field.source_id)
2491-
iceberg_transform = partition_field.transform
2492-
2493-
if not iceberg_transform.preserves_order:
2494-
raise ValueError(
2495-
f"Cannot infer partition value from parquet metadata for a non-linear Partition Field: "
2496-
f"{partition_field.name} with transform {partition_field.transform}"
2497-
)
2498-
2499-
transform_func = iceberg_transform.transform(source_field.field_type)
2500-
2501-
lower_value = transform_func(
2502-
partition_record_value(
2503-
partition_field=partition_field,
2504-
value=self.column_aggregates[partition_field.source_id].current_min,
2505-
schema=schema,
2506-
)
2507-
)
2508-
upper_value = transform_func(
2509-
partition_record_value(
2510-
partition_field=partition_field,
2511-
value=self.column_aggregates[partition_field.source_id].current_max,
2512-
schema=schema,
2513-
)
2514-
)
2515-
if lower_value != upper_value:
2516-
raise ValueError(
2517-
f"Cannot infer partition value from parquet metadata as there are more than one partition values "
2518-
f"for Partition Field: {partition_field.name}. {lower_value=}, {upper_value=}"
2519-
)
2520-
2521-
return lower_value
2522-
2523-
def partition(self, partition_spec: PartitionSpec, schema: Schema) -> Record:
2524-
return Record(*[self._partition_value(field, schema) for field in partition_spec.fields])
2525-
2526-
def to_serialized_dict(self) -> dict[str, Any]:
2527-
lower_bounds = {}
2528-
upper_bounds = {}
2529-
2530-
for k, agg in self.column_aggregates.items():
2531-
_min = agg.min_as_bytes()
2532-
if _min is not None:
2533-
lower_bounds[k] = _min
2534-
_max = agg.max_as_bytes()
2535-
if _max is not None:
2536-
upper_bounds[k] = _max
2537-
return {
2538-
"record_count": self.record_count,
2539-
"column_sizes": self.column_sizes,
2540-
"value_counts": self.value_counts,
2541-
"null_value_counts": self.null_value_counts,
2542-
"nan_value_counts": self.nan_value_counts,
2543-
"lower_bounds": lower_bounds,
2544-
"upper_bounds": upper_bounds,
2545-
"split_offsets": self.split_offsets,
2546-
}
2547-
2548-
25492477
def data_file_statistics_from_parquet_metadata(
25502478
parquet_metadata: pq.FileMetaData,
25512479
stats_columns: dict[int, StatisticsCollector],

tests/io/test_fileformat.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
19+
def test_backward_compat_import() -> None:
    """DataFileStatistics can still be imported from pyiceberg.io.pyarrow."""
    from pyiceberg.io import fileformat, pyarrow

    # The pyarrow module must re-export the exact same class object, so that
    # isinstance checks and pickling remain compatible across the move.
    assert pyarrow.DataFileStatistics is fileformat.DataFileStatistics

0 commit comments

Comments
 (0)