44import decimal
55import itertools
66import logging
7+ import warnings
78from collections .abc import Generator
89from collections .abc import Iterator
910from collections .abc import Mapping
11+ from io import BufferedReader
1012from io import BytesIO
1113from os import PathLike
1214from pathlib import Path
4648from databento .common .constants import SCHEMA_STRUCT_MAP
4749from databento .common .constants import SCHEMA_STRUCT_MAP_V1
4850from databento .common .error import BentoError
51+ from databento .common .error import BentoWarning
4952from databento .common .symbology import InstrumentMap
5053from databento .common .types import DBNRecord
5154from databento .common .types import Default
@@ -150,7 +153,7 @@ def __init__(self, source: PathLike[str] | str):
150153 )
151154
152155 self ._name = self ._path .name
153- self .__buffer : IO [ bytes ] | None = None
156+ self .__buffer : BufferedReader | None = None
154157
155158 @property
156159 def name (self ) -> str :
@@ -189,13 +192,13 @@ def path(self) -> Path:
189192 return self ._path
190193
191194 @property
192- def reader (self ) -> IO [ bytes ] :
195+ def reader (self ) -> BufferedReader :
193196 """
194197 Return a reader for this file.
195198
196199 Returns
197200 -------
198- IO
201+ BufferedReader
199202
200203 """
201204 if self .__buffer is None :
@@ -259,14 +262,14 @@ def nbytes(self) -> int:
259262 return self .__buffer .getbuffer ().nbytes
260263
261264 @property
262- def reader (self ) -> IO [ bytes ] :
265+ def reader (self ) -> BytesIO :
263266 """
264267 Return a reader for this buffer. The reader beings at the start of the
265268 buffer.
266269
267270 Returns
268271 -------
269- IO
272+ BytesIO
270273
271274 """
272275 self .__buffer .seek (0 )
@@ -391,8 +394,8 @@ def __iter__(self) -> Generator[DBNRecord, None, None]:
391394 yield record
392395 else :
393396 if len (decoder .buffer ()) > 0 :
394- raise BentoError (
395- "DBN file is truncated or contains an incomplete record" ,
397+ warnings . warn (
398+ BentoWarning ( "DBN file is truncated or contains an incomplete record" ) ,
396399 )
397400 break
398401
@@ -516,21 +519,18 @@ def reader(self) -> IO[bytes]:
516519
517520 Returns
518521 -------
519- BinaryIO
522+ IO[bytes]
520523
521524 See Also
522525 --------
523526 DBNStore.raw
524527
525528 """
526529 if self .compression == Compression .ZSTD :
527- reader : IO [ bytes ] = zstandard .ZstdDecompressor ().stream_reader (
530+ return zstandard .ZstdDecompressor ().stream_reader (
528531 self ._data_source .reader ,
529532 )
530- else :
531- reader = self ._data_source .reader
532-
533- return reader
533+ return self ._data_source .reader
534534
535535 @property
536536 def schema (self ) -> Schema | None :
@@ -792,6 +792,7 @@ def to_csv(
792792 map_symbols : bool = True ,
793793 compression : Compression | str = Compression .NONE ,
794794 schema : Schema | str | None = None ,
795+ mode : Literal ["w" , "x" ] = "w" ,
795796 ) -> None :
796797 """
797798 Write the data to a file in CSV format.
@@ -816,6 +817,8 @@ def to_csv(
816817 schema : Schema or str, optional
817818 The DBN schema for the csv.
818819 This is only required when reading a DBN stream with mixed record types.
820+ mode : str, default "w"
821+ The file write mode to use, either "x" or "w".
819822
820823 Raises
821824 ------
@@ -825,14 +828,15 @@ def to_csv(
825828 """
826829 compression = validate_enum (compression , Compression , "compression" )
827830 schema = validate_maybe_enum (schema , Schema , "schema" )
831+ file_path = validate_file_write_path (path , "path" , exist_ok = mode == "w" )
828832 if schema is None :
829833 if self .schema is None :
830834 raise ValueError ("a schema must be specified for mixed DBN data" )
831835 schema = self .schema
832836
833- with open (path , "xb " ) as output :
837+ with open (file_path , f" { mode } b " ) as output :
834838 self ._transcode (
835- output = output ,
839+ output = output , # type: ignore [arg-type]
836840 encoding = Encoding .CSV ,
837841 pretty_px = pretty_px ,
838842 pretty_ts = pretty_ts ,
@@ -961,6 +965,7 @@ def to_parquet(
961965 pretty_ts : bool = True ,
962966 map_symbols : bool = True ,
963967 schema : Schema | str | None = None ,
968+ mode : Literal ["w" , "x" ] = "w" ,
964969 ** kwargs : Any ,
965970 ) -> None :
966971 """
@@ -983,6 +988,8 @@ def to_parquet(
983988 schema : Schema or str, optional
984989 The DBN schema for the parquet file.
985990 This is only required when reading a DBN stream with mixed record types.
991+ mode : str, default "w"
992+ The file write mode to use, either "x" or "w".
986993
987994 Raises
988995 ------
@@ -994,6 +1001,7 @@ def to_parquet(
9941001 if price_type == "decimal" :
9951002 raise ValueError ("the 'decimal' price type is not currently supported" )
9961003
1004+ file_path = validate_file_write_path (path , "path" , exist_ok = mode == "w" )
9971005 schema = validate_maybe_enum (schema , Schema , "schema" )
9981006 if schema is None :
9991007 if self .schema is None :
@@ -1015,7 +1023,7 @@ def to_parquet(
10151023 # Initialize the writer using the first DataFrame
10161024 parquet_schema = pa .Schema .from_pandas (frame )
10171025 writer = pq .ParquetWriter (
1018- where = path ,
1026+ where = file_path ,
10191027 schema = parquet_schema ,
10201028 ** kwargs ,
10211029 )
@@ -1033,6 +1041,7 @@ def to_file(
10331041 self ,
10341042 path : PathLike [str ] | str ,
10351043 mode : Literal ["w" , "x" ] = "w" ,
1044+ compression : Compression | str | None = None ,
10361045 ) -> None :
10371046 """
10381047 Write the data to a DBN file at the given path.
@@ -1043,6 +1052,8 @@ def to_file(
10431052 The file path to write to.
10441053 mode : str, default "w"
10451054 The file write mode to use, either "x" or "w".
1055+ compression : Compression or str, optional
1056+ The compression format to write. If `None`, uses the same compression as the underlying data.
10461057
10471058 Raises
10481059 ------
@@ -1054,9 +1065,35 @@ def to_file(
10541065 If path is not writable.
10551066
10561067 """
1068+ compression = validate_maybe_enum (compression , Compression , "compression" )
10571069 file_path = validate_file_write_path (path , "path" , exist_ok = mode == "w" )
1058- file_path .write_bytes (self ._data_source .reader .read ())
1059- self ._data_source = FileDataSource (file_path )
1070+
1071+ writer : IO [bytes ] | zstandard .ZstdCompressionWriter
1072+ if compression is None or compression == self .compression :
1073+ # Handle trivial case
1074+ with open (file_path , mode = f"{ mode } b" ) as writer :
1075+ reader = self ._data_source .reader
1076+ while chunk := reader .read (2 ** 16 ):
1077+ writer .write (chunk )
1078+ return
1079+
1080+ if compression == Compression .ZSTD :
1081+ writer = zstandard .ZstdCompressor (
1082+ write_checksum = True ,
1083+ ).stream_writer (
1084+ open (file_path , mode = f"{ mode } b" ),
1085+ closefd = True ,
1086+ )
1087+ else :
1088+ writer = open (file_path , mode = f"{ mode } b" )
1089+
1090+ try :
1091+ reader = self .reader
1092+
1093+ while chunk := reader .read (2 ** 16 ):
1094+ writer .write (chunk )
1095+ finally :
1096+ writer .close ()
10601097
10611098 def to_json (
10621099 self ,
@@ -1066,6 +1103,7 @@ def to_json(
10661103 map_symbols : bool = True ,
10671104 compression : Compression | str = Compression .NONE ,
10681105 schema : Schema | str | None = None ,
1106+ mode : Literal ["w" , "x" ] = "w" ,
10691107 ) -> None :
10701108 """
10711109 Write the data to a file in JSON format.
@@ -1089,6 +1127,8 @@ def to_json(
10891127 schema : Schema or str, optional
10901128 The DBN schema for the json.
10911129 This is only required when reading a DBN stream with mixed record types.
1130+ mode : str, default "w"
1131+ The file write mode to use, either "x" or "w".
10921132
10931133 Raises
10941134 ------
@@ -1098,14 +1138,16 @@ def to_json(
10981138 """
10991139 compression = validate_enum (compression , Compression , "compression" )
11001140 schema = validate_maybe_enum (schema , Schema , "schema" )
1141+ file_path = validate_file_write_path (path , "path" , exist_ok = mode == "w" )
1142+
11011143 if schema is None :
11021144 if self .schema is None :
11031145 raise ValueError ("a schema must be specified for mixed DBN data" )
11041146 schema = self .schema
11051147
1106- with open (path , "xb " ) as output :
1148+ with open (file_path , f" { mode } b " ) as output :
11071149 self ._transcode (
1108- output = output ,
1150+ output = output , # type: ignore [arg-type]
11091151 encoding = Encoding .JSON ,
11101152 pretty_px = pretty_px ,
11111153 pretty_ts = pretty_ts ,
@@ -1239,8 +1281,10 @@ def _transcode(
12391281 transcoder .write (byte_chunk )
12401282
12411283 if transcoder .buffer ():
1242- raise BentoError (
1243- "DBN file is truncated or contains an incomplete record" ,
1284+ warnings .warn (
1285+ BentoWarning (
1286+ "DBN file is truncated or contains an incomplete record" ,
1287+ ),
12441288 )
12451289
12461290 transcoder .flush ()
@@ -1285,28 +1329,38 @@ def __init__(
12851329 self ._dtype = np .dtype (dtype )
12861330 self ._offset = offset
12871331 self ._count = count
1332+ self ._close_on_next = False
12881333
12891334 self ._reader .seek (offset )
12901335
12911336 def __iter__ (self ) -> NDArrayStreamIterator :
12921337 return self
12931338
12941339 def __next__ (self ) -> np .ndarray [Any , Any ]:
1340+ if self ._close_on_next :
1341+ raise StopIteration
1342+
12951343 if self ._count is None :
12961344 read_size = - 1
12971345 else :
12981346 read_size = self ._dtype .itemsize * max (self ._count , 1 )
12991347
13001348 if buffer := self ._reader .read (read_size ):
1349+ loose_bytes = len (buffer ) % self ._dtype .itemsize
1350+ if loose_bytes != 0 :
1351+ warnings .warn (
1352+ BentoWarning ("DBN file is truncated or contains an incomplete record" ),
1353+ )
1354+ buffer = buffer [:- loose_bytes ]
1355+ self ._close_on_next = True # decode one more buffer before stopping
1356+
13011357 try :
13021358 return np .frombuffer (
13031359 buffer = buffer ,
13041360 dtype = self ._dtype ,
13051361 )
1306- except ValueError :
1307- raise BentoError (
1308- "DBN file is truncated or contains an incomplete record" ,
1309- )
1362+ except ValueError as exc :
1363+ raise BentoError ("Cannot decode DBN stream" ) from exc
13101364
13111365 raise StopIteration
13121366
@@ -1351,10 +1405,8 @@ def __next__(self) -> np.ndarray[Any, Any]:
13511405 dtype = self ._dtype ,
13521406 count = num_records ,
13531407 )
1354- except ValueError :
1355- raise BentoError (
1356- "DBN file is truncated or contains an incomplete record" ,
1357- ) from None
1408+ except ValueError as exc :
1409+ raise BentoError ("Cannot decode DBN stream" ) from exc
13581410
13591411
13601412class DataFrameIterator :
0 commit comments