Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace np.ndarray with npt.NDArray for improved type hinting #4899

Merged
merged 1 commit into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions benchmarks/different_model_options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pybamm
from benchmarks.benchmark_utils import set_random_seed
import numpy as np
import numpy.typing as npt


def compute_discretisation(model, param):
Expand Down Expand Up @@ -33,8 +34,8 @@ def build_model(parameter, model_, option, value):
class SolveModel:
solver: pybamm.BaseSolver
model: pybamm.BaseModel
t_eval: np.ndarray
t_interp: np.ndarray | None
t_eval: npt.NDArray
t_interp: npt.NDArray | None

def solve_setup(self, parameter, model_, option, value, solver_class):
self.solver = solver_class()
Expand Down
9 changes: 5 additions & 4 deletions benchmarks/time_solve_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pybamm
from benchmarks.benchmark_utils import set_random_seed
import numpy as np
import numpy.typing as npt


def solve_model_once(model, solver, t_eval, t_interp):
Expand All @@ -30,8 +31,8 @@ class TimeSolveSPM:
)
model: pybamm.BaseModel
solver: pybamm.BaseSolver
t_eval: np.ndarray
t_interp: np.ndarray | None
t_eval: npt.NDArray
t_interp: npt.NDArray | None

def setup(self, solve_first, parameters, solver_class):
set_random_seed()
Expand Down Expand Up @@ -96,7 +97,7 @@ class TimeSolveSPMe:
)
model: pybamm.BaseModel
solver: pybamm.BaseSolver
t_eval: np.ndarray
t_eval: npt.NDArray

def setup(self, solve_first, parameters, solver_class):
set_random_seed()
Expand Down Expand Up @@ -160,7 +161,7 @@ class TimeSolveDFN:
)
model: pybamm.BaseModel
solver: pybamm.BaseSolver
t_eval: np.ndarray
t_eval: npt.NDArray

def setup(self, solve_first, parameters, solver_class):
set_random_seed()
Expand Down
7 changes: 4 additions & 3 deletions src/pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
from __future__ import annotations
import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix, issparse

import pybamm
Expand Down Expand Up @@ -38,7 +39,7 @@ class Array(pybamm.Symbol):

