PoC: Python bindings #340

Draft: wants to merge 15 commits into main

jjbayer (Collaborator) commented Mar 19, 2023

This is a PoC for adding Python bindings to argmin. See PR comments for open questions.

Currently, only the Newton solver is implemented.

TODO:

  • Call provided cost functions efficiently.
  • Add more solvers.
  • Add observers.
  • Documentation.
  • Python tests.

codecov-commenter commented Mar 19, 2023

Codecov Report

Attention: 168 lines in your changes are missing coverage. Please review.

Comparison is base (e9bebb2) 90.12% compared to head (ee48f38) 89.36%.

| Files | Patch % | Lines |
|---|---|---|
| crates/argmin-py/src/executor.rs | 0.00% | 36 Missing ⚠️ |
| crates/argmin-py/src/problem.rs | 0.00% | 27 Missing ⚠️ |
| crates/argmin-py/src/solver.rs | 0.00% | 17 Missing ⚠️ |
| crates/argmin-py/src/lib.rs | 0.00% | 7 Missing ⚠️ |
| crates/argmin/src/core/test_utils.rs | 0.00% | 3 Missing ⚠️ |
| crates/argmin/src/solver/brent/brentopt.rs | 0.00% | 3 Missing ⚠️ |
| crates/argmin/src/solver/brent/brentroot.rs | 0.00% | 3 Missing ⚠️ |
| crates/argmin/src/solver/conjugategradient/cg.rs | 0.00% | 3 Missing ⚠️ |
| ...rgmin/src/solver/conjugategradient/nonlinear_cg.rs | 0.00% | 3 Missing ⚠️ |
| ...n/src/solver/gaussnewton/gaussnewton_linesearch.rs | 0.00% | 3 Missing ⚠️ |
| ... and 22 more | | |

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #340      +/-   ##
==========================================
- Coverage   90.12%   89.36%   -0.76%     
==========================================
  Files         162      166       +4     
  Lines       19549    19720     +171     
==========================================
+ Hits        17618    17623       +5     
- Misses       1931     2097     +166     


@@ -26,7 +26,7 @@ pub struct Executor<O, S, I> {
/// Storage for observers
observers: Observers<I>,
/// Checkpoint
- checkpoint: Option<Box<dyn Checkpoint<S, I>>>,
+ checkpoint: Option<Box<dyn Checkpoint<S, I> + Send>>,
Collaborator Author

I added these bounds because PyClass must be Send.

Member

That also sounds good to me. Would it make sense to add the Send bound to the Checkpoint trait itself?
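
For illustration, here is a minimal std-only sketch of the supertrait variant asked about above. The `Checkpoint` trait shown is a simplified, hypothetical stand-in (the real argmin trait is generic and has different methods), and `FileCheckpoint`, `Executor`, and `run_on_thread` are made-up names. The point is only that with `Send` as a supertrait, `Box<dyn Checkpoint>` is `Send` by itself, so the field would not need an extra `+ Send`:

```rust
use std::thread;

/// Hypothetical stand-in for a checkpoint trait, with `Send` as a supertrait.
trait Checkpoint: Send {
    /// Returns the (made-up) path a checkpoint for `iter` would be written to.
    fn save(&self, iter: u64) -> String;
}

struct FileCheckpoint {
    directory: String,
}

impl Checkpoint for FileCheckpoint {
    fn save(&self, iter: u64) -> String {
        format!("{}/iter_{}.chk", self.directory, iter)
    }
}

struct Executor {
    // No `+ Send` here: the supertrait already guarantees it.
    checkpoint: Option<Box<dyn Checkpoint>>,
}

/// Moves the executor to another thread, which only compiles if it is `Send`.
fn run_on_thread(executor: Executor) -> Option<String> {
    thread::spawn(move || executor.checkpoint.map(|c| c.save(3)))
        .join()
        .expect("worker thread panicked")
}
```

The trade-off is that every implementor of the trait must then be `Send`, whereas the `+ Send` on the field only constrains what the executor can store.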

Comment on lines 12 to 16
argmin_testfunctions = "0.1.1"
argmin = {path="../argmin", default-features=false, features=[]}
argmin-math = {path="../argmin-math", features=["ndarray_latest-serde"]}
ndarray-linalg = { version = "0.16", features = ["netlib"] }
ndarray = { version = "0.15", features = ["serde-1"] }
Collaborator Author

This section requires cleanup. I am not sure what the best configuration of features is.

Member

This looks mostly fine. I think in the long run at least the serde1 feature of argmin can be enabled because that would allow checkpointing. But I guess checkpointing will need more work anyways.
At some point we will have to decide which BLAS backend to use. This is probably mostly a platform issue (only Intel-MKL works on Linux, Windows and Mac) and a licensing issue since the compiled code will be packaged into a python module.

Comment on lines 62 to 66
let args = PyTuple::new(py, [param.to_pyarray(py)]);
let pyresult = callable.call(py, args, Default::default())?;
let pyarray = pyresult.extract::<&numpy::PyArray<Scalar, OutputDimension>>(py)?;
// TODO: try to get ownership instead of cloning
Ok(pyarray.to_owned_array())
Collaborator Author

  1. I am unsure how much overhead calling to_pyarray and extract adds for every evaluation of the gradient, Hessian, etc. This probably needs benchmarks.
  2. to_owned_array makes a copy of the data, which should not be necessary.

Member

I just had a new idea: You're currently using the ndarray backend which requires transitioning between numpy arrays and ndarray arrays. Instead we could add a new math backend based on PyArray, which would mean that numpy would do all the heavy lifting. I'm not sure whether numpy or ndarray is faster and I haven't really thought this through either.

I somehow thought that it would be possible to use the underlying memory in both Rust and Python without copying but that doesn't seem to be the case.

Regarding point 2: I agree. I assumed that there is also into_owned_array but that does not seem to be the case.

Collaborator Author

> I somehow thought that it would be possible to use the underlying memory in both Rust and Python without copying but that doesn't seem to be the case.

I'll give it a try!
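
To make the copy-versus-ownership distinction in this thread concrete, here is a small std-only sketch. It uses `Vec<f64>` as a stand-in and says nothing about what the numpy or ndarray crates actually offer; `copy_out` and `take_ownership` are hypothetical names. Copying from a borrowed buffer allocates new memory, while moving an owned buffer keeps the same allocation:

```rust
/// Copies the data out of a borrowed buffer (the situation `to_owned_array`
/// is in: it only has a reference, so it must allocate and copy).
fn copy_out(borrowed: &[f64]) -> Vec<f64> {
    borrowed.to_vec()
}

/// Takes ownership of an already-owned buffer. Moving a `Vec` transfers the
/// pointer, length and capacity; the heap data itself is not copied.
fn take_ownership(owned: Vec<f64>) -> Vec<f64> {
    owned
}
```

Comparing `as_ptr()` before and after each call shows the difference: the copy has a new address, the moved vector keeps the original one.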

@jjbayer jjbayer requested a review from stefan-k March 19, 2023 14:22
jjbayer (Collaborator Author) commented Mar 19, 2023

Requires #341.

argmin-py/src/executor.rs (outdated)
Comment on lines 44 to 51
let param = kwargs
.get_item("param")
.map(|x| x.extract::<&PyArray1>())
.map_or(Ok(None), |r| r.map(Some))?;
let max_iters = kwargs
.get_item("max_iters")
.map(|x| x.extract())
.map_or(Ok(None), |r| r.map(Some))?;
Member

I really like how you transformed the somewhat weird way one needs to set the initial state in argmin (using a closure) into a very pythonic **kwargs thing. This may however turn into quite a chore given that there are multiple kinds of state (IterState and PopulationState) with lots of methods. However, I would label this problem low priority for now.

Collaborator Author

Makes sense. I might rewrite this when I add more solvers to the python extension.
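
As a side note on the snippet under review: `.map_or(Ok(None), |r| r.map(Some))` turns an `Option<Result<T, E>>` into a `Result<Option<T>, E>`, which is what `Option::transpose` from the standard library does. The sketch below shows the equivalence using a plain `HashMap<&str, &str>` and `str::parse` as stand-ins for the Python kwargs dict and `extract`; the function names are hypothetical:

```rust
use std::collections::HashMap;
use std::num::ParseIntError;

/// Same shape as the code under review: an optional key whose value may
/// fail to convert.
fn max_iters_verbose(kwargs: &HashMap<&str, &str>) -> Result<Option<u64>, ParseIntError> {
    kwargs
        .get("max_iters")
        .map(|x| x.parse::<u64>())
        .map_or(Ok(None), |r| r.map(Some))
}

/// Equivalent version using `Option::transpose`.
fn max_iters_transpose(kwargs: &HashMap<&str, &str>) -> Result<Option<u64>, ParseIntError> {
    kwargs
        .get("max_iters")
        .map(|x| x.parse::<u64>())
        .transpose()
}
```

Both return `Ok(None)` when the key is absent, `Ok(Some(value))` when it converts, and the conversion error otherwise. Whether this reads better in the real code depends on how the state setup is eventually restructured.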

}

pub enum DynamicSolver {
// NOTE: I tried using a Box<dyn Solver<> here, but Solver is not object safe.
Member

Too bad!! That's what I was hoping for, but I tend to forget about object safety. I don't think it'll be possible to make it object safe :(

}

impl core::Solver<Problem, IterState> for DynamicSolver {
// TODO: make this a trait method so we can return a dynamic
Member

Sounds good to me! We could have both for backwards compatibility, right? The default impl of the name method would then just return self.NAME.

Collaborator Author

It seems I was able to solve two problems at once: When I remove the associated constant, Solver becomes object-safe, so we can create trait objects for it. Let me know if you object.

Member

Nice! I thought the generics would also be a problem for object safety, but it's great if this isn't the case. Sounds good to me!
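
The object-safety point from this thread can be sketched with a toy trait. Everything below is hypothetical and much simpler than argmin's real `Solver` trait; it only illustrates that an associated constant prevents `dyn` usage, while a `name` method and trait-level generic parameters do not:

```rust
// Not object-safe: a trait with an associated constant cannot be used as
// `dyn Trait` (the compiler reports error E0038):
//
//     trait Solver<P> {
//         const NAME: &'static str;
//         fn step(&mut self, problem: &P, x: f64) -> f64;
//     }

/// Hypothetical, simplified solver trait. The trait-level generic `P` does
/// not affect object safety; the name is a method instead of a constant.
trait Solver<P> {
    fn name(&self) -> &'static str;
    fn step(&mut self, problem: &P, x: f64) -> f64;
}

/// Toy problem: move `x` towards `target`.
struct Quadratic {
    target: f64,
}

struct Halving;

struct Damped {
    factor: f64,
}

impl Solver<Quadratic> for Halving {
    fn name(&self) -> &'static str {
        "Halving"
    }
    fn step(&mut self, problem: &Quadratic, x: f64) -> f64 {
        x + 0.5 * (problem.target - x)
    }
}

impl Solver<Quadratic> for Damped {
    fn name(&self) -> &'static str {
        "Damped"
    }
    fn step(&mut self, problem: &Quadratic, x: f64) -> f64 {
        x + self.factor * (problem.target - x)
    }
}

/// Picks a solver at runtime, which is what a Python-facing API needs.
fn solver_by_name(name: &str) -> Option<Box<dyn Solver<Quadratic>>> {
    match name {
        "halving" => Some(Box::new(Halving)),
        "damped" => Some(Box::new(Damped { factor: 0.25 })),
        _ => None,
    }
}

/// Runs any solver through the trait object for a fixed number of steps.
fn run(solver: &mut dyn Solver<Quadratic>, problem: &Quadratic, mut x: f64, iters: usize) -> f64 {
    for _ in 0..iters {
        x = solver.step(problem, x);
    }
    x
}
```

Generic methods on the trait (as opposed to generic parameters on the trait itself) would still break object safety unless they are restricted with `where Self: Sized`.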


//! Base types for the Python extension.

pub type Scalar = f64; // TODO: allow complex numbers
Member

Complex numbers would be great but I wouldn't give this a high priority.

@@ -1,7 +1,7 @@
# Implementing a solver

In this section we are going to implement the Landweber solver, which essentially is a special form of gradient descent.
- In iteration \\( k \\), the new parameter vector \\( x_{k+1} \\) is calculated from the previous parameter vector \\( x_k \\) and the gradient at \\( x_k \\) according to the following update rule:
+ In iteration \\( k \\), the new parameter vector \\( x\_{k+1} \\) is calculated from the previous parameter vector \\( x_k \\) and the gradient at \\( x_k \\) according to the following update rule:
Collaborator Author

My autoformatter made most changes to this file.

Member

I wasn't sure if the change in line 4 would render correctly but surprisingly it does.
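
For readers without the tutorial at hand: the update rule itself is not part of the diff excerpt above. Assuming the usual Landweber form, x_{k+1} = x_k - ω ∇f(x_k) with a fixed step length ω, a std-only sketch could look as follows. The function name `landweber` and its signature are made up for illustration and are unrelated to argmin's solver API:

```rust
/// Runs `iters` Landweber iterations x_{k+1} = x_k - omega * gradient(x_k),
/// starting from `x0`. The update form is an assumption (see lead-in).
fn landweber<G>(gradient: G, x0: Vec<f64>, omega: f64, iters: usize) -> Vec<f64>
where
    G: Fn(&[f64]) -> Vec<f64>,
{
    let mut x = x0;
    for _ in 0..iters {
        let g = gradient(&x);
        // Element-wise update of the parameter vector.
        for (xi, gi) in x.iter_mut().zip(g.iter()) {
            *xi -= omega * gi;
        }
    }
    x
}
```

For f(x) = 0.5 * ||x - c||^2 the gradient is x - c, so with ω = 0.5 the distance to c halves in every iteration.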
