Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ All notable changes to this project will be documented in this file.
## Unreleased

- Upgrade zstd source code from v1.5.6 to [v1.5.7](https://github.com/facebook/zstd/releases/tag/v1.5.7)
- Raise an exception when attempting to decompress empty data
- Build wheels for Windows ARM64
- Support for PyPy 3.11

Expand Down
2 changes: 1 addition & 1 deletion src/bin_ext/decompressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ decompress(PyObject *module, PyObject *args, PyObject *kwargs)

/* Check data integrity. at_frame_edge flag is 1 when both the input and
output streams are at a frame edge. */
if (self.at_frame_edge == 0) {
if (self.at_frame_edge == 0 || in.pos == 0) {
char *extra_msg = (Py_SIZE(ret) == 0) ? "." :
", if want to output these decompressed data, use "
"decompress_stream function or "
Expand Down
7 changes: 6 additions & 1 deletion src/bin_ext/file.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ typedef struct {
/* File states. On Linux/macOS/Windows, Py_off_t is signed, so
ZstdFile/SeekableZstdFile use int64_t as file position/size. */
PyObject *fp; /* File object */
int sof; /* At SOF, 0 or 1. */
int eof; /* At EOF, 0 or 1. */
int64_t pos; /* Decompressed position, >= 0. */
int64_t size; /* File size, -1 means unknown. */
Expand Down Expand Up @@ -109,6 +110,7 @@ ZstdFileReader_init(ZstdFileReader *self, PyObject *args, PyObject *kwargs)
assert(self->dict == NULL);
assert(self->read_size == NULL);
assert(self->fp == NULL);
assert(self->sof == 0);
assert(self->eof == 0);
assert(self->pos == 0);
assert(self->size == 0);
Expand Down Expand Up @@ -139,6 +141,7 @@ ZstdFileReader_init(ZstdFileReader *self, PyObject *args, PyObject *kwargs)
/* File states */
Py_INCREF(fp);
self->fp = fp;
self->sof = 1;
self->size = -1;

/* Decompression states */
Expand Down Expand Up @@ -239,7 +242,7 @@ decompress_into(ZstdFileReader *self,

/* EOF */
if (read_len == 0) {
if (self->at_frame_edge) {
if (self->at_frame_edge && !self->sof) {
self->eof = 1;
self->pos += out->pos - orig_pos;
self->size = self->pos;
Expand All @@ -254,6 +257,7 @@ decompress_into(ZstdFileReader *self,
self->in.src = read_buf;
self->in.size = read_len;
self->in.pos = 0;
self->sof = 0;
}

/* Decompress */
Expand Down Expand Up @@ -423,6 +427,7 @@ ZstdFileReader_reset_session(ZstdFileReader *self)
{
/* Reset decompression states */
self->needs_input = 1;
self->sof = 1;
self->at_frame_edge = 1;
self->in.size = 0;
self->in.pos = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/cffi/decompressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def decompress(data, zstd_dict=None, option=None):

# Check data integrity. at_frame_edge flag is True when both the input
# and output streams are at a frame edge.
if not decomp._at_frame_edge:
if not decomp._at_frame_edge or not in_buf.pos:
extra_msg = "." if (len(ret) == 0) \
else (", if want to output these decompressed data, use "
"decompress_stream function or "
Expand Down
5 changes: 4 additions & 1 deletion src/cffi/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self, fp, zstd_dict, option, read_size):

# File states, the last three are public attributes.
self._fp = fp
self._sof = True # start of file
self.eof = False
self.pos = 0 # Decompressed position
self.size = -1 # File size, -1 means unknown.
Expand Down Expand Up @@ -68,7 +69,7 @@ def _decompress_into(self, out_b, fill_full):
self._in_dat = self._fp.read(self._read_size)
# EOF
if not self._in_dat:
if self._at_frame_edge:
if self._at_frame_edge and not self._sof:
self.eof = True
self.pos += out_b.pos - orig_pos
self.size = self.pos
Expand All @@ -79,6 +80,7 @@ def _decompress_into(self, out_b, fill_full):
in_b.src = ffi.from_buffer(self._in_dat)
in_b.size = _nbytes(self._in_dat)
in_b.pos = 0
self._sof = False

# Decompress
zstd_ret = m.ZSTD_decompressStream(self._dctx, out_b, in_b)
Expand Down Expand Up @@ -171,6 +173,7 @@ def forward(self, offset):
def reset_session(self):
# Reset decompression states
self._needs_input = True
self._sof = True
self._at_frame_edge = True
self._in_buf.size = 0
self._in_buf.pos = 0
Expand Down
14 changes: 4 additions & 10 deletions tests/test_seekable.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,8 @@ def test_load(self):
# empty
b = BytesIO()
with SeekableZstdFile(b, 'r') as f:
self.assertEqual(f.read(10), b'')
with self.assertRaises(EOFError):
f.read(10)

# not a seekable format
b = BytesIO(COMPRESSED*10)
Expand Down Expand Up @@ -882,15 +883,8 @@ def test_read(self):

def test_read_empty(self):
with SeekableZstdFile(BytesIO(b''), 'r') as f:
self.assertEqual(f.read(), b'')
self.assertEqual(f.tell(), 0)

self.assertEqual(f.seek(2), 0)
self.assertEqual(f.read(), b'')
self.assertEqual(f.tell(), 0)

self.assertEqual(f.seek(-2), 0)
self.assertEqual(f.read(), b'')
with self.assertRaises(EOFError):
f.read()
self.assertEqual(f.tell(), 0)

def test_seek(self):
Expand Down
21 changes: 13 additions & 8 deletions tests/test_zstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,8 @@ def test_compress_empty(self):
bo.close()

def test_decompress_empty(self):
self.assertEqual(decompress(b''), b'')
with self.assertRaises(ZstdError):
decompress(b'')

d = ZstdDecompressor()
self.assertEqual(d.decompress(b''), b'')
Expand Down Expand Up @@ -1325,7 +1326,8 @@ def setUpClass(cls):
cls.TRAIL = b'12345678abcdefg!@#$%^&*()_+|'

def test_function_decompress(self):
self.assertEqual(decompress(b''), b'')
with self.assertRaises(ZstdError):
decompress(b'')

self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*1024)

Expand Down Expand Up @@ -2234,8 +2236,8 @@ def test_empty_input(self):
dat1 = b''

# decompress() function
dat2 = decompress(dat1)
self.assertEqual(len(dat2), 0)
with self.assertRaises(ZstdError):
decompress(dat1)

# ZstdDecompressor class
d = ZstdDecompressor()
Expand Down Expand Up @@ -2880,10 +2882,12 @@ def test_read_0(self):
# empty file
with ZstdFile(BytesIO(b'')) as f:
self.assertEqual(f.read(0), b"")
self.assertEqual(f.read(10), b"")
with self.assertRaises(EOFError):
f.read(10)

with ZstdFile(BytesIO(b'')) as f:
self.assertEqual(f.read(10), b"")
with self.assertRaises(EOFError):
f.read(10)

def test_read_10(self):
with ZstdFile(BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
Expand Down Expand Up @@ -3517,8 +3521,9 @@ def test_binary_modes(self):
def test_text_modes(self):
# empty input
with open(BytesIO(b''), "rt", encoding="utf-8", newline='\n') as reader:
for _ in reader:
pass
with self.assertRaises(EOFError):
for _ in reader:
pass

# read
uncompressed = THIS_FILE_STR.replace(os.linesep, "\n")
Expand Down