diff --git a/src/pad_tail.rs b/src/pad_tail.rs index 5595b42ba..0b66aa83b 100644 --- a/src/pad_tail.rs +++ b/src/pad_tail.rs @@ -1,4 +1,3 @@ -use crate::size_hint; use std::iter::{Fuse, FusedIterator}; /// An iterator adaptor that pads a sequence to a minimum length by filling @@ -11,8 +10,9 @@ use std::iter::{Fuse, FusedIterator}; #[must_use = "iterator adaptors are lazy and do nothing unless consumed"] pub struct PadUsing { iter: Fuse, - min: usize, - pos: usize, + elements_from_next: usize, + elements_from_next_back: usize, + elements_required: usize, filler: F, } @@ -20,19 +20,26 @@ impl std::fmt::Debug for PadUsing where I: std::fmt::Debug, { - debug_fmt_fields!(PadUsing, iter, min, pos); + debug_fmt_fields!( + PadUsing, + iter, + elements_from_next, + elements_from_next_back, + elements_required + ); } /// Create a new `PadUsing` iterator. -pub fn pad_using(iter: I, min: usize, filler: F) -> PadUsing +pub fn pad_using(iter: I, elements_required: usize, filler: F) -> PadUsing where I: Iterator, F: FnMut(usize) -> I::Item, { PadUsing { iter: iter.fuse(), - min, - pos: 0, + elements_from_next: 0, + elements_from_next_back: 0, + elements_required, filler, } } @@ -44,40 +51,35 @@ where { type Item = I::Item; - #[inline] fn next(&mut self) -> Option { - match self.iter.next() { - None => { - if self.pos < self.min { - let e = Some((self.filler)(self.pos)); - self.pos += 1; - e - } else { - None - } - } - e => { - self.pos += 1; - e - } + let total_consumed = self.elements_from_next + self.elements_from_next_back; + + if total_consumed >= self.elements_required { + self.iter.next() + } else if let Some(e) = self.iter.next() { + self.elements_from_next += 1; + Some(e) + } else { + let e = (self.filler)(self.elements_from_next); + self.elements_from_next += 1; + Some(e) } } fn size_hint(&self) -> (usize, Option) { - let tail = self.min.saturating_sub(self.pos); - size_hint::max(self.iter.size_hint(), (tail, Some(tail))) - } + let total_consumed = self.elements_from_next + self.elements_from_next_back; + + if total_consumed >= self.elements_required { + return self.iter.size_hint(); + } + + let elements_remaining = self.elements_required - total_consumed; + let (low, high) = self.iter.size_hint(); - fn fold(self, mut init: B, mut f: G) -> B - where - G: FnMut(B, Self::Item) -> B, - { - let mut pos = self.pos; - init = self.iter.fold(init, |acc, item| { - pos += 1; - f(acc, item) - }); - (pos..self.min).map(self.filler).fold(init, f) + let lower_bound = low.max(elements_remaining); + let upper_bound = high.map(|h| h.max(elements_remaining)); + + (lower_bound, upper_bound) } } @@ -87,25 +89,20 @@ where F: FnMut(usize) -> I::Item, { fn next_back(&mut self) -> Option { - if self.min == 0 { - self.iter.next_back() - } else if self.iter.len() >= self.min { - self.min -= 1; - self.iter.next_back() - } else { - self.min -= 1; - Some((self.filler)(self.min)) + let total_consumed = self.elements_from_next + self.elements_from_next_back; + + if total_consumed >= self.elements_required { + return self.iter.next_back(); } - } - fn rfold(self, mut init: B, mut f: G) -> B - where - G: FnMut(B, Self::Item) -> B, - { - init = (self.iter.len()..self.min) - .map(self.filler) - .rfold(init, &mut f); - self.iter.rfold(init, f) + let index_from_back = self.elements_required - self.elements_from_next_back - 1; + self.elements_from_next_back += 1; + + if index_from_back >= self.iter.len() { + Some((self.filler)(index_from_back)) + } else { + self.iter.next_back() + } } } diff --git a/tests/specializations.rs b/tests/specializations.rs index 26d1f5367..9228829b5 100644 --- a/tests/specializations.rs +++ b/tests/specializations.rs @@ -118,6 +118,15 @@ where } } } + check_specialized!(it, |i| { + let mut parameters_from_fold = vec![]; + let fold_result = i.fold(vec![], |mut acc, v: I::Item| { + parameters_from_fold.push((acc.clone(), v.clone())); + acc.push(v); + acc + }); + (parameters_from_fold, fold_result) + }); check_specialized!(it, |i| { let mut parameters_from_rfold = vec![]; let rfold_result = i.rfold(vec![], |mut acc, v: I::Item| { @@ -131,6 +140,25 @@ where for n in 0..size + 2 { check_specialized!(it, |mut i| i.nth_back(n)); } + + let mut fwd = it.clone(); + let mut bwd = it.clone(); + + while fwd.next().is_some() {} + + assert_eq!( + fwd.next_back(), + None, + "iterator leaks elements after consuming forwards" + ); + + while bwd.next_back().is_some() {} + + assert_eq!( + bwd.next(), + None, + "iterator leaks elements after consuming backwards" + ); } quickcheck! {