-
-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PoC: Python bindings #340
base: main
Are you sure you want to change the base?
PoC: Python bindings #340
Conversation
argmin/src/core/executor.rs
Outdated
@@ -26,7 +26,7 @@ pub struct Executor<O, S, I> { | |||
/// Storage for observers | |||
observers: Observers<I>, | |||
/// Checkpoint | |||
checkpoint: Option<Box<dyn Checkpoint<S, I>>>, | |||
checkpoint: Option<Box<dyn Checkpoint<S, I> + Send>>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added these bounds because PyClass
must be Send
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds also good to me. Would it make sense to add the Send
bound to the Checkpoint
trait?
argmin-py/Cargo.toml
Outdated
argmin_testfunctions = "0.1.1" | ||
argmin = {path="../argmin", default-features=false, features=[]} | ||
argmin-math = {path="../argmin-math", features=["ndarray_latest-serde"]} | ||
ndarray-linalg = { version = "0.16", features = ["netlib"] } | ||
ndarray = { version = "0.15", features = ["serde-1"] } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This section requires cleanup, I am not sure what's the best configuration of features.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks mostly fine. I think in the long run at least the serde1
feature of argmin can be enabled because that would allow checkpointing. But I guess checkpointing will need more work anyways.
At some point we will have to decide which BLAS backend to use. This is probably mostly a platform issue (only Intel-MKL works on Linux, Windows and Mac) and a licensing issue since the compiled code will be packaged into a python module.
argmin-py/src/problem.rs
Outdated
let args = PyTuple::new(py, [param.to_pyarray(py)]); | ||
let pyresult = callable.call(py, args, Default::default())?; | ||
let pyarray = pyresult.extract::<&numpy::PyArray<Scalar, OutputDimension>>(py)?; | ||
// TODO: try to get ownership instead of cloning | ||
Ok(pyarray.to_owned_array()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- I am unsure what's the overhead of calling
to_pyarray
andextract
for every evaluation of the gradient, hessian etc. Probably needs benchmarks. to_owned_array
makes a copy of the data, this should not be necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just had a new idea: You're currently using the ndarray
backend which requires transitioning between numpy arrays and ndarray arrays. Instead we could add a new math backend based on PyArray
, which would mean that numpy would do all the heavy lifting. I'm not sure whether numpy or ndarray is faster and I haven't really thought this through either.
I somehow thought that it would be possible to use the underlying memory in both Rust and Python without copying but that doesn't seem to be the case.
Regarding point 2: I agree. I assumed that there is also into_owned_array
but that does not seem to be the case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I somehow thought that it would be possible to use the underlying memory in both Rust and Python without copying but that doesn't seem to be the case.
I'll give it a try!
Requires #341. |
argmin-py/src/executor.rs
Outdated
let param = kwargs | ||
.get_item("param") | ||
.map(|x| x.extract::<&PyArray1>()) | ||
.map_or(Ok(None), |r| r.map(Some))?; | ||
let max_iters = kwargs | ||
.get_item("max_iters") | ||
.map(|x| x.extract()) | ||
.map_or(Ok(None), |r| r.map(Some))?; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I really like how you transformed the somewhat weird way one needs to set the initial state in argmin (using a closure) into a very pythonic **kwargs
thing. This may however turn into quite a chore given that there are multiple kinds of state (IterState
and PopulationState
) with lots of methods. However, I would label this problem low priority for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense. I might rewrite this when I add more solvers to the python extension.
argmin-py/src/solver.rs
Outdated
} | ||
|
||
pub enum DynamicSolver { | ||
// NOTE: I tried using a Box<dyn Solver<> here, but Solver is not object safe. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Too bad!! That's what I was hoping for, but I tend to forget about object safety. I don't think it'll be possible to make it object safe :(
argmin-py/src/solver.rs
Outdated
} | ||
|
||
impl core::Solver<Problem, IterState> for DynamicSolver { | ||
// TODO: make this a trait method so we can return a dynamic |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good to me! We could have both for backwards compatibility, right? The default impl of the name
method would then just return self.NAME
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems I was able to solve two problems at once: When I remove the associated constant, Solver
becomes object-safe, so we can create trait objects for it. Let me know if you objects.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! I thought the generics would also be a problem for object safety, but it's great if this isn't the case. Sounds good to me!
argmin-py/src/types.rs
Outdated
|
||
//! Base types for the Python extension. | ||
|
||
pub type Scalar = f64; // TODO: allow complex numbers |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Complex numbers would be great but I wouldn't give this a high priority.
argmin-py/src/executor.rs
Outdated
let param = kwargs | ||
.get_item("param") | ||
.map(|x| x.extract::<&PyArray1>()) | ||
.map_or(Ok(None), |r| r.map(Some))?; | ||
let max_iters = kwargs | ||
.get_item("max_iters") | ||
.map(|x| x.extract()) | ||
.map_or(Ok(None), |r| r.map(Some))?; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense. I might rewrite this when I add more solvers to the python extension.
argmin-py/src/problem.rs
Outdated
let args = PyTuple::new(py, [param.to_pyarray(py)]); | ||
let pyresult = callable.call(py, args, Default::default())?; | ||
let pyarray = pyresult.extract::<&numpy::PyArray<Scalar, OutputDimension>>(py)?; | ||
// TODO: try to get ownership instead of cloning | ||
Ok(pyarray.to_owned_array()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I somehow thought that it would be possible to use the underlying memory in both Rust and Python without copying but that doesn't seem to be the case.
I'll give it a try!
argmin-py/src/solver.rs
Outdated
} | ||
|
||
impl core::Solver<Problem, IterState> for DynamicSolver { | ||
// TODO: make this a trait method so we can return a dynamic |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems I was able to solve two problems at once: When I remove the associated constant, Solver
becomes object-safe, so we can create trait objects for it. Let me know if you objects.
@@ -1,7 +1,7 @@ | |||
# Implementing a solver | |||
|
|||
In this section we are going to implement the Landweber solver, which essentially is a special form of gradient descent. | |||
In iteration \\( k \\), the new parameter vector \\( x_{k+1} \\) is calculated from the previous parameter vector \\( x_k \\) and the gradient at \\( x_k \\) according to the following update rule: | |||
In iteration \\( k \\), the new parameter vector \\( x\_{k+1} \\) is calculated from the previous parameter vector \\( x_k \\) and the gradient at \\( x_k \\) according to the following update rule: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My autoformatter made most changes to this file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wasn't sure if the change in line 4 would render correctly but surprisingly it does.
This is a PoC for adding Python bindings to argmin. See PR comments for open questions.
Currently, only the Newton solver is implemented.
TODO: