Make npt.NDArray type hints more specific with dtype #4901
base: develop
Thanks, @vidipsingh! Could you attach the output of the mypy run to the PR description?
@Saransh-cpp, I have attached the output of …
@vidipsingh, it would be nice if you could:
Thank you!
Thanks for the feedback, @agriyakhetarpal! I'll replace the image with the … Just to clarify, are you referring to the "Before" vs. "After" comparison of the …?
The …
I've added the … Please let me know if any changes are needed!
Codecov Report: all modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@            Coverage Diff            @@
##           develop    #4901   +/-   ##
========================================
  Coverage    98.71%   98.71%
========================================
  Files          304      304
  Lines        23509    23519   +10
========================================
+ Hits         23207    23217   +10
  Misses         302      302
Thanks, @vidipsingh! See my comments below. Could you please also comment on why you are using Any in each of the places where you use it? Thank you!
 import pybamm


 @dataclass
 class ProcessedVariableTimeIntegral:
     method: Literal["discrete", "continuous"]
-    initial_condition: npt.NDArray
-    discrete_times: Optional[npt.NDArray]
+    initial_condition: npt.NDArray[np.float64]
Suggested change:
-    initial_condition: npt.NDArray[np.float64]
+    initial_condition: float | npt.NDArray[np.float64]
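The suggestion above widens the field so that either a plain float or a float64 array type-checks. A minimal sketch of how code consuming such a field can tell the two cases apart with isinstance narrowing; TimeIntegralSketch and initial_size are hypothetical names used only for illustration, not PyBaMM's ProcessedVariableTimeIntegral:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import numpy as np
import numpy.typing as npt


@dataclass
class TimeIntegralSketch:
    """Hypothetical stand-in for a time-integral record (not PyBaMM's class)."""

    method: Literal["discrete", "continuous"]
    # The type checker accepts either a plain float or a float64 array here.
    initial_condition: float | npt.NDArray[np.float64]

    def initial_size(self) -> int:
        # isinstance narrows the union: inside this branch it is an ndarray.
        if isinstance(self.initial_condition, np.ndarray):
            return int(self.initial_condition.size)
        # Otherwise it is a scalar, which counts as a single value.
        return 1


scalar_case = TimeIntegralSketch("continuous", 0.0)
array_case = TimeIntegralSketch("discrete", np.zeros(3))
print(scalar_case.initial_size())  # 1
print(array_case.initial_size())  # 3
```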
@@ -259,7 +259,7 @@ def f_isolated(*args, **kwargs):
 
     def jax_value(
         self,
-        t: npt.NDArray = None,
+        t: npt.NDArray[np.float64] = None,
Suggested change:
-        t: npt.NDArray[np.float64] = None,
+        t: npt.NDArray[np.float64] | None = None,
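The | None in this suggestion is needed because mypy (since version 0.990) no longer treats a None default as making a parameter implicitly optional, so t: npt.NDArray[np.float64] = None is reported as an incompatible default. A minimal sketch of the explicit form; resolve_times_sketch and its fallback grid are hypothetical, chosen only for illustration:

```python
from __future__ import annotations

import numpy as np
import numpy.typing as npt


def resolve_times_sketch(
    t: npt.NDArray[np.float64] | None = None,
) -> npt.NDArray[np.float64]:
    """Hypothetical helper: fall back to a default time grid when t is None."""
    if t is None:
        # After this branch the type checker knows t is an array, not None.
        t = np.linspace(0.0, 1.0, 5)
    return t


print(resolve_times_sketch().tolist())  # [0.0, 0.25, 0.5, 0.75, 1.0]
print(resolve_times_sketch(np.array([2.0])).tolist())  # [2.0]
```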
@@ -292,7 +292,7 @@ def jax_value(
 
     def jax_grad(
         self,
-        t: npt.NDArray = None,
+        t: npt.NDArray[np.float64] = None,
Suggested change:
-        t: npt.NDArray[np.float64] = None,
+        t: npt.NDArray[np.float64] | None = None,
@@ -396,9 +396,9 @@ def _jax_solve_array_inputs(self, t, inputs_array):
 
     def _jax_solve(
         self,
-        t: Union[float, npt.NDArray],
+        t: Union[float, npt.NDArray[np.float64]],
Suggested change:
-        t: Union[float, npt.NDArray[np.float64]],
+        t: float | npt.NDArray[np.float64],
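float | npt.NDArray[np.float64] is the PEP 604 spelling of Union[float, npt.NDArray[np.float64]]; a type checker treats the two identically. Evaluating X | Y between types at runtime needs Python 3.10 or later, so code that still supports older interpreters typically adds from __future__ import annotations, which leaves annotations unevaluated. A minimal sketch of a function that takes either a single time or an array of times; as_time_array is a hypothetical helper, not part of PyBaMM:

```python
from __future__ import annotations

import numpy as np
import numpy.typing as npt


def as_time_array(t: float | npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
    """Hypothetical helper: promote a scalar time to a 1-D float64 array."""
    # np.atleast_1d wraps a scalar in a length-1 array and leaves 1-D arrays as-is.
    return np.atleast_1d(np.asarray(t, dtype=np.float64))


print(as_time_array(0.5).tolist())  # [0.5]
print(as_time_array(np.array([0.0, 1.0])).tolist())  # [0.0, 1.0]
```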
@@ -410,7 +410,7 @@ def _jax_solve(
 
     def _jax_jvp_impl(
         self,
-        *args: Union[npt.NDArray],
+        *args: Union[npt.NDArray[np.float64]],
Suggested change:
-        *args: Union[npt.NDArray[np.float64]],
+        *args: npt.NDArray[np.float64],
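A Union with a single member collapses to that member, so Union[npt.NDArray[np.float64]] and npt.NDArray[np.float64] are the same type and the suggestion just drops the redundant wrapper. Note also that an annotation on *args describes each positional argument, not the tuple as a whole. A minimal sketch; total_size is a hypothetical helper used only for illustration:

```python
from typing import Union

import numpy as np
import numpy.typing as npt

# A single-member Union collapses to the member itself, so the wrapper was redundant.
assert Union[npt.NDArray[np.float64]] == npt.NDArray[np.float64]


def total_size(*args: npt.NDArray[np.float64]) -> int:
    """Hypothetical helper: every positional argument is a float64 array."""
    # Inside the body, args has type tuple[npt.NDArray[np.float64], ...].
    return sum(int(arr.size) for arr in args)


print(total_size(np.zeros(2), np.ones(3)))  # 5
```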
@@ -455,9 +455,9 @@ def _jax_jvp_impl_array_inputs(
 
     def _jax_vjp_impl(
         self,
-        y_bar: npt.NDArray,
+        y_bar: npt.NDArray[np.float64],
         invar: Union[str, int],  # index or name of input variable
Suggested change:
-        invar: Union[str, int],  # index or name of input variable
+        invar: str | int,  # index or name of input variable
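Since invar is documented as either the index or the name of an input variable, str | int describes it exactly. A minimal sketch of how such a parameter might be resolved with isinstance narrowing; resolve_input_index and the names list are hypothetical and do not reflect PyBaMM's actual lookup:

```python
from __future__ import annotations


def resolve_input_index(invar: str | int, input_names: list[str]) -> int:
    """Hypothetical helper: map an input's name or position to its position."""
    if isinstance(invar, str):
        # Narrowed to str: look the name up (list.index raises ValueError if absent).
        return input_names.index(invar)
    # Narrowed to int: treat it as a position and bounds-check it.
    if not 0 <= invar < len(input_names):
        raise IndexError(f"input index {invar} is out of range")
    return invar


names = ["current", "temperature"]
print(resolve_input_index("temperature", names))  # 1
print(resolve_input_index(0, names))  # 0
```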
-    initial_condition: npt.NDArray
-    discrete_times: Optional[npt.NDArray]
+    initial_condition: npt.NDArray[np.float64]
+    discrete_times: Optional[npt.NDArray[np.float64]]
Will npt.NDArray[np.float64] | None work?
I think npt.NDArray[np.float64] | None will work; let me look into it.
Thank you for the feedback! I used … I'll also work on the other suggested changes. Appreciate it!
Description
This PR refines the npt.NDArray type hints in PyBaMM by adding an explicit dtype (e.g., np.float64 for time and state arrays, and Any where the dtype varies).
Fixes #4900
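To illustrate the distinction the description draws: a dtype-parameterised hint such as npt.NDArray[np.float64] lets mypy reject an array it knows has a different dtype, while npt.NDArray[Any] (equivalent to a bare npt.NDArray) accepts any dtype and suits code whose input genuinely varies. A minimal sketch with two hypothetical functions, not taken from PyBaMM:

```python
from typing import Any

import numpy as np
import numpy.typing as npt


def time_span(t: npt.NDArray[np.float64]) -> float:
    """Hypothetical helper for a float64 time array."""
    # Passing an array that mypy knows to be NDArray[np.int64] here is a type error.
    return float(t[-1] - t[0])


def element_count(values: npt.NDArray[Any]) -> int:
    """Hypothetical helper that works for any dtype, so Any is appropriate."""
    return int(values.size)


t = np.array([0.0, 0.5, 2.0])  # float64 by default
print(time_span(t))  # 2.0
print(element_count(np.array([True, False])))  # 2
```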
Type of change
Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #)
Important checks:
Please confirm the following before marking the PR as ready for review:
nox -s pre-commit
nox -s tests
nox -s doctests
mypy Output (Before):
mypy Output (After):