diff --git a/CHANGELOG.md b/CHANGELOG.md index d1fc99b..1bbf8dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ All notable changes to this project will be documented in this file. ## Unreleased - Upgrade zstd source code from v1.5.6 to [v1.5.7](https://github.com/facebook/zstd/releases/tag/v1.5.7) +- Raise an exception when attempting to decompress empty data - Build wheels for Windows ARM64 - Support for PyPy 3.11 diff --git a/src/bin_ext/decompressor.c b/src/bin_ext/decompressor.c index b557dd3..f88113a 100644 --- a/src/bin_ext/decompressor.c +++ b/src/bin_ext/decompressor.c @@ -825,7 +825,7 @@ decompress(PyObject *module, PyObject *args, PyObject *kwargs) /* Check data integrity. at_frame_edge flag is 1 when both the input and output streams are at a frame edge. */ - if (self.at_frame_edge == 0) { + if (self.at_frame_edge == 0 || in.pos == 0) { char *extra_msg = (Py_SIZE(ret) == 0) ? "." : ", if want to output these decompressed data, use " "decompress_stream function or " diff --git a/src/bin_ext/file.c b/src/bin_ext/file.c index 4fc1705..0911ea4 100644 --- a/src/bin_ext/file.c +++ b/src/bin_ext/file.c @@ -19,6 +19,7 @@ typedef struct { /* File states. On Linux/macOS/Windows, Py_off_t is signed, so ZstdFile/SeekableZstdFile use int64_t as file position/size. */ PyObject *fp; /* File object */ + int sof; /* At SOF, 0 or 1. */ int eof; /* At EOF, 0 or 1. */ int64_t pos; /* Decompressed position, >= 0. */ int64_t size; /* File size, -1 means unknown. */ @@ -109,6 +110,7 @@ ZstdFileReader_init(ZstdFileReader *self, PyObject *args, PyObject *kwargs) assert(self->dict == NULL); assert(self->read_size == NULL); assert(self->fp == NULL); + assert(self->sof == 0); assert(self->eof == 0); assert(self->pos == 0); assert(self->size == 0); @@ -139,6 +141,7 @@ ZstdFileReader_init(ZstdFileReader *self, PyObject *args, PyObject *kwargs) /* File states */ Py_INCREF(fp); self->fp = fp; + self->sof = 1; self->size = -1; /* Decompression states */ @@ -239,7 +242,7 @@ decompress_into(ZstdFileReader *self, /* EOF */ if (read_len == 0) { - if (self->at_frame_edge) { + if (self->at_frame_edge && !self->sof) { self->eof = 1; self->pos += out->pos - orig_pos; self->size = self->pos; @@ -254,6 +257,7 @@ decompress_into(ZstdFileReader *self, self->in.src = read_buf; self->in.size = read_len; self->in.pos = 0; + self->sof = 0; } /* Decompress */ @@ -423,6 +427,7 @@ ZstdFileReader_reset_session(ZstdFileReader *self) { /* Reset decompression states */ self->needs_input = 1; + self->sof = 1; self->at_frame_edge = 1; self->in.size = 0; self->in.pos = 0; diff --git a/src/cffi/decompressor.py b/src/cffi/decompressor.py index 4fe230e..a268116 100644 --- a/src/cffi/decompressor.py +++ b/src/cffi/decompressor.py @@ -401,7 +401,7 @@ def decompress(data, zstd_dict=None, option=None): # Check data integrity. at_frame_edge flag is True when the both the input # and output streams are at a frame edge. - if not decomp._at_frame_edge: + if not decomp._at_frame_edge or not in_buf.pos: extra_msg = "." if (len(ret) == 0) \ else (", if want to output these decompressed data, use " "decompress_stream function or " diff --git a/src/cffi/file.py b/src/cffi/file.py index cd38e84..9c815e8 100644 --- a/src/cffi/file.py +++ b/src/cffi/file.py @@ -15,6 +15,7 @@ def __init__(self, fp, zstd_dict, option, read_size): # File states, the last three are public attributes. self._fp = fp + self._sof = True # start of file self.eof = False self.pos = 0 # Decompressed position self.size = -1 # File size, -1 means unknown. @@ -68,7 +69,7 @@ def _decompress_into(self, out_b, fill_full): self._in_dat = self._fp.read(self._read_size) # EOF if not self._in_dat: - if self._at_frame_edge: + if self._at_frame_edge and not self._sof: self.eof = True self.pos += out_b.pos - orig_pos self.size = self.pos @@ -79,6 +80,7 @@ def _decompress_into(self, out_b, fill_full): in_b.src = ffi.from_buffer(self._in_dat) in_b.size = _nbytes(self._in_dat) in_b.pos = 0 + self._sof = False # Decompress zstd_ret = m.ZSTD_decompressStream(self._dctx, out_b, in_b) @@ -171,6 +173,7 @@ def forward(self, offset): def reset_session(self): # Reset decompression states self._needs_input = True + self._sof = True self._at_frame_edge = True self._in_buf.size = 0 self._in_buf.pos = 0 diff --git a/tests/test_seekable.py b/tests/test_seekable.py index 5deb01c..5f18abe 100644 --- a/tests/test_seekable.py +++ b/tests/test_seekable.py @@ -845,7 +845,8 @@ def test_load(self): # empty b = BytesIO() with SeekableZstdFile(b, 'r') as f: - self.assertEqual(f.read(10), b'') + with self.assertRaises(EOFError): + f.read(10) # not a seekable format b = BytesIO(COMPRESSED*10) @@ -882,15 +883,8 @@ def test_read(self): def test_read_empty(self): with SeekableZstdFile(BytesIO(b''), 'r') as f: - self.assertEqual(f.read(), b'') - self.assertEqual(f.tell(), 0) - - self.assertEqual(f.seek(2), 0) - self.assertEqual(f.read(), b'') - self.assertEqual(f.tell(), 0) - - self.assertEqual(f.seek(-2), 0) - self.assertEqual(f.read(), b'') + with self.assertRaises(EOFError): + f.read() self.assertEqual(f.tell(), 0) def test_seek(self): diff --git a/tests/test_zstd.py b/tests/test_zstd.py index fc8399a..3435010 100644 --- a/tests/test_zstd.py +++ b/tests/test_zstd.py @@ -1229,7 +1229,8 @@ def test_compress_empty(self): bo.close() def test_decompress_empty(self): - self.assertEqual(decompress(b''), b'') + with self.assertRaises(ZstdError): + decompress(b'') d = ZstdDecompressor() self.assertEqual(d.decompress(b''), b'') @@ -1325,7 +1326,8 @@ def setUpClass(cls): cls.TRAIL = b'12345678abcdefg!@#$%^&*()_+|' def test_function_decompress(self): - self.assertEqual(decompress(b''), b'') + with self.assertRaises(ZstdError): + decompress(b'') self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*1024) @@ -2234,8 +2236,8 @@ def test_empty_input(self): dat1 = b'' # decompress() function - dat2 = decompress(dat1) - self.assertEqual(len(dat2), 0) + with self.assertRaises(ZstdError): + decompress(dat1) # ZstdDecompressor class d = ZstdDecompressor() @@ -2880,10 +2882,12 @@ def test_read_0(self): # empty file with ZstdFile(BytesIO(b'')) as f: self.assertEqual(f.read(0), b"") - self.assertEqual(f.read(10), b"") + with self.assertRaises(EOFError): + f.read(10) with ZstdFile(BytesIO(b'')) as f: - self.assertEqual(f.read(10), b"") + with self.assertRaises(EOFError): + f.read(10) def test_read_10(self): with ZstdFile(BytesIO(COMPRESSED_100_PLUS_32KB)) as f: @@ -3517,8 +3521,9 @@ def test_binary_modes(self): def test_text_modes(self): # empty input with open(BytesIO(b''), "rt", encoding="utf-8", newline='\n') as reader: - for _ in reader: - pass + with self.assertRaises(EOFError): + for _ in reader: + pass # read uncompressed = THIS_FILE_STR.replace(os.linesep, "\n")