Skip to content

Commit

Permalink
make_key/value_iterator: support iterators that return a non-referenc…
Browse files Browse the repository at this point in the history
…e from operator*
  • Loading branch information
oremanj authored and wjakob committed Mar 6, 2024
1 parent e4770f8 commit d39617e
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 4 deletions.
18 changes: 16 additions & 2 deletions include/nanobind/make_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ struct iterator_state {
bool first_or_done;
};

template <typename T>
struct remove_rvalue_ref { using type = T; };
template <typename T>
struct remove_rvalue_ref<T&&> { using type = T; };

// Note: these helpers take the iterator by non-const reference because some
// iterators in the wild can't be dereferenced when const.
template <typename Iterator> struct iterator_access {
Expand All @@ -33,12 +38,21 @@ template <typename Iterator> struct iterator_access {
};

template <typename Iterator> struct iterator_key_access {
using result_type = const decltype((*std::declval<Iterator &>()).first) &;
// Note double parens in decltype((...)) to capture the value category
// as well. This will be lvalue if the iterator's operator* returned an
// lvalue reference, and xvalue if the iterator's operator* returned an
// object (or rvalue reference but that's unlikely). decltype of an xvalue
// produces T&&, but we want to return a value T from operator() in that
// case, in order to avoid creating a Python object that references a
// C++ temporary. Thus, pass the result through remove_rvalue_ref.
using result_type = typename remove_rvalue_ref<
decltype(((*std::declval<Iterator &>()).first))>::type;
result_type operator()(Iterator &it) const { return (*it).first; }
};

template <typename Iterator> struct iterator_value_access {
using result_type = const decltype((*std::declval<Iterator &>()).second) &;
using result_type = typename remove_rvalue_ref<
decltype(((*std::declval<Iterator &>()).second))>::type;
result_type operator()(Iterator &it) const { return (*it).second; }
};

Expand Down
39 changes: 39 additions & 0 deletions tests/test_make_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,47 @@ NB_MODULE(test_make_iterator_ext, m) {
return nb::make_iterator(mod, "pt_iterator", std::begin(s), std::end(s));
});

// test of map where dereferencing the iterator returns a value,
// not a reference
struct IdentityMap {
struct iterator {
int val;
std::pair<int, int> operator*() const { return {val, val}; }
iterator& operator++() { ++val; return *this; }
bool operator==(const iterator& other) const {
return val == other.val;
}
};

iterator begin() const { return iterator{0}; }
iterator end() const { return iterator{10}; }
};
nb::class_<IdentityMap>(m, "IdentityMap")
.def(nb::init<>())
.def("__iter__",
[](const IdentityMap &map) {
return nb::make_key_iterator(nb::type<IdentityMap>(),
"key_iterator",
map.begin(),
map.end());
}, nb::keep_alive<0, 1>())
.def("items",
[](const IdentityMap &map) {
return nb::make_iterator(nb::type<IdentityMap>(),
"item_iterator",
map.begin(),
map.end());
}, nb::keep_alive<0, 1>())
.def("values", [](const IdentityMap &map) {
return nb::make_value_iterator(nb::type<IdentityMap>(),
"value_iterator",
map.begin(),
map.end());
}, nb::keep_alive<0, 1>());

nb::list all;
all.append("iterator_passthrough");
all.append("StringMap");
all.append("IdentityMap");
m.attr("__all__") = all;
}
7 changes: 7 additions & 0 deletions tests/test_make_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,10 @@ def test04_passthrough_iterator():
for d in data:
m = t.StringMap(d)
assert list(t.iterator_passthrough(m.values())) == list(m.values())


def test05_iterator_returning_temporary():
im = t.IdentityMap()
assert list(im) == list(range(10))
assert list(im.values()) == list(range(10))
assert list(im.items()) == list(zip(range(10), range(10)))
13 changes: 11 additions & 2 deletions tests/test_make_iterator_ext.pyi.ref
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from collections.abc import Mapping, Iterator
from collections.abc import Iterator, Mapping
from typing import overload

class IdentityMap:
def __init__(self) -> None: ...

def __iter__(self) -> Iterator[int]: ...

def items(self) -> Iterator[tuple[int, int]]: ...

def values(self) -> Iterator[int]: ...

class StringMap:
@overload
def __init__(self) -> None: ...
Expand All @@ -14,6 +23,6 @@ class StringMap:

def values(self) -> Iterator[str]: ...

__all__: list = ['iterator_passthrough', 'StringMap']
__all__: list = ['iterator_passthrough', 'StringMap', 'IdentityMap']

def iterator_passthrough(arg: Iterator, /) -> Iterator: ...

0 comments on commit d39617e

Please sign in to comment.