From ea324d234e2791c7b4a3f8e68cdc86d2777dd4bc Mon Sep 17 00:00:00 2001 From: ttrenty <154608953+ttrenty@users.noreply.github.com> Date: Thu, 26 Jun 2025 10:49:43 -0600 Subject: [PATCH 1/3] feat: improve abstraction at minimal overhead cost --- TODOs.md | 13 +- src/local_complex.mojo | 339 ++++++++++++ src/local_list.mojo | 1136 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 1483 insertions(+), 5 deletions(-) create mode 100644 src/local_complex.mojo create mode 100644 src/local_list.mojo diff --git a/TODOs.md b/TODOs.md index 4b6cd02..f1591c6 100644 --- a/TODOs.md +++ b/TODOs.md @@ -4,13 +4,19 @@ A pure state is represented by a ket vector, e.g., $ |\psi\rangle = \alpha|0\ran ## Ordered by Priority (5 ↑) / Difficulty (5 ↓) +- 5 / 2 : Reimplement Fig 2 circuit using Swap efficient gates representation + - Generarly when doing an implementation of a abstract circuit also do the implementation + using the functions for people to understand more easily what's happening + - 5 / 1 : density matrix calculcation from state vectors -- 5 / 5 : Extend qubitWiseMultiply() to 2 and multiple qubits gates (Test it) +- 5 / 5 : Extend qubitWiseMultiply() to 2 qubits gates + +- 5 / 5 : Extend qubitWiseMultiply() to multiple qubits gates - 4 / 3 : Implement measurement gates -- 4 / 5 : Implement the computation of statistics: (6.5 and 6.6) + - 4 / 5 : Implement the computation of statistics: (6.5 and 6.6) - 3 / 2 : Use a separate list for things that are not real gate to not slow down the main run logic @@ -21,6 +27,3 @@ A pure state is represented by a ket vector, e.g., $ |\psi\rangle = \alpha|0\ran - 3 / 4 : Compile time circuit creation? - 3 / 2 : Reproduce table from page 10 - -- 2 / 4 : qubitWiseMultiply() but for multiple qubits gates applied to non-adjacent qubits - diff --git a/src/local_complex.mojo b/src/local_complex.mojo new file mode 100644 index 0000000..b88ad43 --- /dev/null +++ b/src/local_complex.mojo @@ -0,0 +1,339 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # +"""Implements the Complex type. + +You can import these APIs from the `complex` package. For example: + +```mojo +from complex import ComplexSIMD +``` +""" + +import math +from math.math import _Expable +from sys import llvm_intrinsic + +alias ComplexFloat32 = ComplexSIMD[DType.float32, 1] +alias ComplexFloat64 = ComplexSIMD[DType.float64, 1] + + +# ===-----------------------------------------------------------------------===# +# ComplexSIMD +# ===-----------------------------------------------------------------------===# + + +@register_passable("trivial") +struct ComplexSIMD[type: DType, size: Int](Stringable, Writable, _Expable): + """Represents a complex SIMD value. + + The class provides basic methods for manipulating complex values. + + Parameters: + type: DType of the value. + size: SIMD width of the value. + """ + + # ===-------------------------------------------------------------------===# + # Fields + # ===-------------------------------------------------------------------===# + + alias element_type = SIMD[type, size] + var re: Self.element_type + """The real part of the complex SIMD value.""" + var im: Self.element_type + """The imaginary part of the complex SIMD value.""" + + # ===-------------------------------------------------------------------===# + # Initialization + # ===-------------------------------------------------------------------===# + + fn __init__(out self, re: Self.element_type, im: Self.element_type = 0): + """Initializes a complex SIMD value. + + Args: + re: The real part of the complex value. + im: The imaginary part of the complex value. + """ + self.re = re + self.im = im + + # ===-------------------------------------------------------------------===# + # Trait implementations + # ===-------------------------------------------------------------------===# + + @no_inline + fn __str__(self) -> String: + """Get the complex as a string. + + Returns: + A string representation. + """ + return String.write(self) + + fn write_to[W: Writer](self, mut writer: W): + """ + Formats this complex value to the provided Writer. + + Parameters: + W: A type conforming to the Writable trait. + + Args: + writer: The object to write to. + """ + + # TODO(MSTDL-700): + # Add a Writer.reserve() method, to afford writer implementations + # to request reservation of additional space from `Writer` + # implementations that support that. Then use the logic below to + # call that method here. + + # Reserve space for opening and closing brackets, plus each element and + # its trailing commas. + # var initial_buffer_size = 2 + # for i in range(size): + # initial_buffer_size += ( + # _calc_initial_buffer_size(self.re[i]) + # + _calc_initial_buffer_size(self.im[i]) + # + 4 # for the ' + i' suffix on the imaginary + # + 2 + # ) + # buf.reserve(initial_buffer_size) + + # Print an opening `[`. + @parameter + if size > 1: + writer.write("[") + + # Print each element. + for i in range(size): + var re = self.re[i] + var im = self.im[i] + # Print separators between each element. + if i != 0: + writer.write(", ") + + writer.write(re) + + if im != 0: + writer.write(" + ", im, "i") + + # Print a closing `]`. + @parameter + if size > 1: + writer.write("]") + + @always_inline + fn __abs__(self) -> SIMD[type, size]: + """Returns the magnitude of the complex value. + + Returns: + Value of `sqrt(re*re + im*im)`. + """ + return self.norm() + + # ===-------------------------------------------------------------------===# + # Operator dunders + # ===-------------------------------------------------------------------===# + + @always_inline + fn __add__(self, rhs: Self) -> Self: + """Adds two complex values. + + Args: + rhs: Complex value to add. + + Returns: + A sum of this and RHS complex values. + """ + return Self(self.re + rhs.re, self.im + rhs.im) + + @always_inline + fn __sub__(self, rhs: Self) -> Self: + """Subtracts two complex values. + + Args: + rhs: Complex value to subtract. + + Returns: + A difference of this and RHS complex values. + """ + return Self(self.re - rhs.re, self.im - rhs.im) + + @always_inline + fn __mul__(self, rhs: Self) -> Self: + """Multiplies two complex values. + + Args: + rhs: Complex value to multiply with. + + Returns: + A product of this and RHS complex values. + """ + return Self( + self.re.fma(rhs.re, -self.im * rhs.im), + self.re.fma(rhs.im, self.im * rhs.re), + ) + + @always_inline + fn __truediv__(self, rhs: Self) -> Self: + """Divides two complex values. + + Args: + rhs: Complex value to divide by. + + Returns: + A quotient of this and RHS complex values. + """ + var denom = rhs.squared_norm() + return Self( + self.re.fma(rhs.re, self.im * rhs.im) / denom, + self.re.fma(rhs.im, -self.im * rhs.re) / denom, + ) + + @always_inline + fn __neg__(self) -> Self: + """Negates the complex value. + + Returns: + The negative of the complex value. + """ + return ComplexSIMD(-self.re, -self.im) + + # ===------------------------------------------------------------------=== # + # In place operations. + # ===------------------------------------------------------------------=== # + + @always_inline("nodebug") + fn __iadd__(mut self, rhs: Self): + """Performs in-place addition. + + Args: + rhs: The rhs of the addition operation. + """ + self = self + rhs + + @always_inline("nodebug") + fn __isub__(mut self, rhs: Self): + """Performs in-place subtraction. + + Args: + rhs: The rhs of the operation. + """ + self = self - rhs + + @always_inline("nodebug") + fn __imul__(mut self, rhs: Self): + """Performs in-place multiplication. + + Args: + rhs: The rhs of the operation. + """ + self = self * rhs + + @always_inline("nodebug") + fn __itruediv__(mut self, rhs: Self): + """In-place true divide operator. + + Args: + rhs: The rhs of the operation. + """ + var denom = rhs.squared_norm() + self.re = self.re.fma(rhs.re, self.im * rhs.im) / denom + self.im = self.re.fma(rhs.im, -self.im * rhs.re) / denom + + # ===-------------------------------------------------------------------===# + # Methods + # ===-------------------------------------------------------------------===# + + @always_inline + fn norm(self) -> SIMD[type, size]: + """Returns the magnitude of the complex value. + + Returns: + Value of `sqrt(re*re + im*im)`. + """ + return llvm_intrinsic["llvm.sqrt", SIMD[type, size]]( + self.squared_norm() + ) + + @always_inline + fn squared_norm(self) -> SIMD[type, size]: + """Returns the squared magnitude of the complex value. + + Returns: + Value of `re*re + im*im`. + """ + return self.re.fma(self.re, self.im * self.im) + + # fma(self, b, c) + @always_inline + fn fma(self, b: Self, c: Self) -> Self: + """Computes FMA operation. + + Compute fused multiple-add with two other complex values: + `result = self * b + c` + + Args: + b: Multiplier complex value. + c: Complex value to add. + + Returns: + Computed `Self * B + C` complex value. + """ + return Self( + self.re.fma(b.re, -(self.im.fma(b.im, -c.re))), + self.re.fma(b.im, self.im.fma(b.re, c.im)), + ) + + # fma(self, self, c) + @always_inline + fn squared_add(self, c: Self) -> Self: + """Computes Square-Add operation. + + Compute `Self * Self + C`. + + Args: + c: Complex value to add. + + Returns: + Computed `Self * Self + C` complex value. + """ + return Self( + self.re.fma(self.re, self.im.fma(-self.im, c.re)), + self.re.fma(self.im + self.im, c.im), + ) + + @always_inline + fn __exp__(self) -> Self: + """Computes the exponential of the complex value. + + Returns: + The exponential of the complex value. + """ + var exp_re = math.exp(self.re) + return Self(exp_re * math.cos(self.im), exp_re * math.sin(self.im)) + + +# TODO: we need this overload, because the Absable trait requires returning Self +# type. We could maybe get rid of this if we had associated types? +@always_inline +fn abs(x: ComplexSIMD[*_]) -> SIMD[x.type, x.size]: + """Performs elementwise abs (norm) on each element of the complex value. + + Args: + x: The complex vector to perform absolute value on. + + Returns: + The elementwise abs of x. + """ + return x.__abs__() diff --git a/src/local_list.mojo b/src/local_list.mojo new file mode 100644 index 0000000..fb0dcb4 --- /dev/null +++ b/src/local_list.mojo @@ -0,0 +1,1136 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # +"""Defines the List type. + +These APIs are imported automatically, just like builtins. +""" + + +from os import abort +from sys import sizeof +from sys.intrinsics import _type_is_eq + +from memory import Pointer, memcpy, memset_zero + +# from .optional import Optional + +# ===-----------------------------------------------------------------------===# +# List +# ===-----------------------------------------------------------------------===# + + +@fieldwise_init +struct _ListIter[ + list_mutability: Bool, //, + T: Copyable & Movable, + hint_trivial_type: Bool, + list_origin: Origin[list_mutability], + forward: Bool = True, +](Copyable, IteratorTrait, Movable): + """Iterator for List. + + Parameters: + list_mutability: Whether the reference to the list is mutable. + T: The type of the elements in the list. + hint_trivial_type: Set to `True` if the type `T` is trivial, this is not + mandatory, but it helps performance. Will go away in the future. + list_origin: The origin of the List + forward: The iteration direction. `False` is backwards. + """ + + alias Element = T # FIXME(MOCO-2068): shouldn't be needed. + alias list_type = CustomList[T, hint_trivial_type] + + var index: Int + var src: Pointer[Self.list_type, list_origin] + + fn __next_ref__(mut self) -> ref [list_origin] T: + @parameter + if forward: + self.index += 1 + return self.src[][self.index - 1] + else: + self.index -= 1 + return self.src[][self.index] + + @always_inline + fn __next__(mut self) -> T: + return self.__next_ref__() + + @always_inline + fn __has_next__(self) -> Bool: + return self.__len__() > 0 + + @always_inline + fn __iter__(self) -> Self: + return self + + fn __len__(self) -> Int: + @parameter + if forward: + return len(self.src[]) - self.index + else: + return self.index + + +struct CustomList[T: Copyable & Movable, hint_trivial_type: Bool = False]( + Boolable, Copyable, Defaultable, ExplicitlyCopyable, Movable, Sized +): + """The `List` type is a dynamically-allocated list. + + Parameters: + T: The type of the elements. + hint_trivial_type: A hint to the compiler that the type T is trivial. + It's not mandatory, but if set, it allows some optimizations. + + Notes: + It supports pushing and popping from the back resizing the underlying + storage as needed. When it is deallocated, it frees its memory. + """ + + # Fields + var data: UnsafePointer[T] + """The underlying storage for the list.""" + var _len: Int + """The number of elements in the list.""" + var capacity: Int + """The amount of elements that can fit in the list without resizing it.""" + + # ===-------------------------------------------------------------------===# + # Life cycle methods + # ===-------------------------------------------------------------------===# + + fn __init__(out self): + """Constructs an empty list.""" + self.data = UnsafePointer[T]() + self._len = 0 + self.capacity = 0 + + fn copy(self) -> Self: + """Creates a deep copy of the given list. + + Returns: + A copy of the value. + """ + var copy = Self(capacity=self.capacity) + for e in self: + copy.append(e) + return copy^ + + fn __init__(out self, *, capacity: Int): + """Constructs a list with the given capacity. + + Args: + capacity: The requested capacity of the list. + """ + if capacity: + self.data = UnsafePointer[T].alloc(capacity) + else: + self.data = UnsafePointer[T]() + self._len = 0 + self.capacity = capacity + + fn __init__(out self, *, length: UInt, fill: T): + """Constructs a list with the given capacity. + + Args: + length: The requested length of the list. + fill: The element to fill each element of the list. + """ + self = Self() + self.resize(length, fill) + + @always_inline + fn __init__(out self, owned *values: T, __list_literal__: () = ()): + """Constructs a list from the given values. + + Args: + values: The values to populate the list with. + __list_literal__: Tell Mojo to use this method for list literals. + """ + self = Self(elements=values^) + + fn __init__(out self, *, owned elements: VariadicListMem[T, _]): + """Constructs a list from the given values. + + Args: + elements: The values to populate the list with. + """ + var length = len(elements) + + self = Self(capacity=length) + + for i in range(length): + var src = UnsafePointer(to=elements[i]) + var dest = self.data + i + + src.move_pointee_into(dest) + + # Do not destroy the elements when their backing storage goes away. + __disable_del elements + + self._len = length + + fn __init__(out self, span: Span[T]): + """Constructs a list from the a Span of values. + + Args: + span: The span of values to populate the list with. + """ + self = Self(capacity=len(span)) + for value in span: + self.append(value) + + @always_inline + fn __init__(out self, *, unsafe_uninit_length: Int): + """Construct a list with the specified length, with uninitialized + memory. This is unsafe, as it relies on the caller initializing the + elements with unsafe operations, not assigning over the uninitialized + data. + + Args: + unsafe_uninit_length: The number of elements to allocate. + """ + self = Self(capacity=unsafe_uninit_length) + self._len = unsafe_uninit_length + + fn __copyinit__(out self, existing: Self): + """Creates a deepcopy of the given list. + + Args: + existing: The list to copy. + """ + self = Self(capacity=existing.capacity) + for i in range(len(existing)): + self.append(existing[i]) + + fn __del__(owned self): + """Destroy all elements in the list and free its memory.""" + + @parameter + if not hint_trivial_type: + for i in range(len(self)): + (self.data + i).destroy_pointee() + self.data.free() + + # ===-------------------------------------------------------------------===# + # Operator dunders + # ===-------------------------------------------------------------------===# + + @always_inline + fn __eq__[ + U: EqualityComparable & Copyable & Movable, // + ](self: List[U, *_], other: List[U, *_]) -> Bool: + """Checks if two lists are equal. + + Parameters: + U: The type of the elements in the list. Must implement the + trait `EqualityComparable`. + + Args: + other: The list to compare with. + + Returns: + True if the lists are equal, False otherwise. + + Examples: + + ```mojo + var x = [1, 2, 3] + var y = [1, 2, 3] + print("x and y are equal" if x == y else "x and y are not equal") + ``` + """ + if len(self) != len(other): + return False + var index = 0 + for element in self: + if element != other[index]: + return False + index += 1 + return True + + @always_inline + fn __ne__[ + U: EqualityComparable & Copyable & Movable, // + ](self: List[U, *_], other: List[U, *_]) -> Bool: + """Checks if two lists are not equal. + + Parameters: + U: The type of the elements in the list. Must implement the + trait `EqualityComparable`. + + Args: + other: The list to compare with. + + Returns: + True if the lists are not equal, False otherwise. + + Examples: + + ```mojo + var x = [1, 2, 3] + var y = [1, 2, 4] + print("x and y are not equal" if x != y else "x and y are equal") + ``` + """ + return not (self == other) + + fn __contains__[ + U: EqualityComparable & Copyable & Movable, // + ](self: List[U, *_], value: U) -> Bool: + """Verify if a given value is present in the list. + + Parameters: + U: The type of the elements in the list. Must implement the + trait `EqualityComparable`. + + Args: + value: The value to find. + + Returns: + True if the value is contained in the list, False otherwise. + + Examples: + + ```mojo + var x = [1, 2, 3] + print("x contains 3" if 3 in x else "x does not contain 3") + ``` + """ + for i in self: + if i == value: + return True + return False + + fn __mul__(self, x: Int) -> Self: + """Multiplies the list by x and returns a new list. + + Args: + x: The multiplier number. + + Returns: + The new list. + """ + # avoid the copy since it would be cleared immediately anyways + if x == 0: + return Self() + var result = self.copy() + result *= x + return result^ + + fn __imul__(mut self, x: Int): + """Appends the original elements of this list x-1 times or clears it if + x is <= 0. + + ```mojo + var a = [1, 2] + a *= 2 # a = [1, 2, 1, 2] + ``` + + Args: + x: The multiplier number. + """ + if x <= 0 or len(self) == 0: + self.clear() + return + var orig = self.copy() + self.reserve(len(self) * x) + for _ in range(x - 1): + self.extend(orig) + + fn __add__(self, owned other: Self) -> Self: + """Concatenates self with other and returns the result as a new list. + + Args: + other: List whose elements will be combined with the elements of + self. + + Returns: + The newly created list. + """ + var result = self.copy() + result.extend(other^) + return result^ + + fn __iadd__(mut self, owned other: Self): + """Appends the elements of other into self. + + Args: + other: List whose elements will be appended to self. + """ + self.extend(other^) + + fn __iter__(ref self) -> _ListIter[T, hint_trivial_type, __origin_of(self)]: + """Iterate over elements of the list, returning immutable references. + + Returns: + An iterator of immutable references to the list elements. + """ + return _ListIter(0, Pointer(to=self)) + + fn __reversed__( + ref self, + ) -> _ListIter[T, hint_trivial_type, __origin_of(self), False]: + """Iterate backwards over the list, returning immutable references. + + Returns: + A reversed iterator of immutable references to the list elements. + """ + return _ListIter[forward=False](len(self), Pointer(to=self)) + + # ===-------------------------------------------------------------------===# + # Trait implementations + # ===-------------------------------------------------------------------===# + + @always_inline("nodebug") + fn __len__(self) -> Int: + """Gets the number of elements in the list. + + Returns: + The number of elements in the list. + """ + return self._len + + fn __bool__(self) -> Bool: + """Checks whether the list has any elements or not. + + Returns: + `False` if the list is empty, `True` if there is at least one + element. + """ + return len(self) > 0 + + @no_inline + fn __str__[ + U: Representable & Copyable & Movable, // + ](self: List[U, *_]) -> String: + """Returns a string representation of a `List`. + + Parameters: + U: The type of the elements in the list. Must implement the + trait `Representable`. + + Returns: + A string representation of the list. + + Notes: + Note that since we can't condition methods on a trait yet, + the way to call this method is a bit special. Here is an example + below: + + ```mojo + var my_list = [1, 2, 3] + print(my_list.__str__()) + ``` + + When the compiler supports conditional methods, then a simple + `String(my_list)` will be enough. + """ + # at least 1 byte per item e.g.: [a, b, c, d] = 4 + 2 * 3 + [] + null + var l = len(self) + var output = String(capacity=l + 2 * (l - 1) * Int(l > 1) + 3) + self.write_to(output) + return output^ + + @no_inline + fn write_to[ + W: Writer, U: Representable & Copyable & Movable, // + ](self: List[U, *_], mut writer: W): + """Write `my_list.__str__()` to a `Writer`. + + Parameters: + W: A type conforming to the Writable trait. + U: The type of the List elements. Must have the trait + `Representable`. + + Args: + writer: The object to write to. + """ + writer.write("[") + for i in range(len(self)): + writer.write(repr(self[i])) + if i < len(self) - 1: + writer.write(", ") + writer.write("]") + + @no_inline + fn __repr__[ + U: Representable & Copyable & Movable, // + ](self: List[U, *_]) -> String: + """Returns a string representation of a `List`. + + Parameters: + U: The type of the elements in the list. Must implement the + trait `Representable`. + + Returns: + A string representation of the list. + + Notes: + Note that since we can't condition methods on a trait yet, the way + to call this method is a bit special. Here is an example below: + + ```mojo + var my_list = [1, 2, 3] + print(my_list.__repr__()) + ``` + + When the compiler supports conditional methods, then a simple + `repr(my_list)` will be enough. + """ + return self.__str__() + + # ===-------------------------------------------------------------------===# + # Methods + # ===-------------------------------------------------------------------===# + + fn byte_length(self) -> Int: + """Gets the byte length of the List (`len(self) * sizeof[T]()`). + + Returns: + The byte length of the List (`len(self) * sizeof[T]()`). + """ + return len(self) * sizeof[T]() + + @no_inline + fn _realloc(mut self, new_capacity: Int): + var new_data = UnsafePointer[T].alloc(new_capacity) + + @parameter + if hint_trivial_type: + memcpy(new_data, self.data, len(self)) + else: + for i in range(len(self)): + (self.data + i).move_pointee_into(new_data + i) + + if self.data: + self.data.free() + self.data = new_data + self.capacity = new_capacity + + fn memset_zero(mut self): + """Sets all elements in the list to zero.""" + + @parameter + if hint_trivial_type: + memset_zero(self.data, len(self)) + else: + return # TODO: how to reset to 0 unknown type? + + fn append(mut self, owned value: T): + """Appends a value to this list. + + Args: + value: The value to append. + + Notes: + If there is no capacity left, resizes to twice the current capacity. + Except for 0 capacity where it sets 1. + """ + if self._len >= self.capacity: + self._realloc(self.capacity * 2 | Int(self.capacity == 0)) + self._unsafe_next_uninit_ptr().init_pointee_move(value^) + self._len += 1 + + fn append(mut self, elements: Span[T, _]): + """Appends elements to this list. + + Args: + elements: The elements to append. + """ + var elements_len = len(elements) + var new_num_elts = self._len + elements_len + if new_num_elts > self.capacity: + # Make sure our capacity at least doubles to avoid O(n^2) behavior. + self._realloc(max(self.capacity * 2, new_num_elts)) + + var i = self._len + self._len = new_num_elts + + @parameter + if hint_trivial_type: + memcpy(self.data + i, elements.unsafe_ptr(), elements_len) + else: + for elt in elements: + UnsafePointer(to=self[i]).init_pointee_copy(elt) + i += 1 + + fn insert(mut self, i: Int, owned value: T): + """Inserts a value to the list at the given index. + `a.insert(len(a), value)` is equivalent to `a.append(value)`. + + Args: + i: The index for the value. + value: The value to insert. + """ + debug_assert(i <= len(self), "insert index out of range") + + var normalized_idx = i + if i < 0: + normalized_idx = max(0, len(self) + i) + + var earlier_idx = len(self) + var later_idx = len(self) - 1 + self.append(value^) + + for _ in range(normalized_idx, len(self) - 1): + var earlier_ptr = self.data + earlier_idx + var later_ptr = self.data + later_idx + + var tmp = earlier_ptr.take_pointee() + later_ptr.move_pointee_into(earlier_ptr) + later_ptr.init_pointee_move(tmp^) + + earlier_idx -= 1 + later_idx -= 1 + + fn extend(mut self, owned other: List[T, *_]): + """Extends this list by consuming the elements of `other`. + + Args: + other: List whose elements will be added in order at the end of this + list. + """ + + var other_len = len(other) + var final_size = len(self) + other_len + self.reserve(final_size) + + var dest_ptr = self.data + self._len + var src_ptr = other.unsafe_ptr() + + @parameter + if hint_trivial_type: + memcpy(dest_ptr, src_ptr, other_len) + else: + for _ in range(other_len): + # This (TODO: optimistically) moves an element directly from the + # `other` list into this list using a single `T.__moveinit()__` + # call, without moving into an intermediate temporary value + # (avoiding an extra redundant move constructor call). + src_ptr.move_pointee_into(dest_ptr) + src_ptr += 1 + dest_ptr += 1 + + # Update the size now since all elements have been moved into this list. + self._len = final_size + # The elements of `other` are now consumed, so we mark it as empty so + # they don't get destroyed when it goes out of scope. + other._len = 0 + + fn extend[ + D: DType, // + ](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _]): + """Extends this list with the elements of a vector. + + Parameters: + D: The DType. + + Args: + value: The value to append. + + Notes: + If there is no capacity left, resizes to `len(self) + value.size`. + """ + self.reserve(self._len + value.size) + self._unsafe_next_uninit_ptr().store(value) + self._len += value.size + + fn extend[ + D: DType, // + ](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _], *, count: Int): + """Extends this list with `count` number of elements from a vector. + + Parameters: + D: The DType. + + Args: + value: The value to append. + count: The amount of items to append. Must be less than or equal to + `value.size`. + + Notes: + If there is no capacity left, resizes to `len(self) + count`. + """ + debug_assert(count <= value.size, "count must be <= value.size") + self.reserve(self._len + count) + var v_ptr = UnsafePointer(to=value).bitcast[Scalar[D]]() + memcpy(self._unsafe_next_uninit_ptr(), v_ptr, count) + self._len += count + + fn extend[ + D: DType, // + ](mut self: List[Scalar[D], *_, **_], value: Span[Scalar[D]]): + """Extends this list with the elements of a `Span`. + + Parameters: + D: The DType. + + Args: + value: The value to append. + + Notes: + If there is no capacity left, resizes to `len(self) + len(value)`. + """ + self.reserve(self._len + len(value)) + memcpy(self._unsafe_next_uninit_ptr(), value.unsafe_ptr(), len(value)) + self._len += len(value) + + fn pop(mut self, i: Int = -1) -> T: + """Pops a value from the list at the given index. + + Args: + i: The index of the value to pop. + + Returns: + The popped value. + """ + debug_assert(-self._len <= i < self._len, "pop index out of range") + + var normalized_idx = i + if i < 0: + normalized_idx += self._len + + var ret_val = (self.data + normalized_idx).take_pointee() + for j in range(normalized_idx + 1, self._len): + (self.data + j).move_pointee_into(self.data + j - 1) + self._len -= 1 + + return ret_val^ + + fn reserve(mut self, new_capacity: Int): + """Reserves the requested capacity. + + Args: + new_capacity: The new capacity. + + Notes: + If the current capacity is greater or equal, this is a no-op. + Otherwise, the storage is reallocated and the date is moved. + """ + if self.capacity >= new_capacity: + return + self._realloc(new_capacity) + + fn resize(mut self, new_size: Int, value: T): + """Resizes the list to the given new size. + + Args: + new_size: The new size. + value: The value to use to populate new elements. + + Notes: + If the new size is smaller than the current one, elements at the end + are discarded. If the new size is larger than the current one, the + list is appended with new values elements up to the requested size. + """ + if new_size <= self._len: + self.shrink(new_size) + else: + self.reserve(new_size) + for i in range(self._len, new_size): + (self.data + i).init_pointee_copy(value) + self._len = new_size + + fn resize(mut self, *, unsafe_uninit_length: Int): + """Resizes the list to the given new size leaving any new elements + uninitialized. + + If the new size is smaller than the current one, elements at the end + are discarded. If the new size is larger than the current one, the + list is extended and the new elements are left uninitialized. + + Args: + unsafe_uninit_length: The new size. + """ + if unsafe_uninit_length <= self._len: + self.shrink(unsafe_uninit_length) + else: + self.reserve(unsafe_uninit_length) + self._len = unsafe_uninit_length + + fn shrink(mut self, new_size: Int): + """Resizes to the given new size which must be <= the current size. + + Args: + new_size: The new size. + + Notes: + With no new value provided, the new size must be smaller than or + equal to the current one. Elements at the end are discarded. + """ + if len(self) < new_size: + abort( + "You are calling List.resize with a new_size bigger than the" + " current size. If you want to make the List bigger, provide a" + " value to fill the new slots with. If not, make sure the new" + " size is smaller than the current size." + ) + + @parameter + if not hint_trivial_type: + for i in range(new_size, len(self)): + (self.data + i).destroy_pointee() + self._len = new_size + self.reserve(new_size) + + fn reverse(mut self): + """Reverses the elements of the list.""" + + var earlier_idx = 0 + var later_idx = len(self) - 1 + + var effective_len = len(self) + var half_len = effective_len // 2 + + for _ in range(half_len): + var earlier_ptr = self.data + earlier_idx + var later_ptr = self.data + later_idx + + var tmp = earlier_ptr.take_pointee() + later_ptr.move_pointee_into(earlier_ptr) + later_ptr.init_pointee_move(tmp^) + + earlier_idx += 1 + later_idx -= 1 + + # TODO: Remove explicit self type when issue 1876 is resolved. + fn index[ + C: EqualityComparable & Copyable & Movable, // + ]( + ref self: List[C, *_], + value: C, + start: Int = 0, + stop: Optional[Int] = None, + ) raises -> Int: + """Returns the index of the first occurrence of a value in a list + restricted by the range given the start and stop bounds. + + Args: + value: The value to search for. + start: The starting index of the search, treated as a slice index + (defaults to 0). + stop: The ending index of the search, treated as a slice index + (defaults to None, which means the end of the list). + + Parameters: + C: The type of the elements in the list. Must implement the + `EqualityComparable` trait. + + Returns: + The index of the first occurrence of the value in the list. + + Raises: + ValueError: If the value is not found in the list. + + Examples: + + ```mojo + var my_list = [1, 2, 3] + print(my_list.index(2)) # prints `1` + ``` + """ + var start_normalized = start + + var stop_normalized: Int + if stop is None: + # Default end + stop_normalized = len(self) + else: + stop_normalized = stop.value() + + if start_normalized < 0: + start_normalized += len(self) + if stop_normalized < 0: + stop_normalized += len(self) + + start_normalized = _clip(start_normalized, 0, len(self)) + stop_normalized = _clip(stop_normalized, 0, len(self)) + + for i in range(start_normalized, stop_normalized): + if self[i] == value: + return i + raise "ValueError: Given element is not in list" + + fn _binary_search_index[ + dtype: DType, //, + ](self: List[Scalar[dtype], **_], needle: Scalar[dtype]) -> Optional[UInt]: + """Finds the index of `needle` with binary search. + + Args: + needle: The value to binary search for. + + Returns: + Returns None if `needle` is not present, or if `self` was not + sorted. + + Notes: + This function will return an unspecified index if `self` is not + sorted in ascending order. + """ + var cursor = UInt(0) + var b = self.data + var length = len(self) + while length > 1: + var half = length >> 1 + length -= half + cursor += Int(b[cursor + half - 1] < needle) * half + + return Optional(cursor) if b[cursor] == needle else None + + fn clear(mut self): + """Clears the elements in the list.""" + for i in range(self._len): + (self.data + i).destroy_pointee() + self._len = 0 + + fn steal_data(mut self) -> UnsafePointer[T]: + """Take ownership of the underlying pointer from the list. + + Returns: + The underlying data. + """ + var ptr = self.data + self.data = UnsafePointer[T]() + self._len = 0 + self.capacity = 0 + return ptr + + fn __getitem__(self, slice: Slice) -> Self: + """Gets the sequence of elements at the specified positions. + + Args: + slice: A slice that specifies positions of the new list. + + Returns: + A new list containing the list at the specified slice. + """ + var start, end, step = slice.indices(len(self)) + var r = range(start, end, step) + + if not len(r): + return Self() + + var res = Self(capacity=len(r)) + for i in r: + res.append(self[i]) + + return res^ + + fn __getitem__[I: Indexer](ref self, idx: I) -> ref [self] T: + """Gets the list element at the given index. + + Args: + idx: The index of the element. + + Parameters: + I: A type that can be used as an index. + + Returns: + A reference to the element at the given index. + """ + + @parameter + if _type_is_eq[I, UInt](): + return (self.data + idx)[] + else: + var normalized_idx = Int(idx) + debug_assert( + -self._len <= normalized_idx < self._len, + "index: ", + normalized_idx, + " is out of bounds for `List` of length: ", + self._len, + ) + if normalized_idx < 0: + normalized_idx += len(self) + + return (self.data + normalized_idx)[] + + @always_inline + fn unsafe_get(ref self, idx: Int) -> ref [self] Self.T: + """Get a reference to an element of self without checking index bounds. + + Args: + idx: The index of the element to get. + + Returns: + A reference to the element at the given index. + + Notes: + Users should consider using `__getitem__` instead of this method as + it is unsafe. If an index is out of bounds, this method will not + abort, it will be considered undefined behavior. + + Note that there is no wraparound for negative indices, caution is + advised. Using negative indices is considered undefined behavior. + Never use `my_list.unsafe_get(-1)` to get the last element of the + list. Instead, do `my_list.unsafe_get(len(my_list) - 1)`. + """ + debug_assert( + 0 <= idx < len(self), + ( + "The index provided must be within the range [0, len(List) -1]" + " when using List.unsafe_get()" + ), + ) + return (self.data + idx)[] + + @always_inline + fn unsafe_set(mut self, idx: Int, owned value: T): + """Write a value to a given location without checking index bounds. + + Args: + idx: The index of the element to set. + value: The value to set. + + Notes: + Users should consider using `my_list[idx] = value` instead of this + method as it is unsafe. If an index is out of bounds, this method + will not abort, it will be considered undefined behavior. + + Note that there is no wraparound for negative indices, caution is + advised. Using negative indices is considered undefined behavior. + Never use `my_list.unsafe_set(-1, value)` to set the last element of + the list. Instead, do `my_list.unsafe_set(len(my_list) - 1, value)`. + """ + debug_assert( + 0 <= idx < len(self), + ( + "The index provided must be within the range [0, len(List) -1]" + " when using List.unsafe_set()" + ), + ) + (self.data + idx).destroy_pointee() + (self.data + idx).init_pointee_move(value^) + + fn count[ + T: EqualityComparable & Copyable & Movable, // + ](self: List[T, *_], value: T) -> Int: + """Counts the number of occurrences of a value in the list. + + Parameters: + T: The type of the elements in the list. Must implement the + trait `EqualityComparable`. + + Args: + value: The value to count. + + Returns: + The number of occurrences of the value in the list. + """ + var count = 0 + for elem in self: + if elem == value: + count += 1 + return count + + fn swap_elements(mut self, elt_idx_1: Int, elt_idx_2: Int): + """Swaps elements at the specified indexes if they are different. + + Args: + elt_idx_1: The index of one element. + elt_idx_2: The index of the other element. + + Examples: + + ```mojo + var my_list = [1, 2, 3] + my_list.swap_elements(0, 2) + print(my_list.__str__()) # 3, 2, 1 + ``` + + Notes: + This is useful because `swap(my_list[i], my_list[j])` cannot be + supported by Mojo, because a mutable alias may be formed. + """ + debug_assert( + 0 <= elt_idx_1 < len(self) and 0 <= elt_idx_2 < len(self), + ( + "The indices provided to swap_elements must be within the range" + " [0, len(List)-1]" + ), + ) + if elt_idx_1 != elt_idx_2: + swap((self.data + elt_idx_1)[], (self.data + elt_idx_2)[]) + + fn unsafe_ptr( + ref self, + ) -> UnsafePointer[ + T, + mut = Origin(__origin_of(self)).mut, + origin = __origin_of(self), + ]: + """Retrieves a pointer to the underlying memory. + + Returns: + The pointer to the underlying memory. + """ + return self.data.origin_cast[ + mut = Origin(__origin_of(self)).mut, origin = __origin_of(self) + ]() + + @always_inline + fn _unsafe_next_uninit_ptr( + ref self, + ) -> UnsafePointer[ + T, + mut = Origin(__origin_of(self)).mut, + origin = __origin_of(self), + ]: + """Retrieves a pointer to the next uninitialized element position. + + Safety: + + - This pointer MUST not be used to read or write memory beyond the + allocated capacity of this list. + - This pointer may not be used to initialize non-contiguous elements. + - Ensure that `List._len` is updated to reflect the new number of + initialized elements, otherwise elements may be unexpectedly + overwritten or not destroyed correctly. + + Notes: + This returns a pointer that points to the element position immediately + after the last initialized element. This is equivalent to + `list.unsafe_ptr() + len(list)`. + """ + debug_assert( + self.capacity > 0 and self.capacity > self._len, + ( + "safety violation: Insufficient capacity to retrieve pointer to" + " next uninitialized element" + ), + ) + + # self.unsafe_ptr() + self._len won't work because .unsafe_ptr() + # takes a ref that might mutate self + var length = self._len + return self.unsafe_ptr() + length + + fn _cast_hint_trivial_type[ + hint_trivial_type: Bool + ](owned self) -> List[T, hint_trivial_type]: + var result = List[T, hint_trivial_type]() + result.data = self.data + result._len = self._len + result.capacity = self.capacity + + # We stole the elements, don't destroy them. + __disable_del self + + return result^ + + +fn _clip(value: Int, start: Int, end: Int) -> Int: + return max(start, min(value, end)) From 43c73c06901b8b77fed57d20f42c6a105c62c3d5 Mon Sep 17 00:00:00 2001 From: ttrenty <154608953+ttrenty@users.noreply.github.com> Date: Thu, 26 Jun 2025 18:09:54 -0600 Subject: [PATCH 2/3] feat: re-organise code structure + add partial trace + pixi setup --- TODOs.md | 10 +- src/local_complex.mojo | 339 ------------ src/local_list.mojo | 1136 ---------------------------------------- 3 files changed, 3 insertions(+), 1482 deletions(-) delete mode 100644 src/local_complex.mojo delete mode 100644 src/local_list.mojo diff --git a/TODOs.md b/TODOs.md index f1591c6..803d02a 100644 --- a/TODOs.md +++ b/TODOs.md @@ -4,15 +4,9 @@ A pure state is represented by a ket vector, e.g., $ |\psi\rangle = \alpha|0\ran ## Ordered by Priority (5 ↑) / Difficulty (5 ↓) -- 5 / 2 : Reimplement Fig 2 circuit using Swap efficient gates representation - - Generarly when doing an implementation of a abstract circuit also do the implementation - using the functions for people to understand more easily what's happening - - 5 / 1 : density matrix calculcation from state vectors -- 5 / 5 : Extend qubitWiseMultiply() to 2 qubits gates - -- 5 / 5 : Extend qubitWiseMultiply() to multiple qubits gates +- 5 / 5 : Extend qubitWiseMultiply() to 2 and multiple qubits gates (Test it) - 4 / 3 : Implement measurement gates @@ -20,6 +14,8 @@ A pure state is represented by a ket vector, e.g., $ |\psi\rangle = \alpha|0\ran - 3 / 2 : Use a separate list for things that are not real gate to not slow down the main run logic +- 3 / 2 : Use a separate list for things that are not real gate to not slow down the main run logic + - 3 / 3 : Implement naive implementation of the functions - matrix multiplication (but starting from right or smart) - partial trace diff --git a/src/local_complex.mojo b/src/local_complex.mojo deleted file mode 100644 index b88ad43..0000000 --- a/src/local_complex.mojo +++ /dev/null @@ -1,339 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Copyright (c) 2025, Modular Inc. All rights reserved. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions: -# https://llvm.org/LICENSE.txt -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ===----------------------------------------------------------------------=== # -"""Implements the Complex type. - -You can import these APIs from the `complex` package. For example: - -```mojo -from complex import ComplexSIMD -``` -""" - -import math -from math.math import _Expable -from sys import llvm_intrinsic - -alias ComplexFloat32 = ComplexSIMD[DType.float32, 1] -alias ComplexFloat64 = ComplexSIMD[DType.float64, 1] - - -# ===-----------------------------------------------------------------------===# -# ComplexSIMD -# ===-----------------------------------------------------------------------===# - - -@register_passable("trivial") -struct ComplexSIMD[type: DType, size: Int](Stringable, Writable, _Expable): - """Represents a complex SIMD value. - - The class provides basic methods for manipulating complex values. - - Parameters: - type: DType of the value. - size: SIMD width of the value. - """ - - # ===-------------------------------------------------------------------===# - # Fields - # ===-------------------------------------------------------------------===# - - alias element_type = SIMD[type, size] - var re: Self.element_type - """The real part of the complex SIMD value.""" - var im: Self.element_type - """The imaginary part of the complex SIMD value.""" - - # ===-------------------------------------------------------------------===# - # Initialization - # ===-------------------------------------------------------------------===# - - fn __init__(out self, re: Self.element_type, im: Self.element_type = 0): - """Initializes a complex SIMD value. - - Args: - re: The real part of the complex value. - im: The imaginary part of the complex value. - """ - self.re = re - self.im = im - - # ===-------------------------------------------------------------------===# - # Trait implementations - # ===-------------------------------------------------------------------===# - - @no_inline - fn __str__(self) -> String: - """Get the complex as a string. - - Returns: - A string representation. - """ - return String.write(self) - - fn write_to[W: Writer](self, mut writer: W): - """ - Formats this complex value to the provided Writer. - - Parameters: - W: A type conforming to the Writable trait. - - Args: - writer: The object to write to. - """ - - # TODO(MSTDL-700): - # Add a Writer.reserve() method, to afford writer implementations - # to request reservation of additional space from `Writer` - # implementations that support that. Then use the logic below to - # call that method here. - - # Reserve space for opening and closing brackets, plus each element and - # its trailing commas. - # var initial_buffer_size = 2 - # for i in range(size): - # initial_buffer_size += ( - # _calc_initial_buffer_size(self.re[i]) - # + _calc_initial_buffer_size(self.im[i]) - # + 4 # for the ' + i' suffix on the imaginary - # + 2 - # ) - # buf.reserve(initial_buffer_size) - - # Print an opening `[`. - @parameter - if size > 1: - writer.write("[") - - # Print each element. - for i in range(size): - var re = self.re[i] - var im = self.im[i] - # Print separators between each element. - if i != 0: - writer.write(", ") - - writer.write(re) - - if im != 0: - writer.write(" + ", im, "i") - - # Print a closing `]`. - @parameter - if size > 1: - writer.write("]") - - @always_inline - fn __abs__(self) -> SIMD[type, size]: - """Returns the magnitude of the complex value. - - Returns: - Value of `sqrt(re*re + im*im)`. - """ - return self.norm() - - # ===-------------------------------------------------------------------===# - # Operator dunders - # ===-------------------------------------------------------------------===# - - @always_inline - fn __add__(self, rhs: Self) -> Self: - """Adds two complex values. - - Args: - rhs: Complex value to add. - - Returns: - A sum of this and RHS complex values. - """ - return Self(self.re + rhs.re, self.im + rhs.im) - - @always_inline - fn __sub__(self, rhs: Self) -> Self: - """Subtracts two complex values. - - Args: - rhs: Complex value to subtract. - - Returns: - A difference of this and RHS complex values. - """ - return Self(self.re - rhs.re, self.im - rhs.im) - - @always_inline - fn __mul__(self, rhs: Self) -> Self: - """Multiplies two complex values. - - Args: - rhs: Complex value to multiply with. - - Returns: - A product of this and RHS complex values. - """ - return Self( - self.re.fma(rhs.re, -self.im * rhs.im), - self.re.fma(rhs.im, self.im * rhs.re), - ) - - @always_inline - fn __truediv__(self, rhs: Self) -> Self: - """Divides two complex values. - - Args: - rhs: Complex value to divide by. - - Returns: - A quotient of this and RHS complex values. - """ - var denom = rhs.squared_norm() - return Self( - self.re.fma(rhs.re, self.im * rhs.im) / denom, - self.re.fma(rhs.im, -self.im * rhs.re) / denom, - ) - - @always_inline - fn __neg__(self) -> Self: - """Negates the complex value. - - Returns: - The negative of the complex value. - """ - return ComplexSIMD(-self.re, -self.im) - - # ===------------------------------------------------------------------=== # - # In place operations. - # ===------------------------------------------------------------------=== # - - @always_inline("nodebug") - fn __iadd__(mut self, rhs: Self): - """Performs in-place addition. - - Args: - rhs: The rhs of the addition operation. - """ - self = self + rhs - - @always_inline("nodebug") - fn __isub__(mut self, rhs: Self): - """Performs in-place subtraction. - - Args: - rhs: The rhs of the operation. - """ - self = self - rhs - - @always_inline("nodebug") - fn __imul__(mut self, rhs: Self): - """Performs in-place multiplication. - - Args: - rhs: The rhs of the operation. - """ - self = self * rhs - - @always_inline("nodebug") - fn __itruediv__(mut self, rhs: Self): - """In-place true divide operator. - - Args: - rhs: The rhs of the operation. - """ - var denom = rhs.squared_norm() - self.re = self.re.fma(rhs.re, self.im * rhs.im) / denom - self.im = self.re.fma(rhs.im, -self.im * rhs.re) / denom - - # ===-------------------------------------------------------------------===# - # Methods - # ===-------------------------------------------------------------------===# - - @always_inline - fn norm(self) -> SIMD[type, size]: - """Returns the magnitude of the complex value. - - Returns: - Value of `sqrt(re*re + im*im)`. - """ - return llvm_intrinsic["llvm.sqrt", SIMD[type, size]]( - self.squared_norm() - ) - - @always_inline - fn squared_norm(self) -> SIMD[type, size]: - """Returns the squared magnitude of the complex value. - - Returns: - Value of `re*re + im*im`. - """ - return self.re.fma(self.re, self.im * self.im) - - # fma(self, b, c) - @always_inline - fn fma(self, b: Self, c: Self) -> Self: - """Computes FMA operation. - - Compute fused multiple-add with two other complex values: - `result = self * b + c` - - Args: - b: Multiplier complex value. - c: Complex value to add. - - Returns: - Computed `Self * B + C` complex value. - """ - return Self( - self.re.fma(b.re, -(self.im.fma(b.im, -c.re))), - self.re.fma(b.im, self.im.fma(b.re, c.im)), - ) - - # fma(self, self, c) - @always_inline - fn squared_add(self, c: Self) -> Self: - """Computes Square-Add operation. - - Compute `Self * Self + C`. - - Args: - c: Complex value to add. - - Returns: - Computed `Self * Self + C` complex value. - """ - return Self( - self.re.fma(self.re, self.im.fma(-self.im, c.re)), - self.re.fma(self.im + self.im, c.im), - ) - - @always_inline - fn __exp__(self) -> Self: - """Computes the exponential of the complex value. - - Returns: - The exponential of the complex value. - """ - var exp_re = math.exp(self.re) - return Self(exp_re * math.cos(self.im), exp_re * math.sin(self.im)) - - -# TODO: we need this overload, because the Absable trait requires returning Self -# type. We could maybe get rid of this if we had associated types? -@always_inline -fn abs(x: ComplexSIMD[*_]) -> SIMD[x.type, x.size]: - """Performs elementwise abs (norm) on each element of the complex value. - - Args: - x: The complex vector to perform absolute value on. - - Returns: - The elementwise abs of x. - """ - return x.__abs__() diff --git a/src/local_list.mojo b/src/local_list.mojo deleted file mode 100644 index fb0dcb4..0000000 --- a/src/local_list.mojo +++ /dev/null @@ -1,1136 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Copyright (c) 2025, Modular Inc. All rights reserved. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions: -# https://llvm.org/LICENSE.txt -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ===----------------------------------------------------------------------=== # -"""Defines the List type. - -These APIs are imported automatically, just like builtins. -""" - - -from os import abort -from sys import sizeof -from sys.intrinsics import _type_is_eq - -from memory import Pointer, memcpy, memset_zero - -# from .optional import Optional - -# ===-----------------------------------------------------------------------===# -# List -# ===-----------------------------------------------------------------------===# - - -@fieldwise_init -struct _ListIter[ - list_mutability: Bool, //, - T: Copyable & Movable, - hint_trivial_type: Bool, - list_origin: Origin[list_mutability], - forward: Bool = True, -](Copyable, IteratorTrait, Movable): - """Iterator for List. - - Parameters: - list_mutability: Whether the reference to the list is mutable. - T: The type of the elements in the list. - hint_trivial_type: Set to `True` if the type `T` is trivial, this is not - mandatory, but it helps performance. Will go away in the future. - list_origin: The origin of the List - forward: The iteration direction. `False` is backwards. - """ - - alias Element = T # FIXME(MOCO-2068): shouldn't be needed. - alias list_type = CustomList[T, hint_trivial_type] - - var index: Int - var src: Pointer[Self.list_type, list_origin] - - fn __next_ref__(mut self) -> ref [list_origin] T: - @parameter - if forward: - self.index += 1 - return self.src[][self.index - 1] - else: - self.index -= 1 - return self.src[][self.index] - - @always_inline - fn __next__(mut self) -> T: - return self.__next_ref__() - - @always_inline - fn __has_next__(self) -> Bool: - return self.__len__() > 0 - - @always_inline - fn __iter__(self) -> Self: - return self - - fn __len__(self) -> Int: - @parameter - if forward: - return len(self.src[]) - self.index - else: - return self.index - - -struct CustomList[T: Copyable & Movable, hint_trivial_type: Bool = False]( - Boolable, Copyable, Defaultable, ExplicitlyCopyable, Movable, Sized -): - """The `List` type is a dynamically-allocated list. - - Parameters: - T: The type of the elements. - hint_trivial_type: A hint to the compiler that the type T is trivial. - It's not mandatory, but if set, it allows some optimizations. - - Notes: - It supports pushing and popping from the back resizing the underlying - storage as needed. When it is deallocated, it frees its memory. - """ - - # Fields - var data: UnsafePointer[T] - """The underlying storage for the list.""" - var _len: Int - """The number of elements in the list.""" - var capacity: Int - """The amount of elements that can fit in the list without resizing it.""" - - # ===-------------------------------------------------------------------===# - # Life cycle methods - # ===-------------------------------------------------------------------===# - - fn __init__(out self): - """Constructs an empty list.""" - self.data = UnsafePointer[T]() - self._len = 0 - self.capacity = 0 - - fn copy(self) -> Self: - """Creates a deep copy of the given list. - - Returns: - A copy of the value. - """ - var copy = Self(capacity=self.capacity) - for e in self: - copy.append(e) - return copy^ - - fn __init__(out self, *, capacity: Int): - """Constructs a list with the given capacity. - - Args: - capacity: The requested capacity of the list. - """ - if capacity: - self.data = UnsafePointer[T].alloc(capacity) - else: - self.data = UnsafePointer[T]() - self._len = 0 - self.capacity = capacity - - fn __init__(out self, *, length: UInt, fill: T): - """Constructs a list with the given capacity. - - Args: - length: The requested length of the list. - fill: The element to fill each element of the list. - """ - self = Self() - self.resize(length, fill) - - @always_inline - fn __init__(out self, owned *values: T, __list_literal__: () = ()): - """Constructs a list from the given values. - - Args: - values: The values to populate the list with. - __list_literal__: Tell Mojo to use this method for list literals. - """ - self = Self(elements=values^) - - fn __init__(out self, *, owned elements: VariadicListMem[T, _]): - """Constructs a list from the given values. - - Args: - elements: The values to populate the list with. - """ - var length = len(elements) - - self = Self(capacity=length) - - for i in range(length): - var src = UnsafePointer(to=elements[i]) - var dest = self.data + i - - src.move_pointee_into(dest) - - # Do not destroy the elements when their backing storage goes away. - __disable_del elements - - self._len = length - - fn __init__(out self, span: Span[T]): - """Constructs a list from the a Span of values. - - Args: - span: The span of values to populate the list with. - """ - self = Self(capacity=len(span)) - for value in span: - self.append(value) - - @always_inline - fn __init__(out self, *, unsafe_uninit_length: Int): - """Construct a list with the specified length, with uninitialized - memory. This is unsafe, as it relies on the caller initializing the - elements with unsafe operations, not assigning over the uninitialized - data. - - Args: - unsafe_uninit_length: The number of elements to allocate. - """ - self = Self(capacity=unsafe_uninit_length) - self._len = unsafe_uninit_length - - fn __copyinit__(out self, existing: Self): - """Creates a deepcopy of the given list. - - Args: - existing: The list to copy. - """ - self = Self(capacity=existing.capacity) - for i in range(len(existing)): - self.append(existing[i]) - - fn __del__(owned self): - """Destroy all elements in the list and free its memory.""" - - @parameter - if not hint_trivial_type: - for i in range(len(self)): - (self.data + i).destroy_pointee() - self.data.free() - - # ===-------------------------------------------------------------------===# - # Operator dunders - # ===-------------------------------------------------------------------===# - - @always_inline - fn __eq__[ - U: EqualityComparable & Copyable & Movable, // - ](self: List[U, *_], other: List[U, *_]) -> Bool: - """Checks if two lists are equal. - - Parameters: - U: The type of the elements in the list. Must implement the - trait `EqualityComparable`. - - Args: - other: The list to compare with. - - Returns: - True if the lists are equal, False otherwise. - - Examples: - - ```mojo - var x = [1, 2, 3] - var y = [1, 2, 3] - print("x and y are equal" if x == y else "x and y are not equal") - ``` - """ - if len(self) != len(other): - return False - var index = 0 - for element in self: - if element != other[index]: - return False - index += 1 - return True - - @always_inline - fn __ne__[ - U: EqualityComparable & Copyable & Movable, // - ](self: List[U, *_], other: List[U, *_]) -> Bool: - """Checks if two lists are not equal. - - Parameters: - U: The type of the elements in the list. Must implement the - trait `EqualityComparable`. - - Args: - other: The list to compare with. - - Returns: - True if the lists are not equal, False otherwise. - - Examples: - - ```mojo - var x = [1, 2, 3] - var y = [1, 2, 4] - print("x and y are not equal" if x != y else "x and y are equal") - ``` - """ - return not (self == other) - - fn __contains__[ - U: EqualityComparable & Copyable & Movable, // - ](self: List[U, *_], value: U) -> Bool: - """Verify if a given value is present in the list. - - Parameters: - U: The type of the elements in the list. Must implement the - trait `EqualityComparable`. - - Args: - value: The value to find. - - Returns: - True if the value is contained in the list, False otherwise. - - Examples: - - ```mojo - var x = [1, 2, 3] - print("x contains 3" if 3 in x else "x does not contain 3") - ``` - """ - for i in self: - if i == value: - return True - return False - - fn __mul__(self, x: Int) -> Self: - """Multiplies the list by x and returns a new list. - - Args: - x: The multiplier number. - - Returns: - The new list. - """ - # avoid the copy since it would be cleared immediately anyways - if x == 0: - return Self() - var result = self.copy() - result *= x - return result^ - - fn __imul__(mut self, x: Int): - """Appends the original elements of this list x-1 times or clears it if - x is <= 0. - - ```mojo - var a = [1, 2] - a *= 2 # a = [1, 2, 1, 2] - ``` - - Args: - x: The multiplier number. - """ - if x <= 0 or len(self) == 0: - self.clear() - return - var orig = self.copy() - self.reserve(len(self) * x) - for _ in range(x - 1): - self.extend(orig) - - fn __add__(self, owned other: Self) -> Self: - """Concatenates self with other and returns the result as a new list. - - Args: - other: List whose elements will be combined with the elements of - self. - - Returns: - The newly created list. - """ - var result = self.copy() - result.extend(other^) - return result^ - - fn __iadd__(mut self, owned other: Self): - """Appends the elements of other into self. - - Args: - other: List whose elements will be appended to self. - """ - self.extend(other^) - - fn __iter__(ref self) -> _ListIter[T, hint_trivial_type, __origin_of(self)]: - """Iterate over elements of the list, returning immutable references. - - Returns: - An iterator of immutable references to the list elements. - """ - return _ListIter(0, Pointer(to=self)) - - fn __reversed__( - ref self, - ) -> _ListIter[T, hint_trivial_type, __origin_of(self), False]: - """Iterate backwards over the list, returning immutable references. - - Returns: - A reversed iterator of immutable references to the list elements. - """ - return _ListIter[forward=False](len(self), Pointer(to=self)) - - # ===-------------------------------------------------------------------===# - # Trait implementations - # ===-------------------------------------------------------------------===# - - @always_inline("nodebug") - fn __len__(self) -> Int: - """Gets the number of elements in the list. - - Returns: - The number of elements in the list. - """ - return self._len - - fn __bool__(self) -> Bool: - """Checks whether the list has any elements or not. - - Returns: - `False` if the list is empty, `True` if there is at least one - element. - """ - return len(self) > 0 - - @no_inline - fn __str__[ - U: Representable & Copyable & Movable, // - ](self: List[U, *_]) -> String: - """Returns a string representation of a `List`. - - Parameters: - U: The type of the elements in the list. Must implement the - trait `Representable`. - - Returns: - A string representation of the list. - - Notes: - Note that since we can't condition methods on a trait yet, - the way to call this method is a bit special. Here is an example - below: - - ```mojo - var my_list = [1, 2, 3] - print(my_list.__str__()) - ``` - - When the compiler supports conditional methods, then a simple - `String(my_list)` will be enough. - """ - # at least 1 byte per item e.g.: [a, b, c, d] = 4 + 2 * 3 + [] + null - var l = len(self) - var output = String(capacity=l + 2 * (l - 1) * Int(l > 1) + 3) - self.write_to(output) - return output^ - - @no_inline - fn write_to[ - W: Writer, U: Representable & Copyable & Movable, // - ](self: List[U, *_], mut writer: W): - """Write `my_list.__str__()` to a `Writer`. - - Parameters: - W: A type conforming to the Writable trait. - U: The type of the List elements. Must have the trait - `Representable`. - - Args: - writer: The object to write to. - """ - writer.write("[") - for i in range(len(self)): - writer.write(repr(self[i])) - if i < len(self) - 1: - writer.write(", ") - writer.write("]") - - @no_inline - fn __repr__[ - U: Representable & Copyable & Movable, // - ](self: List[U, *_]) -> String: - """Returns a string representation of a `List`. - - Parameters: - U: The type of the elements in the list. Must implement the - trait `Representable`. - - Returns: - A string representation of the list. - - Notes: - Note that since we can't condition methods on a trait yet, the way - to call this method is a bit special. Here is an example below: - - ```mojo - var my_list = [1, 2, 3] - print(my_list.__repr__()) - ``` - - When the compiler supports conditional methods, then a simple - `repr(my_list)` will be enough. - """ - return self.__str__() - - # ===-------------------------------------------------------------------===# - # Methods - # ===-------------------------------------------------------------------===# - - fn byte_length(self) -> Int: - """Gets the byte length of the List (`len(self) * sizeof[T]()`). - - Returns: - The byte length of the List (`len(self) * sizeof[T]()`). - """ - return len(self) * sizeof[T]() - - @no_inline - fn _realloc(mut self, new_capacity: Int): - var new_data = UnsafePointer[T].alloc(new_capacity) - - @parameter - if hint_trivial_type: - memcpy(new_data, self.data, len(self)) - else: - for i in range(len(self)): - (self.data + i).move_pointee_into(new_data + i) - - if self.data: - self.data.free() - self.data = new_data - self.capacity = new_capacity - - fn memset_zero(mut self): - """Sets all elements in the list to zero.""" - - @parameter - if hint_trivial_type: - memset_zero(self.data, len(self)) - else: - return # TODO: how to reset to 0 unknown type? - - fn append(mut self, owned value: T): - """Appends a value to this list. - - Args: - value: The value to append. - - Notes: - If there is no capacity left, resizes to twice the current capacity. - Except for 0 capacity where it sets 1. - """ - if self._len >= self.capacity: - self._realloc(self.capacity * 2 | Int(self.capacity == 0)) - self._unsafe_next_uninit_ptr().init_pointee_move(value^) - self._len += 1 - - fn append(mut self, elements: Span[T, _]): - """Appends elements to this list. - - Args: - elements: The elements to append. - """ - var elements_len = len(elements) - var new_num_elts = self._len + elements_len - if new_num_elts > self.capacity: - # Make sure our capacity at least doubles to avoid O(n^2) behavior. - self._realloc(max(self.capacity * 2, new_num_elts)) - - var i = self._len - self._len = new_num_elts - - @parameter - if hint_trivial_type: - memcpy(self.data + i, elements.unsafe_ptr(), elements_len) - else: - for elt in elements: - UnsafePointer(to=self[i]).init_pointee_copy(elt) - i += 1 - - fn insert(mut self, i: Int, owned value: T): - """Inserts a value to the list at the given index. - `a.insert(len(a), value)` is equivalent to `a.append(value)`. - - Args: - i: The index for the value. - value: The value to insert. - """ - debug_assert(i <= len(self), "insert index out of range") - - var normalized_idx = i - if i < 0: - normalized_idx = max(0, len(self) + i) - - var earlier_idx = len(self) - var later_idx = len(self) - 1 - self.append(value^) - - for _ in range(normalized_idx, len(self) - 1): - var earlier_ptr = self.data + earlier_idx - var later_ptr = self.data + later_idx - - var tmp = earlier_ptr.take_pointee() - later_ptr.move_pointee_into(earlier_ptr) - later_ptr.init_pointee_move(tmp^) - - earlier_idx -= 1 - later_idx -= 1 - - fn extend(mut self, owned other: List[T, *_]): - """Extends this list by consuming the elements of `other`. - - Args: - other: List whose elements will be added in order at the end of this - list. - """ - - var other_len = len(other) - var final_size = len(self) + other_len - self.reserve(final_size) - - var dest_ptr = self.data + self._len - var src_ptr = other.unsafe_ptr() - - @parameter - if hint_trivial_type: - memcpy(dest_ptr, src_ptr, other_len) - else: - for _ in range(other_len): - # This (TODO: optimistically) moves an element directly from the - # `other` list into this list using a single `T.__moveinit()__` - # call, without moving into an intermediate temporary value - # (avoiding an extra redundant move constructor call). - src_ptr.move_pointee_into(dest_ptr) - src_ptr += 1 - dest_ptr += 1 - - # Update the size now since all elements have been moved into this list. - self._len = final_size - # The elements of `other` are now consumed, so we mark it as empty so - # they don't get destroyed when it goes out of scope. - other._len = 0 - - fn extend[ - D: DType, // - ](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _]): - """Extends this list with the elements of a vector. - - Parameters: - D: The DType. - - Args: - value: The value to append. - - Notes: - If there is no capacity left, resizes to `len(self) + value.size`. - """ - self.reserve(self._len + value.size) - self._unsafe_next_uninit_ptr().store(value) - self._len += value.size - - fn extend[ - D: DType, // - ](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _], *, count: Int): - """Extends this list with `count` number of elements from a vector. - - Parameters: - D: The DType. - - Args: - value: The value to append. - count: The amount of items to append. Must be less than or equal to - `value.size`. - - Notes: - If there is no capacity left, resizes to `len(self) + count`. - """ - debug_assert(count <= value.size, "count must be <= value.size") - self.reserve(self._len + count) - var v_ptr = UnsafePointer(to=value).bitcast[Scalar[D]]() - memcpy(self._unsafe_next_uninit_ptr(), v_ptr, count) - self._len += count - - fn extend[ - D: DType, // - ](mut self: List[Scalar[D], *_, **_], value: Span[Scalar[D]]): - """Extends this list with the elements of a `Span`. - - Parameters: - D: The DType. - - Args: - value: The value to append. - - Notes: - If there is no capacity left, resizes to `len(self) + len(value)`. - """ - self.reserve(self._len + len(value)) - memcpy(self._unsafe_next_uninit_ptr(), value.unsafe_ptr(), len(value)) - self._len += len(value) - - fn pop(mut self, i: Int = -1) -> T: - """Pops a value from the list at the given index. - - Args: - i: The index of the value to pop. - - Returns: - The popped value. - """ - debug_assert(-self._len <= i < self._len, "pop index out of range") - - var normalized_idx = i - if i < 0: - normalized_idx += self._len - - var ret_val = (self.data + normalized_idx).take_pointee() - for j in range(normalized_idx + 1, self._len): - (self.data + j).move_pointee_into(self.data + j - 1) - self._len -= 1 - - return ret_val^ - - fn reserve(mut self, new_capacity: Int): - """Reserves the requested capacity. - - Args: - new_capacity: The new capacity. - - Notes: - If the current capacity is greater or equal, this is a no-op. - Otherwise, the storage is reallocated and the date is moved. - """ - if self.capacity >= new_capacity: - return - self._realloc(new_capacity) - - fn resize(mut self, new_size: Int, value: T): - """Resizes the list to the given new size. - - Args: - new_size: The new size. - value: The value to use to populate new elements. - - Notes: - If the new size is smaller than the current one, elements at the end - are discarded. If the new size is larger than the current one, the - list is appended with new values elements up to the requested size. - """ - if new_size <= self._len: - self.shrink(new_size) - else: - self.reserve(new_size) - for i in range(self._len, new_size): - (self.data + i).init_pointee_copy(value) - self._len = new_size - - fn resize(mut self, *, unsafe_uninit_length: Int): - """Resizes the list to the given new size leaving any new elements - uninitialized. - - If the new size is smaller than the current one, elements at the end - are discarded. If the new size is larger than the current one, the - list is extended and the new elements are left uninitialized. - - Args: - unsafe_uninit_length: The new size. - """ - if unsafe_uninit_length <= self._len: - self.shrink(unsafe_uninit_length) - else: - self.reserve(unsafe_uninit_length) - self._len = unsafe_uninit_length - - fn shrink(mut self, new_size: Int): - """Resizes to the given new size which must be <= the current size. - - Args: - new_size: The new size. - - Notes: - With no new value provided, the new size must be smaller than or - equal to the current one. Elements at the end are discarded. - """ - if len(self) < new_size: - abort( - "You are calling List.resize with a new_size bigger than the" - " current size. If you want to make the List bigger, provide a" - " value to fill the new slots with. If not, make sure the new" - " size is smaller than the current size." - ) - - @parameter - if not hint_trivial_type: - for i in range(new_size, len(self)): - (self.data + i).destroy_pointee() - self._len = new_size - self.reserve(new_size) - - fn reverse(mut self): - """Reverses the elements of the list.""" - - var earlier_idx = 0 - var later_idx = len(self) - 1 - - var effective_len = len(self) - var half_len = effective_len // 2 - - for _ in range(half_len): - var earlier_ptr = self.data + earlier_idx - var later_ptr = self.data + later_idx - - var tmp = earlier_ptr.take_pointee() - later_ptr.move_pointee_into(earlier_ptr) - later_ptr.init_pointee_move(tmp^) - - earlier_idx += 1 - later_idx -= 1 - - # TODO: Remove explicit self type when issue 1876 is resolved. - fn index[ - C: EqualityComparable & Copyable & Movable, // - ]( - ref self: List[C, *_], - value: C, - start: Int = 0, - stop: Optional[Int] = None, - ) raises -> Int: - """Returns the index of the first occurrence of a value in a list - restricted by the range given the start and stop bounds. - - Args: - value: The value to search for. - start: The starting index of the search, treated as a slice index - (defaults to 0). - stop: The ending index of the search, treated as a slice index - (defaults to None, which means the end of the list). - - Parameters: - C: The type of the elements in the list. Must implement the - `EqualityComparable` trait. - - Returns: - The index of the first occurrence of the value in the list. - - Raises: - ValueError: If the value is not found in the list. - - Examples: - - ```mojo - var my_list = [1, 2, 3] - print(my_list.index(2)) # prints `1` - ``` - """ - var start_normalized = start - - var stop_normalized: Int - if stop is None: - # Default end - stop_normalized = len(self) - else: - stop_normalized = stop.value() - - if start_normalized < 0: - start_normalized += len(self) - if stop_normalized < 0: - stop_normalized += len(self) - - start_normalized = _clip(start_normalized, 0, len(self)) - stop_normalized = _clip(stop_normalized, 0, len(self)) - - for i in range(start_normalized, stop_normalized): - if self[i] == value: - return i - raise "ValueError: Given element is not in list" - - fn _binary_search_index[ - dtype: DType, //, - ](self: List[Scalar[dtype], **_], needle: Scalar[dtype]) -> Optional[UInt]: - """Finds the index of `needle` with binary search. - - Args: - needle: The value to binary search for. - - Returns: - Returns None if `needle` is not present, or if `self` was not - sorted. - - Notes: - This function will return an unspecified index if `self` is not - sorted in ascending order. - """ - var cursor = UInt(0) - var b = self.data - var length = len(self) - while length > 1: - var half = length >> 1 - length -= half - cursor += Int(b[cursor + half - 1] < needle) * half - - return Optional(cursor) if b[cursor] == needle else None - - fn clear(mut self): - """Clears the elements in the list.""" - for i in range(self._len): - (self.data + i).destroy_pointee() - self._len = 0 - - fn steal_data(mut self) -> UnsafePointer[T]: - """Take ownership of the underlying pointer from the list. - - Returns: - The underlying data. - """ - var ptr = self.data - self.data = UnsafePointer[T]() - self._len = 0 - self.capacity = 0 - return ptr - - fn __getitem__(self, slice: Slice) -> Self: - """Gets the sequence of elements at the specified positions. - - Args: - slice: A slice that specifies positions of the new list. - - Returns: - A new list containing the list at the specified slice. - """ - var start, end, step = slice.indices(len(self)) - var r = range(start, end, step) - - if not len(r): - return Self() - - var res = Self(capacity=len(r)) - for i in r: - res.append(self[i]) - - return res^ - - fn __getitem__[I: Indexer](ref self, idx: I) -> ref [self] T: - """Gets the list element at the given index. - - Args: - idx: The index of the element. - - Parameters: - I: A type that can be used as an index. - - Returns: - A reference to the element at the given index. - """ - - @parameter - if _type_is_eq[I, UInt](): - return (self.data + idx)[] - else: - var normalized_idx = Int(idx) - debug_assert( - -self._len <= normalized_idx < self._len, - "index: ", - normalized_idx, - " is out of bounds for `List` of length: ", - self._len, - ) - if normalized_idx < 0: - normalized_idx += len(self) - - return (self.data + normalized_idx)[] - - @always_inline - fn unsafe_get(ref self, idx: Int) -> ref [self] Self.T: - """Get a reference to an element of self without checking index bounds. - - Args: - idx: The index of the element to get. - - Returns: - A reference to the element at the given index. - - Notes: - Users should consider using `__getitem__` instead of this method as - it is unsafe. If an index is out of bounds, this method will not - abort, it will be considered undefined behavior. - - Note that there is no wraparound for negative indices, caution is - advised. Using negative indices is considered undefined behavior. - Never use `my_list.unsafe_get(-1)` to get the last element of the - list. Instead, do `my_list.unsafe_get(len(my_list) - 1)`. - """ - debug_assert( - 0 <= idx < len(self), - ( - "The index provided must be within the range [0, len(List) -1]" - " when using List.unsafe_get()" - ), - ) - return (self.data + idx)[] - - @always_inline - fn unsafe_set(mut self, idx: Int, owned value: T): - """Write a value to a given location without checking index bounds. - - Args: - idx: The index of the element to set. - value: The value to set. - - Notes: - Users should consider using `my_list[idx] = value` instead of this - method as it is unsafe. If an index is out of bounds, this method - will not abort, it will be considered undefined behavior. - - Note that there is no wraparound for negative indices, caution is - advised. Using negative indices is considered undefined behavior. - Never use `my_list.unsafe_set(-1, value)` to set the last element of - the list. Instead, do `my_list.unsafe_set(len(my_list) - 1, value)`. - """ - debug_assert( - 0 <= idx < len(self), - ( - "The index provided must be within the range [0, len(List) -1]" - " when using List.unsafe_set()" - ), - ) - (self.data + idx).destroy_pointee() - (self.data + idx).init_pointee_move(value^) - - fn count[ - T: EqualityComparable & Copyable & Movable, // - ](self: List[T, *_], value: T) -> Int: - """Counts the number of occurrences of a value in the list. - - Parameters: - T: The type of the elements in the list. Must implement the - trait `EqualityComparable`. - - Args: - value: The value to count. - - Returns: - The number of occurrences of the value in the list. - """ - var count = 0 - for elem in self: - if elem == value: - count += 1 - return count - - fn swap_elements(mut self, elt_idx_1: Int, elt_idx_2: Int): - """Swaps elements at the specified indexes if they are different. - - Args: - elt_idx_1: The index of one element. - elt_idx_2: The index of the other element. - - Examples: - - ```mojo - var my_list = [1, 2, 3] - my_list.swap_elements(0, 2) - print(my_list.__str__()) # 3, 2, 1 - ``` - - Notes: - This is useful because `swap(my_list[i], my_list[j])` cannot be - supported by Mojo, because a mutable alias may be formed. - """ - debug_assert( - 0 <= elt_idx_1 < len(self) and 0 <= elt_idx_2 < len(self), - ( - "The indices provided to swap_elements must be within the range" - " [0, len(List)-1]" - ), - ) - if elt_idx_1 != elt_idx_2: - swap((self.data + elt_idx_1)[], (self.data + elt_idx_2)[]) - - fn unsafe_ptr( - ref self, - ) -> UnsafePointer[ - T, - mut = Origin(__origin_of(self)).mut, - origin = __origin_of(self), - ]: - """Retrieves a pointer to the underlying memory. - - Returns: - The pointer to the underlying memory. - """ - return self.data.origin_cast[ - mut = Origin(__origin_of(self)).mut, origin = __origin_of(self) - ]() - - @always_inline - fn _unsafe_next_uninit_ptr( - ref self, - ) -> UnsafePointer[ - T, - mut = Origin(__origin_of(self)).mut, - origin = __origin_of(self), - ]: - """Retrieves a pointer to the next uninitialized element position. - - Safety: - - - This pointer MUST not be used to read or write memory beyond the - allocated capacity of this list. - - This pointer may not be used to initialize non-contiguous elements. - - Ensure that `List._len` is updated to reflect the new number of - initialized elements, otherwise elements may be unexpectedly - overwritten or not destroyed correctly. - - Notes: - This returns a pointer that points to the element position immediately - after the last initialized element. This is equivalent to - `list.unsafe_ptr() + len(list)`. - """ - debug_assert( - self.capacity > 0 and self.capacity > self._len, - ( - "safety violation: Insufficient capacity to retrieve pointer to" - " next uninitialized element" - ), - ) - - # self.unsafe_ptr() + self._len won't work because .unsafe_ptr() - # takes a ref that might mutate self - var length = self._len - return self.unsafe_ptr() + length - - fn _cast_hint_trivial_type[ - hint_trivial_type: Bool - ](owned self) -> List[T, hint_trivial_type]: - var result = List[T, hint_trivial_type]() - result.data = self.data - result._len = self._len - result.capacity = self.capacity - - # We stole the elements, don't destroy them. - __disable_del self - - return result^ - - -fn _clip(value: Int, start: Int, end: Int) -> Int: - return max(start, min(value, end)) From 8cc6b08395f96c844d39cc6105a30a2b8ede304d Mon Sep 17 00:00:00 2001 From: ttrenty <154608953+ttrenty@users.noreply.github.com> Date: Thu, 26 Jun 2025 23:20:54 -0600 Subject: [PATCH 3/3] fix: update README + test partial_trace() + Update TODOs --- README.md | 4 +- TODOs.md | 53 +++++++--- examples/main.mojo | 53 +++++++++- pixi.toml | 7 +- src/abstractions/gate_circuit.mojo | 3 +- src/abstractions/simulator.mojo | 23 ++--- src/base/qubits_operations.mojo | 58 ++++++----- src/base/state_and_matrix.mojo | 10 +- tests/base/test_qubit_operations.mojo | 139 ++++++++++++++++++++++++++ tests/base/testing_matrix.mojo | 51 ++++++++++ 10 files changed, 336 insertions(+), 65 deletions(-) create mode 100644 tests/base/test_qubit_operations.mojo create mode 100644 tests/base/testing_matrix.mojo diff --git a/README.md b/README.md index fe00738..a796c23 100644 --- a/README.md +++ b/README.md @@ -28,8 +28,8 @@ curl -sSf https://pixi.sh/install.sh | bash # Install all project dependencies: pixi install -# Build and run the simulator: -pixi run mojo build src/main.mojo && ./main +# Build and run examples of the simulator: +pixi run main ``` diff --git a/TODOs.md b/TODOs.md index 803d02a..af6d789 100644 --- a/TODOs.md +++ b/TODOs.md @@ -1,25 +1,54 @@ # TODOs -A pure state is represented by a ket vector, e.g., $ |\psi\rangle = \alpha|0\rangle + \beta|1\rangle $, where $ \alpha $ and $ \beta $ are complex numbers satisfying the normalization condition $ |\alpha|^2 + |\beta|^2 = 1 $. - ## Ordered by Priority (5 ↑) / Difficulty (5 ↓) -- 5 / 1 : density matrix calculcation from state vectors - -- 5 / 5 : Extend qubitWiseMultiply() to 2 and multiple qubits gates (Test it) +### Implementations - 4 / 3 : Implement measurement gates - - 4 / 5 : Implement the computation of statistics: (6.5 and 6.6) - -- 3 / 2 : Use a separate list for things that are not real gate to not slow down the main run logic - -- 3 / 2 : Use a separate list for things that are not real gate to not slow down the main run logic +- 4 / 5 : Implement the computation of statistics (6.5 and 6.6) -- 3 / 3 : Implement naive implementation of the functions +- 3 / 3 : Implement naive implementation of the functions to compare performances - matrix multiplication (but starting from right or smart) - partial trace -- 3 / 4 : Compile time circuit creation? +- 5 / 5 : Start adding support for GPU in the base classes if needed (not possible to use SIMD(complexfloat64) anymore, or keep them but seperate them when moving data to GPU) + - struct PureBasisState + - struct ComplexMatrix + - struct Gate + +- 5 / ? : GPU implementation of: + - qubit_wise_multiply() + - apply_swap() + - partial_trace() + + +### Tests + +- 5 / 2 : Test qubit_wise_multiply_extended() that can take multiple qubits gates (2 and more, iSWAP for example) + +- 5 / 2 : Test for everything that will be implement in GPU + - qubit_wise_multiply() + - apply_swap() + - partial_trace() + - struct PureBasisState's methods + - struct ComplexMatrix's methods + - struct Gate's Gate + +### Benchmarks - 3 / 2 : Reproduce table from page 10 + +## Droped for now + +- 4 / 4 : Gradient computation with Finite Difference + +- 3 / 2 : Use a separate list for things that are not real gate to not slow down the main run logic + +- 3 / 4 : Compile time circuit creation? + +- 3 / 4 : Gradient computation with Parameter-Shift + +- 3 / 100 : Gradient computation with Adjoint Method + +- 2 / 4 : qubit_wise_multiply_extended() but for gates applied to non-adjacent qubits diff --git a/examples/main.mojo b/examples/main.mojo index f817026..139afbd 100644 --- a/examples/main.mojo +++ b/examples/main.mojo @@ -380,6 +380,47 @@ fn presentation() -> None: print("Final quantum state:\n", final_state) +fn test_density_matrix() -> None: + """ + Returns the density matrix of the given quantum state. + If qubits is empty, returns the full density matrix. + """ + num_qubits: Int = 2 + qc: GateCircuit = GateCircuit(num_qubits) + + qc.apply_gates( + Hadamard(0), + Hadamard(1, controls=[0]), + Z(0), + X(1), + ) + + print("Quantum circuit created:\n", qc) + + qsimu = StateVectorSimulator( + qc, + initial_state=PureBasisState.from_bitstring("00"), + optimisation_level=0, # No optimisations for now + verbose=True, + verbose_step_size=ShowAfterEachGate, # ShowAfterEachGate, ShowOnlyEnd + ) + final_state = qsimu.run() + print("Final quantum state:\n", final_state) + + matrix = final_state.to_density_matrix() + print("Density matrix:\n", matrix) + other_matrix = partial_trace(final_state, []) # Empty list means full trace + print("Partial trace matrix:\n", other_matrix) + other_matrix_0 = partial_trace( + final_state, [0] + ) # Empty list means full trace + print("Partial trace matrix qubit 0:\n", other_matrix_0) + other_matrix_1 = partial_trace( + final_state, [1] + ) # Empty list means full trace + print("Partial trace matrix qubit 1:\n", other_matrix_1) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # MARK: Tests # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # @@ -427,8 +468,8 @@ fn presentation() -> None: def main(): args = argv() - number_qubits: Int = 15 - number_layers: Int = 10 + number_qubits: Int = 10 + number_layers: Int = 20 if len(args) == 3: try: number_qubits = Int(args[1]) @@ -441,14 +482,16 @@ def main(): else: print("Usage: ./main [number_of_qubits] [number_of_layers]") - simulate_figure1_circuit() + # simulate_figure1_circuit() - # simulate_figure1_circuit_abstract() + simulate_figure1_circuit_abstract() - # simulate_random_circuit(number_qubits, number_layers) + simulate_random_circuit(number_qubits, number_layers) # simulate_figure4_circuit() # simulate_figure4_circuit_abstract() # presentation() + + # test_density_matrix() diff --git a/pixi.toml b/pixi.toml index f9384e7..8f16304 100644 --- a/pixi.toml +++ b/pixi.toml @@ -13,17 +13,17 @@ cmd = ".github/scripts/check-format.sh" [tasks.format] # Format the code cmd = "pixi run mojo format ./src ./tests" -inputs = ["./src/**/*.mojo", "./tests/**/*.mojo"] +inputs = ["./examples/**/*.mojo", "./src/**/*.mojo", "./tests/**/*.mojo"] [tasks.create_build_dir] cmd = "mkdir -p build/" [tasks.build] # Compile any mojo file args = [ - { "arg" = "full_file_path", "default" = "main" }, + { "arg" = "full_file_path", "default" = "examples/main.mojo" }, { "arg" = "executable_name", "default" = "main" }, ] -inputs = ["{{ full_file_path }}"] +inputs = ["{{ full_file_path }}", "./src/**/*.mojo"] outputs = ["build/{{ executable_name }}"] cmd = "pixi run mojo build {{ full_file_path }} -o {{ executable_name }} && cp {{ executable_name }} build/{{ executable_name }} && rm {{ executable_name }}" depends-on = ["create_build_dir"] @@ -70,6 +70,7 @@ depends-on = ["install"] [tasks] +tests = [{ task = "test" }] p = [{ task = "clear" }, { task = "package" }] m = [{ task = "clear" }, { task = "main" }] # t = "clear && pixi run package && pixi run mojo test tests --filter" diff --git a/src/abstractions/gate_circuit.mojo b/src/abstractions/gate_circuit.mojo index fcd7e98..c3da343 100644 --- a/src/abstractions/gate_circuit.mojo +++ b/src/abstractions/gate_circuit.mojo @@ -112,7 +112,8 @@ struct GateCircuit(Movable, Stringable, Writable): controls_flags: List[List[Int]] = gate.control_qubits_with_flags for control in controls_flags: - control_index, flag = control[0], control[1] + # control_index, flag = control[0], control[1] + control_index = control[0] while ( wires_current_gate[qubit_index] < wires_current_gate[control_index] diff --git a/src/abstractions/simulator.mojo b/src/abstractions/simulator.mojo index 7d2424b..e59aed2 100644 --- a/src/abstractions/simulator.mojo +++ b/src/abstractions/simulator.mojo @@ -211,15 +211,8 @@ struct StateVectorSimulator(Copyable, Movable): i: Int = 0 layer_index: Int = 0 for gate in self.circuit.gates: # Iterate over the gates in the circuit - if gate.symbol not in [_SEPARATOR.symbol, SWAP.symbol]: - # Apply the next gate - quantum_state = qubit_wise_multiply_extended( - len(gate.target_qubits), # Number of target qubits - gate.matrix, - gate.target_qubits, # Assuming single target qubit - quantum_state, - gate.control_qubits_with_flags, - ) + if gate.symbol == _SEPARATOR.symbol: + continue elif gate.symbol == SWAP.symbol: if len(gate.target_qubits) != 2: print("Error: SWAP gate must have exactly 2 target qubits.") @@ -231,11 +224,15 @@ struct StateVectorSimulator(Copyable, Movable): quantum_state, gate.control_qubits_with_flags, ) - elif gate.symbol == _SEPARATOR.symbol: - continue else: - print("Error: Unexpected gate symbol:", gate.symbol) - continue # Skip unexpected symbols + # Apply the next gate + quantum_state = qubit_wise_multiply_extended( + len(gate.target_qubits), # Number of target qubits + gate.matrix, + gate.target_qubits, # Assuming single target qubit + quantum_state, + gate.control_qubits_with_flags, + ) i += 1 if self.verbose: diff --git a/src/base/qubits_operations.mojo b/src/base/qubits_operations.mojo index 5c4507c..e530b36 100644 --- a/src/base/qubits_operations.mojo +++ b/src/base/qubits_operations.mojo @@ -587,40 +587,48 @@ fn partial_trace[ n = quantum_state.number_qubits() conj_quantum_state = quantum_state.conjugate() - is_traced_out: List[Bool] = [False] * n - for i in range(len(qubits_to_trace_out)): - is_traced_out[qubits_to_trace_out[i]] = True - - qubits_to_keep: List[Int] = [] - for i in range(n): - if not is_traced_out[i]: - qubits_to_keep.append(i) - # is_traced_out: List[Bool] = [False] * n - # qubits_to_keep: List[Int] = [] + # for i in range(len(qubits_to_trace_out)): + # is_traced_out[qubits_to_trace_out[i]] = True - # current_index: Int = 0 - # current_traced_index: Int = 0 - # for _ in range(n): - # if qubits_to_trace_out[current_traced_index] == current_index: - # is_traced_out[current_index] = True - # current_traced_index += 1 - # if current_traced_index >= len(qubits_to_trace_out): - # break - # else: + # qubits_to_keep: List[Int] = [] + # for i in range(n): + # if not is_traced_out[i]: # qubits_to_keep.append(i) - # current_index += 1 - # # Ensure all qubits have been added to qubits_to_keep - # for i in range(current_index, n): - # qubits_to_keep.append(i) + is_traced_out: List[Bool] = [False] * n + qubits_to_keep: List[Int] = [] + + current_index: Int = 0 + current_traced_index: Int = 0 + for _ in range(n): + if current_traced_index >= len(qubits_to_trace_out): + break + if qubits_to_trace_out[current_traced_index] == current_index: + is_traced_out[current_index] = True + current_traced_index += 1 + current_index += 1 + if current_traced_index >= len(qubits_to_trace_out): + break + else: + qubits_to_keep.append(current_index) + current_index += 1 + + # Ensure all qubits have been added to qubits_to_keep + for i in range(current_index, n): + qubits_to_keep.append(i) num_qubits_to_trace_out: Int = len(qubits_to_trace_out) num_qubits_to_keep: Int = len(qubits_to_keep) if num_qubits_to_trace_out + num_qubits_to_keep != n: print( - "Error: The total number of qubits to trace out and keep does not" - " match the total number of qubits." + "Error: The total number of qubits to trace out (", + num_qubits_to_trace_out, + ") and keep (", + num_qubits_to_keep, + ") does not match the total number of qubits (", + n, + ").", ) return ComplexMatrix(0, 0) # Return an empty matrix diff --git a/src/base/state_and_matrix.mojo b/src/base/state_and_matrix.mojo index 3a5dcd6..ff4defc 100644 --- a/src/base/state_and_matrix.mojo +++ b/src/base/state_and_matrix.mojo @@ -36,7 +36,7 @@ struct PureBasisState(Copyable, Movable, Stringable, Writable): self.state_vector = CustomList[ComplexFloat64, hint_trivial_type=True]( length=size, fill=ComplexFloat64(0.0, 0.0) ) - self.state_vector.memset_zero() # Initialize the state vector with zeros + # self.state_vector.memset_zero() @always_inline fn __getitem__(self, index: Int) -> ComplexFloat64: @@ -64,9 +64,11 @@ struct PureBasisState(Copyable, Movable, Stringable, Writable): if amplitude_im == 0.0 and amplitude_re == 0.0: amplitude_str: String = String(Int(amplitude_re)) elif amplitude_im == 0.0: - amplitude_str = String(round(amplitude_re, 2)) + # amplitude_str = String(round(amplitude_re, 2)) + amplitude_str = String(amplitude_re) elif amplitude_re == 0.0: - amplitude_str = String(round(amplitude_im, 2)) + "i" + # amplitude_str = String(round(amplitude_im, 2)) + "i" + amplitude_str = String(amplitude_im) + "i" else: amplitude_str = ( String(round(amplitude_re, 2)) @@ -179,7 +181,7 @@ struct PureBasisState(Copyable, Movable, Stringable, Writable): conjugated_state.state_vector[i] = self.state_vector[i].conjugate() return conjugated_state - fn get_density_matrix(self) -> ComplexMatrix: + fn to_density_matrix(self) -> ComplexMatrix: """Returns the density matrix of the pure state. The density matrix is computed as the outer product of the state vector with itself. diff --git a/tests/base/test_qubit_operations.mojo b/tests/base/test_qubit_operations.mojo new file mode 100644 index 0000000..d6e6ef2 --- /dev/null +++ b/tests/base/test_qubit_operations.mojo @@ -0,0 +1,139 @@ +from testing import ( + assert_true, + assert_false, + assert_equal, + assert_not_equal, + assert_almost_equal, +) + +from testing_matrix import assert_matrix_almost_equal + +from qlabs.base import ( + PureBasisState, + ComplexMatrix, + Gate, + Hadamard, + PauliX, + PauliY, + PauliZ, + NOT, + H, + X, + Y, + Z, + SWAP, + iSWAP, + qubit_wise_multiply, + qubit_wise_multiply_extended, + apply_swap, + partial_trace, +) + +from qlabs.local_stdlib import CustomList +from qlabs.local_stdlib.complex import ComplexFloat64 + + +def test_partial_trace_all(): + """Test the partial trace operation on a 2-qubit state. Keep all qubits.""" + state: PureBasisState = PureBasisState( + 2, + CustomList[ComplexFloat64, hint_trivial_type=True]( + ComplexFloat64(0, 0), + ComplexFloat64(-0.5, 0), + ComplexFloat64(0.7071067811863477, 0), + ComplexFloat64(-0.5, 0), + ), + ) + matrix: ComplexMatrix = partial_trace(state, []) + assert_matrix_almost_equal( + matrix, + ComplexMatrix( + List[List[ComplexFloat64]]( + [ + ComplexFloat64(0, 0), + ComplexFloat64(0, 0), + ComplexFloat64(0, 0), + ComplexFloat64(0, 0), + ], + [ + ComplexFloat64(0, 0), + ComplexFloat64(0.25, 0), + ComplexFloat64(-0.3535533905929741, 0), + ComplexFloat64(0.25, 0), + ], + [ + ComplexFloat64(0, 0), + ComplexFloat64(-0.3535533905929741, 0), + ComplexFloat64(0.5, 0), + ComplexFloat64(-0.3535533905929741, 0), + ], + [ + ComplexFloat64(0, 0), + ComplexFloat64(0.25, 0), + ComplexFloat64(-0.3535533905929741, 0), + ComplexFloat64(0.25, 0), + ], + ) + ), + "full trace", + ) + + +def test_partial_trace_qubit0(): + """Test the partial trace operation on a 2-qubit state. Trace out qubit 0, keep qubit 1. + """ + state: PureBasisState = PureBasisState( + 2, + CustomList[ComplexFloat64, hint_trivial_type=True]( + ComplexFloat64(0, 0), + ComplexFloat64(-0.5, 0), + ComplexFloat64(0.7071067811863477, 0), + ComplexFloat64(-0.5, 0), + ), + ) + matrix = partial_trace(state, [0]) + assert_matrix_almost_equal( + matrix, + ComplexMatrix( + List[List[ComplexFloat64]]( + [ + ComplexFloat64(0, 0), + ComplexFloat64(0, 0), + ], + [ + ComplexFloat64(0, 0), + ComplexFloat64(1, 0), + ], + ) + ), + ) + + +def test_partial_trace_qubit1(): + """Test the partial trace operation on a 2-qubit state. Trace out qubit 1, keep qubit 0. + """ + state: PureBasisState = PureBasisState( + 2, + CustomList[ComplexFloat64, hint_trivial_type=True]( + ComplexFloat64(0, 0), + ComplexFloat64(-0.5, 0), + ComplexFloat64(0.7071067811863477, 0), + ComplexFloat64(-0.5, 0), + ), + ) + matrix = partial_trace(state, [1]) + assert_matrix_almost_equal( + matrix, + ComplexMatrix( + List[List[ComplexFloat64]]( + [ + ComplexFloat64(0, 0), + ComplexFloat64(0, 0), + ], + [ + ComplexFloat64(0, 0), + ComplexFloat64(0.5, 0), + ], + ) + ), + ) diff --git a/tests/base/testing_matrix.mojo b/tests/base/testing_matrix.mojo new file mode 100644 index 0000000..01a77a6 --- /dev/null +++ b/tests/base/testing_matrix.mojo @@ -0,0 +1,51 @@ +from testing import ( + assert_true, + assert_false, + assert_equal, + assert_not_equal, + assert_almost_equal, +) + +from qlabs.base import ComplexMatrix + + +def assert_matrix_almost_equal( + reference_matrix: ComplexMatrix, matrix: ComplexMatrix, message: String = "" +) -> None: + """Asserts that two complex matrices are almost equal. + + Args: + reference_matrix: The reference matrix to compare against. + matrix: The matrix to check for equality. + """ + assert_equal( + reference_matrix.size(), + matrix.size(), + String("Matrices must have the same size.") + message, + ) + for i in range(reference_matrix.size()): + for j in range(reference_matrix.size()): + assert_almost_equal( + reference_matrix[i, j].re, + matrix[i, j].re, + String( + "Real parts of matrices are not equal at indices (", + i, + ", ", + j, + "). ", + ) + + message, + ) + assert_almost_equal( + reference_matrix[i, j].im, + matrix[i, j].im, + String( + "Imaginary parts of matrices are not equal at indices (", + i, + ", ", + j, + "). ", + ) + + message, + )