From defe68c33ddaa2c2a9b3d4e53818a1fbe63c3f85 Mon Sep 17 00:00:00 2001 From: Rob Brackett Date: Fri, 7 Feb 2025 13:21:54 -0800 Subject: [PATCH] Raise in partition_all when length is invalid Sometimes an iterable's `__len__` is incorrect, which can lead to bad output from `partition_all`. We now raise an exception (when we previously output bad data) in these cases so the user can more easily find the invalid iterable and fix it. Fixes #602. --- toolz/itertoolz.py | 7 ++++++- toolz/tests/test_itertoolz.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py index 1ab54851..1115ac4e 100644 --- a/toolz/itertoolz.py +++ b/toolz/itertoolz.py @@ -732,7 +732,12 @@ def partition_all(n, seq): try: # If seq defines __len__, then # we can quickly calculate where no_pad starts - yield prev[:len(seq) % n] + end = len(seq) % n + if prev[end - 1] is no_pad or prev[end] is not no_pad: + raise LookupError( + 'The sequence passed to `parition_all` has invalid length' + ) + yield prev[:end] except TypeError: # Get first index of no_pad without using .index() # https://github.com/pytoolz/toolz/issues/387 diff --git a/toolz/tests/test_itertoolz.py b/toolz/tests/test_itertoolz.py index 27907b9e..b0da58a9 100644 --- a/toolz/tests/test_itertoolz.py +++ b/toolz/tests/test_itertoolz.py @@ -355,6 +355,20 @@ def __eq__(self, other): assert list(partition_all(4, [obj]*7)) == result assert list(partition_all(4, iter([obj]*7))) == result + # Test invalid __len__: https://github.com/pytoolz/toolz/issues/602 + class ListWithBadLength(list): + def __init__(self, contents, off_by=1): + self.off_by = off_by + super().__init__(contents) + + def __len__(self): + return super().__len__() + self.off_by + + too_long_list = ListWithBadLength([1, 2], off_by=+1) + assert raises(LookupError, lambda: list(partition_all(5, too_long_list))) + too_short_list = ListWithBadLength([1, 2], off_by=-1) + assert raises(LookupError, lambda: list(partition_all(5, too_short_list))) + def test_count(): assert count((1, 2, 3)) == 3