Make npt.NDArray type hints more specific with dtype #4901

Open, wants to merge 1 commit into base: develop


@vidipsingh (Contributor) commented Mar 8, 2025

Description

This PR refines the npt.NDArray type hints in PyBaMM by adding an explicit dtype (e.g., np.float64 for time/state arrays, Any for variable cases).
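
As a rough illustration of what a dtype-specific hint looks like, here is a minimal sketch. The function names `time_step_sizes` and `flatten_any` are hypothetical and are not taken from PyBaMM; only `npt.NDArray`, `np.float64`, and `Any` come from the description above.

```python
from typing import Any

import numpy as np
import numpy.typing as npt


def time_step_sizes(t: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
    """Return the differences between consecutive time points."""
    return np.diff(t)


def flatten_any(values: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """Accept an array of any dtype; the dtype is deliberately left open."""
    return values.ravel()


t_eval = np.linspace(0.0, 1.0, 5)  # np.linspace produces float64 by default
steps = time_step_sizes(t_eval)
flat = flatten_any(np.array([[1, 2], [3, 4]]))
```

The dtype parameter is only checked statically (for example by mypy); NumPy does not enforce it at runtime.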

Fixes: #4900

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #)

Important checks:

Please confirm the following before marking the PR as ready for review:

  • No style issues: nox -s pre-commit
  • All tests pass: nox -s tests
  • The documentation builds: nox -s doctests
  • Code is commented for hard-to-understand areas
  • Tests added that prove fix is effective or that feature works
mypy Output (Before):

src/pybamm/telemetry.py:18: error: Incompatible types in assignment (expression has type "Posthog", variable has type "MockTelemetry")  [assignment]
src/pybamm/telemetry.py:23: error: "MockTelemetry" has no attribute "log"  [attr-defined]
src/pybamm/config.py:168: error: Library stubs not installed for "yaml"  [import-untyped]
src/pybamm/config.py:168: note: Hint: "python3 -m pip install types-PyYAML"
src/pybamm/config.py:168: note: (or run "mypy --install-types" to install all missing stub packages)
src/pybamm/config.py:168: note: See https://mypy.readthedocs.io/en/stable/running_mypy.html#missing-imports
src/pybamm/solvers/summary_variable.py:43: error: Need type annotation for "_variables" (hint: "_variables: dict[<type>, <type>] = ...")  [var-annotated]
src/pybamm/solvers/summary_variable.py:71: error: Incompatible types in assignment (expression has type "list[SummaryVariables]", variable has type "None")  [assignment]
src/pybamm/solvers/summary_variable.py:72: error: Argument 1 to "len" has incompatible type "None"; expected "Sized"  [arg-type]
src/pybamm/solvers/summary_variable.py:73: error: Value of type "None" is not indexable  [index]
src/pybamm/solvers/summary_variable.py:85: error: Cannot determine type of "_all_variables"  [has-type]
src/pybamm/solvers/summary_variable.py:103: error: Item "None" of "ElectrodeSOHSolver | None" has no attribute "_get_electrode_soh_sims_full"  [union-attr]
src/pybamm/solvers/summary_variable.py:105: error: Incompatible types in assignment (expression has type "list[Any]", variable has type "None")  [assignment]
src/pybamm/solvers/summary_variable.py:126: error: Incompatible return value type (got "ndarray[Any, dtype[Any]]", expected "float | list[float]")  [return-value]
src/pybamm/solvers/summary_variable.py:151: error: "None" has no attribute "__iter__" (not iterable)  [attr-defined]
src/pybamm/solvers/summary_variable.py:153: error: "None" has no attribute "__iter__" (not iterable)  [attr-defined]
src/pybamm/solvers/summary_variable.py:184: error: Item "None" of "ElectrodeSOHSolver | None" has no attribute "solve"  [union-attr]
src/pybamm/solvers/solution.py:160: error: Missing return statement  [return]
src/pybamm/solvers/processed_variable_time_integral.py:19: error: Argument "initial_condition" to "ProcessedVariableTimeIntegral" has incompatible type "float"; expected "ndarray[Any, dtype[Any]]"  [arg-type]
src/pybamm/solvers/idaklu_jax.py:262: error: Incompatible default for argument "t" (default has type "None", argument has type "ndarray[Any, dtype[Any]]")  [assignment]
src/pybamm/solvers/idaklu_jax.py:262: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
src/pybamm/solvers/idaklu_jax.py:262: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
src/pybamm/solvers/idaklu_jax.py:295: error: Incompatible default for argument "t" (default has type "None", argument has type "ndarray[Any, dtype[Any]]")  [assignment]
src/pybamm/solvers/idaklu_jax.py:295: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
src/pybamm/solvers/idaklu_jax.py:295: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
src/pybamm/expression_tree/functions.py:164: error: Expected iterable as variadic argument  [misc]
src/pybamm/expression_tree/functions.py:173: error: Argument 1 to "_function_new_copy" of "Function" has incompatible type "Symbol"; expected "list[Any]"  [arg-type]
src/pybamm/expression_tree/concatenations.py:479: error: Argument 1 to "intersect" has incompatible type "str | None"; expected "str"  [arg-type]
src/pybamm/expression_tree/concatenations.py:480: error: Argument 1 to "len" has incompatible type "str | None"; expected "Sized"  [arg-type]
src/pybamm/expression_tree/concatenations.py:483: error: Argument 1 to "len" has incompatible type "str | None"; expected "Sized"  [arg-type]
src/pybamm/expression_tree/concatenations.py:484: error: Value of type "str | None" is not indexable  [index]
src/pybamm/expression_tree/concatenations.py:545: error: Cannot determine type of "child"  [has-type]
src/pybamm/citations.py:34: note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs  [annotation-unchecked]
src/pybamm/expression_tree/unary_operators.py:74: error: Value of type "Symbol" is not indexable  [index]
src/pybamm/expression_tree/symbol.py:974: error: Incompatible return value type (got "list[Symbol]", expected "Symbol")  [return-value]
src/pybamm/expression_tree/symbol.py:996: error: Argument 2 to "Symbol" has incompatible type "Symbol"; expected "Sequence[Symbol] | None"  [arg-type]
src/pybamm/expression_tree/broadcasts.py:82: error: "Broadcast" has no attribute "broadcast_domain"  [attr-defined]
src/pybamm/expression_tree/binary_operators.py:131: error: Missing positional argument "right_child" in call to "BinaryOperator"  [call-arg]
src/pybamm/expression_tree/binary_operators.py:131: error: Value of type "Symbol" is not indexable  [index]
src/pybamm/expression_tree/binary_operators.py:135: error: Value of type "Symbol" is not indexable  [index]
src/pybamm/experiment/step/base_step.py:143: error: Argument 3 to "Interpolant" has incompatible type "Subtraction"; expected "Sequence[Symbol] | Time"  [arg-type]
src/pybamm/experiment/experiment.py:59: error: Incompatible types in assignment (expression has type "tuple[str | BaseStep]", variable has type "str | tuple[str] | BaseStep")  [assignment]
src/pybamm/experiment/experiment.py:62: error: Argument 1 to "len" has incompatible type "str | tuple[str] | BaseStep"; expected "Sized"  [arg-type]
src/pybamm/experiment/experiment.py:64: error: Item "BaseStep" of "str | tuple[str] | BaseStep" has no attribute "__iter__" (not iterable)  [union-attr]
src/pybamm/solvers/base_solver.py:97: error: Name "root_method" already defined on line 85  [no-redef]
src/pybamm/solvers/base_solver.py:97: error: "Callable[[BaseSolver], Any]" has no attribute "setter"  [attr-defined]
src/pybamm/solvers/base_solver.py:1124: error: "BaseModel" has no attribute "calculate_sensitivities"  [attr-defined]
src/pybamm/solvers/base_solver.py:1125: error: Need type annotation for "initial_conditions"  [var-annotated]
src/pybamm/solvers/base_solver.py:1131: error: "BaseModel" has no attribute "calculate_sensitivities"  [attr-defined]
src/pybamm/solvers/base_solver.py:1150: error: Incompatible return value type (got "list[Any]", expected "tuple[Any, ...]")  [return-value]
tests/unit/test_parameters/test_bpx.py:11: note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs  [annotation-unchecked]
tests/unit/test_expression_tree/test_binary_operators.py:14: error: Need type annotation for "EMPTY_DOMAINS"  [var-annotated]
examples/scripts/run_ecmd.py:14: error: List item 0 has incompatible type "tuple[str, str, str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/run_ecm.py:9: error: List item 0 has incompatible type "tuple[str, str, str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/minimal_example_of_lookup_tables.py:37: error: Name "D_s_n" already defined on line 25  [no-redef]
examples/scripts/experiment_drive_cycle.py:32: error: List item 0 has incompatible type "tuple[str, str, str, Any, str, Any, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/SPMe_step.py:47: error: Value of type "Any | None" is not indexable  [index]
examples/scripts/SPMe_step.py:49: error: Value of type "Any | None" is not indexable  [index]
examples/scripts/SPM_compare_particle_grid.py:59: error: Value of type "None" is not indexable  [index]
examples/scripts/SPM_compare_particle_grid.py:60: error: Value of type "None" is not indexable  [index]
examples/scripts/MSMR.py:29: error: List item 0 has incompatible type "tuple[str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/experimental_protocols/gitt.py:8: error: List item 0 has incompatible type "tuple[str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/experimental_protocols/cccv.py:10: error: List item 0 has incompatible type "tuple[str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
docs/conf.py:110: error: Need type annotation for "suppress_warnings" (hint: "suppress_warnings: list[<type>] = ...")  [var-annotated]
docs/conf.py:202: error: Incompatible types in assignment (expression has type "bool", target has type "str")  [assignment]
docs/conf.py:221: error: Need type annotation for "latex_elements" (hint: "latex_elements: dict[<type>, <type>] = ...")  [var-annotated]
docs/conf.py:338: error: Incompatible types in assignment (expression has type "str | None", variable has type "str")  [assignment]
docs/conf.py:494: error: Dict entry 0 has incompatible type "str": "ParameterSets"; expected "str": "str"  [dict-item]
Found 59 errors in 26 files (checked 581 source files)

mypy Output (After):

src/pybamm/telemetry.py:18: error: Incompatible types in assignment (expression has type "Posthog", variable has type "MockTelemetry")  [assignment]
src/pybamm/telemetry.py:23: error: "MockTelemetry" has no attribute "log"  [attr-defined]
src/pybamm/config.py:168: error: Library stubs not installed for "yaml"  [import-untyped]
src/pybamm/config.py:168: note: Hint: "python3 -m pip install types-PyYAML"
src/pybamm/config.py:168: note: (or run "mypy --install-types" to install all missing stub packages)
src/pybamm/config.py:168: note: See https://mypy.readthedocs.io/en/stable/running_mypy.html#missing-imports
src/pybamm/solvers/summary_variable.py:43: error: Need type annotation for "_variables" (hint: "_variables: dict[<type>, <type>] = ...")  [var-annotated]
src/pybamm/solvers/summary_variable.py:71: error: Incompatible types in assignment (expression has type "list[SummaryVariables]", variable has type "None")  [assignment]
src/pybamm/solvers/summary_variable.py:72: error: Argument 1 to "len" has incompatible type "None"; expected "Sized"  [arg-type]
src/pybamm/solvers/summary_variable.py:73: error: Value of type "None" is not indexable  [index]
src/pybamm/solvers/summary_variable.py:85: error: Cannot determine type of "_all_variables"  [has-type]
src/pybamm/solvers/summary_variable.py:103: error: Item "None" of "ElectrodeSOHSolver | None" has no attribute "_get_electrode_soh_sims_full"  [union-attr]
src/pybamm/solvers/summary_variable.py:105: error: Incompatible types in assignment (expression has type "list[Any]", variable has type "None")  [assignment]
src/pybamm/solvers/summary_variable.py:126: error: Incompatible return value type (got "ndarray[Any, dtype[Any]]", expected "float | list[float]")  [return-value]
src/pybamm/solvers/summary_variable.py:151: error: "None" has no attribute "__iter__" (not iterable)  [attr-defined]
src/pybamm/solvers/summary_variable.py:153: error: "None" has no attribute "__iter__" (not iterable)  [attr-defined]
src/pybamm/solvers/summary_variable.py:184: error: Item "None" of "ElectrodeSOHSolver | None" has no attribute "solve"  [union-attr]
src/pybamm/solvers/solution.py:160: error: Missing return statement  [return]
src/pybamm/solvers/processed_variable_time_integral.py:20: error: Argument "initial_condition" to "ProcessedVariableTimeIntegral" has incompatible type "float"; expected "ndarray[Any, dtype[floating[_64Bit]]]"  [arg-type]
src/pybamm/solvers/idaklu_jax.py:262: error: Incompatible default for argument "t" (default has type "None", argument has type "ndarray[Any, dtype[floating[_64Bit]]]")  [assignment]
src/pybamm/solvers/idaklu_jax.py:262: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
src/pybamm/solvers/idaklu_jax.py:262: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
src/pybamm/solvers/idaklu_jax.py:295: error: Incompatible default for argument "t" (default has type "None", argument has type "ndarray[Any, dtype[floating[_64Bit]]]")  [assignment]
src/pybamm/solvers/idaklu_jax.py:295: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
src/pybamm/solvers/idaklu_jax.py:295: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
src/pybamm/expression_tree/functions.py:164: error: Expected iterable as variadic argument  [misc]
src/pybamm/expression_tree/functions.py:173: error: Argument 1 to "_function_new_copy" of "Function" has incompatible type "Symbol"; expected "list[Any]"  [arg-type]
src/pybamm/expression_tree/concatenations.py:480: error: Argument 1 to "intersect" has incompatible type "str | None"; expected "str"  [arg-type]
src/pybamm/expression_tree/concatenations.py:481: error: Argument 1 to "len" has incompatible type "str | None"; expected "Sized"  [arg-type]
src/pybamm/expression_tree/concatenations.py:484: error: Argument 1 to "len" has incompatible type "str | None"; expected "Sized"  [arg-type]
src/pybamm/expression_tree/concatenations.py:485: error: Value of type "str | None" is not indexable  [index]
src/pybamm/expression_tree/concatenations.py:546: error: Cannot determine type of "child"  [has-type]
src/pybamm/citations.py:34: note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs  [annotation-unchecked]
src/pybamm/expression_tree/unary_operators.py:74: error: Value of type "Symbol" is not indexable  [index]
src/pybamm/expression_tree/symbol.py:974: error: Incompatible return value type (got "list[Symbol]", expected "Symbol")  [return-value]
src/pybamm/expression_tree/symbol.py:996: error: Argument 2 to "Symbol" has incompatible type "Symbol"; expected "Sequence[Symbol] | None"  [arg-type]
src/pybamm/expression_tree/broadcasts.py:82: error: "Broadcast" has no attribute "broadcast_domain"  [attr-defined]
src/pybamm/expression_tree/binary_operators.py:131: error: Missing positional argument "right_child" in call to "BinaryOperator"  [call-arg]
src/pybamm/expression_tree/binary_operators.py:131: error: Value of type "Symbol" is not indexable  [index]
src/pybamm/expression_tree/binary_operators.py:135: error: Value of type "Symbol" is not indexable  [index]
src/pybamm/experiment/step/base_step.py:143: error: Argument 3 to "Interpolant" has incompatible type "Subtraction"; expected "Sequence[Symbol] | Time"  [arg-type]
src/pybamm/experiment/experiment.py:59: error: Incompatible types in assignment (expression has type "tuple[str | BaseStep]", variable has type "str | tuple[str] | BaseStep")  [assignment]
src/pybamm/experiment/experiment.py:62: error: Argument 1 to "len" has incompatible type "str | tuple[str] | BaseStep"; expected "Sized"  [arg-type]
src/pybamm/experiment/experiment.py:64: error: Item "BaseStep" of "str | tuple[str] | BaseStep" has no attribute "__iter__" (not iterable)  [union-attr]
src/pybamm/solvers/base_solver.py:97: error: Name "root_method" already defined on line 85  [no-redef]
src/pybamm/solvers/base_solver.py:97: error: "Callable[[BaseSolver], Any]" has no attribute "setter"  [attr-defined]
src/pybamm/solvers/base_solver.py:1124: error: "BaseModel" has no attribute "calculate_sensitivities"  [attr-defined]
src/pybamm/solvers/base_solver.py:1125: error: Need type annotation for "initial_conditions"  [var-annotated]
src/pybamm/solvers/base_solver.py:1131: error: "BaseModel" has no attribute "calculate_sensitivities"  [attr-defined]
src/pybamm/solvers/base_solver.py:1150: error: Incompatible return value type (got "list[Any]", expected "tuple[Any, ...]")  [return-value]
tests/unit/test_parameters/test_bpx.py:11: note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs  [annotation-unchecked]
tests/unit/test_expression_tree/test_binary_operators.py:14: error: Need type annotation for "EMPTY_DOMAINS"  [var-annotated]
examples/scripts/run_ecmd.py:14: error: List item 0 has incompatible type "tuple[str, str, str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/run_ecm.py:9: error: List item 0 has incompatible type "tuple[str, str, str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/minimal_example_of_lookup_tables.py:37: error: Name "D_s_n" already defined on line 25  [no-redef]
examples/scripts/experiment_drive_cycle.py:32: error: List item 0 has incompatible type "tuple[str, str, str, Any, str, Any, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/SPMe_step.py:47: error: Value of type "Any | None" is not indexable  [index]
examples/scripts/SPMe_step.py:49: error: Value of type "Any | None" is not indexable  [index]
examples/scripts/SPM_compare_particle_grid.py:59: error: Value of type "None" is not indexable  [index]
examples/scripts/SPM_compare_particle_grid.py:60: error: Value of type "None" is not indexable  [index]
examples/scripts/MSMR.py:29: error: List item 0 has incompatible type "tuple[str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/experimental_protocols/gitt.py:8: error: List item 0 has incompatible type "tuple[str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/experimental_protocols/cccv.py:10: error: List item 0 has incompatible type "tuple[str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
docs/conf.py:110: error: Need type annotation for "suppress_warnings" (hint: "suppress_warnings: list[<type>] = ...")  [var-annotated]
docs/conf.py:202: error: Incompatible types in assignment (expression has type "bool", target has type "str")  [assignment]
docs/conf.py:221: error: Need type annotation for "latex_elements" (hint: "latex_elements: dict[<type>, <type>] = ...")  [var-annotated]
docs/conf.py:338: error: Incompatible types in assignment (expression has type "str | None", variable has type "str")  [assignment]
docs/conf.py:494: error: Dict entry 0 has incompatible type "str": "ParameterSets"; expected "str": "str"  [dict-item]
Found 59 errors in 26 files (checked 581 source files)

@vidipsingh vidipsingh requested a review from a team as a code owner March 8, 2025 18:28
@Saransh-cpp (Member) left a comment

Thanks, @vidipsingh! Could you attach the output of mypy run in the PR description?

@vidipsingh (Contributor Author)

@Saransh-cpp, I have attached the output of mypy run in the PR description.

@agriyakhetarpal (Member)

@vidipsingh, it would be nice if you could:

  • paste the output as text (wrapped in triple backticks, i.e., as code) rather than as an image, so that it is easier to read, copy, and quote
  • show what was fixed with a "Before" vs. "After" comparison (you may use GitHub's slash (/) commands and choose "Details" to wrap the code blocks inside a collapsible dropdown section)

Thank you!

@vidipsingh (Contributor Author) commented Mar 10, 2025

Thanks for the feedback, @agriyakhetarpal!

I’ll replace the image with the mypy output in code blocks and will also add a "Before" vs. "After" comparison using collapsible sections.

Just to clarify, are you referring to the "Before" vs. "After" comparison of the mypy run output or the code changes?

@Saransh-cpp (Member)

The mypy run!

@vidipsingh (Contributor Author)

@Saransh-cpp @agriyakhetarpal

I’ve added the mypy run output for both "Before" and "After" in their respective collapsible sections in the PR description.
However, it seems that the outputs are the same for both.

Please let me know if any changes are needed!
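
The two logs in the PR description are close but not identical: both report 59 errors, yet a few messages now expect `dtype[floating[_64Bit]]` instead of `dtype[Any]`, and some line numbers shift by one. A minimal sketch of surfacing such differences with the standard-library `difflib`, using two lines copied from each log (the variable names are illustrative):

```python
import difflib

# Two lines copied from the "Before" log.
before = [
    "src/pybamm/solvers/solution.py:160: error: Missing return statement  [return]",
    'src/pybamm/solvers/idaklu_jax.py:262: error: Incompatible default for argument "t" '
    '(default has type "None", argument has type "ndarray[Any, dtype[Any]]")  [assignment]',
]
# The corresponding lines from the "After" log.
after = [
    "src/pybamm/solvers/solution.py:160: error: Missing return statement  [return]",
    'src/pybamm/solvers/idaklu_jax.py:262: error: Incompatible default for argument "t" '
    '(default has type "None", argument has type "ndarray[Any, dtype[floating[_64Bit]]]")  [assignment]',
]

# Keep only added/removed lines, dropping the "---"/"+++" file headers.
changed = [
    line
    for line in difflib.unified_diff(before, after, lineterm="", n=0)
    if line.startswith(("+", "-")) and not line.startswith(("+++", "---"))
]

for line in changed:
    print(line)
```

Only the `idaklu_jax.py` message appears in the result; the unchanged `solution.py` line is filtered out.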

codecov bot commented Mar 10, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 98.71%. Comparing base (8f91615) to head (cdb7be8).
Report is 1 commits behind head on develop.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #4901   +/-   ##
========================================
  Coverage    98.71%   98.71%           
========================================
  Files          304      304           
  Lines        23509    23519   +10     
========================================
+ Hits         23207    23217   +10     
  Misses         302      302           


@Saransh-cpp (Member) left a comment

Thanks, @vidipsingh! See my comments below. Could you please also comment why you are using Any in all the places where you are using it? Thank you!

 import pybamm


 @dataclass
 class ProcessedVariableTimeIntegral:
     method: Literal["discrete", "continuous"]
-    initial_condition: npt.NDArray
-    discrete_times: Optional[npt.NDArray]
+    initial_condition: npt.NDArray[np.float64]
Member

Suggested change
-    initial_condition: npt.NDArray[np.float64]
+    initial_condition: float | npt.NDArray[np.float64]
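
The `float | npt.NDArray[np.float64]` union matches the mypy error above, where a plain `float` is passed as `initial_condition`. Code that consumes such a value may want to normalise it first. A minimal sketch, assuming a hypothetical helper `as_float_array` that is not part of PyBaMM:

```python
from __future__ import annotations  # lets "X | Y" appear in annotations on older Pythons

import numpy as np
import numpy.typing as npt


def as_float_array(value: float | npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
    """Normalise a scalar or array to a float64 array with at least one dimension."""
    return np.atleast_1d(np.asarray(value, dtype=np.float64))


scalar_case = as_float_array(0.5)
array_case = as_float_array(np.array([1.0, 2.0]))
```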

@@ -259,7 +259,7 @@ def f_isolated(*args, **kwargs):

     def jax_value(
         self,
-        t: npt.NDArray = None,
+        t: npt.NDArray[np.float64] = None,
Member

Suggested change
-        t: npt.NDArray[np.float64] = None,
+        t: npt.NDArray[np.float64] | None = None,
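
This suggestion addresses the "Incompatible default for argument" errors in the mypy output: with implicit Optional disabled, a `None` default has to be spelled out in the annotation. A minimal sketch with a hypothetical function `evaluate_at`; the fallback used when `t` is `None` is an assumption for illustration, not PyBaMM's behaviour:

```python
from __future__ import annotations

import numpy as np
import numpy.typing as npt


def evaluate_at(t: npt.NDArray[np.float64] | None = None) -> npt.NDArray[np.float64]:
    """Hypothetical stand-in for a method with an optional time argument."""
    if t is None:
        # Assumed fallback for illustration only.
        t = np.zeros(1, dtype=np.float64)
    # After the None check, type checkers narrow t to an ndarray.
    return 2.0 * t


default_result = evaluate_at()
explicit_result = evaluate_at(np.array([1.0, 2.0]))
```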

@@ -292,7 +292,7 @@ def jax_value(

     def jax_grad(
         self,
-        t: npt.NDArray = None,
+        t: npt.NDArray[np.float64] = None,
Member

Suggested change
-        t: npt.NDArray[np.float64] = None,
+        t: npt.NDArray[np.float64] | None = None,

@@ -396,9 +396,9 @@ def _jax_solve_array_inputs(self, t, inputs_array):

     def _jax_solve(
         self,
-        t: Union[float, npt.NDArray],
+        t: Union[float, npt.NDArray[np.float64]],
Member

Suggested change
-        t: Union[float, npt.NDArray[np.float64]],
+        t: float | npt.NDArray[np.float64],

@@ -410,7 +410,7 @@ def _jax_solve(

     def _jax_jvp_impl(
         self,
-        *args: Union[npt.NDArray],
+        *args: Union[npt.NDArray[np.float64]],
Member

Suggested change
-        *args: Union[npt.NDArray[np.float64]],
+        *args: npt.NDArray[np.float64],
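
Dropping the `Union[...]` wrapper here does not change the type: a `Union` with a single member collapses to that member. A small standard-library sketch (using `float` in place of the array type to keep it dependency-free):

```python
from typing import Optional, Union

# A Union with exactly one member collapses to that member.
collapsed = Union[float]

# Optional[X] is shorthand for Union[X, None].
optional_equals_union = Optional[float] == Union[float, None]

print(collapsed is float)
print(optional_equals_union)
```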

@@ -455,9 +455,9 @@ def _jax_jvp_impl_array_inputs(

     def _jax_vjp_impl(
         self,
-        y_bar: npt.NDArray,
+        y_bar: npt.NDArray[np.float64],
         invar: Union[str, int],  # index or name of input variable
Member

Suggested change
-        invar: Union[str, int],  # index or name of input variable
+        invar: str | int,  # index or name of input variable

-    initial_condition: npt.NDArray
-    discrete_times: Optional[npt.NDArray]
+    initial_condition: npt.NDArray[np.float64]
+    discrete_times: Optional[npt.NDArray[np.float64]]
Member

Will npt.NDArray[np.float64] | None work?

Contributor Author

I think npt.NDArray[np.float64] | None will work, let me look into it.
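
One way to check is a small dataclass sketch. The class name `TimeIntegralSketch` and the `None` default are hypothetical. The `from __future__ import annotations` line is an assumption that keeps the `|` syntax from being evaluated at runtime on older Python versions; whether the PyBaMM module already has that import is not stated in this thread.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import numpy as np
import numpy.typing as npt


@dataclass
class TimeIntegralSketch:
    """Hypothetical stand-in mirroring the fields discussed in the review."""

    method: Literal["discrete", "continuous"]
    initial_condition: float | npt.NDArray[np.float64]
    discrete_times: npt.NDArray[np.float64] | None = None


continuous = TimeIntegralSketch(method="continuous", initial_condition=0.0)
discrete = TimeIntegralSketch(
    method="discrete",
    initial_condition=np.array([0.0]),
    discrete_times=np.array([0.0, 1.0, 2.0]),
)
```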

@vidipsingh (Contributor Author)

Thanks, @vidipsingh! See my comments below. Could you please also comment why you are using Any in all the places where you are using it? Thank you!

Thank you for the feedback! I used Any because I wasn't entirely sure of the exact type to apply in those cases. If you have any suggestions, I'd be happy to make the changes.

I'll also work on the other suggested changes. Appreciate it!

Development

Successfully merging this pull request may close these issues.

Make npt.NDArray type hints more specific
3 participants