diff --git a/pyproject.toml b/pyproject.toml index dc28c323..4daa00ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ - + [build-system] requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" @@ -44,6 +44,7 @@ classifiers = [ dynamic = ["version"] [project.optional-dependencies] +qs = ["bluesky-queueserver-api"] dev = [ "pytest", "pytest-cov", @@ -54,6 +55,7 @@ dev = [ "pandas-stubs", "coverage", "pyright", + "blop[qs]", ] cpu = [ # Empty extra - the source configuration below routes to CPU-only index @@ -83,23 +85,23 @@ local_scheme = "no-local-version" src = ["src", "examples", "docs/source/tutorials"] line-length = 125 lint.select = [ - "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b - "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 - "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e - "F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f - "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w - "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i - "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up - "SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self - "PLC2701", # private import - https://docs.astral.sh/ruff/rules/import-private-name/ - "LOG015", # root logger call - https://docs.astral.sh/ruff/rules/root-logger-call/ - "S101", # assert - https://docs.astral.sh/ruff/rules/assert/ - "D", # docstring - https://docs.astral.sh/ruff/rules/#pydocstyle-d + "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b + "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 + "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e + "F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f + "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w + "I", # isort - 
https://docs.astral.sh/ruff/rules/#isort-i + "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up + "SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self + "PLC2701", # private import - https://docs.astral.sh/ruff/rules/import-private-name/ + "LOG015", # root logger call - https://docs.astral.sh/ruff/rules/root-logger-call/ + "S101", # assert - https://docs.astral.sh/ruff/rules/assert/ + "D", # docstring - https://docs.astral.sh/ruff/rules/#pydocstyle-d ] lint.ignore = [ - "D", # TODO: Add docstrings, then enforce these errors - "SLF001", # TODO: Fix private member access, https://github.com/NSLS-II/blop/issues/94 - "B901", # return-in-generator - https://docs.astral.sh/ruff/rules/return-in-generator/ + "D", # TODO: Add docstrings, then enforce these errors + "SLF001", # TODO: Fix private member access, https://github.com/NSLS-II/blop/issues/94 + "B901", # return-in-generator - https://docs.astral.sh/ruff/rules/return-in-generator/ ] lint.preview = true # so that preview mode PLC2701, and LOG015 is enabled @@ -114,20 +116,18 @@ convention = "numpy" [tool.pyright] ignore = [ - "sim/", - "src/blop/tests/", - "src/blop/bayesian/", # TODO: Remove this and fix type errors - "src/blop/ax/qserver_agent.py", # TODO: Remove this and fix type errors + "sim/", + "src/blop/tests/", + "src/blop/bayesian/", # TODO: Remove this and fix type errors + "src/blop/ax/qserver_agent.py", # TODO: Remove this and fix type errors ] # Configure PyTorch CPU-only installation when [cpu] extra is requested # This requires uv (https://docs.astral.sh/uv/) [tool.uv.sources] -torch = [ - { index = "pytorch-cpu", marker = "extra == 'cpu'" }, -] +torch = [{ index = "pytorch-cpu", marker = "extra == 'cpu'" }] [[tool.uv.index]] name = "pytorch-cpu" url = "https://download.pytorch.org/whl/cpu" -explicit = true # Only use this index for torch-related packages +explicit = true # Only use this index for torch-related packages diff --git a/src/blop/ax/__init__.py 
b/src/blop/ax/__init__.py index e5a192c4..cdc4dc9b 100644 --- a/src/blop/ax/__init__.py +++ b/src/blop/ax/__init__.py @@ -2,11 +2,9 @@ from .dof import DOF, ChoiceDOF, DOFConstraint, RangeDOF from .objective import Objective, OutcomeConstraint, ScalarizedObjective, to_ax_objective_str from .optimizer import AxOptimizer -from .qserver_agent import BlopQserverAgent as QserverAgent __all__ = [ "Agent", - "QserverAgent", "DOF", "RangeDOF", "ChoiceDOF", diff --git a/src/blop/ax/agent.py b/src/blop/ax/agent.py index 76a7057d..ba6ce71c 100644 --- a/src/blop/ax/agent.py +++ b/src/blop/ax/agent.py @@ -1,7 +1,7 @@ import importlib.util import logging from collections.abc import Sequence -from typing import Any, TypeGuard +from typing import Any, TypeGuard, cast from ax import Client from ax.analysis import ContourPlot @@ -14,9 +14,18 @@ from ax.analysis.analysis_card import AnalysisCardBase # type: ignore[import-untyped] # =============================== from bluesky.utils import MsgGenerator +from bluesky_queueserver_api.zmq import REManagerAPI from ..plans import acquire_baseline, optimize, sample_suggestions -from ..protocols import AcquisitionPlan, Actuator, EvaluationFunction, OptimizationProblem, Sensor +from ..protocols import ( + AcquisitionPlan, + Actuator, + EvaluationFunction, + OptimizationProblem, + QueueserverOptimizationProblem, + Sensor, +) +from ..queueserver import QueueserverClient, QueueserverOptimizationRunner from ..utils import InferredReadable from .dof import DOF, DOFConstraint from .objective import Objective, OutcomeConstraint, to_ax_objective_str @@ -33,7 +42,147 @@ def _has_dof_keys(d: dict[DOF, Any] | dict[str, Any]) -> TypeGuard[dict[DOF, Any return all(isinstance(key, DOF) for key in d.keys()) -class Agent: +class _AxAgentMixin: + """ + Mixin providing Ax-related functionality shared by agents. + Expects subclasses to define `self._optimizer` as an `AxOptimizer`. 
+ """ + + _optimizer: AxOptimizer + + @property + def ax_client(self) -> Client: + return self._optimizer.ax_client + + @property + def checkpoint_path(self) -> str | None: + return self._optimizer.checkpoint_path + + @property + def fixed_dofs(self) -> dict[str, Any] | None: + return self._optimizer.fixed_parameters + + @fixed_dofs.setter + def fixed_dofs(self, fixed_dofs: dict[DOF, Any] | dict[str, Any] | None) -> None: + """ + Fix degrees of freedom to a certain value for future optimizations. + + Parameters + ---------- + fixed_dofs : dict[DOF, Any] | dict[str, Any] | None + A mapping of DOFs or DOF names to the values they should be fixed to. + + """ + if not fixed_dofs: + self._optimizer.fixed_parameters = None + return + + if _has_str_keys(fixed_dofs): + self._optimizer.fixed_parameters = fixed_dofs + elif _has_dof_keys(fixed_dofs): + self._optimizer.fixed_parameters = {dof.parameter_name: value for dof, value in fixed_dofs.items()} + else: + raise ValueError( + f"Keys must all be either {type(DOF)} or {type(str)}, but got {type(list(fixed_dofs.keys())[0])}" + ) + + def suggest(self, num_points: int = 1) -> list[dict]: + """ + Get the next point(s) to evaluate in the search space. + + Uses the Bayesian optimization algorithm to suggest promising points based + on all previously acquired data. Each suggestion includes an "_id" key for + tracking. + + Parameters + ---------- + num_points : int, optional + The number of points to suggest. Default is 1. Higher values enable + batch optimization but may reduce optimization efficiency per iteration. + + Returns + ------- + list[dict] + A list of dictionaries, each containing a parameterization of a point to + evaluate next. Each dictionary includes an "_id" key for identification. + """ + return self._optimizer.suggest(num_points) + + def ingest(self, points: list[dict]) -> None: + """ + Ingest evaluation results into the optimizer. + + Updates the optimizer's model with new data. 
Can ingest both suggested points + (with "_id" key) and external data (without "_id" key). + + Parameters + ---------- + points : list[dict] + A list of dictionaries, each containing outcomes for a trial. For suggested + points, include the "_id" key. For external data, include DOF names and + objective values, and omit "_id". + + Notes + ----- + This method is typically called automatically by :meth:`optimize`. Manual usage + is only needed for custom workflows or when ingesting external data. + + For complete examples, see :doc:`/how-to-guides/attach-data-to-experiments`. + """ + self._optimizer.ingest(points) + + def plot_objective( + self, x_dof_name: str, y_dof_name: str, objective_name: str, *args: Any, **kwargs: Any + ) -> list[AnalysisCardBase]: + """ + Plot the predicted objective as a function of two DOFs. + + Creates a contour plot showing the model's prediction of an objective across + the space defined by two DOFs. Useful for visualizing the optimization landscape. + + Parameters + ---------- + x_dof_name : str + The name of the DOF to plot on the x-axis. + y_dof_name : str + The name of the DOF to plot on the y-axis. + objective_name : str + The name of the objective to plot. + *args : Any + Additional positional arguments passed to Ax's compute_analyses. + **kwargs : Any + Additional keyword arguments passed to Ax's compute_analyses. + + Returns + ------- + list[AnalysisCard] + The computed analysis cards containing the plot data. + + See Also + -------- + ax.analysis.ContourPlot : Pre-built analysis for plotting objectives. + ax.analysis.AnalysisCard : Contains the raw and computed data. + """ + return self.ax_client.compute_analyses( + [ + ContourPlot( + x_parameter_name=x_dof_name, + y_parameter_name=y_dof_name, + metric_name=objective_name, + ), + ], + *args, + **kwargs, + ) + + def checkpoint(self) -> None: + """ + Save the agent's state to a JSON file. 
+ """ + self._optimizer.checkpoint() + + +class Agent(_AxAgentMixin): """ An interface that uses Ax as the backend for optimization and experiment tracking. @@ -93,8 +242,13 @@ def __init__( checkpoint_path: str | None = None, **kwargs: Any, ): + if any(isinstance(dof.actuator, str) for dof in dofs): + dof_actuator_strs = [dof.actuator for dof in dofs if isinstance(dof.actuator, str)] + raise ValueError( + f"DOFs with actuators must be `Actuator` instances, not strings. Got strings for: {dof_actuator_strs}" + ) self._sensors = sensors - self._actuators = [dof.actuator for dof in dofs if dof.actuator is not None] + self._actuators: Sequence[Actuator] = [cast(Actuator, dof.actuator) for dof in dofs if dof.actuator is not None] self._evaluation_function = evaluation_function self._acquisition_plan = acquisition_plan self._optimizer = AxOptimizer( @@ -165,42 +319,6 @@ def evaluation_function(self) -> EvaluationFunction: def acquisition_plan(self) -> AcquisitionPlan | None: return self._acquisition_plan - @property - def ax_client(self) -> Client: - return self._optimizer.ax_client - - @property - def checkpoint_path(self) -> str | None: - return self._optimizer.checkpoint_path - - @property - def fixed_dofs(self) -> dict[str, Any] | None: - return self._optimizer.fixed_parameters - - @fixed_dofs.setter - def fixed_dofs(self, fixed_dofs: dict[DOF, Any] | dict[str, Any] | None) -> None: - """ - Fix degrees of freedom to a certain value for future optimizations. - - Parameters - ---------- - fixed_dofs : dict[DOF, Any] | dict[str, Any] | None - A mapping of DOFs or DOF names to the values they should be fixed to. 
- - """ - if not fixed_dofs: - self._optimizer.fixed_parameters = None - return - - if _has_str_keys(fixed_dofs): - self._optimizer.fixed_parameters = fixed_dofs - elif _has_dof_keys(fixed_dofs): - self._optimizer.fixed_parameters = {dof.parameter_name: value for dof, value in fixed_dofs.items()} - else: - raise ValueError( - f"Keys must all be either {type(DOF)} or {type(str)}, but got {type(list(fixed_dofs.keys())[0])}" - ) - def to_optimization_problem(self) -> OptimizationProblem: """ Construct an optimization problem from the agent. @@ -227,51 +345,6 @@ def to_optimization_problem(self) -> OptimizationProblem: acquisition_plan=self.acquisition_plan, ) - def suggest(self, num_points: int = 1) -> list[dict]: - """ - Get the next point(s) to evaluate in the search space. - - Uses the Bayesian optimization algorithm to suggest promising points based - on all previously acquired data. Each suggestion includes an "_id" key for - tracking. - - Parameters - ---------- - num_points : int, optional - The number of points to suggest. Default is 1. Higher values enable - batch optimization but may reduce optimization efficiency per iteration. - - Returns - ------- - list[dict] - A list of dictionaries, each containing a parameterization of a point to - evaluate next. Each dictionary includes an "_id" key for identification. - """ - return self._optimizer.suggest(num_points) - - def ingest(self, points: list[dict]) -> None: - """ - Ingest evaluation results into the optimizer. - - Updates the optimizer's model with new data. Can ingest both suggested points - (with "_id" key) and external data (without "_id" key). - - Parameters - ---------- - points : list[dict] - A list of dictionaries, each containing outcomes for a trial. For suggested - points, include the "_id" key. For external data, include DOF names and - objective values, and omit "_id". - - Notes - ----- - This method is typically called automatically by :meth:`optimize`. 
Manual usage - is only needed for custom workflows or when ingesting external data. - - For complete examples, see :doc:`/how-to-guides/attach-data-to-experiments`. - """ - self._optimizer.ingest(points) - def acquire_baseline(self, parameterization: dict[str, Any] | None = None) -> MsgGenerator[None]: """ Acquire a baseline reading for reference. @@ -363,52 +436,153 @@ def sample_suggestions(self, suggestions: list[dict]) -> MsgGenerator[tuple[str, ) ) - def plot_objective( - self, x_dof_name: str, y_dof_name: str, objective_name: str, *args: Any, **kwargs: Any - ) -> list[AnalysisCardBase]: + +class QueueserverAgent(_AxAgentMixin): + """ + An asynchronous interface that uses Ax as the backend for optimization and experiment tracking + and the bluesky-queueserver-api for scheduling plan execution. + + Parameters + ---------- + re_manager_api : REManagerAPI + The manager API for interaction with Bluesky queueserver. + zmq_consumer_addr : str + A ZMQ address to consume Bluesky messages from, to react to plan execution on the + remote server. + sensors : Sequence[str] + The sensors to use for acquisition. These should be the minimal set + of sensors that are needed to compute the objectives. + dofs : Sequence[DOF] + The degrees of freedom that the agent can control, which determine the search space. + objectives : Sequence[Objective] + The objectives which the agent will try to optimize. + evaluation_function : EvaluationFunction + The function to evaluate acquired data and produce outcomes. + acquisition_plan : str | None, optional + The acquisition plan to use for acquiring data from the beamline. If not provided, + :func:`blop.plans.default_acquire` will be assumed. + dof_constraints : Sequence[DOFConstraint] | None, optional + Constraints on DOFs to refine the search space. + outcome_constraints : Sequence[OutcomeConstraint] | None, optional + Constraints on outcomes to be satisfied during optimization. 
+ checkpoint_path : str | None, optional + The path to the checkpoint file to save the optimizer's state to. + **kwargs : Any + Additional keyword arguments to configure the Ax experiment. + + See Also + -------- + blop.protocols.Sensor : The protocol for sensors. + blop.ax.dof.RangeDOF : For continuous parameters. + blop.ax.dof.ChoiceDOF : For discrete parameters. + blop.ax.objective.Objective : For defining objectives. + blop.ax.optimizer.AxOptimizer : The optimizer used internally. + blop.queueserver.QueueserverOptimizationRunner : Runner that handles interaction with bluesky-queueserver. + """ + + def __init__( + self, + re_manager_api: REManagerAPI, + zmq_consumer_addr: str, + sensors: Sequence[str], + dofs: Sequence[DOF], + objectives: Sequence[Objective], + evaluation_function: EvaluationFunction, + acquisition_plan: str | None = None, + dof_constraints: Sequence[DOFConstraint] | None = None, + outcome_constraints: Sequence[OutcomeConstraint] | None = None, + checkpoint_path: str | None = None, + **kwargs: Any, + ): + self._sensors = sensors + self._actuators: Sequence[str] = [] + for dof in dofs: + if dof.actuator is not None: + if isinstance(dof.actuator, Actuator): + self._actuators.append(dof.actuator.name) + else: + self._actuators.append(dof.actuator) + self._evaluation_function = evaluation_function + self._acquisition_plan = acquisition_plan + self._optimizer = AxOptimizer( + parameters=[dof.to_ax_parameter_config() for dof in dofs], + objective=to_ax_objective_str(objectives), + parameter_constraints=[constraint.ax_constraint for constraint in dof_constraints] if dof_constraints else None, + outcome_constraints=[constraint.ax_constraint for constraint in outcome_constraints] + if outcome_constraints + else None, + checkpoint_path=checkpoint_path, + **kwargs, + ) + self._runner = QueueserverOptimizationRunner( + self.to_optimization_problem(), + QueueserverClient(re_manager_api, zmq_consumer_addr), + ) + + @property + def evaluation_function(self) -> 
EvaluationFunction: + return self._evaluation_function + + @property + def actuators(self) -> Sequence[str]: + return self._actuators + + @property + def sensors(self) -> Sequence[str]: + return self._sensors + + @property + def acquisition_plan(self) -> str | None: + return self._acquisition_plan + + def to_optimization_problem(self) -> QueueserverOptimizationProblem: + return QueueserverOptimizationProblem( + optimizer=self._optimizer, + actuators=self._actuators, + sensors=self._sensors, + evaluation_function=self._evaluation_function, + acquisition_plan=self._acquisition_plan, + ) + + def run(self, iterations=1, n_points=1) -> None: """ - Plot the predicted objective as a function of two DOFs. + Start the optimization loop. - Creates a contour plot showing the model's prediction of an objective across - the space defined by two DOFs. Useful for visualizing the optimization landscape. + Validates the queueserver state, then begins the suggest -> acquire -> ingest + cycle. This method returns immediately; the optimization runs asynchronously + via callbacks. Parameters ---------- - x_dof_name : str - The name of the DOF to plot on the x-axis. - y_dof_name : str - The name of the DOF to plot on the y-axis. - objective_name : str - The name of the objective to plot. - *args : Any - Additional positional arguments passed to Ax's compute_analyses. - **kwargs : Any - Additional keyword arguments passed to Ax's compute_analyses. + iterations : int + Number of optimization iterations to run. + n_points : int + Number of points to suggest per iteration. - Returns - ------- - list[AnalysisCard] - The computed analysis cards containing the plot data. - - See Also - -------- - ax.analysis.ContourPlot : Pre-built analysis for plotting objectives. - ax.analysis.AnalysisCard : Contains the raw and computed data. + Raises + ------ + RuntimeError + If the queueserver environment is not ready. + ValueError + If required devices or plans are not available. 
""" - return self.ax_client.compute_analyses( - [ - ContourPlot( - x_parameter_name=x_dof_name, - y_parameter_name=y_dof_name, - metric_name=objective_name, - ), - ], - *args, - **kwargs, - ) - def checkpoint(self) -> None: + self._runner.run(iterations, n_points) + + def submit_suggestions(self, suggestions: list[dict]) -> None: """ - Save the agent's state to a JSON file. + Evaluate specific parameter combinations. + + Acquires data for given suggestions and ingests results. Supports both + optimizer suggestions and manual points. + + Parameters + ---------- + suggestions : list[dict] + Either optimizer suggestions (with "_id") or manual points (without "_id"). + + See Also + -------- + suggest : Get optimizer suggestions. """ - self._optimizer.checkpoint() + self._runner.submit_suggestions(suggestions) diff --git a/src/blop/ax/dof.py b/src/blop/ax/dof.py index abcd3e3f..a469f0a6 100644 --- a/src/blop/ax/dof.py +++ b/src/blop/ax/dof.py @@ -24,8 +24,8 @@ class DOF(ABC): ---------- name : str | None The name of the DOF. Provide a name if the DOF is not an actuator. - actuator : Actuator | None - The actuator to use for the DOF. Provide an actuator if the DOF is controllable by Bluesky. + actuator : Actuator | str | None + The actuator or its name to use for the DOF. Provide an actuator if the DOF is controllable by Bluesky. 
Notes ----- @@ -42,7 +42,7 @@ class DOF(ABC): """ name: str | None = None - actuator: Actuator | None = None + actuator: Actuator | str | None = None def __post_init__(self) -> None: if not (bool(self.name) ^ bool(self.actuator)): @@ -50,7 +50,13 @@ def __post_init__(self) -> None: @property def parameter_name(self) -> str: - return self.name or cast(Actuator, self.actuator).name + if isinstance(self.actuator, Actuator): + param_name = self.actuator.name + elif isinstance(self.actuator, str): + param_name = self.actuator + else: + param_name = cast(str, self.name) + return param_name @abstractmethod def to_ax_parameter_config(self) -> RangeParameterConfig | ChoiceParameterConfig: ... diff --git a/src/blop/ax/qserver_agent.py b/src/blop/ax/qserver_agent.py deleted file mode 100644 index 787ce4eb..00000000 --- a/src/blop/ax/qserver_agent.py +++ /dev/null @@ -1,334 +0,0 @@ -import logging - -### -import threading -import uuid -from collections.abc import Sequence -from typing import Any - -from ax.api.types import TParameterization -from bluesky.callbacks import CallbackBase -from bluesky.callbacks.zmq import RemoteDispatcher -from bluesky_queueserver_api import BPlan -from bluesky_queueserver_api.zmq import REManagerAPI - -from ..protocols import EvaluationFunction, Sensor -from .agent import Agent as BlopAxAgent # type: ignore[import-untyped] -from .dof import DOF, DOFConstraint -from .objective import Objective - -logger = logging.getLogger("blop") - - -class ConsumerCallback(CallbackBase): - """ - A child of Callback base which caches the start document and calls a callback function on the stop document - - """ - - def __init__(self, callback: callable = None, enable=True, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.start_doc_cache = None - self.callback = callback # A function that is called when stop is called - self.enable = enable - - def start(self, doc): - if self.enable: - self.start_doc_cache = doc - - def stop(self, doc): - if 
self.enable: - self.callback(self.start_doc_cache, doc) - self._clear_cache() - - def _clear_cache(self): - self.start_doc_cache = None - - -class ZMQConsumer: - """ - Allows us to start a thread which will listen to docs and call the callback in CallbackBase - """ - - def __init__(self, zmq_consumer_ip_address, zmq_consumer_port, callback: callable = None): - self.zmq_consumer_ip_address = zmq_consumer_ip_address - self.zmq_consumer_port = zmq_consumer_port - - self.zmq_consumer = RemoteDispatcher(f"{self.zmq_consumer_ip_address}:{self.zmq_consumer_port}") - self.zmq_consumer_callback = ConsumerCallback(callback=callback, enable=True) - self.zmq_consumer.subscribe(self.zmq_consumer_callback) - self._zmq_thread = None - - def start_zmq_listener_thread(self): - logger.info("Starting ZMQ Callback Thread") - - self._zmq_thread = threading.Thread(target=self.zmq_consumer.start, name="zmq-consumer", daemon=True) - self._zmq_thread.start() - - -class BlopQserverAgent(BlopAxAgent): - """ - An interface that uses Ax as the backend for optimization and experiment tracking. - - The Agent is the main entry point for setting up and running Bayesian optimization - using Blop. It coordinates the DOFs, objectives, evaluation function, and optimizer - to perform intelligent exploration of the parameter space. - - This class sends JSON strings to a queueserver, rather than emmitting messages to be - consumed directly by a RE. - - Parameters - ---------- - sensors : Sequence[Sensor] - The sensors to use for acquisition. These should be the minimal set - of sensors that are needed to compute the objectives. - dofs : Sequence[DOF] - The degrees of freedom that the agent can control, which determine the search space. - objectives : Sequence[Objective] - The objectives which the agent will try to optimize. - evaluation_function : EvaluationFunction - The function to evaluate acquired data and produce outcomes. 
- acquisition_plan : str, optional - The name of the plan on the queueserver - dof_constraints : Sequence[DOFConstraint] | None, optional - Constraints on DOFs to refine the search space. - outcome_constraints : Sequence[OutcomeConstraint] | None, optional - Constraints on outcomes to be satisfied during optimization. - qserver_control_addr : str, default="tcp://localhost:60615" - Queueserver Control Address - qserver_info_addr : str, default="tcp://localhost:60625" - Queueserver Info Address - zmq_consumer_ip : str, default= "localhost" - The IP address of the ZMQ proxy to listen for stop document - zmq_consumer_port : str, default= "5578" - The PORT of the ZMQ proxy to listen for stop document - **kwargs : Any - Additional keyword arguments to configure the Ax experiment. - - Notes - ----- - For more complex setups, you can configure the Ax client directly via ``self.ax_client``. - - For complete working examples of creating and using an Agent, see the tutorial - documentation, particularly :doc:`/tutorials/qserver-experiment`. 
- - - """ - - def __init__( - self, - sensors: Sequence[Sensor], - dofs: Sequence[DOF], - objectives: Sequence[Objective], - evaluation_function: EvaluationFunction = None, - acquisition_plan: str = "acquire", - dof_constraints: Sequence[DOFConstraint] = None, - qserver_control_addr: str = "tcp://localhost:60615", - qserver_info_addr: str = "tcp://localhost:60625", - zmq_consumer_ip: str = "localhost", - zmq_consumer_port: str = "5578", - **kwargs: Any, - ): - super().__init__( - sensors=sensors, - dofs=dofs, - objectives=objectives, - evaluation_function=evaluation_function, - acquisition_plan=acquisition_plan, - dof_constraints=dof_constraints, - **kwargs, - ) - - # Store dofs for qserver-specific operations (parent class only stores actuators) - self._dofs = list(dofs) - - # Instantiate an object that can communicate with the queueserver - self.RM = REManagerAPI( - zmq_control_addr=qserver_control_addr, zmq_info_addr=qserver_info_addr - ) # To Do, Add arguements to class init - - # Should plans be submitted and automatically started, or not? - self._queue_autostart = True - - # Instantiate an object that will listen for start and stop documents and call a method on the stop document - self.zmq_consumer = ZMQConsumer(zmq_consumer_ip, zmq_consumer_port, callback=self._stop_doc_callback) - self.zmq_consumer.start_zmq_listener_thread() - - # Should we do something when there is a new event? - self._listen_to_events = True - - # Should new suggestions be made automatically until all of the trials are complete? 
- self.continuous_suggestion = True - - # Learning parameters - self.num_itterations = 30 - self.n_points = 1 - - # Variables used to keep track of the current optimization - self.current_itteration = 0 - self.agent_suggestion_uid = None - self.trials = None - self.acquisition_finished = False - self.optimization_problem = None - - @property - def dofs(self) -> Sequence[DOF]: - """The degrees of freedom for this agent.""" - return self._dofs - - def _stop_doc_callback(self, start_doc, stop_doc): - """ - In here we can decide whether our experiment requested has completed - - If it has completed, we can digest the data from it and move on to the next point. - """ - - if self._listen_to_events: - # Mark the current acquisition as finished - - logger.info("A stop document has been received, evaluating") - - # Evaluate it with the evaluation function - outcomes = self.optimization_problem.evaluation_function(uid=self.agent_suggestion_uid, suggestions=self.trials) - - logger.info(f"successfully evaluated id: {self.agent_suggestion_uid}") - - self.acquisition_finished = True - # ingest the data, updating the model of the optimizer - self.optimization_problem.optimizer.ingest(outcomes) - - # After this is complete, call gen_next_trials again if required - if self.continuous_suggestion: - if self.current_itteration < self.num_itterations: - logger.info("making another suggestion") - self.suggest() - else: - self.current_itteration = 0 - logger.info("made all required suggestions") - - def optimize(self, iterations=1, n_points=1): - """ - This method will create the optimization problem, suggest points and execute them in the QS - """ - - # Before we do anything check the connection to the Queueserver - status = self.RM.status() - if not status["worker_environment_exists"]: - raise ValueError("The queueserver environment is not open") - - # Check that the devices we want to interact with are in the queueserver environment - res = self.RM.devices_allowed() - for dof in 
self.dofs: - if dof.parameter_name not in res["devices_allowed"]: - raise ValueError(f"The device {dof.parameter_name} is not in the Queueserver Environment") - - for sensor in self.sensors: - # Handle both sensor objects (with .name) and string sensor names - sensor_name = sensor.name if hasattr(sensor, "name") else sensor - if sensor_name not in res["devices_allowed"]: - raise ValueError(f"The device {sensor_name} is not in the Queueserver Environment") - - # Check that the plan we want to call is in the queueserver environment - res = self.RM.plans_allowed() - if self._acquisition_plan not in res["plans_allowed"]: - raise ValueError(f"The plan {self._acquisition_plan} is not in the Queueserver Environment") - - # Form the problem and start suggesting points to measure at - self.optimization_problem = self.to_optimization_problem() - self.num_itterations = iterations - self.n_points = n_points - - self.suggest() - - def suggest(self): - """ - get suggestions from the optimizer, then send them to the plan on the queueserver - - """ - - # record this itteration - self.current_itteration = self.current_itteration + 1 - - # Get the trials to perform - self.trials = self.optimization_problem.optimizer._client.get_next_trials(self.n_points) - - # acquire the values from those trials - self.agent_suggestion_uid = self.acquire(self.trials) - logger.info( - f"sending suggestion {self.current_itteration} to queueserver with suggestion id: {self.agent_suggestion_uid}" - ) - - def acquire_baseline(self, parameterization: dict[str, Any] | None = None): - logger.info("This is not implemented for the qserver agent") - - def acquire(self, trials: dict[int, TParameterization] | None = None): - """ - Acquire the new data from the system by submitting the suggested - points to the queueserver. This method does not block while the - queueserver is running. 
- - Parameters - ---------- - trials : dict[int, TParameterization] - A dictionary mapping trial indices to their suggested parameterizations. - """ - - try: - self.acquisition_finished = False - - # Create a unique identifier which will connect the children to the parent batch - # This batch ID will be used by all runs from this request by the agent. - # It will be used by the EvaluationFunction later to work out what happened. - - agent_suggestion_uid = str(uuid.uuid4()) - kwargs = {} - kwargs.setdefault("md", {}) - - # Add the unique suggestion ID so we can find this run later - kwargs["md"]["agent_suggestion_uid"] = agent_suggestion_uid - - # Add the suggestion _id key so we can work out which number we are on later - suggestions = [ - { - "_id": trial_index, - **parameterization, - } - for trial_index, parameterization in trials.items() - ] - kwargs["md"]["blop_suggestions"] = suggestions - - # Create the BPlan object to send to the queue. Convert dofs to strings - item = BPlan( - self.acquisition_plan, - readables=self.sensors, - dofs=[dof.parameter_name for dof in self.dofs], - trials=trials, - md=kwargs["md"], - ) - - # Send the plan to the Run Engine Manager - r = self.RM.item_add(item) - logger.debug( - f"Sent http-server request for trials {trials}" - f"with agent_suggestion_uid= {agent_suggestion_uid}\n.Received reponse: {r}" - ) - - # If the queue should start automatically, then start the queue. 
- if self._queue_autostart: - logger.debug("Waiting for Queue to be idle or paused") - self.RM.wait_for_idle_or_paused(timeout=600) - r = self.RM.queue_start() - logger.debug(f"Sent http-server request to start the queue\n.Received reponse: {r}") - - except KeyboardInterrupt as interrupt: - raise interrupt - - return agent_suggestion_uid - - def stop(self): - """ - Stop the agent if it is running - """ - self._queue_autostart = False - self._listen_to_events = False diff --git a/src/blop/protocols.py b/src/blop/protocols.py index dc31f0c6..9373ecab 100644 --- a/src/blop/protocols.py +++ b/src/blop/protocols.py @@ -1,14 +1,36 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Literal, Protocol, runtime_checkable +from typing import Any, Generic, Literal, Protocol, TypeVar, runtime_checkable -from bluesky.protocols import EventCollectable, EventPageCollectable, Flyable, NamedMovable, Readable +from bluesky.protocols import EventCollectable, EventPageCollectable, Flyable, HasName, Movable, Readable from bluesky.utils import MsgGenerator, plan + +@runtime_checkable +class MovableHasName(Movable, HasName, Protocol): + """ + A movable that has a name. + + We use this instead of `bluesky.protocols.NamedMovable` since + we do not want to require `HasHints` on the movable. + + A `Movable` and `HasName` is sufficient. `HasHints` should be optional. + """ + + ... 
+ + ID_KEY: Literal["_id"] = "_id" -Actuator = NamedMovable | Flyable +Actuator = MovableHasName | Flyable Sensor = Readable | EventCollectable | EventPageCollectable +TActuator = TypeVar("TActuator") +"""Actuator generic type""" +TSensor = TypeVar("TSensor") +"""Sensor generic type""" +TPlan = TypeVar("TPlan") +"""Plan generic type""" + @runtime_checkable class CanRegisterSuggestions(Protocol): @@ -175,6 +197,7 @@ def __call__( suggestions: list[dict], actuators: Sequence[Actuator], sensors: Sequence[Sensor] | None = None, + md: dict[str, Any] | None = None, ) -> MsgGenerator[str]: """ Acquire data for optimization. @@ -192,6 +215,8 @@ def __call__( The actuators to move to their suggested positions. sensors: Sequence[Sensor], optional The sensors that produce data to evaluate. + md : dict[str, Any] | None, optional + Metadata to attach to the start document Returns ------- @@ -202,7 +227,15 @@ def __call__( @dataclass(frozen=True) -class OptimizationProblem: +class BaseOptimizationProblem(Generic[TActuator, TSensor, TPlan]): + optimizer: Optimizer + actuators: Sequence[TActuator] + sensors: Sequence[TSensor] + evaluation_function: EvaluationFunction + acquisition_plan: TPlan | None = None + + +class OptimizationProblem(BaseOptimizationProblem[Actuator, Sensor, AcquisitionPlan]): """ An optimization problem to solve. Immutable once initialized. @@ -230,8 +263,37 @@ class OptimizationProblem: blop.plans.optimize : Bluesky plan that uses an OptimizationProblem. """ + ... + + +class QueueserverOptimizationProblem(BaseOptimizationProblem[str, str, str]): + """ + An optimization problem to solve. Immutable once initialized. + + This dataclass encapsulates all components needed for optimization into a single + immutable structure. It is typically created via :meth:`blop.ax.QueueserverAgent.to_optimization_problem` + and used with bluesky-queueserver-api. 
Actuators, sensors, and the acquisition plan are referenced + by their names, since their instances live on a remote server. + + Attributes + ---------- optimizer: Optimizer - actuators: Sequence[Actuator] - sensors: Sequence[Sensor] + Suggests points to evaluate and ingests outcomes to inform the optimization. + actuators: Sequence[str] + Names of objects that can be moved to control the beamline using the Bluesky RunEngine. + A subset of the actuators' names must match the names of suggested parameterizations. + sensors: Sequence[str] + Names of objects that can produce data to acquire data from the beamline using the Bluesky RunEngine. evaluation_function: EvaluationFunction - acquisition_plan: AcquisitionPlan | None = None + A callable to evaluate data from a Bluesky run and produce outcomes. + acquisition_plan: str, optional + The name of a Bluesky plan to acquire data from the beamline. If not provided, a default plan name will be used. + The plan must match the arguments of :ref:`AcquisitionPlan`. + + See Also + -------- + blop.ax.QueueserverAgent.to_optimization_problem : Creates a QueueserverOptimizationProblem from an agent. + blop.queueserver.QueueserverOptimizationRunner : Runs the optimization loop using the bluesky-queueserver-api. + """ + + ... diff --git a/src/blop/queueserver.py b/src/blop/queueserver.py new file mode 100644 index 00000000..7015e7f3 --- /dev/null +++ b/src/blop/queueserver.py @@ -0,0 +1,418 @@ +""" +Queueserver integration for running optimization through a Bluesky queueserver. + +This module provides components for running optimization loops remotely through +a queueserver, rather than directly through a RunEngine. 
+""" + +import logging +import threading +import uuid +from collections.abc import Callable, Sequence +from dataclasses import dataclass, field +from typing import Any, Literal + +from bluesky.callbacks import CallbackBase +from bluesky.callbacks.zmq import RemoteDispatcher +from bluesky_queueserver_api import BPlan +from bluesky_queueserver_api.zmq import REManagerAPI +from event_model import RunStart, RunStop + +from .plans import default_acquire +from .protocols import ID_KEY, CanRegisterSuggestions, QueueserverOptimizationProblem + +logger = logging.getLogger("blop") + + +DEFAULT_ACQUIRE_PLAN_NAME: str = default_acquire.__name__ +CORRELATION_UID_KEY: Literal["blop_correlation_uid"] = "blop_correlation_uid" + + +class ConsumerCallback(CallbackBase): + """ + A callback that caches the start document and invokes a callback on stop. + + Parameters + ---------- + callback : callable + Function to call when a stop document is received. + Signature: callback(start_doc, stop_doc) + """ + + def __init__(self, callback: Callable[[RunStart, RunStop], None] | None = None): + super().__init__() + self._start_doc_cache: RunStart | None = None + self._callback = callback + + def start(self, doc: RunStart) -> None: + """Caches the start document if it came from Blop""" + if doc.get(CORRELATION_UID_KEY, None): + self._start_doc_cache = doc + else: + self._start_doc_cache = None + + def stop(self, doc: RunStop) -> None: + """Executes the callback if the cached start and stop document match""" + if self._callback is not None and self._start_doc_cache is not None and self._start_doc_cache["uid"] == doc["uid"]: + self._callback(self._start_doc_cache, doc) + self._start_doc_cache = None + + +class QueueserverClient: + """ + Handles communication with a Bluesky queueserver. + + This class encapsulates all ZMQ and HTTP communication with the queueserver, + including plan submission and event listening. 
+ + Parameters + ---------- + re_manager_api : bluesky_queueserver_api.zmq.REManagerAPI + Manager instance for communication with Bluesky Queueserver + zmq_consumer_addr : str + Address for ZMQ document consumer (e.g., "localhost:5578"). + """ + + def __init__( + self, + re_manager_api: REManagerAPI, + zmq_consumer_addr: str, + ): + self._zmq_consumer_addr = zmq_consumer_addr + + self._rm = re_manager_api + self._dispatcher: RemoteDispatcher | None = None + self._consumer_callback: ConsumerCallback | None = None + self._listener_thread: threading.Thread | None = None + + def check_environment(self) -> None: + """ + Verify that the queueserver environment is ready. + + Raises + ------ + RuntimeError + If the queueserver environment is not open. + """ + status = self._rm.status() + if status is None or not status.get("worker_environment_exists", False): + raise RuntimeError("The queueserver environment is not open") + + def check_devices_available(self, device_names: Sequence[str]) -> None: + """ + Verify that all specified devices are available in the queueserver. + + Parameters + ---------- + device_names : Sequence[str] + Names of devices to check. + + Raises + ------ + ValueError + If any device is not available. + """ + res = self._rm.devices_allowed() + allowed = res["devices_allowed"] + for name in device_names: + if name not in allowed: + raise ValueError(f"Device '{name}' is not available in the queueserver environment") + + def check_plan_available(self, plan_name: str) -> None: + """ + Verify that a plan is available in the queueserver. + + Parameters + ---------- + plan_name : str + Name of the plan to check. + + Raises + ------ + ValueError + If the plan is not available. 
+ """ + res = self._rm.plans_allowed() + if plan_name not in res["plans_allowed"]: + raise ValueError(f"Plan '{plan_name}' is not available in the queueserver environment") + + def submit_plan(self, plan: BPlan, autostart: bool = True, timeout: int = 600) -> None: + """ + Submit a plan to the queueserver queue. + + Parameters + ---------- + plan : BPlan + The plan to submit. + autostart : bool, optional + If True, start the queue after adding the plan. + timeout : float, optional + Timeout in seconds when waiting for queue to be idle. + """ + response = self._rm.item_add(plan) + logger.debug(f"Submitted plan to queue. Response: {response}") + + if autostart: + logger.debug("Waiting for queue to be idle or paused") + self._rm.wait_for_idle_or_paused(timeout=timeout) + response = self._rm.queue_start() + logger.debug(f"Started queue. Response: {response}") + + def start_listener(self, on_stop: Callable[[RunStart, RunStop], None]) -> None: + """ + Start listening for document events from the queueserver. + + Parameters + ---------- + on_stop : callable + Callback invoked when a stop document is received. 
+            Signature: on_stop(start_doc, stop_doc)
+        """
+        if self._listener_thread is not None:
+            logger.warning("Listener already running")
+            return
+
+        dispatcher = RemoteDispatcher(self._zmq_consumer_addr)
+        self._consumer_callback = ConsumerCallback(callback=on_stop)
+        dispatcher.subscribe(self._consumer_callback)
+
+        logger.info("Starting ZMQ listener thread")
+        self._listener_thread = threading.Thread(
+            target=dispatcher.start,
+            name="qserver-zmq-consumer",
+            daemon=True,
+        )
+        self._listener_thread.start()
+        self._dispatcher = dispatcher
+
+    def stop_listener(self) -> None:
+        """Stop the ZMQ listener thread."""
+        if self._dispatcher is not None:
+            self._dispatcher.stop()
+        self._dispatcher = None
+        self._consumer_callback = None
+        self._listener_thread = None
+        logger.info("Stopped ZMQ listener")
+
+
+@dataclass
+class _OptimizationState:
+    """Internal mutable state for an optimization run."""
+
+    max_iterations: int = 1
+    num_points: int = 1
+    current_iteration: int = 0
+    current_suggestions: list[dict] = field(default_factory=list)
+    current_uid: str | None = None
+    finished: bool = False
+
+
+class QueueserverOptimizationRunner:
+    """
+    Runs optimization loops through a Bluesky queueserver.
+
+    This class coordinates the optimization workflow by getting suggestions from
+    the optimizer, submitting acquisition plans to the queueserver, and ingesting
+    results when plans complete.
+
+    Parameters
+    ----------
+    optimization_problem : QueueserverOptimizationProblem
+        The optimization problem to solve, containing the optimizer, actuators,
+        sensors, and evaluation function.
+    queueserver_client : QueueserverClient
+        Client for communicating with the queueserver.
+
+    The acquisition plan name is taken from ``optimization_problem.acquisition_plan``.
+ """ + + def __init__( + self, + optimization_problem: QueueserverOptimizationProblem, + queueserver_client: QueueserverClient, + ): + self._problem = optimization_problem + self._client = queueserver_client + self._plan_name = optimization_problem.acquisition_plan or DEFAULT_ACQUIRE_PLAN_NAME + self._state: _OptimizationState | None = None + self._continuous = True + self._autostart = True + + @property + def optimization_problem(self) -> QueueserverOptimizationProblem: + """The optimization problem being solved.""" + return self._problem + + @property + def is_running(self) -> bool: + """Whether an optimization run is currently in progress.""" + return self._state is not None and not self._state.finished + + @property + def current_iteration(self) -> int: + """The current iteration number (0 if not running).""" + return self._state.current_iteration if self._state else 0 + + def run(self, iterations: int = 1, num_points: int = 1) -> None: + """ + Start the optimization loop. + + Validates the queueserver state, then begins the suggest -> acquire -> ingest + cycle. This method returns immediately; the optimization runs asynchronously + via callbacks. + + Parameters + ---------- + iterations : int + Number of optimization iterations to run. + num_points : int + Number of points to suggest per iteration. + + Raises + ------ + RuntimeError + If the queueserver environment is not ready. + ValueError + If required devices or plans are not available. + """ + # TODO: What if there is already a run in-progress? + self._validate() + self._state = _OptimizationState(max_iterations=iterations, num_points=num_points) + self._continuous = True + self._client.start_listener(on_stop=self._on_acquisition_complete) + self._submit_next() + + def submit_suggestions(self, suggestions: list[dict]) -> None: + """ + Manually submit suggestions to the queue. This method returns immediately; the + optimization runs asynchronously via callbacks. 
+ + Parameters + ---------- + suggestions : list[dict] + Parameter combinations to evaluate. Can be: + + - Optimizer suggestions (with "_id" keys from suggest()) + - Manual points (without "_id", requires CanRegisterSuggestions protocol) + """ + self._validate() + self._state = _OptimizationState(max_iterations=1, num_points=len(suggestions)) + self._continuous = False + self._client.start_listener(on_stop=self._on_acquisition_complete) + self._submit_next_manual(suggestions) + + def stop(self) -> None: + """ + Stop the optimization loop gracefully. + + The current acquisition will complete, but no further iterations will run. + """ + self._continuous = False + self._client.stop_listener() + if self._state is not None: + self._state.finished = True + logger.info("Optimization stopped") + + def _validate(self) -> None: + """Validate queueserver environment, devices, and plan availability.""" + self._client.check_environment() + + # Collect device names from actuators and sensors + actuator_names = list(self._problem.actuators) + sensor_names = list(self._problem.sensors) + self._client.check_devices_available(actuator_names + sensor_names) + + self._client.check_plan_available(self._plan_name) + + def _submit_next_manual(self, suggestions: list[dict]) -> None: + """Get suggestions from optimizer and submit plan to queueserver.""" + if self._state is None: + raise RuntimeError("_submit_next called before run()") + + # Ensure the suggestions have an ID_KEY or register them with the optimizer + if not isinstance(self.optimization_problem.optimizer, CanRegisterSuggestions) and any( + ID_KEY not in suggestion for suggestion in suggestions + ): + raise ValueError( + f"All suggestions must contain an '{ID_KEY}' key to later match with the outcomes or your optimizer must " + "implement the `blop.protocols.CanRegisterSuggestions` protocol. Please review your optimizer " + f"implementation. 
Got suggestions: {suggestions}" + ) + elif isinstance(self.optimization_problem.optimizer, CanRegisterSuggestions): + suggestions = self.optimization_problem.optimizer.register_suggestions(suggestions) + + self._state.current_iteration += 1 + self._state.current_suggestions = suggestions + self._state.current_uid = str(uuid.uuid4()) + + logger.info(f"Submitting manually specified suggestion(s) with correlation uid: {self._state.current_uid}") + + plan = self._build_plan() + self._client.submit_plan(plan, autostart=self._autostart) + + def _submit_next(self) -> None: + """Get suggestions from optimizer and submit plan to queueserver.""" + if self._state is None: + raise RuntimeError("_submit_next called before run()") + self._state.current_iteration += 1 + self._state.current_suggestions = self._problem.optimizer.suggest(self._state.num_points) + self._state.current_uid = str(uuid.uuid4()) + + logger.info( + f"Submitting iteration {self._state.current_iteration}/{self._state.max_iterations} " + f"with correlation uid: {self._state.current_uid}" + ) + + plan = self._build_plan() + self._client.submit_plan(plan, autostart=self._autostart) + + def _build_plan(self) -> BPlan: + """Construct a BPlan from the current suggestions.""" + if self._state is None: + raise RuntimeError("_build_plan called before run()") + # Build metadata + md: dict[str, Any] = { + CORRELATION_UID_KEY: self._state.current_uid, + "blop_suggestions": self._state.current_suggestions, + } + + return BPlan( + self._plan_name, + self._state.current_suggestions, + list(self._problem.actuators), + list(self._problem.sensors), + md=md, + ) + + def _on_acquisition_complete(self, start_doc: RunStart, stop_doc: RunStop) -> None: + """Callback when acquisition finishes. 
Ingest results and maybe continue.""" + if self._state is None: + raise RuntimeError("_on_acquisition_complete called before run()") + if self._state.current_uid is None: + raise RuntimeError("current_uid not set") + if self._state.current_uid != start_doc.get("blop_correlation_uid", None): + raise RuntimeError( + "current_uid did not match start document. " + f"Got: {start_doc.get('blop_correlation_uid', None)}, Expected: {self._state.current_uid}" + ) + logger.info(f"Acquisition complete for uid: {self._state.current_uid}") + + # Evaluate the results + outcomes = self._problem.evaluation_function( + uid=start_doc["uid"], + suggestions=self._state.current_suggestions, + ) + + logger.info(f"Evaluated {len(outcomes)} outcomes") + + # Ingest into optimizer + self._problem.optimizer.ingest(outcomes) + + # Continue if appropriate + if self._continuous and self._state.current_iteration < self._state.max_iterations: + logger.info("Continuing to next iteration") + self._submit_next() + else: + self._state.finished = True + self._client.stop_listener() + logger.info(f"Optimization complete after {self._state.current_iteration} iterations") diff --git a/src/blop/tests/ax/test_agent.py b/src/blop/tests/ax/test_agent.py index 9f67f821..abababfc 100644 --- a/src/blop/tests/ax/test_agent.py +++ b/src/blop/tests/ax/test_agent.py @@ -1,10 +1,11 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import numpy as np import pytest from ax import Client +from bluesky_queueserver_api.zmq import REManagerAPI -from blop.ax.agent import Agent +from blop.ax.agent import Agent, QueueserverAgent from blop.ax.dof import DOFConstraint, RangeDOF from blop.ax.objective import Objective from blop.ax.optimizer import AxOptimizer @@ -23,6 +24,11 @@ def mock_acquisition_plan(): return MagicMock(spec=AcquisitionPlan) +@pytest.fixture(scope="function") +def mock_re_manager_api(): + return MagicMock(spec=REManagerAPI) + + def test_agent_init(mock_evaluation_function, 
mock_acquisition_plan): """Test that the agent can be initialized.""" movable1 = MovableSignal(name="test_movable1") @@ -211,3 +217,97 @@ def test_ingest_baseline(mock_evaluation_function): summary_df = agent.ax_client.summarize() assert len(summary_df) == 1 assert summary_df["arm_name"].values[0] == "baseline" + + +def test_agent_init_actuator_string_raises(mock_evaluation_function): + dof1 = RangeDOF(actuator="test_movable1", bounds=(0, 10), parameter_type="float") + dof2 = RangeDOF(actuator="test_movable2", bounds=(0, 10), parameter_type="float") + objective = Objective(name="test_objective", minimize=False) + + with pytest.raises(ValueError, match="not strings"): + Agent(sensors=[], dofs=[dof1, dof2], objectives=[objective], evaluation_function=mock_evaluation_function) + + +def test_queueserver_agent_init(mock_re_manager_api, mock_evaluation_function): + dof1 = RangeDOF(actuator="test_motor1", bounds=(0, 10), parameter_type="float") + dof2 = RangeDOF(actuator="test_motor2", bounds=(0, 10), parameter_type="float") + agent = QueueserverAgent( + mock_re_manager_api, + "inproc://test", + ["det"], + [dof1, dof2], + [Objective(name="obj1", minimize=False)], + mock_evaluation_function, + ) + assert agent.sensors == ["det"] + assert agent.actuators == [dof1.actuator, dof2.actuator] + assert agent.evaluation_function == mock_evaluation_function + assert agent.acquisition_plan is None + assert isinstance(agent.ax_client, Client) + + problem = agent.to_optimization_problem() + assert problem.acquisition_plan is None + assert problem.actuators == [dof1.parameter_name, dof2.parameter_name] + assert problem.sensors == ["det"] + assert problem.evaluation_function == mock_evaluation_function + + +def test_queueserver_agent_init_actuator_instance(mock_re_manager_api, mock_evaluation_function): + movable1 = MovableSignal(name="test_movable1") + dof1 = RangeDOF(actuator=movable1, bounds=(0, 10), parameter_type="float") + dof2 = RangeDOF(actuator="test_movable2", bounds=(0, 10), 
parameter_type="float") + agent = QueueserverAgent( + mock_re_manager_api, + "inproc://test", + ["det"], + [dof1, dof2], + [Objective(name="obj1", minimize=False)], + mock_evaluation_function, + ) + + assert agent.actuators == [movable1.name, dof2.parameter_name] + + +@patch("blop.ax.agent.QueueserverClient") +@patch("blop.ax.agent.QueueserverOptimizationRunner") +def test_queueserver_agent_run( + mock_queueserver_runner_cls, mock_queueserver_client_cls, mock_re_manager_api, mock_evaluation_function +): + dof1 = RangeDOF(actuator="test_motor1", bounds=(0, 10), parameter_type="float") + dof2 = RangeDOF(actuator="test_motor2", bounds=(0, 10), parameter_type="float") + agent = QueueserverAgent( + mock_re_manager_api, + "inproc://test", + ["det"], + [dof1, dof2], + [Objective(name="obj1", minimize=False)], + mock_evaluation_function, + ) + mock_queueserver_client_cls.assert_called_once() + mock_queueserver_runner_cls.assert_called_once() + + agent.run() + mock_queueserver_runner_cls.return_value.run.assert_called_once_with(1, 1) + + +@patch("blop.ax.agent.QueueserverClient") +@patch("blop.ax.agent.QueueserverOptimizationRunner") +def test_queueserver_agent_submit_suggestions( + mock_queueserver_runner_cls, mock_queueserver_client_cls, mock_re_manager_api, mock_evaluation_function +): + dof1 = RangeDOF(actuator="test_motor1", bounds=(0, 10), parameter_type="float") + dof2 = RangeDOF(actuator="test_motor2", bounds=(0, 10), parameter_type="float") + agent = QueueserverAgent( + mock_re_manager_api, + "inproc://test", + ["det"], + [dof1, dof2], + [Objective(name="obj1", minimize=False)], + mock_evaluation_function, + ) + mock_queueserver_client_cls.assert_called_once() + mock_queueserver_runner_cls.assert_called_once() + + suggestions = [{"test_motor1": 5, "test_motor2": 9}] + agent.submit_suggestions(suggestions) + mock_queueserver_runner_cls.return_value.submit_suggestions.assert_called_once_with(suggestions) diff --git a/src/blop/tests/ax/test_qserver_agent.py 
b/src/blop/tests/ax/test_qserver_agent.py deleted file mode 100644 index c1b85562..00000000 --- a/src/blop/tests/ax/test_qserver_agent.py +++ /dev/null @@ -1,227 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from blop.ax.dof import RangeDOF -from blop.ax.objective import Objective -from blop.ax.qserver_agent import BlopQserverAgent, ConsumerCallback, ZMQConsumer -from blop.protocols import EvaluationFunction - -from ..conftest import MovableSignal, ReadableSignal - - -@pytest.fixture(scope="function") -def mock_evaluation_function(): - return MagicMock(spec=EvaluationFunction) - - -@pytest.fixture(scope="function") -def basic_dofs(): - movable1 = MovableSignal(name="test_movable1") - movable2 = MovableSignal(name="test_movable2") - dof1 = RangeDOF(actuator=movable1, bounds=(0, 10), parameter_type="float") - dof2 = RangeDOF(actuator=movable2, bounds=(0, 10), parameter_type="float") - return [dof1, dof2] - - -@pytest.fixture(scope="function") -def basic_sensors(): - return [ReadableSignal(name="test_readable")] - - -@pytest.fixture(scope="function") -def basic_objective(): - return Objective(name="test_objective", minimize=False) - - -@patch("blop.ax.qserver_agent.REManagerAPI") -@patch("blop.ax.qserver_agent.ZMQConsumer") -def test_qserver_agent_init( - mock_zmq_consumer, mock_re_manager, mock_evaluation_function, basic_dofs, basic_sensors, basic_objective -): - """Test that the qserver agent can be initialized with proper components.""" - agent = BlopQserverAgent( - sensors=basic_sensors, - dofs=basic_dofs, - objectives=[basic_objective], - evaluation_function=mock_evaluation_function, - acquisition_plan="test_plan", - qserver_control_addr="tcp://localhost:60615", - qserver_info_addr="tcp://localhost:60625", - zmq_consumer_ip="localhost", - zmq_consumer_port="5578", - ) - - # Test public properties - assert agent.sensors == basic_sensors - assert list(agent.dofs) == basic_dofs - assert agent.continuous_suggestion is True - assert 
agent.num_itterations == 30 - assert agent.n_points == 1 - assert agent.current_itteration == 0 - - # Verify REManagerAPI was initialized with correct addresses - mock_re_manager.assert_called_once_with( - zmq_control_addr="tcp://localhost:60615", - zmq_info_addr="tcp://localhost:60625", - ) - # Verify ZMQ consumer was started - mock_zmq_consumer.assert_called_once() - mock_zmq_consumer.return_value.start_zmq_listener_thread.assert_called_once() - - -@patch("blop.ax.qserver_agent.REManagerAPI") -@patch("blop.ax.qserver_agent.ZMQConsumer") -def test_qserver_agent_optimize_validation( - mock_zmq_consumer, mock_re_manager, mock_evaluation_function, basic_dofs, basic_sensors, basic_objective -): - """Test that optimize validates qserver environment, devices, and plans.""" - # Test 1: Error when qserver environment is not open - agent = BlopQserverAgent( - sensors=basic_sensors, - dofs=basic_dofs, - objectives=[basic_objective], - evaluation_function=mock_evaluation_function, - ) - agent.RM.status.return_value = {"worker_environment_exists": False} - - with pytest.raises(ValueError, match="queueserver environment is not open"): - agent.optimize(iterations=1, n_points=1) - - # Test 2: Error when required device is not in qserver (uses actuator-based DOFs) - agent2 = BlopQserverAgent( - sensors=basic_sensors, - dofs=basic_dofs, - objectives=[basic_objective], - evaluation_function=mock_evaluation_function, - ) - agent2.RM.status.return_value = {"worker_environment_exists": True} - agent2.RM.devices_allowed.return_value = {"devices_allowed": {}} - - with pytest.raises(ValueError, match="device test_movable1 is not in the Queueserver Environment"): - agent2.optimize(iterations=1, n_points=1) - - # Test 3: Error when acquisition plan is not in qserver - agent3 = BlopQserverAgent( - sensors=basic_sensors, - dofs=basic_dofs, - objectives=[basic_objective], - evaluation_function=mock_evaluation_function, - acquisition_plan="missing_plan", - ) - agent3.RM.status.return_value = 
{"worker_environment_exists": True} - agent3.RM.devices_allowed.return_value = { - "devices_allowed": {"test_movable1": {}, "test_movable2": {}, "test_readable": {}} - } - agent3.RM.plans_allowed.return_value = {"plans_allowed": {}} - - with pytest.raises(ValueError, match="plan missing_plan is not in the Queueserver Environment"): - agent3.optimize(iterations=1, n_points=1) - - -@patch("blop.ax.qserver_agent.REManagerAPI") -@patch("blop.ax.qserver_agent.ZMQConsumer") -def test_qserver_agent_acquire( - mock_zmq_consumer, mock_re_manager, mock_evaluation_function, basic_dofs, basic_sensors, basic_objective -): - """Test that acquire submits a plan to the qserver with proper metadata and starts queue.""" - agent = BlopQserverAgent( - sensors=basic_sensors, - dofs=basic_dofs, - objectives=[basic_objective], - evaluation_function=mock_evaluation_function, - acquisition_plan="acquire", - ) - - trials = {0: {"test_movable1": 5.0, "test_movable2": 3.0}} - uid = agent.acquire(trials) - - # Verify uid is returned - assert uid is not None - assert isinstance(uid, str) - assert agent.acquisition_finished is False - - # Verify plan was submitted to qserver - agent.RM.item_add.assert_called_once() - call_args = agent.RM.item_add.call_args - bplan = call_args[0][0] - assert bplan.name == "acquire" - assert bplan.kwargs["md"]["agent_suggestion_uid"] == uid - assert bplan.kwargs["md"]["blop_suggestions"] == [{"_id": 0, "test_movable1": 5.0, "test_movable2": 3.0}] - - # Verify queue was started (autostart is enabled by default) - agent.RM.wait_for_idle_or_paused.assert_called_once_with(timeout=600) - agent.RM.queue_start.assert_called_once() - - -@patch("blop.ax.qserver_agent.REManagerAPI") -@patch("blop.ax.qserver_agent.ZMQConsumer") -def test_qserver_agent_stop( - mock_zmq_consumer, mock_re_manager, mock_evaluation_function, basic_dofs, basic_sensors, basic_objective -): - """Test that stop prevents the agent from auto-starting the queue on acquire.""" - agent = BlopQserverAgent( 
"""Tests for the blop queueserver integration.

Exercises ConsumerCallback start/stop document handling, QueueserverClient
environment/device/plan validation, plan submission, and listener lifecycle,
and the QueueserverOptimizationRunner suggest -> acquire -> ingest cycle,
using mocked queueserver APIs throughout (no live queueserver required).
"""

from unittest.mock import MagicMock, patch

import pytest

from blop.protocols import CanRegisterSuggestions, Optimizer, QueueserverOptimizationProblem
from blop.queueserver import CORRELATION_UID_KEY, ConsumerCallback, QueueserverClient, QueueserverOptimizationRunner


@pytest.fixture  # "function" scope is the pytest default; no need to spell it out
def mock_optimization_problem():
    """Create a mock OptimizationProblem with necessary components."""
    mock_optimizer = MagicMock()
    mock_optimizer.suggest.return_value = [
        {"_id": 0, "motor1": 5.0, "motor2": 3.0},
    ]

    mock_eval_func = MagicMock()
    mock_eval_func.return_value = [{"_id": 0, "objective": 1.0}]

    return QueueserverOptimizationProblem(
        optimizer=mock_optimizer,
        actuators=["motor1", "motor2"],
        sensors=["detector"],
        evaluation_function=mock_eval_func,
    )


def test_consumer_callback_caches_start_and_calls_on_stop():
    """Test ConsumerCallback caches start doc and calls callback on stop."""
    mock_callback = MagicMock()
    callback = ConsumerCallback(callback=mock_callback)
    start_doc = {"uid": "test-uid", CORRELATION_UID_KEY: "123", "time": 123}
    stop_doc = {"uid": "test-uid", "exit_status": "success"}

    callback.start(start_doc)
    mock_callback.assert_not_called()

    callback.stop(stop_doc)
    mock_callback.assert_called_once_with(start_doc, stop_doc)


def test_consumer_callback_clears_cache_after_stop():
    """Test ConsumerCallback clears cache after stop is called."""
    mock_callback = MagicMock()
    callback = ConsumerCallback(callback=mock_callback)
    start_doc = {"uid": "test-uid", CORRELATION_UID_KEY: "123"}
    stop_doc = {"uid": "test-uid"}

    callback.start(start_doc)
    callback.stop(stop_doc)

    # Second stop should not call callback (no cached start doc)
    callback.stop(stop_doc)
    assert mock_callback.call_count == 1


@patch("blop.queueserver.REManagerAPI")
def test_queueserver_client_check_environment_raises_when_not_ready(mock_re_manager):
    """Test check_environment raises RuntimeError when environment not open."""
    mock_re_manager.status.return_value = {"worker_environment_exists": False}
    client = QueueserverClient(mock_re_manager, "inproc://test")

    with pytest.raises(RuntimeError, match="queueserver environment is not open"):
        client.check_environment()


@patch("blop.queueserver.REManagerAPI")
def test_queueserver_client_check_devices_raises_for_missing_device(mock_re_manager):
    """Test check_devices_available raises ValueError for missing devices."""
    mock_re_manager.devices_allowed.return_value = {"devices_allowed": {"motor1": {}}}
    client = QueueserverClient(mock_re_manager, "inproc://test")

    with pytest.raises(ValueError, match="Device 'motor2' is not available"):
        client.check_devices_available(["motor1", "motor2"])


@patch("blop.queueserver.REManagerAPI")
def test_queueserver_client_check_plan_raises_for_missing_plan(mock_re_manager):
    """Test check_plan_available raises ValueError for missing plan."""
    mock_re_manager.plans_allowed.return_value = {"plans_allowed": {"other_plan": {}}}
    client = QueueserverClient(mock_re_manager, "inproc://test")

    with pytest.raises(ValueError, match="Plan 'my_plan' is not available"):
        client.check_plan_available("my_plan")


@patch("blop.queueserver.REManagerAPI")
def test_queueserver_client_submit_plan_with_autostart(mock_re_manager):
    """Test submit_plan adds item and starts queue when autostart=True."""
    client = QueueserverClient(mock_re_manager, "inproc://test")
    mock_plan = MagicMock()

    client.submit_plan(mock_plan, autostart=True)

    mock_re_manager.item_add.assert_called_once_with(mock_plan)
    mock_re_manager.wait_for_idle_or_paused.assert_called_once()
    mock_re_manager.queue_start.assert_called_once()


@patch("blop.queueserver.REManagerAPI")
def test_queueserver_client_submit_plan_without_autostart(mock_re_manager):
    """Test submit_plan only adds item when autostart=False."""
    client = QueueserverClient(mock_re_manager, "inproc://test")
    mock_plan = MagicMock()

    client.submit_plan(mock_plan, autostart=False)

    mock_re_manager.item_add.assert_called_once_with(mock_plan)
    mock_re_manager.queue_start.assert_not_called()


@patch("blop.queueserver.threading.Thread")
@patch("blop.queueserver.RemoteDispatcher")
@patch("blop.queueserver.REManagerAPI")
def test_queueserver_client_start_listener(mock_re_manager, mock_dispatcher_cls, mock_thread_cls):
    """Test start_listener creates dispatcher, subscribes callback, and starts thread."""
    mock_re_manager.status.return_value = {"worker_environment_exists": True}
    mock_re_manager.devices_allowed.return_value = {"devices_allowed": {"motor1": {}, "detector": {}}}
    mock_re_manager.plans_allowed.return_value = {"plans_allowed": {"default_acquire": {}}}

    client = QueueserverClient(mock_re_manager, "tcp://localhost:5578")
    mock_callback = MagicMock()

    client.start_listener(on_stop=mock_callback)

    mock_dispatcher_cls.assert_called_once_with("tcp://localhost:5578")
    mock_dispatcher = mock_dispatcher_cls.return_value
    mock_dispatcher.subscribe.assert_called_once()
    subscribed_callback = mock_dispatcher.subscribe.call_args[0][0]
    assert isinstance(subscribed_callback, ConsumerCallback)
    assert subscribed_callback._callback is mock_callback

    mock_thread_cls.assert_called_once()
    call_kwargs = mock_thread_cls.call_args[1]
    assert call_kwargs["target"] == mock_dispatcher.start
    mock_thread_cls.return_value.start.assert_called_once()


@patch("blop.queueserver.threading.Thread")
@patch("blop.queueserver.RemoteDispatcher")
@patch("blop.queueserver.REManagerAPI")
def test_queueserver_client_start_listener_already_running_returns_early(
    mock_re_manager, mock_dispatcher_cls, mock_thread_cls
):
    """Test start_listener returns early when listener is already running."""
    mock_re_manager.status.return_value = {"worker_environment_exists": True}
    mock_re_manager.devices_allowed.return_value = {"devices_allowed": {"motor1": {}, "detector": {}}}
    mock_re_manager.plans_allowed.return_value = {"plans_allowed": {"default_acquire": {}}}

    client = QueueserverClient(mock_re_manager, "tcp://localhost:5578")
    client._listener_thread = MagicMock()  # Simulate already running

    client.start_listener(on_stop=MagicMock())

    mock_dispatcher_cls.assert_not_called()
    mock_thread_cls.assert_not_called()


@patch("blop.queueserver.threading.Thread")
@patch("blop.queueserver.RemoteDispatcher")
@patch("blop.queueserver.REManagerAPI")
def test_queueserver_client_stop_listener(mock_re_manager, mock_dispatcher_cls, mock_thread_cls):
    """Test stop_listener stops dispatcher and clears state."""
    mock_re_manager.status.return_value = {"worker_environment_exists": True}
    mock_re_manager.devices_allowed.return_value = {"devices_allowed": {"motor1": {}, "detector": {}}}
    mock_re_manager.plans_allowed.return_value = {"plans_allowed": {"default_acquire": {}}}

    client = QueueserverClient(mock_re_manager, "tcp://localhost:5578")
    client.start_listener(on_stop=MagicMock())

    client.stop_listener()

    mock_dispatcher_cls.return_value.stop.assert_called_once()
    assert client._dispatcher is None
    assert client._consumer_callback is None
    assert client._listener_thread is None


@patch("blop.queueserver.REManagerAPI")
def test_queueserver_client_stop_listener_when_not_started(mock_re_manager):
    """Test stop_listener is safe to call when listener was never started."""
    client = QueueserverClient(mock_re_manager, "inproc://test")

    client.stop_listener()  # Should not raise

    assert client._dispatcher is None
    assert client._listener_thread is None


def test_runner_run_validates_environment(mock_optimization_problem):
    """Test run() validates queueserver environment before starting."""
    mock_client = MagicMock(spec=QueueserverClient)
    mock_client.check_environment.side_effect = RuntimeError("not open")

    runner = QueueserverOptimizationRunner(
        optimization_problem=mock_optimization_problem,
        queueserver_client=mock_client,
    )

    with pytest.raises(RuntimeError, match="not open"):
        runner.run(iterations=1)

    mock_client.check_environment.assert_called_once()


def test_runner_run_submits_suggestions_to_queueserver():
    """Test run() gets suggestions from optimizer and submits plan to queueserver."""
    mock_client = MagicMock(spec=QueueserverClient)
    mock_optimization_problem = QueueserverOptimizationProblem(
        optimizer=MagicMock(),
        actuators=["motor1"],
        sensors=["det"],
        evaluation_function=MagicMock(),
        acquisition_plan="my_acquire",
    )
    runner = QueueserverOptimizationRunner(
        optimization_problem=mock_optimization_problem,
        queueserver_client=mock_client,
    )
    assert runner.optimization_problem == mock_optimization_problem

    runner.run(iterations=1, num_points=1)

    # Verify optimizer.suggest was called
    mock_optimization_problem.optimizer.suggest.assert_called_once_with(1)

    # Verify plan was submitted
    mock_client.submit_plan.assert_called_once()
    submitted_plan = mock_client.submit_plan.call_args[0][0]
    assert submitted_plan.name == "my_acquire"


def test_runner_stop_sets_finished_state(mock_optimization_problem):
    """Test stop() marks the runner as finished and stops listener."""
    mock_client = MagicMock(spec=QueueserverClient)
    runner = QueueserverOptimizationRunner(
        optimization_problem=mock_optimization_problem,
        queueserver_client=mock_client,
    )

    # The acquisition completion callback never fires here due to the mocked client,
    # therefore the first plan runs forever
    runner.run(10)
    assert runner.is_running is True

    runner.stop()
    assert runner.is_running is False
    mock_client.stop_listener.assert_called()


def test_runner_submit_suggestions_to_queueserver():
    """Test submit_suggestions() registers suggestions with the optimizer and submits a plan.

    Unlike run(), submit_suggestions() must not ask the optimizer for new points;
    it registers the caller-provided suggestions and submits the acquisition plan.
    """
    mock_client = MagicMock(spec=QueueserverClient)

    class CustomOptimizer(Optimizer, CanRegisterSuggestions): ...

    mock_optimization_problem = QueueserverOptimizationProblem(
        optimizer=MagicMock(spec=CustomOptimizer),
        actuators=["motor1"],
        sensors=["det"],
        evaluation_function=MagicMock(),
        acquisition_plan="my_acquire",
    )
    runner = QueueserverOptimizationRunner(
        optimization_problem=mock_optimization_problem,
        queueserver_client=mock_client,
    )

    suggestions = [{"motor1": 5}]
    runner.submit_suggestions(suggestions)

    # Verify optimizer.suggest was NOT called
    mock_optimization_problem.optimizer.suggest.assert_not_called()
    mock_optimization_problem.optimizer.register_suggestions.assert_called_once_with(suggestions)

    # Verify plan was submitted
    mock_client.submit_plan.assert_called_once()
    submitted_plan = mock_client.submit_plan.call_args[0][0]
    assert submitted_plan.name == "my_acquire"


def test_runner_run_full_cycle(mock_optimization_problem):
    """Test run() completes full suggest -> acquire -> ingest cycle across 3 iterations."""
    # Configure for num_points=2: suggest returns 2 items, evaluation_function returns 2 outcomes
    mock_optimization_problem.optimizer.suggest.return_value = [
        {"_id": 0, "motor1": 5.0, "motor2": 3.0},
        {"_id": 1, "motor1": 6.0, "motor2": 4.0},
    ]
    mock_optimization_problem.evaluation_function.return_value = [
        {"_id": 0, "objective": 1.0},
        {"_id": 1, "objective": 2.0},
    ]

    mock_client = MagicMock(spec=QueueserverClient)

    def capture_callback(on_stop):
        mock_client._on_stop = on_stop

    mock_client.start_listener.side_effect = capture_callback

    runner = QueueserverOptimizationRunner(
        optimization_problem=mock_optimization_problem,
        queueserver_client=mock_client,
    )

    runner.run(iterations=3, num_points=2)

    # Simulate 3 acquisition completions by invoking the captured callback
    for i in range(3):
        current_uid = runner._state.current_uid
        uid = f"fake-uid-{i}"
        start_doc = {"uid": uid, CORRELATION_UID_KEY: current_uid}
        stop_doc = {"uid": uid}
        mock_client._on_stop(start_doc, stop_doc)

    assert mock_client.submit_plan.call_count == 3
    assert mock_optimization_problem.optimizer.suggest.call_count == 3
    assert mock_optimization_problem.optimizer.ingest.call_count == 3
    assert mock_optimization_problem.evaluation_function.call_count == 3
    assert runner.is_running is False


def test_runner_on_acquisition_complete_raises_on_uid_mismatch(mock_optimization_problem):
    """Test _on_acquisition_complete raises RuntimeError when blop_correlation_uid does not match."""
    mock_client = MagicMock(spec=QueueserverClient)

    def capture_callback(on_stop):
        mock_client._on_stop = on_stop

    mock_client.start_listener.side_effect = capture_callback

    runner = QueueserverOptimizationRunner(
        optimization_problem=mock_optimization_problem,
        queueserver_client=mock_client,
    )

    runner.run(iterations=1, num_points=1)

    # Callback with wrong blop_correlation_uid should raise
    start_doc = {"uid": "fake-uid", CORRELATION_UID_KEY: "wrong-uid"}
    stop_doc = {"uid": "fake-uid"}

    with pytest.raises(RuntimeError, match="current_uid did not match start document"):
        mock_client._on_stop(start_doc, stop_doc)