diff --git a/pyadjoint/__init__.py b/pyadjoint/__init__.py index 3e500943..48d30af1 100644 --- a/pyadjoint/__init__.py +++ b/pyadjoint/__init__.py @@ -16,7 +16,7 @@ from .verification import taylor_test, taylor_to_dict from .drivers import compute_gradient, compute_derivative, compute_tlm, compute_hessian, solve_adjoint from .checkpointing import disk_checkpointing_callback -from .reduced_functional import ReducedFunctional +from .reduced_functional import ReducedFunctional, ParametrisedReducedFunctional from .adjfloat import AdjFloat, exp, log from .tape import ( Tape, @@ -54,6 +54,7 @@ "log", "Control", "ReducedFunctional", + "ParametrisedReducedFunctional", "create_overloaded_object", "OverloadedType", "compute_gradient", diff --git a/pyadjoint/reduced_functional.py b/pyadjoint/reduced_functional.py index 694e85e4..8d1ebb78 100644 --- a/pyadjoint/reduced_functional.py +++ b/pyadjoint/reduced_functional.py @@ -379,3 +379,118 @@ def marked_controls(self): finally: for control in self.controls: control.unmark_as_control() + +class ParametrisedReducedFunctional(ReducedFunctional): + """Class representing the reduced functional with parameters. + + A reduced functional maps a control value to the provided functional. + It may also be used to compute the derivative of the functional with + respect to the control. In addition, parameters may be specified which + are updated, but not included in the derivative calculations. + + Args: + functional (:obj:`OverloadedType`): An instance of an OverloadedType, + usually :class:`AdjFloat`. This should be the return value of the + functional you want to reduce. + controls (list[Control]): A list of Control instances, which you want + to map to the functional. It is also possible to supply a single + Control instance instead of a list. + parameters (list): A list of parameters, which are updated, but not included in the derivative. + scale (float): A scaling factor applied to the functional and its + gradient with respect to the control. + tape (Tape): A tape object that the reduced functional will use to + evaluate the functional and its gradients (or derivatives). + eval_cb_pre (function): Callback function before evaluating the + functional. Input is a list of Controls. + eval_cb_pos (function): Callback function after evaluating the + functional. Inputs are the functional value and a list of Controls. + derivative_cb_pre (function): Callback function before evaluating + derivatives. Input is a list of Controls. + Should return a list of Controls (usually the same + list as the input) to be passed to compute_derivative. + derivative_cb_post (function): Callback function after evaluating + derivatives. Inputs are: functional.block_variable.checkpoint, + list of functional derivatives, list of functional values. + Should return a list of derivatives (usually the same + list as the input) to be returned from self.derivative. + hessian_cb_pre (function): Callback function before evaluating the Hessian. + Input is a list of Controls. + hessian_cb_post (function): Callback function after evaluating the Hessian. + Inputs are the functional, a list of Hessian, and controls. + tlm_cb_pre (function): Callback function before evaluating the tangent linear model. + Input is a list of Controls. + tlm_cb_post (function): Callback function after evaluating the tangent linear model. + Inputs are the functional, the tlm result, and controls. + """ + + def __init__(self, functional, controls, parameters, + scale=1.0, tape=None, + eval_cb_pre=lambda *args: None, + eval_cb_post=lambda *args: None, + derivative_cb_pre=lambda controls: controls, + derivative_cb_post=lambda checkpoint, derivative_components, + controls: derivative_components, + hessian_cb_pre=lambda *args: None, + hessian_cb_post=lambda *args: None, + tlm_cb_pre=lambda *args: None, + tlm_cb_post=lambda *args: None): + + + + + self._parameters = Enlist(parameters) + controls = Enlist(controls) + self.n_opt = len(controls) + derivative_components = tuple(range(self.n_opt)) # Tuple of indices corresponding to optimization controls which are included in derivative calculations. + + # Prepare controls + parameters list for base class. By default, parameters are appended after the optimization controls. + all_controls = controls + Enlist(Control(parameters)) + + super().__init__(functional=functional, + controls=all_controls, + derivative_components=derivative_components, + scale=scale, + tape=tape, + eval_cb_pre=eval_cb_pre, + eval_cb_post=eval_cb_post, + derivative_cb_pre=derivative_cb_pre, + derivative_cb_post=derivative_cb_post, + hessian_cb_pre=hessian_cb_pre, + hessian_cb_post=hessian_cb_post, + tlm_cb_pre=tlm_cb_pre, + tlm_cb_post=tlm_cb_post) + + + + + @property + def parameters(self) -> list[Control]: + return self._parameters + + @no_annotations + def parameter_update(self, new_parameters): + if len(Enlist(new_parameters)) != len(self._parameters): + raise ValueError( + "new_parameters should be a list of same length as parameters." + ) + self._parameters = Enlist(new_parameters) + + + @no_annotations + def derivative(self, adj_input=1.0, apply_riesz=False): + derivatives_full = super().derivative(adj_input=adj_input, apply_riesz=apply_riesz) + # Return only derivatives corresponding to optimization controls + return Enlist(derivatives_full)[:self.n_opt] + + + @no_annotations + def __call__(self, values): + values = Enlist(values) + if len(values) != self.n_opt: + raise ValueError("Length of values passed to ParametrisedReducedFunctional" \ + " must match the number of optimization controls.") + # concatenate optimization controls + parameters + full_values = values + self._parameters + return super().__call__(full_values) + +