Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Use normalize_index #4001

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions stdlib/src/collections/inline_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ from collections import InlineList
```
"""

from collections._index_normalization import normalize_index
from sys.intrinsics import _type_is_eq

from memory.maybe_uninitialized import UnsafeMaybeUninitialized
Expand Down Expand Up @@ -145,15 +146,11 @@ struct InlineList[ElementType: CollectionElementNew, capacity: Int = 16](Sized):
Returns:
A reference to the item at the given index.
"""
var index = Int(idx)
debug_assert(
-self._size <= index < self._size, "Index must be within bounds."
# Using UInt to avoid extra signed normalization in self._array
var normalized_index = normalize_index["InlineList"](
idx, UInt(self._size)
)

if index < 0:
index += len(self)

return self._array[index].assume_initialized()
return self._array[normalized_index].assume_initialized()

# ===-------------------------------------------------------------------===#
# Trait implementations
Expand Down
14 changes: 7 additions & 7 deletions stdlib/src/collections/linked_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ struct LinkedList[
elem.free()
return value^

fn pop[I: Indexer](mut self, owned i: I) raises -> ElementType:
fn pop[I: Indexer](mut self, idx: I) raises -> ElementType:
"""
Remove the ith element of the list, counting from the tail if
given a negative index.
Expand All @@ -295,12 +295,12 @@ struct LinkedList[
I: The type of index to use.

Args:
i: The index of the element to get.
idx: The index of the element to get.

Returns:
Ownership of the indicated element.
"""
var current = self._get_node_ptr(Int(i))
var current = self._get_node_ptr(idx)

if current:
var node = current[]
Expand All @@ -323,7 +323,7 @@ struct LinkedList[
self._size -= 1
return data^

raise String("Invalid index for pop: {}").format(Int(i))
raise String("Invalid index for pop: ", Int(idx))

fn maybe_pop(mut self) -> Optional[ElementType]:
"""
Expand All @@ -347,7 +347,7 @@ struct LinkedList[
elem.free()
return value^

fn maybe_pop[I: Indexer](mut self, owned i: I) -> Optional[ElementType]:
fn maybe_pop[I: Indexer](mut self, idx: I) -> Optional[ElementType]:
"""
Remove the ith element of the list, counting from the tail if
given a negative index.
Expand All @@ -358,12 +358,12 @@ struct LinkedList[
I: The type of index to use.

Args:
i: The index of the element to get.
idx: The index of the element to get.

Returns:
The element, if it was found.
"""
var current = self._get_node_ptr(Int(i))
var current = self._get_node_ptr(idx)

if not current:
return Optional[ElementType]()
Expand Down
20 changes: 3 additions & 17 deletions stdlib/src/collections/list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from collections import List
"""


from collections._index_normalization import normalize_index
from os import abort
from sys import sizeof
from sys.intrinsics import _type_is_eq
Expand Down Expand Up @@ -884,23 +885,8 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False](
Returns:
A reference to the element at the given index.
"""

@parameter
if _type_is_eq[I, UInt]():
return (self.data + idx)[]
else:
var normalized_idx = Int(idx)
debug_assert(
-self.size <= normalized_idx < self.size,
"index: ",
normalized_idx,
" is out of bounds for `List` of size: ",
self.size,
)
if normalized_idx < 0:
normalized_idx += len(self)

return (self.data + normalized_idx)[]
var normalized_index = normalize_index["List"](idx, self.size)
return (self.data + normalized_index)[]

@always_inline
fn unsafe_get(ref self, idx: Int) -> ref [self] Self.T:
Expand Down
3 changes: 2 additions & 1 deletion stdlib/src/collections/string/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,8 @@ struct String(
A new string containing the character at the specified position.
"""
# TODO(#933): implement this for unicode when we support llvm intrinsic evaluation at compile time
var normalized_idx = normalize_index["String"](idx, len(self))
# Using UInt to avoid extra signed normalization in self._buffer
var normalized_idx = normalize_index["String"](idx, UInt(len(self)))
var buf = Self._buffer_type(capacity=1)
buf.append(self._buffer[normalized_idx])
buf.append(0)
Expand Down
12 changes: 3 additions & 9 deletions stdlib/src/memory/span.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from memory import Span
```
"""

from collections._index_normalization import normalize_index
from collections import InlineArray
from sys.info import simdwidthof

Expand Down Expand Up @@ -178,15 +179,8 @@ struct Span[
Returns:
An element reference.
"""
# TODO: Simplify this with a UInt type.
debug_assert(
-self._len <= Int(idx) < self._len, "index must be within bounds"
)
# TODO(MSTDL-1086): optimize away SIMD/UInt normalization check
var offset = Int(idx)
if offset < 0:
offset += len(self)
return self._data[offset]
var normalized_index = normalize_index["Span"](idx, self._len)
return self._data[normalized_index]

@always_inline
fn __getitem__(self, slc: Slice) -> Self:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# CHECK-FAIL-LABEL: test_fail_list_index
fn main():
print("== test_fail_list_index")
# CHECK-FAIL: index: 4 is out of bounds for `List` of size: 3
# CHECK-FAIL: index out of bounds
nums = List[Int](1, 2, 3)
print(nums[4])

Expand Down