3 changes: 2 additions & 1 deletion pyadjoint/__init__.py
@@ -16,7 +16,7 @@
from .verification import taylor_test, taylor_to_dict
from .drivers import compute_gradient, compute_derivative, compute_tlm, compute_hessian, solve_adjoint
from .checkpointing import disk_checkpointing_callback
from .reduced_functional import ReducedFunctional
from .reduced_functional import ReducedFunctional, ParametrisedReducedFunctional
from .adjfloat import AdjFloat, exp, log
from .tape import (
Tape,
@@ -54,6 +54,7 @@
"log",
"Control",
"ReducedFunctional",
"ParametrisedReducedFunctional",
"create_overloaded_object",
"OverloadedType",
"compute_gradient",
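With the export above, the new class is importable from the package root. A minimal smoke test (hypothetical, not part of the diff):

from pyadjoint import ParametrisedReducedFunctional, ReducedFunctional

# The new class specialises ReducedFunctional, so it should satisfy
# isinstance/issubclass checks written against the base class.
assert issubclass(ParametrisedReducedFunctional, ReducedFunctional)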
115 changes: 115 additions & 0 deletions pyadjoint/reduced_functional.py
@@ -379,3 +379,118 @@ def marked_controls(self):
        finally:
            for control in self.controls:
                control.unmark_as_control()


class ParametrisedReducedFunctional(ReducedFunctional):
"""Class representing the reduced functional with parameters.

A reduced functional maps a control value to the provided functional.
It may also be used to compute the derivative of the functional with
respect to the control. In addition, parameters may be specified which
are updated, but not included in the derivative calculations.

Args:
functional (:obj:`OverloadedType`): An instance of an OverloadedType,
usually :class:`AdjFloat`. This should be the return value of the
functional you want to reduce.
controls (list[Control]): A list of Control instances, which you want
to map to the functional. It is also possible to supply a single
Control instance instead of a list.
parameters (list): A list of parameters, which are updated, but not included in the derivative.
scale (float): A scaling factor applied to the functional and its
gradient with respect to the control.
tape (Tape): A tape object that the reduced functional will use to
evaluate the functional and its gradients (or derivatives).
eval_cb_pre (function): Callback function before evaluating the
functional. Input is a list of Controls.
eval_cb_pos (function): Callback function after evaluating the
functional. Inputs are the functional value and a list of Controls.
derivative_cb_pre (function): Callback function before evaluating
derivatives. Input is a list of Controls.
Should return a list of Controls (usually the same
list as the input) to be passed to compute_derivative.
derivative_cb_post (function): Callback function after evaluating
derivatives. Inputs are: functional.block_variable.checkpoint,
list of functional derivatives, list of functional values.
Should return a list of derivatives (usually the same
list as the input) to be returned from self.derivative.
hessian_cb_pre (function): Callback function before evaluating the Hessian.
Input is a list of Controls.
hessian_cb_post (function): Callback function after evaluating the Hessian.
Inputs are the functional, a list of Hessian, and controls.
tlm_cb_pre (function): Callback function before evaluating the tangent linear model.
Input is a list of Controls.
tlm_cb_post (function): Callback function after evaluating the tangent linear model.
Inputs are the functional, the tlm result, and controls.
"""

    def __init__(self, functional, controls, parameters,
                 scale=1.0, tape=None,
                 eval_cb_pre=lambda *args: None,
                 eval_cb_post=lambda *args: None,
                 derivative_cb_pre=lambda controls: controls,
                 derivative_cb_post=lambda checkpoint, derivative_components,
                 controls: derivative_components,
                 hessian_cb_pre=lambda *args: None,
                 hessian_cb_post=lambda *args: None,
                 tlm_cb_pre=lambda *args: None,
                 tlm_cb_post=lambda *args: None):
        self._parameters = Enlist(parameters)
        controls = Enlist(controls)
        self.n_opt = len(controls)
        # Indices of the optimization controls that are included in the
        # derivative calculations.
        derivative_components = tuple(range(self.n_opt))

        # Prepare the combined controls + parameters list for the base class.
        # Parameters are appended after the optimization controls and wrapped
        # individually in Controls, so their count matches the values passed
        # in __call__.
        all_controls = controls + [Control(p) for p in self._parameters]

        super().__init__(functional=functional,
                         controls=all_controls,
                         derivative_components=derivative_components,
                         scale=scale,
                         tape=tape,
                         eval_cb_pre=eval_cb_pre,
                         eval_cb_post=eval_cb_post,
                         derivative_cb_pre=derivative_cb_pre,
                         derivative_cb_post=derivative_cb_post,
                         hessian_cb_pre=hessian_cb_pre,
                         hessian_cb_post=hessian_cb_post,
                         tlm_cb_pre=tlm_cb_pre,
                         tlm_cb_post=tlm_cb_post)

    @property
    def parameters(self) -> list:
        """Return the current parameter values."""
        return self._parameters

    @no_annotations
    def parameter_update(self, new_parameters):
        """Replace the parameter values used in subsequent evaluations."""
        if len(Enlist(new_parameters)) != len(self._parameters):
            raise ValueError(
                "new_parameters should be a list of the same length as parameters."
            )
        self._parameters = Enlist(new_parameters)

    @no_annotations
    def derivative(self, adj_input=1.0, apply_riesz=False):
        """Compute derivatives with respect to the optimization controls only."""
        derivatives_full = super().derivative(adj_input=adj_input, apply_riesz=apply_riesz)
        # Drop the trailing entries, which correspond to the parameters.
        return Enlist(derivatives_full)[:self.n_opt]

    @no_annotations
    def __call__(self, values):
        values = Enlist(values)
        if len(values) != self.n_opt:
            raise ValueError(
                "Length of values passed to ParametrisedReducedFunctional"
                " must match the number of optimization controls.")
        # Concatenate the optimization control values with the current
        # parameter values before evaluating the full functional.
        full_values = values + self._parameters
        return super().__call__(full_values)
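A minimal usage sketch for reviewers, using plain pyadjoint AdjFloat arithmetic. The functional J(m; p) = p * m**2 and all values are illustrative, not part of the diff, and default tape annotation is assumed:

from pyadjoint import AdjFloat, Control, ParametrisedReducedFunctional

# Record J(m; p) = p * m**2 on the working tape. m is the optimization
# control; p is a parameter excluded from derivative calculations.
m = AdjFloat(2.0)
p = AdjFloat(3.0)
J = p * m * m

rf = ParametrisedReducedFunctional(J, Control(m), parameters=p)

print(rf(AdjFloat(4.0)))    # J at m = 4 with p = 3 -> 48.0
print(rf.derivative())      # [dJ/dm] = [2 * p * m] = [24.0]; no dJ/dp entry

rf.parameter_update(AdjFloat(5.0))  # swap in a new parameter value
print(rf(AdjFloat(4.0)))    # re-evaluates with p = 5 -> 80.0

Note that rf.derivative() returns a list even for a single optimization control, because the override slices the full derivative list rather than delisting it.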