def __init__(
self,
entries: np.ndarray | list[float] | csr_matrix,
entries: npt.NDArray | list[float] | csr_matrix,
name: str | None = None,
domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
Expand Down Expand Up @@ -144,8 +145,8 @@ def create_copy(
def _base_evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
Expand Down
5 changes: 3 additions & 2 deletions src/pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numbers

import numpy as np
import numpy.typing as npt
import sympy
from scipy.sparse import csr_matrix, issparse
import functools
Expand Down Expand Up @@ -152,8 +153,8 @@ def _binary_new_copy(self, left: ChildSymbol, right: ChildSymbol):
def evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol.evaluate()`."""
Expand Down
9 changes: 5 additions & 4 deletions src/pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional

import numpy as np
import numpy.typing as npt
import sympy
from scipy.sparse import issparse, vstack
from collections.abc import Sequence
Expand Down Expand Up @@ -112,7 +113,7 @@ def get_children_domains(self, children: Sequence[pybamm.Symbol]):

return domains

def _concatenation_evaluate(self, children_eval: list[np.ndarray]):
def _concatenation_evaluate(self, children_eval: list[npt.NDArray]):
"""See :meth:`Concatenation._concatenation_evaluate()`."""
if len(children_eval) == 0:
return np.array([])
Expand All @@ -122,8 +123,8 @@ def _concatenation_evaluate(self, children_eval: list[np.ndarray]):
def evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol.evaluate()`."""
Expand Down Expand Up @@ -367,7 +368,7 @@ def create_slices(self, node: pybamm.Symbol) -> defaultdict:
start = end
return slices

def _concatenation_evaluate(self, children_eval: list[np.ndarray]):
def _concatenation_evaluate(self, children_eval: list[npt.NDArray]):
"""See :meth:`Concatenation._concatenation_evaluate()`."""
# preallocate vector
vector = np.empty((self._size, 1))
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/expression_tree/discrete_time_sum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pybamm
import numpy as np
import numpy.typing as npt


class DiscreteTimeData(pybamm.Interpolant):
Expand All @@ -19,7 +19,7 @@ class DiscreteTimeData(pybamm.Interpolant):

"""

def __init__(self, time_points: np.ndarray, data: np.ndarray, name: str):
def __init__(self, time_points: npt.NDArray, data: npt.NDArray, name: str):
super().__init__(time_points, data, pybamm.t, name)

def create_copy(self, new_children=None, perform_simplifications=True):
Expand Down
5 changes: 3 additions & 2 deletions src/pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import numpy as np
import numpy.typing as npt
from scipy import special
import sympy
from typing import Callable
Expand Down Expand Up @@ -122,8 +123,8 @@ def _function_jac(self, children_jacs):
def evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol.evaluate()`."""
Expand Down
7 changes: 3 additions & 4 deletions src/pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
#
from __future__ import annotations
import sympy
import numpy as np

import numpy.typing as npt
import pybamm
from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType

Expand Down Expand Up @@ -94,8 +93,8 @@ def create_copy(
def _base_evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
Expand Down
5 changes: 3 additions & 2 deletions src/pybamm/expression_tree/input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations
import numbers
import numpy as np
import numpy.typing as npt
import scipy.sparse
import pybamm

Expand Down Expand Up @@ -88,8 +89,8 @@ def _jac(self, variable: pybamm.StateVector) -> pybamm.Matrix:
def _base_evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
# inputs should be a dictionary
Expand Down
7 changes: 4 additions & 3 deletions src/pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
from __future__ import annotations
import numpy as np
import numpy.typing as npt
from scipy import interpolate
from collections.abc import Sequence
import numbers
Expand Down Expand Up @@ -43,8 +44,8 @@ class Interpolant(pybamm.Function):

def __init__(
self,
x: np.ndarray | Sequence[np.ndarray],
y: np.ndarray,
x: npt.NDArray | Sequence[npt.NDArray],
y: npt.NDArray,
children: Sequence[pybamm.Symbol] | pybamm.Time,
name: str | None = None,
interpolator: str | None = "linear",
Expand Down Expand Up @@ -96,7 +97,7 @@ def __init__(
x1 = x[0]
else:
x1 = x
x: list[np.ndarray] = [x] # type: ignore[no-redef]
x: list[npt.NDArray] = [x] # type: ignore[no-redef]
x2 = None
if x1.shape[0] != y.shape[0]:
raise ValueError(
Expand Down
3 changes: 2 additions & 1 deletion src/pybamm/expression_tree/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
from __future__ import annotations
import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix, issparse

import pybamm
Expand All @@ -16,7 +17,7 @@ class Matrix(pybamm.Array):

def __init__(
self,
entries: np.ndarray | list[float] | csr_matrix,
entries: npt.NDArray | list[float] | csr_matrix,
name: str | None = None,
domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
Expand Down
5 changes: 3 additions & 2 deletions src/pybamm/expression_tree/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
from __future__ import annotations
import numpy as np
import numpy.typing as npt
import sympy
from typing import Literal

Expand Down Expand Up @@ -66,8 +67,8 @@ def set_id(self):
def _base_evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
Expand Down
9 changes: 5 additions & 4 deletions src/pybamm/expression_tree/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
from __future__ import annotations
import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix, vstack

import pybamm
Expand Down Expand Up @@ -281,8 +282,8 @@ def __init__(
def _base_evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
Expand Down Expand Up @@ -365,8 +366,8 @@ def __init__(
def _base_evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
Expand Down
9 changes: 5 additions & 4 deletions src/pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings

import numpy as np
import numpy.typing as npt
import sympy
from scipy.sparse import csr_matrix, issparse
from functools import cached_property
Expand Down Expand Up @@ -769,8 +770,8 @@ def _jac(self, variable):
def _base_evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""
Expand Down Expand Up @@ -801,8 +802,8 @@ def _base_evaluate(
def evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
) -> ChildValue:
"""Evaluate expression tree (wrapper to allow using dict of known values).
Expand Down
5 changes: 3 additions & 2 deletions src/pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix, issparse
import sympy
import pybamm
Expand Down Expand Up @@ -93,8 +94,8 @@ def _unary_evaluate(self, child):
def evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol.evaluate()`."""
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/expression_tree/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
from __future__ import annotations
import numpy as np

import numpy.typing as npt
import pybamm
from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType

Expand All @@ -15,7 +15,7 @@ class Vector(pybamm.Array):

def __init__(
self,
entries: np.ndarray | list[float] | np.matrix,
entries: npt.NDArray | list[float] | np.matrix,
name: str | None = None,
domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
Expand Down
7 changes: 3 additions & 4 deletions src/pybamm/models/event.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

from enum import Enum
import numpy as np

import numpy.typing as npt
from typing import TypeVar


Expand Down Expand Up @@ -75,8 +74,8 @@ def _from_json(cls: type[E], snippet: dict) -> E:
def evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | None = None,
):
"""
Expand Down
Loading
Loading