From 4d269309d92172b077ec0c4b2ceeb423f1482575 Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Thu, 5 Mar 2026 17:03:04 -0500 Subject: [PATCH 01/13] Fix qserver agent and add tests --- src/blop/ax/qserver_agent.py | 20 ++- src/blop/tests/ax/test_qserver_agent.py | 219 ++++++++++++++++++++++++ 2 files changed, 234 insertions(+), 5 deletions(-) create mode 100644 src/blop/tests/ax/test_qserver_agent.py diff --git a/src/blop/ax/qserver_agent.py b/src/blop/ax/qserver_agent.py index 694b036a..787ce4eb 100644 --- a/src/blop/ax/qserver_agent.py +++ b/src/blop/ax/qserver_agent.py @@ -140,6 +140,9 @@ def __init__( **kwargs, ) + # Store dofs for qserver-specific operations (parent class only stores actuators) + self._dofs = list(dofs) + # Instantiate an object that can communicate with the queueserver self.RM = REManagerAPI( zmq_control_addr=qserver_control_addr, zmq_info_addr=qserver_info_addr @@ -169,6 +172,11 @@ def __init__( self.acquisition_finished = False self.optimization_problem = None + @property + def dofs(self) -> Sequence[DOF]: + """The degrees of freedom for this agent.""" + return self._dofs + def _stop_doc_callback(self, start_doc, stop_doc): """ In here we can decide whether our experiment requested has completed @@ -212,12 +220,14 @@ def optimize(self, iterations=1, n_points=1): # Check that the devices we want to interact with are in the queueserver environment res = self.RM.devices_allowed() for dof in self.dofs: - if dof.name not in res["devices_allowed"]: - raise ValueError(f"The device {dof.name} is not in the Queueserver Environment") + if dof.parameter_name not in res["devices_allowed"]: + raise ValueError(f"The device {dof.parameter_name} is not in the Queueserver Environment") for sensor in self.sensors: - if sensor not in res["devices_allowed"]: - raise ValueError(f"The device {sensor} is not in the Queueserver Environment") + # Handle both sensor objects (with .name) and string sensor names + sensor_name = sensor.name if hasattr(sensor, "name") else sensor + if sensor_name not in res["devices_allowed"]: + raise ValueError(f"The device {sensor_name} is not in the Queueserver Environment") # Check that the plan we want to call is in the queueserver environment res = self.RM.plans_allowed() @@ -292,7 +302,7 @@ def acquire(self, trials: dict[int, TParameterization] | None = None): item = BPlan( self.acquisition_plan, readables=self.sensors, - dofs=[dof.name for dof in self.dofs], + dofs=[dof.parameter_name for dof in self.dofs], trials=trials, md=kwargs["md"], ) diff --git a/src/blop/tests/ax/test_qserver_agent.py b/src/blop/tests/ax/test_qserver_agent.py new file mode 100644 index 00000000..b2dea817 --- /dev/null +++ b/src/blop/tests/ax/test_qserver_agent.py @@ -0,0 +1,219 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from blop.ax.dof import RangeDOF +from blop.ax.objective import Objective +from blop.ax.qserver_agent import BlopQserverAgent, ConsumerCallback, ZMQConsumer +from blop.protocols import EvaluationFunction + +from ..conftest import MovableSignal, ReadableSignal + + +@pytest.fixture(scope="function") +def mock_evaluation_function(): + return MagicMock(spec=EvaluationFunction) + + +@pytest.fixture(scope="function") +def basic_dofs(): + movable1 = MovableSignal(name="test_movable1") + movable2 = MovableSignal(name="test_movable2") + dof1 = RangeDOF(actuator=movable1, bounds=(0, 10), parameter_type="float") + dof2 = RangeDOF(actuator=movable2, bounds=(0, 10), parameter_type="float") + return [dof1, dof2] + + +@pytest.fixture(scope="function") +def basic_sensors(): + return [ReadableSignal(name="test_readable")] + + +@pytest.fixture(scope="function") +def basic_objective(): + return Objective(name="test_objective", minimize=False) + + +@patch("blop.ax.qserver_agent.REManagerAPI") +@patch("blop.ax.qserver_agent.ZMQConsumer") +def test_qserver_agent_init(mock_zmq_consumer, mock_re_manager, mock_evaluation_function, basic_dofs, basic_sensors, basic_objective): + """Test that the qserver agent can be initialized with proper components.""" + agent = BlopQserverAgent( + sensors=basic_sensors, + dofs=basic_dofs, + objectives=[basic_objective], + evaluation_function=mock_evaluation_function, + acquisition_plan="test_plan", + qserver_control_addr="tcp://localhost:60615", + qserver_info_addr="tcp://localhost:60625", + zmq_consumer_ip="localhost", + zmq_consumer_port="5578", + ) + + # Test public properties + assert agent.sensors == basic_sensors + assert list(agent.dofs) == basic_dofs + assert agent.continuous_suggestion is True + assert agent.num_itterations == 30 + assert agent.n_points == 1 + assert agent.current_itteration == 0 + + # Verify REManagerAPI was initialized with correct addresses + mock_re_manager.assert_called_once_with( + zmq_control_addr="tcp://localhost:60615", + zmq_info_addr="tcp://localhost:60625", + ) + # Verify ZMQ consumer was started + mock_zmq_consumer.assert_called_once() + mock_zmq_consumer.return_value.start_zmq_listener_thread.assert_called_once() + + +@patch("blop.ax.qserver_agent.REManagerAPI") +@patch("blop.ax.qserver_agent.ZMQConsumer") +def test_qserver_agent_optimize_validation(mock_zmq_consumer, mock_re_manager, mock_evaluation_function, basic_dofs, basic_sensors, basic_objective): + """Test that optimize validates qserver environment, devices, and plans.""" + # Test 1: Error when qserver environment is not open + agent = BlopQserverAgent( + sensors=basic_sensors, + dofs=basic_dofs, + objectives=[basic_objective], + evaluation_function=mock_evaluation_function, + ) + agent.RM.status.return_value = {"worker_environment_exists": False} + + with pytest.raises(ValueError, match="queueserver environment is not open"): + agent.optimize(iterations=1, n_points=1) + + # Test 2: Error when required device is not in qserver (uses actuator-based DOFs) + agent2 = BlopQserverAgent( + sensors=basic_sensors, + dofs=basic_dofs, + objectives=[basic_objective], + evaluation_function=mock_evaluation_function, + ) + agent2.RM.status.return_value = {"worker_environment_exists": True} + agent2.RM.devices_allowed.return_value = {"devices_allowed": {}} + + with pytest.raises(ValueError, match="device test_movable1 is not in the Queueserver Environment"): + agent2.optimize(iterations=1, n_points=1) + + # Test 3: Error when acquisition plan is not in qserver + agent3 = BlopQserverAgent( + sensors=basic_sensors, + dofs=basic_dofs, + objectives=[basic_objective], + evaluation_function=mock_evaluation_function, + acquisition_plan="missing_plan", + ) + agent3.RM.status.return_value = {"worker_environment_exists": True} + agent3.RM.devices_allowed.return_value = { + "devices_allowed": {"test_movable1": {}, "test_movable2": {}, "test_readable": {}} + } + agent3.RM.plans_allowed.return_value = {"plans_allowed": {}} + + with pytest.raises(ValueError, match="plan missing_plan is not in the Queueserver Environment"): + agent3.optimize(iterations=1, n_points=1) + + +@patch("blop.ax.qserver_agent.REManagerAPI") +@patch("blop.ax.qserver_agent.ZMQConsumer") +def test_qserver_agent_acquire(mock_zmq_consumer, mock_re_manager, mock_evaluation_function, basic_dofs, basic_sensors, basic_objective): + """Test that acquire submits a plan to the qserver with proper metadata and starts queue.""" + agent = BlopQserverAgent( + sensors=basic_sensors, + dofs=basic_dofs, + objectives=[basic_objective], + evaluation_function=mock_evaluation_function, + acquisition_plan="acquire", + ) + + trials = {0: {"test_movable1": 5.0, "test_movable2": 3.0}} + uid = agent.acquire(trials) + + # Verify uid is returned + assert uid is not None + assert isinstance(uid, str) + assert agent.acquisition_finished is False + + # Verify plan was submitted to qserver + agent.RM.item_add.assert_called_once() + call_args = agent.RM.item_add.call_args + bplan = call_args[0][0] + assert bplan.name == "acquire" + assert bplan.kwargs["md"]["agent_suggestion_uid"] == uid + assert bplan.kwargs["md"]["blop_suggestions"] == [{"_id": 0, "test_movable1": 5.0, "test_movable2": 3.0}] + + # Verify queue was started (autostart is enabled by default) + agent.RM.wait_for_idle_or_paused.assert_called_once_with(timeout=600) + agent.RM.queue_start.assert_called_once() + + +@patch("blop.ax.qserver_agent.REManagerAPI") +@patch("blop.ax.qserver_agent.ZMQConsumer") +def test_qserver_agent_stop(mock_zmq_consumer, mock_re_manager, mock_evaluation_function, basic_dofs, basic_sensors, basic_objective): + """Test that stop prevents the agent from auto-starting the queue on acquire.""" + agent = BlopQserverAgent( + sensors=basic_sensors, + dofs=basic_dofs, + objectives=[basic_objective], + evaluation_function=mock_evaluation_function, + ) + + # Stop the agent + agent.stop() + + # Now acquire should not start the queue + trials = {0: {"test_movable1": 5.0, "test_movable2": 3.0}} + agent.acquire(trials) + + # Verify plan was still submitted + agent.RM.item_add.assert_called_once() + # But queue was NOT started + agent.RM.queue_start.assert_not_called() + + +def test_consumer_callback(): + """Test ConsumerCallback caches start doc, calls callback on stop, and clears cache.""" + mock_callback = MagicMock() + callback = ConsumerCallback(callback=mock_callback, enable=True) + start_doc = {"uid": "test-uid", "time": 123} + stop_doc = {"uid": "test-uid", "exit_status": "success"} + + # Test caching on start + callback.start(start_doc) + assert callback.start_doc_cache == start_doc + + # Test callback invocation on stop and cache clearing + callback.stop(stop_doc) + mock_callback.assert_called_once_with(start_doc, stop_doc) + assert callback.start_doc_cache is None + + # Test disabled callback does nothing + disabled_callback = ConsumerCallback(callback=MagicMock(), enable=False) + disabled_callback.start(start_doc) + disabled_callback.stop(stop_doc) + assert disabled_callback.start_doc_cache is None + disabled_callback.callback.assert_not_called() + + +@patch("blop.ax.qserver_agent.RemoteDispatcher") +def test_zmq_consumer(mock_remote_dispatcher): + """Test ZMQConsumer initialization and thread startup.""" + mock_callback = MagicMock() + + consumer = ZMQConsumer( + zmq_consumer_ip_address="localhost", + zmq_consumer_port="5578", + callback=mock_callback, + ) + + # Test public attributes + assert consumer.zmq_consumer_ip_address == "localhost" + assert consumer.zmq_consumer_port == "5578" + + # Verify RemoteDispatcher was initialized and subscribed + mock_remote_dispatcher.assert_called_once_with("localhost:5578") + mock_remote_dispatcher.return_value.subscribe.assert_called_once() + + # Test thread startup (verify it doesn't raise) + consumer.start_zmq_listener_thread() From 20f80de4fad49d2d2fb1b0dde7cc67e40ce4f3a7 Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Thu, 5 Mar 2026 17:15:06 -0500 Subject: [PATCH 02/13] Refactor bluesky-queueserver integration --- src/blop/ax/__init__.py | 6 +- src/blop/ax/qserver_agent.py | 334 --------------------- src/blop/ax/queueserver.py | 380 ++++++++++++++++++++++++ src/blop/tests/ax/test_qserver_agent.py | 219 -------------- src/blop/tests/ax/test_queueserver.py | 196 ++++++++++++ 5 files changed, 580 insertions(+), 555 deletions(-) delete mode 100644 src/blop/ax/qserver_agent.py create mode 100644 src/blop/ax/queueserver.py delete mode 100644 src/blop/tests/ax/test_qserver_agent.py create mode 100644 src/blop/tests/ax/test_queueserver.py diff --git a/src/blop/ax/__init__.py b/src/blop/ax/__init__.py index e5a192c4..8e7de507 100644 --- a/src/blop/ax/__init__.py +++ b/src/blop/ax/__init__.py @@ -2,11 +2,10 @@ from .dof import DOF, ChoiceDOF, DOFConstraint, RangeDOF from .objective import Objective, OutcomeConstraint, ScalarizedObjective, to_ax_objective_str from .optimizer import AxOptimizer -from .qserver_agent import BlopQserverAgent as QserverAgent +from .queueserver import ConsumerCallback, QServerClient, QServerOptimizationRunner __all__ = [ "Agent", - "QserverAgent", "DOF", "RangeDOF", "ChoiceDOF", @@ -16,4 +15,7 @@ "ScalarizedObjective", "to_ax_objective_str", "AxOptimizer", + "QServerClient", + "QServerOptimizationRunner", + "ConsumerCallback", ] diff --git a/src/blop/ax/qserver_agent.py b/src/blop/ax/qserver_agent.py deleted file mode 100644 index 787ce4eb..00000000 --- a/src/blop/ax/qserver_agent.py +++ /dev/null @@ -1,334 +0,0 @@ -import logging - -### -import threading -import uuid -from collections.abc import Sequence -from typing import Any - -from ax.api.types import TParameterization -from bluesky.callbacks import CallbackBase -from bluesky.callbacks.zmq import RemoteDispatcher -from bluesky_queueserver_api import BPlan -from bluesky_queueserver_api.zmq import REManagerAPI - -from ..protocols import EvaluationFunction, Sensor -from .agent import Agent as BlopAxAgent # type: ignore[import-untyped] -from .dof import DOF, DOFConstraint -from .objective import Objective - -logger = logging.getLogger("blop") - - -class ConsumerCallback(CallbackBase): - """ - A child of Callback base which caches the start document and calls a callback function on the stop document - - """ - - def __init__(self, callback: callable = None, enable=True, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.start_doc_cache = None - self.callback = callback # A function that is called when stop is called - self.enable = enable - - def start(self, doc): - if self.enable: - self.start_doc_cache = doc - - def stop(self, doc): - if self.enable: - self.callback(self.start_doc_cache, doc) - self._clear_cache() - - def _clear_cache(self): - self.start_doc_cache = None - - -class ZMQConsumer: - """ - Allows us to start a thread which will listen to docs and call the callback in CallbackBase - """ - - def __init__(self, zmq_consumer_ip_address, zmq_consumer_port, callback: callable = None): - self.zmq_consumer_ip_address = zmq_consumer_ip_address - self.zmq_consumer_port = zmq_consumer_port - - self.zmq_consumer = RemoteDispatcher(f"{self.zmq_consumer_ip_address}:{self.zmq_consumer_port}") - self.zmq_consumer_callback = ConsumerCallback(callback=callback, enable=True) - self.zmq_consumer.subscribe(self.zmq_consumer_callback) - self._zmq_thread = None - - def start_zmq_listener_thread(self): - logger.info("Starting ZMQ Callback Thread") - - self._zmq_thread = threading.Thread(target=self.zmq_consumer.start, name="zmq-consumer", daemon=True) - self._zmq_thread.start() - - -class BlopQserverAgent(BlopAxAgent): - """ - An interface that uses Ax as the backend for optimization and experiment tracking. - - The Agent is the main entry point for setting up and running Bayesian optimization - using Blop. It coordinates the DOFs, objectives, evaluation function, and optimizer - to perform intelligent exploration of the parameter space. - - This class sends JSON strings to a queueserver, rather than emmitting messages to be - consumed directly by a RE. - - Parameters - ---------- - sensors : Sequence[Sensor] - The sensors to use for acquisition. These should be the minimal set - of sensors that are needed to compute the objectives. - dofs : Sequence[DOF] - The degrees of freedom that the agent can control, which determine the search space. - objectives : Sequence[Objective] - The objectives which the agent will try to optimize. - evaluation_function : EvaluationFunction - The function to evaluate acquired data and produce outcomes. - acquisition_plan : str, optional - The name of the plan on the queueserver - dof_constraints : Sequence[DOFConstraint] | None, optional - Constraints on DOFs to refine the search space. - outcome_constraints : Sequence[OutcomeConstraint] | None, optional - Constraints on outcomes to be satisfied during optimization. - qserver_control_addr : str, default="tcp://localhost:60615" - Queueserver Control Address - qserver_info_addr : str, default="tcp://localhost:60625" - Queueserver Info Address - zmq_consumer_ip : str, default= "localhost" - The IP address of the ZMQ proxy to listen for stop document - zmq_consumer_port : str, default= "5578" - The PORT of the ZMQ proxy to listen for stop document - **kwargs : Any - Additional keyword arguments to configure the Ax experiment. - - Notes - ----- - For more complex setups, you can configure the Ax client directly via ``self.ax_client``. - - For complete working examples of creating and using an Agent, see the tutorial - documentation, particularly :doc:`/tutorials/qserver-experiment`. - - - """ - - def __init__( - self, - sensors: Sequence[Sensor], - dofs: Sequence[DOF], - objectives: Sequence[Objective], - evaluation_function: EvaluationFunction = None, - acquisition_plan: str = "acquire", - dof_constraints: Sequence[DOFConstraint] = None, - qserver_control_addr: str = "tcp://localhost:60615", - qserver_info_addr: str = "tcp://localhost:60625", - zmq_consumer_ip: str = "localhost", - zmq_consumer_port: str = "5578", - **kwargs: Any, - ): - super().__init__( - sensors=sensors, - dofs=dofs, - objectives=objectives, - evaluation_function=evaluation_function, - acquisition_plan=acquisition_plan, - dof_constraints=dof_constraints, - **kwargs, - ) - - # Store dofs for qserver-specific operations (parent class only stores actuators) - self._dofs = list(dofs) - - # Instantiate an object that can communicate with the queueserver - self.RM = REManagerAPI( - zmq_control_addr=qserver_control_addr, zmq_info_addr=qserver_info_addr - ) # To Do, Add arguements to class init - - # Should plans be submitted and automatically started, or not? - self._queue_autostart = True - - # Instantiate an object that will listen for start and stop documents and call a method on the stop document - self.zmq_consumer = ZMQConsumer(zmq_consumer_ip, zmq_consumer_port, callback=self._stop_doc_callback) - self.zmq_consumer.start_zmq_listener_thread() - - # Should we do something when there is a new event? - self._listen_to_events = True - - # Should new suggestions be made automatically until all of the trials are complete? - self.continuous_suggestion = True - - # Learning parameters - self.num_itterations = 30 - self.n_points = 1 - - # Variables used to keep track of the current optimization - self.current_itteration = 0 - self.agent_suggestion_uid = None - self.trials = None - self.acquisition_finished = False - self.optimization_problem = None - - @property - def dofs(self) -> Sequence[DOF]: - """The degrees of freedom for this agent.""" - return self._dofs - - def _stop_doc_callback(self, start_doc, stop_doc): - """ - In here we can decide whether our experiment requested has completed - - If it has completed, we can digest the data from it and move on to the next point. - """ - - if self._listen_to_events: - # Mark the current acquisition as finished - - logger.info("A stop document has been received, evaluating") - - # Evaluate it with the evaluation function - outcomes = self.optimization_problem.evaluation_function(uid=self.agent_suggestion_uid, suggestions=self.trials) - - logger.info(f"successfully evaluated id: {self.agent_suggestion_uid}") - - self.acquisition_finished = True - # ingest the data, updating the model of the optimizer - self.optimization_problem.optimizer.ingest(outcomes) - - # After this is complete, call gen_next_trials again if required - if self.continuous_suggestion: - if self.current_itteration < self.num_itterations: - logger.info("making another suggestion") - self.suggest() - else: - self.current_itteration = 0 - logger.info("made all required suggestions") - - def optimize(self, iterations=1, n_points=1): - """ - This method will create the optimization problem, suggest points and execute them in the QS - """ - - # Before we do anything check the connection to the Queueserver - status = self.RM.status() - if not status["worker_environment_exists"]: - raise ValueError("The queueserver environment is not open") - - # Check that the devices we want to interact with are in the queueserver environment - res = self.RM.devices_allowed() - for dof in self.dofs: - if dof.parameter_name not in res["devices_allowed"]: - raise ValueError(f"The device {dof.parameter_name} is not in the Queueserver Environment") - - for sensor in self.sensors: - # Handle both sensor objects (with .name) and string sensor names - sensor_name = sensor.name if hasattr(sensor, "name") else sensor - if sensor_name not in res["devices_allowed"]: - raise ValueError(f"The device {sensor_name} is not in the Queueserver Environment") - - # Check that the plan we want to call is in the queueserver environment - res = self.RM.plans_allowed() - if self._acquisition_plan not in res["plans_allowed"]: - raise ValueError(f"The plan {self._acquisition_plan} is not in the Queueserver Environment") - - # Form the problem and start suggesting points to measure at - self.optimization_problem = self.to_optimization_problem() - self.num_itterations = iterations - self.n_points = n_points - - self.suggest() - - def suggest(self): - """ - get suggestions from the optimizer, then send them to the plan on the queueserver - - """ - - # record this itteration - self.current_itteration = self.current_itteration + 1 - - # Get the trials to perform - self.trials = self.optimization_problem.optimizer._client.get_next_trials(self.n_points) - - # acquire the values from those trials - self.agent_suggestion_uid = self.acquire(self.trials) - logger.info( - f"sending suggestion {self.current_itteration} to queueserver with suggestion id: {self.agent_suggestion_uid}" - ) - - def acquire_baseline(self, parameterization: dict[str, Any] | None = None): - logger.info("This is not implemented for the qserver agent") - - def acquire(self, trials: dict[int, TParameterization] | None = None): - """ - Acquire the new data from the system by submitting the suggested - points to the queueserver. This method does not block while the - queueserver is running. - - Parameters - ---------- - trials : dict[int, TParameterization] - A dictionary mapping trial indices to their suggested parameterizations. - """ - - try: - self.acquisition_finished = False - - # Create a unique identifier which will connect the children to the parent batch - # This batch ID will be used by all runs from this request by the agent. - # It will be used by the EvaluationFunction later to work out what happened. - - agent_suggestion_uid = str(uuid.uuid4()) - kwargs = {} - kwargs.setdefault("md", {}) - - # Add the unique suggestion ID so we can find this run later - kwargs["md"]["agent_suggestion_uid"] = agent_suggestion_uid - - # Add the suggestion _id key so we can work out which number we are on later - suggestions = [ - { - "_id": trial_index, - **parameterization, - } - for trial_index, parameterization in trials.items() - ] - kwargs["md"]["blop_suggestions"] = suggestions - - # Create the BPlan object to send to the queue. Convert dofs to strings - item = BPlan( - self.acquisition_plan, - readables=self.sensors, - dofs=[dof.parameter_name for dof in self.dofs], - trials=trials, - md=kwargs["md"], - ) - - # Send the plan to the Run Engine Manager - r = self.RM.item_add(item) - logger.debug( - f"Sent http-server request for trials {trials}" - f"with agent_suggestion_uid= {agent_suggestion_uid}\n.Received reponse: {r}" - ) - - # If the queue should start automatically, then start the queue. - if self._queue_autostart: - logger.debug("Waiting for Queue to be idle or paused") - self.RM.wait_for_idle_or_paused(timeout=600) - r = self.RM.queue_start() - logger.debug(f"Sent http-server request to start the queue\n.Received reponse: {r}") - - except KeyboardInterrupt as interrupt: - raise interrupt - - return agent_suggestion_uid - - def stop(self): - """ - Stop the agent if it is running - """ - self._queue_autostart = False - self._listen_to_events = False diff --git a/src/blop/ax/queueserver.py b/src/blop/ax/queueserver.py new file mode 100644 index 00000000..d6a97b3e --- /dev/null +++ b/src/blop/ax/queueserver.py @@ -0,0 +1,380 @@ +""" +Queueserver integration for running optimization through a Bluesky queueserver. + +This module provides components for running optimization loops remotely through +a queueserver, rather than directly through a RunEngine. +""" + +import logging +import threading +import uuid +from collections.abc import Sequence +from dataclasses import dataclass, field + +from bluesky.callbacks import CallbackBase +from bluesky.callbacks.zmq import RemoteDispatcher +from bluesky_queueserver_api import BPlan +from bluesky_queueserver_api.zmq import REManagerAPI + +from ..protocols import OptimizationProblem + +logger = logging.getLogger("blop") + + +class ConsumerCallback(CallbackBase): + """ + A callback that caches the start document and invokes a callback on stop. + + Parameters + ---------- + callback : callable + Function to call when a stop document is received. + Signature: callback(start_doc, stop_doc) + """ + + def __init__(self, callback: callable = None): + super().__init__() + self._start_doc_cache = None + self._callback = callback + + def start(self, doc): + self._start_doc_cache = doc + + def stop(self, doc): + if self._callback is not None and self._start_doc_cache is not None: + self._callback(self._start_doc_cache, doc) + self._start_doc_cache = None + + +class QServerClient: + """ + Handles communication with a Bluesky queueserver. + + This class encapsulates all ZMQ and HTTP communication with the queueserver, + including plan submission and event listening. + + Parameters + ---------- + control_addr : str + ZMQ address for queueserver control (e.g., "tcp://localhost:60615"). + info_addr : str + ZMQ address for queueserver info (e.g., "tcp://localhost:60625"). + zmq_consumer_addr : str + Address for ZMQ document consumer (e.g., "localhost:5578"). + """ + + def __init__( + self, + control_addr: str = "tcp://localhost:60615", + info_addr: str = "tcp://localhost:60625", + zmq_consumer_addr: str = "localhost:5578", + ): + self._control_addr = control_addr + self._info_addr = info_addr + self._zmq_consumer_addr = zmq_consumer_addr + + self._rm = REManagerAPI(zmq_control_addr=control_addr, zmq_info_addr=info_addr) + self._dispatcher: RemoteDispatcher | None = None + self._consumer_callback: ConsumerCallback | None = None + self._listener_thread: threading.Thread | None = None + + def check_environment(self) -> None: + """ + Verify that the queueserver environment is ready. + + Raises + ------ + RuntimeError + If the queueserver environment is not open. + """ + status = self._rm.status() + if not status["worker_environment_exists"]: + raise RuntimeError("The queueserver environment is not open") + + def check_devices_available(self, device_names: Sequence[str]) -> None: + """ + Verify that all specified devices are available in the queueserver. + + Parameters + ---------- + device_names : Sequence[str] + Names of devices to check. + + Raises + ------ + ValueError + If any device is not available. + """ + res = self._rm.devices_allowed() + allowed = res["devices_allowed"] + for name in device_names: + if name not in allowed: + raise ValueError(f"Device '{name}' is not available in the queueserver environment") + + def check_plan_available(self, plan_name: str) -> None: + """ + Verify that a plan is available in the queueserver. + + Parameters + ---------- + plan_name : str + Name of the plan to check. + + Raises + ------ + ValueError + If the plan is not available. + """ + res = self._rm.plans_allowed() + if plan_name not in res["plans_allowed"]: + raise ValueError(f"Plan '{plan_name}' is not available in the queueserver environment") + + def submit_plan(self, plan: BPlan, autostart: bool = True, timeout: float = 600) -> None: + """ + Submit a plan to the queueserver queue. + + Parameters + ---------- + plan : BPlan + The plan to submit. + autostart : bool + If True, start the queue after adding the plan. + timeout : float + Timeout in seconds when waiting for queue to be idle. + """ + response = self._rm.item_add(plan) + logger.debug(f"Submitted plan to queue. Response: {response}") + + if autostart: + logger.debug("Waiting for queue to be idle or paused") + self._rm.wait_for_idle_or_paused(timeout=timeout) + response = self._rm.queue_start() + logger.debug(f"Started queue. Response: {response}") + + def start_listener(self, on_stop: callable) -> None: + """ + Start listening for document events from the queueserver. + + Parameters + ---------- + on_stop : callable + Callback invoked when a stop document is received. + Signature: on_stop(start_doc, stop_doc) + """ + if self._listener_thread is not None: + logger.warning("Listener already running") + return + + self._dispatcher = RemoteDispatcher(self._zmq_consumer_addr) + self._consumer_callback = ConsumerCallback(callback=on_stop) + self._dispatcher.subscribe(self._consumer_callback) + + logger.info("Starting ZMQ listener thread") + self._listener_thread = threading.Thread( + target=self._dispatcher.start, + name="qserver-zmq-consumer", + daemon=True, + ) + self._listener_thread.start() + + def stop_listener(self) -> None: + """Stop the ZMQ listener thread.""" + if self._dispatcher is not None: + self._dispatcher.stop() + self._dispatcher = None + self._consumer_callback = None + self._listener_thread = None + logger.info("Stopped ZMQ listener") + + +@dataclass +class _OptimizationState: + """Internal mutable state for an optimization run.""" + + max_iterations: int = 1 + num_points: int = 1 + current_iteration: int = 0 + current_trials: list[dict] = field(default_factory=list) + current_uid: str | None = None + finished: bool = False + + +class QServerOptimizationRunner: + """ + Runs optimization loops through a Bluesky queueserver. + + This class coordinates the optimization workflow by getting suggestions from + the optimizer, submitting acquisition plans to the queueserver, and ingesting + results when plans complete. + + Parameters + ---------- + optimization_problem : OptimizationProblem + The optimization problem to solve, containing the optimizer, actuators, + sensors, and evaluation function. + qserver_client : QServerClient + Client for communicating with the queueserver. + acquisition_plan_name : str + Name of the acquisition plan registered in the queueserver. + + Examples + -------- + >>> from blop.protocols import OptimizationProblem + >>> from blop.ax import AxOptimizer + >>> + >>> # Create optimization problem + >>> problem = OptimizationProblem( + ... optimizer=optimizer, + ... actuators=[motor1, motor2], + ... sensors=[detector], + ... evaluation_function=my_eval_func, + ... ) + >>> + >>> # Create qserver client and runner + >>> client = QServerClient() + >>> runner = QServerOptimizationRunner(problem, client, "my_acquire_plan") + >>> + >>> # Run optimization + >>> runner.run(iterations=10, num_points=1) + """ + + def __init__( + self, + optimization_problem: OptimizationProblem, + qserver_client: QServerClient, + acquisition_plan_name: str = "acquire", + ): + self._problem = optimization_problem + self._client = qserver_client + self._plan_name = acquisition_plan_name + self._state: _OptimizationState | None = None + self._continuous = True + self._autostart = True + + @property + def optimization_problem(self) -> OptimizationProblem: + """The optimization problem being solved.""" + return self._problem + + @property + def is_running(self) -> bool: + """Whether an optimization run is currently in progress.""" + return self._state is not None and not self._state.finished + + @property + def current_iteration(self) -> int: + """The current iteration number (0 if not running).""" + return self._state.current_iteration if self._state else 0 + + def run(self, iterations: int = 1, num_points: int = 1) -> None: + """ + Start the optimization loop. + + Validates the queueserver state, then begins the suggest -> acquire -> ingest + cycle. This method returns immediately; the optimization runs asynchronously + via callbacks. + + Parameters + ---------- + iterations : int + Number of optimization iterations to run. + num_points : int + Number of points to suggest per iteration. + + Raises + ------ + RuntimeError + If the queueserver environment is not ready. + ValueError + If required devices or plans are not available. + """ + self._validate() + self._state = _OptimizationState(max_iterations=iterations, num_points=num_points) + self._continuous = True + self._client.start_listener(on_stop=self._on_acquisition_complete) + self._submit_next() + + def stop(self) -> None: + """ + Stop the optimization loop gracefully. + + The current acquisition will complete, but no further iterations will run. + """ + self._continuous = False + self._client.stop_listener() + if self._state is not None: + self._state.finished = True + logger.info("Optimization stopped") + + def _validate(self) -> None: + """Validate queueserver environment, devices, and plan availability.""" + self._client.check_environment() + + # Collect device names from actuators and sensors + actuator_names = [a.name for a in self._problem.actuators] + sensor_names = [s.name for s in self._problem.sensors] + self._client.check_devices_available(actuator_names + sensor_names) + + self._client.check_plan_available(self._plan_name) + + def _submit_next(self) -> None: + """Get suggestions from optimizer and submit plan to queueserver.""" + self._state.current_iteration += 1 + self._state.current_trials = self._problem.optimizer.suggest(self._state.num_points) + self._state.current_uid = str(uuid.uuid4()) + + logger.info( + f"Submitting iteration {self._state.current_iteration}/{self._state.max_iterations} " + f"with suggestion uid: {self._state.current_uid}" + ) + + plan = self._build_plan() + self._client.submit_plan(plan, autostart=self._autostart) + + def _build_plan(self) -> BPlan: + """Construct a BPlan from the current suggestions.""" + # Build metadata + md = { + "agent_suggestion_uid": self._state.current_uid, + "blop_suggestions": self._state.current_trials, + } + + # Convert trials list to dict format expected by the plan + # The plan expects {trial_index: parameterization} + trials_dict = {trial["_id"]: {k: v for k, v in trial.items() if k != "_id"} for trial in self._state.current_trials} + + # Get device names + actuator_names = [a.name for a in self._problem.actuators] + sensor_names = [s.name for s in self._problem.sensors] + + return BPlan( + self._plan_name, + readables=sensor_names, + dofs=actuator_names, + trials=trials_dict, + md=md, + ) + + def _on_acquisition_complete(self, start_doc: dict, stop_doc: dict) -> None: + """Callback when acquisition finishes. Ingest results and maybe continue.""" + logger.info(f"Acquisition complete for uid: {self._state.current_uid}") + + # Evaluate the results + outcomes = self._problem.evaluation_function( + uid=self._state.current_uid, + suggestions=self._state.current_trials, + ) + + logger.info(f"Evaluated {len(outcomes)} outcomes") + + # Ingest into optimizer + self._problem.optimizer.ingest(outcomes) + + # Continue if appropriate + if self._continuous and self._state.current_iteration < self._state.max_iterations: + logger.info("Continuing to next iteration") + self._submit_next() + else: + self._state.finished = True + self._client.stop_listener() + logger.info(f"Optimization complete after {self._state.current_iteration} iterations") diff --git a/src/blop/tests/ax/test_qserver_agent.py b/src/blop/tests/ax/test_qserver_agent.py deleted file mode 100644 index b2dea817..00000000 --- a/src/blop/tests/ax/test_qserver_agent.py +++ /dev/null @@ -1,219 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from blop.ax.dof import RangeDOF -from blop.ax.objective import Objective -from blop.ax.qserver_agent import BlopQserverAgent, ConsumerCallback, ZMQConsumer -from blop.protocols import EvaluationFunction - -from ..conftest import MovableSignal, ReadableSignal - - -@pytest.fixture(scope="function") -def mock_evaluation_function(): - return MagicMock(spec=EvaluationFunction) - - -@pytest.fixture(scope="function") -def basic_dofs(): - movable1 = MovableSignal(name="test_movable1") - movable2 = MovableSignal(name="test_movable2") - dof1 = RangeDOF(actuator=movable1, bounds=(0, 10), parameter_type="float") - dof2 = RangeDOF(actuator=movable2, bounds=(0, 10), parameter_type="float") - return [dof1, dof2] - - -@pytest.fixture(scope="function") -def basic_sensors(): - return [ReadableSignal(name="test_readable")] - - -@pytest.fixture(scope="function") -def basic_objective(): - return Objective(name="test_objective", minimize=False) - - -@patch("blop.ax.qserver_agent.REManagerAPI") -@patch("blop.ax.qserver_agent.ZMQConsumer") -def test_qserver_agent_init(mock_zmq_consumer, mock_re_manager, mock_evaluation_function, basic_dofs, basic_sensors, basic_objective): - """Test that the qserver agent can be initialized with proper components.""" - agent = BlopQserverAgent( - sensors=basic_sensors, - dofs=basic_dofs, - objectives=[basic_objective], - evaluation_function=mock_evaluation_function, - acquisition_plan="test_plan", - qserver_control_addr="tcp://localhost:60615", - qserver_info_addr="tcp://localhost:60625", - zmq_consumer_ip="localhost", - zmq_consumer_port="5578", - ) - - # Test public properties - assert agent.sensors == basic_sensors - assert list(agent.dofs) == basic_dofs - assert agent.continuous_suggestion is True - assert agent.num_itterations == 30 - assert agent.n_points == 1 - assert agent.current_itteration == 0 - - # Verify REManagerAPI was initialized with correct addresses - mock_re_manager.assert_called_once_with( - zmq_control_addr="tcp://localhost:60615", - zmq_info_addr="tcp://localhost:60625", - ) - # Verify ZMQ consumer was started - mock_zmq_consumer.assert_called_once() - mock_zmq_consumer.return_value.start_zmq_listener_thread.assert_called_once() - - -@patch("blop.ax.qserver_agent.REManagerAPI") -@patch("blop.ax.qserver_agent.ZMQConsumer") -def test_qserver_agent_optimize_validation(mock_zmq_consumer, mock_re_manager, mock_evaluation_function, basic_dofs, basic_sensors, basic_objective): - """Test that optimize validates qserver environment, devices, and plans.""" - # Test 1: Error when qserver environment is not open - agent = BlopQserverAgent( - sensors=basic_sensors, - dofs=basic_dofs, - objectives=[basic_objective], - evaluation_function=mock_evaluation_function, - ) - agent.RM.status.return_value = {"worker_environment_exists": False} - - with pytest.raises(ValueError, match="queueserver environment is not open"): - agent.optimize(iterations=1, n_points=1) - - # Test 2: Error when required device is not in qserver (uses actuator-based DOFs) - agent2 = BlopQserverAgent( - sensors=basic_sensors, - dofs=basic_dofs, - objectives=[basic_objective], - evaluation_function=mock_evaluation_function, - ) - agent2.RM.status.return_value = {"worker_environment_exists": True} - agent2.RM.devices_allowed.return_value = {"devices_allowed": {}} - - with pytest.raises(ValueError, match="device test_movable1 is not in the Queueserver Environment"): - agent2.optimize(iterations=1, n_points=1) - - # Test 3: Error when acquisition plan is not in qserver - agent3 = BlopQserverAgent( - sensors=basic_sensors, - dofs=basic_dofs, - objectives=[basic_objective], - evaluation_function=mock_evaluation_function, - acquisition_plan="missing_plan", - ) - agent3.RM.status.return_value = {"worker_environment_exists": True} - agent3.RM.devices_allowed.return_value = { - "devices_allowed": {"test_movable1": {}, "test_movable2": {}, "test_readable": {}} - } - agent3.RM.plans_allowed.return_value = {"plans_allowed": {}} - - with pytest.raises(ValueError, match="plan missing_plan is not in the Queueserver Environment"): - agent3.optimize(iterations=1, n_points=1) - - -@patch("blop.ax.qserver_agent.REManagerAPI") -@patch("blop.ax.qserver_agent.ZMQConsumer") -def test_qserver_agent_acquire(mock_zmq_consumer, mock_re_manager, mock_evaluation_function, basic_dofs, basic_sensors, basic_objective): - """Test that acquire submits a plan to the qserver with proper metadata and starts queue.""" - agent = BlopQserverAgent( - sensors=basic_sensors, - dofs=basic_dofs, - objectives=[basic_objective], - evaluation_function=mock_evaluation_function, - acquisition_plan="acquire", - ) - - trials = {0: {"test_movable1": 5.0, "test_movable2": 3.0}} - uid = agent.acquire(trials) - - # Verify uid is returned - assert uid is not None - assert isinstance(uid, str) - assert agent.acquisition_finished is False - - # Verify plan was submitted to qserver - agent.RM.item_add.assert_called_once() - call_args = agent.RM.item_add.call_args - bplan = call_args[0][0] - assert bplan.name == "acquire" - assert bplan.kwargs["md"]["agent_suggestion_uid"] == uid - assert bplan.kwargs["md"]["blop_suggestions"] == [{"_id": 0, "test_movable1": 5.0, "test_movable2": 3.0}] - - # Verify queue was started (autostart is enabled by default) - agent.RM.wait_for_idle_or_paused.assert_called_once_with(timeout=600) - agent.RM.queue_start.assert_called_once() - - -@patch("blop.ax.qserver_agent.REManagerAPI") -@patch("blop.ax.qserver_agent.ZMQConsumer") -def test_qserver_agent_stop(mock_zmq_consumer, mock_re_manager, mock_evaluation_function, basic_dofs, basic_sensors, basic_objective): - """Test that stop prevents the agent from auto-starting the queue on acquire.""" - agent = BlopQserverAgent( - sensors=basic_sensors, - dofs=basic_dofs, - objectives=[basic_objective], - evaluation_function=mock_evaluation_function, - ) - - # Stop the agent - agent.stop() - - # Now acquire should not start the queue - trials = {0: {"test_movable1": 5.0, "test_movable2": 3.0}} - agent.acquire(trials) - - # Verify plan was still submitted - agent.RM.item_add.assert_called_once() - # But queue was NOT started - agent.RM.queue_start.assert_not_called() - - -def test_consumer_callback(): - """Test ConsumerCallback caches start doc, calls callback on stop, and clears cache.""" - mock_callback = MagicMock() - callback = ConsumerCallback(callback=mock_callback, enable=True) - start_doc = {"uid": "test-uid", "time": 123} - stop_doc = {"uid": "test-uid", "exit_status": "success"} - - # Test caching on start - callback.start(start_doc) - assert callback.start_doc_cache == start_doc - - # Test callback invocation on stop and cache clearing - callback.stop(stop_doc) - mock_callback.assert_called_once_with(start_doc, stop_doc) - assert callback.start_doc_cache is None - - # Test disabled callback does nothing - disabled_callback = ConsumerCallback(callback=MagicMock(), enable=False) - disabled_callback.start(start_doc) - disabled_callback.stop(stop_doc) - assert disabled_callback.start_doc_cache is None - disabled_callback.callback.assert_not_called() - - -@patch("blop.ax.qserver_agent.RemoteDispatcher") -def test_zmq_consumer(mock_remote_dispatcher): - """Test ZMQConsumer initialization and thread startup.""" - mock_callback = MagicMock() - - consumer = ZMQConsumer( - zmq_consumer_ip_address="localhost", - zmq_consumer_port="5578", - callback=mock_callback, - ) - - # Test public attributes - assert consumer.zmq_consumer_ip_address == "localhost" - assert consumer.zmq_consumer_port == "5578" - - # Verify RemoteDispatcher was initialized and subscribed - mock_remote_dispatcher.assert_called_once_with("localhost:5578") - mock_remote_dispatcher.return_value.subscribe.assert_called_once() - - # Test thread startup (verify it doesn't raise) - consumer.start_zmq_listener_thread() diff --git a/src/blop/tests/ax/test_queueserver.py b/src/blop/tests/ax/test_queueserver.py new file mode 100644 index 00000000..e18ea55e --- /dev/null +++ b/src/blop/tests/ax/test_queueserver.py @@ -0,0 +1,196 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from blop.ax.queueserver import ConsumerCallback, QServerClient, QServerOptimizationRunner +from blop.protocols import OptimizationProblem + + +@pytest.fixture(scope="function") +def mock_optimization_problem(): + """Create a mock OptimizationProblem with necessary components.""" + mock_optimizer = MagicMock() + mock_optimizer.suggest.return_value = [ + {"_id": 0, "motor1": 5.0, "motor2": 3.0}, + ] + + mock_actuator1 = MagicMock() + mock_actuator1.name = "motor1" + mock_actuator2 = MagicMock() + mock_actuator2.name = "motor2" + + mock_sensor = MagicMock() + mock_sensor.name = "detector" + + mock_eval_func = MagicMock() + mock_eval_func.return_value = [{"_id": 0, "objective": 1.0}] + + return OptimizationProblem( + optimizer=mock_optimizer, + actuators=[mock_actuator1, mock_actuator2], + sensors=[mock_sensor], + evaluation_function=mock_eval_func, + ) + + +def test_consumer_callback_caches_start_and_calls_on_stop(): + """Test ConsumerCallback caches start doc and calls callback on stop.""" + mock_callback = MagicMock() + callback = ConsumerCallback(callback=mock_callback) + start_doc = {"uid": "test-uid", "time": 123} + stop_doc = {"uid": "test-uid", "exit_status": "success"} + + callback.start(start_doc) + callback.stop(stop_doc) + + mock_callback.assert_called_once_with(start_doc, stop_doc) + + +def test_consumer_callback_clears_cache_after_stop(): + """Test ConsumerCallback clears cache after stop is called.""" + callback = ConsumerCallback(callback=MagicMock()) + start_doc = {"uid": "test-uid"} + stop_doc = {"uid": "test-uid"} + + callback.start(start_doc) + callback.stop(stop_doc) + + # Second stop should not call callback (no cached start doc) + callback.stop(stop_doc) + assert callback._callback.call_count == 1 + + +@patch("blop.ax.queueserver.REManagerAPI") +def test_qserver_client_check_environment_raises_when_not_ready(mock_re_manager): + """Test check_environment raises RuntimeError when environment not open.""" + client = QServerClient() + client._rm.status.return_value = {"worker_environment_exists": False} + + with pytest.raises(RuntimeError, match="queueserver environment is not open"): + client.check_environment() + + +@patch("blop.ax.queueserver.REManagerAPI") +def test_qserver_client_check_devices_raises_for_missing_device(mock_re_manager): + """Test check_devices_available raises ValueError for missing devices.""" + client = QServerClient() + client._rm.devices_allowed.return_value = {"devices_allowed": {"motor1": {}}} + + with pytest.raises(ValueError, match="Device 'motor2' is not available"): + client.check_devices_available(["motor1", "motor2"]) + + +@patch("blop.ax.queueserver.REManagerAPI") +def test_qserver_client_check_plan_raises_for_missing_plan(mock_re_manager): + """Test check_plan_available raises ValueError for missing plan.""" + client = QServerClient() + client._rm.plans_allowed.return_value = {"plans_allowed": {"other_plan": {}}} + + with pytest.raises(ValueError, match="Plan 'my_plan' is not available"): + client.check_plan_available("my_plan") + + +@patch("blop.ax.queueserver.REManagerAPI") +def test_qserver_client_submit_plan_with_autostart(mock_re_manager): + """Test submit_plan adds item and starts queue when autostart=True.""" + client = QServerClient() + mock_plan = MagicMock() + + client.submit_plan(mock_plan, autostart=True) + + client._rm.item_add.assert_called_once_with(mock_plan) + client._rm.wait_for_idle_or_paused.assert_called_once() + client._rm.queue_start.assert_called_once() + + +@patch("blop.ax.queueserver.REManagerAPI") +def test_qserver_client_submit_plan_without_autostart(mock_re_manager): + """Test submit_plan only adds item when autostart=False.""" + client = QServerClient() + mock_plan = MagicMock() + + client.submit_plan(mock_plan, autostart=False) + + client._rm.item_add.assert_called_once_with(mock_plan) + client._rm.queue_start.assert_not_called() + + +@patch("blop.ax.queueserver.REManagerAPI") +def test_runner_run_validates_environment(mock_re_manager, mock_optimization_problem): + """Test run() validates qserver environment before starting.""" + mock_client = MagicMock(spec=QServerClient) + mock_client.check_environment.side_effect = RuntimeError("not open") + + runner = QServerOptimizationRunner( + optimization_problem=mock_optimization_problem, + qserver_client=mock_client, + ) + + with pytest.raises(RuntimeError, match="not open"): + runner.run(iterations=1) + + mock_client.check_environment.assert_called_once() + + +@patch("blop.ax.queueserver.REManagerAPI") +def test_runner_run_submits_suggestions_to_qserver(mock_re_manager, mock_optimization_problem): + """Test run() gets suggestions from optimizer and submits plan to qserver.""" + mock_client = MagicMock(spec=QServerClient) + runner = QServerOptimizationRunner( + optimization_problem=mock_optimization_problem, + qserver_client=mock_client, + acquisition_plan_name="my_acquire", + ) + + runner.run(iterations=1, num_points=1) + + # Verify optimizer.suggest was called + mock_optimization_problem.optimizer.suggest.assert_called_once_with(1) + + # Verify plan was submitted + mock_client.submit_plan.assert_called_once() + submitted_plan = mock_client.submit_plan.call_args[0][0] + assert submitted_plan.name == "my_acquire" + + +@patch("blop.ax.queueserver.REManagerAPI") +def test_runner_stop_sets_finished_state(mock_re_manager, mock_optimization_problem): + """Test stop() marks the runner as finished and stops listener.""" + mock_client = MagicMock(spec=QServerClient) + runner = QServerOptimizationRunner( + optimization_problem=mock_optimization_problem, + qserver_client=mock_client, + ) + + runner.run(iterations=10) + assert runner.is_running is True + + runner.stop() + + assert runner.is_running is False + mock_client.stop_listener.assert_called() + + +@patch("blop.ax.queueserver.REManagerAPI") +def test_runner_ingests_outcomes_on_acquisition_complete(mock_re_manager, mock_optimization_problem): + """Test that outcomes are ingested into optimizer when acquisition completes.""" + mock_client = MagicMock(spec=QServerClient) + runner = QServerOptimizationRunner( + optimization_problem=mock_optimization_problem, + qserver_client=mock_client, + ) + + # Start the runner (sets up state) + runner.run(iterations=1, num_points=1) + + # Simulate acquisition completion callback + runner._on_acquisition_complete( + start_doc={"uid": "run-uid"}, + stop_doc={"exit_status": "success"}, + ) + + # Verify evaluation function was called + mock_optimization_problem.evaluation_function.assert_called_once() + + # Verify outcomes were ingested + mock_optimization_problem.optimizer.ingest.assert_called_once() From 3d0936c07ec7650985949afd65f0795b7b631a18 Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Thu, 5 Mar 2026 17:19:31 -0500 Subject: [PATCH 03/13] Fix type issues --- src/blop/ax/queueserver.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/src/blop/ax/queueserver.py b/src/blop/ax/queueserver.py index d6a97b3e..1e96b75f 100644 --- a/src/blop/ax/queueserver.py +++ b/src/blop/ax/queueserver.py @@ -8,13 +8,15 @@ import logging import threading import uuid -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass, field +from typing import Any from bluesky.callbacks import CallbackBase from bluesky.callbacks.zmq import RemoteDispatcher from bluesky_queueserver_api import BPlan from bluesky_queueserver_api.zmq import REManagerAPI +from event_model import RunStart, RunStop from ..protocols import OptimizationProblem @@ -32,15 +34,15 @@ class ConsumerCallback(CallbackBase): Signature: callback(start_doc, stop_doc) """ - def __init__(self, callback: callable = None): + def __init__(self, callback: Callable[[RunStart, RunStop], None] | None = None): super().__init__() - self._start_doc_cache = None + self._start_doc_cache: RunStart | None = None self._callback = callback - def start(self, doc): + def start(self, doc: RunStart) -> None: self._start_doc_cache = doc - def stop(self, doc): + def stop(self, doc: RunStop) -> None: if self._callback is not None and self._start_doc_cache is not None: self._callback(self._start_doc_cache, doc) self._start_doc_cache = None @@ -88,7 +90,7 @@ def check_environment(self) -> None: If the queueserver environment is not open. """ status = self._rm.status() - if not status["worker_environment_exists"]: + if status is None or not status.get("worker_environment_exists", False): raise RuntimeError("The queueserver environment is not open") def check_devices_available(self, device_names: Sequence[str]) -> None: @@ -129,7 +131,7 @@ def check_plan_available(self, plan_name: str) -> None: if plan_name not in res["plans_allowed"]: raise ValueError(f"Plan '{plan_name}' is not available in the queueserver environment") - def submit_plan(self, plan: BPlan, autostart: bool = True, timeout: float = 600) -> None: + def submit_plan(self, plan: BPlan, autostart: bool = True, timeout: int = 600) -> None: """ Submit a plan to the queueserver queue. @@ -151,7 +153,7 @@ def submit_plan(self, plan: BPlan, autostart: bool = True, timeout: float = 600) response = self._rm.queue_start() logger.debug(f"Started queue. Response: {response}") - def start_listener(self, on_stop: callable) -> None: + def start_listener(self, on_stop: Callable[[RunStart, RunStop], None]) -> None: """ Start listening for document events from the queueserver. @@ -165,17 +167,18 @@ def start_listener(self, on_stop: callable) -> None: logger.warning("Listener already running") return - self._dispatcher = RemoteDispatcher(self._zmq_consumer_addr) + dispatcher = RemoteDispatcher(self._zmq_consumer_addr) self._consumer_callback = ConsumerCallback(callback=on_stop) - self._dispatcher.subscribe(self._consumer_callback) + dispatcher.subscribe(self._consumer_callback) logger.info("Starting ZMQ listener thread") self._listener_thread = threading.Thread( - target=self._dispatcher.start, + target=dispatcher.start, name="qserver-zmq-consumer", daemon=True, ) self._listener_thread.start() + self._dispatcher = dispatcher def stop_listener(self) -> None: """Stop the ZMQ listener thread.""" @@ -319,6 +322,8 @@ def _validate(self) -> None: def _submit_next(self) -> None: """Get suggestions from optimizer and submit plan to queueserver.""" + if self._state is None: + raise RuntimeError("_submit_next called before run()") self._state.current_iteration += 1 self._state.current_trials = self._problem.optimizer.suggest(self._state.num_points) self._state.current_uid = str(uuid.uuid4()) @@ -333,8 +338,10 @@ def _submit_next(self) -> None: def _build_plan(self) -> BPlan: """Construct a BPlan from the current suggestions.""" + if self._state is None: + raise RuntimeError("_build_plan called before run()") # Build metadata - md = { + md: dict[str, Any] = { "agent_suggestion_uid": self._state.current_uid, "blop_suggestions": self._state.current_trials, } @@ -355,8 +362,12 @@ def _build_plan(self) -> BPlan: md=md, ) - def _on_acquisition_complete(self, start_doc: dict, stop_doc: dict) -> None: + def _on_acquisition_complete(self, start_doc: RunStart, stop_doc: RunStop) -> None: """Callback when acquisition finishes. Ingest results and maybe continue.""" + if self._state is None: + raise RuntimeError("_on_acquisition_complete called before run()") + if self._state.current_uid is None: + raise RuntimeError("current_uid not set") logger.info(f"Acquisition complete for uid: {self._state.current_uid}") # Evaluate the results From 8b8623b5fb5da6c98af536f87933cc86ffdf16dc Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Mon, 9 Mar 2026 11:50:24 -0400 Subject: [PATCH 04/13] Continue refactor of queueserver support --- pyproject.toml | 48 ++++----- src/blop/ax/__init__.py | 4 - src/blop/protocols.py | 44 +++++++- src/blop/{ax => }/queueserver.py | 106 ++++++++------------ src/blop/tests/{ax => }/test_queueserver.py | 7 +- 5 files changed, 115 insertions(+), 94 deletions(-) rename src/blop/{ax => }/queueserver.py (80%) rename src/blop/tests/{ax => }/test_queueserver.py (97%) diff --git a/pyproject.toml b/pyproject.toml index dc28c323..4daa00ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ - + [build-system] requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" @@ -44,6 +44,7 @@ classifiers = [ dynamic = ["version"] [project.optional-dependencies] +qs = ["bluesky-queueserver-api"] dev = [ "pytest", "pytest-cov", @@ -54,6 +55,7 @@ dev = [ "pandas-stubs", "coverage", "pyright", + "blop[qs]", ] cpu = [ # Empty extra - the source configuration below routes to CPU-only index @@ -83,23 +85,23 @@ local_scheme = "no-local-version" src = ["src", "examples", "docs/source/tutorials"] line-length = 125 lint.select = [ - "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b - "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 - "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e - "F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f - "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w - "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i - "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up - "SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self - "PLC2701", # private import - https://docs.astral.sh/ruff/rules/import-private-name/ - "LOG015", # root logger call - https://docs.astral.sh/ruff/rules/root-logger-call/ - "S101", # assert - https://docs.astral.sh/ruff/rules/assert/ - "D", # docstring - https://docs.astral.sh/ruff/rules/#pydocstyle-d + "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b + "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 + "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e + "F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f + "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w + "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i + "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up + "SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self + "PLC2701", # private import - https://docs.astral.sh/ruff/rules/import-private-name/ + "LOG015", # root logger call - https://docs.astral.sh/ruff/rules/root-logger-call/ + "S101", # assert - https://docs.astral.sh/ruff/rules/assert/ + "D", # docstring - https://docs.astral.sh/ruff/rules/#pydocstyle-d ] lint.ignore = [ - "D", # TODO: Add docstrings, then enforce these errors - "SLF001", # TODO: Fix private member access, https://github.com/NSLS-II/blop/issues/94 - "B901", # return-in-generator - https://docs.astral.sh/ruff/rules/return-in-generator/ + "D", # TODO: Add docstrings, then enforce these errors + "SLF001", # TODO: Fix private member access, https://github.com/NSLS-II/blop/issues/94 + "B901", # return-in-generator - https://docs.astral.sh/ruff/rules/return-in-generator/ ] lint.preview = true # so that preview mode PLC2701, and LOG015 is enabled @@ -114,20 +116,18 @@ convention = "numpy" [tool.pyright] ignore = [ - "sim/", - "src/blop/tests/", - "src/blop/bayesian/", # TODO: Remove this and fix type errors - "src/blop/ax/qserver_agent.py", # TODO: Remove this and fix type errors + "sim/", + "src/blop/tests/", + "src/blop/bayesian/", # TODO: Remove this and fix type errors + "src/blop/ax/qserver_agent.py", # TODO: Remove this and fix type errors ] # Configure PyTorch CPU-only installation when [cpu] extra is requested # This requires uv (https://docs.astral.sh/uv/) [tool.uv.sources] -torch = [ - { index = "pytorch-cpu", marker = "extra == 'cpu'" }, -] +torch = [{ index = "pytorch-cpu", marker = "extra == 'cpu'" }] [[tool.uv.index]] name = "pytorch-cpu" url = "https://download.pytorch.org/whl/cpu" -explicit = true # Only use this index for torch-related packages +explicit = true # Only use this index for torch-related packages diff --git a/src/blop/ax/__init__.py b/src/blop/ax/__init__.py index 8e7de507..cdc4dc9b 100644 --- a/src/blop/ax/__init__.py +++ b/src/blop/ax/__init__.py @@ -2,7 +2,6 @@ from .dof import DOF, ChoiceDOF, DOFConstraint, RangeDOF from .objective import Objective, OutcomeConstraint, ScalarizedObjective, to_ax_objective_str from .optimizer import AxOptimizer -from .queueserver import ConsumerCallback, QServerClient, QServerOptimizationRunner __all__ = [ "Agent", @@ -15,7 +14,4 @@ "ScalarizedObjective", "to_ax_objective_str", "AxOptimizer", - "QServerClient", - "QServerOptimizationRunner", - "ConsumerCallback", ] diff --git a/src/blop/protocols.py b/src/blop/protocols.py index dc31f0c6..39566656 100644 --- a/src/blop/protocols.py +++ b/src/blop/protocols.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Literal, Protocol, runtime_checkable +from typing import Any, Literal, Protocol, runtime_checkable from bluesky.protocols import EventCollectable, EventPageCollectable, Flyable, NamedMovable, Readable from bluesky.utils import MsgGenerator, plan @@ -175,6 +175,7 @@ def __call__( suggestions: list[dict], actuators: Sequence[Actuator], sensors: Sequence[Sensor] | None = None, + md: dict[str, Any] | None = None, ) -> MsgGenerator[str]: """ Acquire data for optimization. @@ -192,6 +193,8 @@ def __call__( The actuators to move to their suggested positions. sensors: Sequence[Sensor], optional The sensors that produce data to evaluate. + md : dict[str, Any] | None, optional + Metadata to attach to the start document Returns ------- @@ -235,3 +238,42 @@ class OptimizationProblem: sensors: Sequence[Sensor] evaluation_function: EvaluationFunction acquisition_plan: AcquisitionPlan | None = None + + +@dataclass(frozen=True) +class RemoteOptimizationProblem: + """ + An optimization problem to solve. Immutable once initialized. + + This dataclass encapsulates all components needed for optimization into a single + immutable structure. It is typically used with optimization plans that exist on + a remote server (e.g. bluesky-queueserver). + + Since the remote server manages the instances of the actuators, sensors, and acquistion_plan, + we only refer to them by their names here. + + Attributes + ---------- + optimizer : Optimizer + Suggests points to evaluate and ingests outcomes to inform the optimization. + actuators : Sequence[str] + Names of objects that can be moved. + sensors : Sequence[str] + Names of objects that can produce data. + evaluation_function : EvaluationFunction + A callable to evaluate data from a Bluesky run and produce outcomes. + acquisition_plan: str, optional + Name of a Bluesky plan to acquire data from the beamline. If not provided, a default plan will be used. + Function signature must match `blop.protocols.AcquisitionPlan`. + + See Also + -------- + OptimizationProblem : Alternative configuration when control is local. + blop.queueserver.QueueserverOptimizationRunner : Runner for remote optimization problems using bluesky-queueserver + """ + + optimizer: Optimizer + actuators: Sequence[str] + sensors: Sequence[str] + evaluation_function: EvaluationFunction + acquisition_plan: str | None = None diff --git a/src/blop/ax/queueserver.py b/src/blop/queueserver.py similarity index 80% rename from src/blop/ax/queueserver.py rename to src/blop/queueserver.py index 1e96b75f..275ed86c 100644 --- a/src/blop/ax/queueserver.py +++ b/src/blop/queueserver.py @@ -18,11 +18,15 @@ from bluesky_queueserver_api.zmq import REManagerAPI from event_model import RunStart, RunStop -from ..protocols import OptimizationProblem +from .plans import default_acquire +from .protocols import RemoteOptimizationProblem logger = logging.getLogger("blop") +DEFAULT_ACQUIRE_PLAN_NAME: str = default_acquire.__name__ + + class ConsumerCallback(CallbackBase): """ A callback that caches the start document and invokes a callback on stop. @@ -40,15 +44,20 @@ def __init__(self, callback: Callable[[RunStart, RunStop], None] | None = None): self._callback = callback def start(self, doc: RunStart) -> None: - self._start_doc_cache = doc + """Caches the start document if it came from Blop""" + if doc.get("blop_correlation_id", None): + self._start_doc_cache = doc + else: + self._start_doc_cache = None def stop(self, doc: RunStop) -> None: - if self._callback is not None and self._start_doc_cache is not None: + """Executes the callback if the cached start and stop document match""" + if self._callback is not None and self._start_doc_cache is not None and self._start_doc_cache["uid"] == doc["uid"]: self._callback(self._start_doc_cache, doc) self._start_doc_cache = None -class QServerClient: +class QueueserverClient: """ Handles communication with a Bluesky queueserver. @@ -57,25 +66,20 @@ class QServerClient: Parameters ---------- - control_addr : str - ZMQ address for queueserver control (e.g., "tcp://localhost:60615"). - info_addr : str - ZMQ address for queueserver info (e.g., "tcp://localhost:60625"). + re_manager_api : bluesky_queueserver_api.zmq.REManagerAPI + Manager instance for communication with Bluesky Queueserver zmq_consumer_addr : str Address for ZMQ document consumer (e.g., "localhost:5578"). """ def __init__( self, - control_addr: str = "tcp://localhost:60615", - info_addr: str = "tcp://localhost:60625", - zmq_consumer_addr: str = "localhost:5578", + re_manager_api: REManagerAPI, + zmq_consumer_addr: str, ): - self._control_addr = control_addr - self._info_addr = info_addr self._zmq_consumer_addr = zmq_consumer_addr - self._rm = REManagerAPI(zmq_control_addr=control_addr, zmq_info_addr=info_addr) + self._rm = re_manager_api self._dispatcher: RemoteDispatcher | None = None self._consumer_callback: ConsumerCallback | None = None self._listener_thread: threading.Thread | None = None @@ -139,9 +143,9 @@ def submit_plan(self, plan: BPlan, autostart: bool = True, timeout: int = 600) - ---------- plan : BPlan The plan to submit. - autostart : bool + autostart : bool, optional If True, start the queue after adding the plan. - timeout : float + timeout : float, optional Timeout in seconds when waiting for queue to be idle. """ response = self._rm.item_add(plan) @@ -197,12 +201,12 @@ class _OptimizationState: max_iterations: int = 1 num_points: int = 1 current_iteration: int = 0 - current_trials: list[dict] = field(default_factory=list) + current_suggestions: list[dict] = field(default_factory=list) current_uid: str | None = None finished: bool = False -class QServerOptimizationRunner: +class QueueserverOptimizationRunner: """ Runs optimization loops through a Bluesky queueserver. @@ -212,40 +216,20 @@ class QServerOptimizationRunner: Parameters ---------- - optimization_problem : OptimizationProblem + optimization_problem : RemoteOptimizationProblem The optimization problem to solve, containing the optimizer, actuators, sensors, and evaluation function. - qserver_client : QServerClient + qserver_client : QueueserverClient Client for communicating with the queueserver. acquisition_plan_name : str Name of the acquisition plan registered in the queueserver. - - Examples - -------- - >>> from blop.protocols import OptimizationProblem - >>> from blop.ax import AxOptimizer - >>> - >>> # Create optimization problem - >>> problem = OptimizationProblem( - ... optimizer=optimizer, - ... actuators=[motor1, motor2], - ... sensors=[detector], - ... evaluation_function=my_eval_func, - ... ) - >>> - >>> # Create qserver client and runner - >>> client = QServerClient() - >>> runner = QServerOptimizationRunner(problem, client, "my_acquire_plan") - >>> - >>> # Run optimization - >>> runner.run(iterations=10, num_points=1) """ def __init__( self, - optimization_problem: OptimizationProblem, - qserver_client: QServerClient, - acquisition_plan_name: str = "acquire", + optimization_problem: RemoteOptimizationProblem, + qserver_client: QueueserverClient, + acquisition_plan_name: str = DEFAULT_ACQUIRE_PLAN_NAME, ): self._problem = optimization_problem self._client = qserver_client @@ -255,7 +239,7 @@ def __init__( self._autostart = True @property - def optimization_problem(self) -> OptimizationProblem: + def optimization_problem(self) -> RemoteOptimizationProblem: """The optimization problem being solved.""" return self._problem @@ -291,6 +275,7 @@ def run(self, iterations: int = 1, num_points: int = 1) -> None: ValueError If required devices or plans are not available. """ + # TODO: What if there is already a run? self._validate() self._state = _OptimizationState(max_iterations=iterations, num_points=num_points) self._continuous = True @@ -314,8 +299,8 @@ def _validate(self) -> None: self._client.check_environment() # Collect device names from actuators and sensors - actuator_names = [a.name for a in self._problem.actuators] - sensor_names = [s.name for s in self._problem.sensors] + actuator_names = list(self._problem.actuators) + sensor_names = list(self._problem.sensors) self._client.check_devices_available(actuator_names + sensor_names) self._client.check_plan_available(self._plan_name) @@ -325,7 +310,7 @@ def _submit_next(self) -> None: if self._state is None: raise RuntimeError("_submit_next called before run()") self._state.current_iteration += 1 - self._state.current_trials = self._problem.optimizer.suggest(self._state.num_points) + self._state.current_suggestions = self._problem.optimizer.suggest(self._state.num_points) self._state.current_uid = str(uuid.uuid4()) logger.info( @@ -342,23 +327,15 @@ def _build_plan(self) -> BPlan: raise RuntimeError("_build_plan called before run()") # Build metadata md: dict[str, Any] = { - "agent_suggestion_uid": self._state.current_uid, - "blop_suggestions": self._state.current_trials, + "blop_correlation_id": self._state.current_uid, + "blop_suggestions": self._state.current_suggestions, } - # Convert trials list to dict format expected by the plan - # The plan expects {trial_index: parameterization} - trials_dict = {trial["_id"]: {k: v for k, v in trial.items() if k != "_id"} for trial in self._state.current_trials} - - # Get device names - actuator_names = [a.name for a in self._problem.actuators] - sensor_names = [s.name for s in self._problem.sensors] - return BPlan( self._plan_name, - readables=sensor_names, - dofs=actuator_names, - trials=trials_dict, + self._state.current_suggestions, + list(self._problem.actuators), + list(self._problem.sensors), md=md, ) @@ -368,12 +345,17 @@ def _on_acquisition_complete(self, start_doc: RunStart, stop_doc: RunStop) -> No raise RuntimeError("_on_acquisition_complete called before run()") if self._state.current_uid is None: raise RuntimeError("current_uid not set") + if self._state.current_uid != start_doc.get("blop_correlation_uid", None): + raise RuntimeError( + "current_uid did not match start document. " + f"Got: {start_doc.get('blop_correlation_uid', None)}, Expected: {self._state.current_uid}" + ) logger.info(f"Acquisition complete for uid: {self._state.current_uid}") # Evaluate the results outcomes = self._problem.evaluation_function( - uid=self._state.current_uid, - suggestions=self._state.current_trials, + uid=start_doc["uid"], + suggestions=self._state.current_suggestions, ) logger.info(f"Evaluated {len(outcomes)} outcomes") diff --git a/src/blop/tests/ax/test_queueserver.py b/src/blop/tests/test_queueserver.py similarity index 97% rename from src/blop/tests/ax/test_queueserver.py rename to src/blop/tests/test_queueserver.py index e18ea55e..b5de7e18 100644 --- a/src/blop/tests/ax/test_queueserver.py +++ b/src/blop/tests/test_queueserver.py @@ -2,8 +2,8 @@ import pytest -from blop.ax.queueserver import ConsumerCallback, QServerClient, QServerOptimizationRunner from blop.protocols import OptimizationProblem +from blop.queueserver import ConsumerCallback, QServerClient, QServerOptimizationRunner @pytest.fixture(scope="function") @@ -41,8 +41,9 @@ def test_consumer_callback_caches_start_and_calls_on_stop(): stop_doc = {"uid": "test-uid", "exit_status": "success"} callback.start(start_doc) - callback.stop(stop_doc) + mock_callback.assert_not_called() + callback.stop(stop_doc) mock_callback.assert_called_once_with(start_doc, stop_doc) @@ -60,7 +61,7 @@ def test_consumer_callback_clears_cache_after_stop(): assert callback._callback.call_count == 1 -@patch("blop.ax.queueserver.REManagerAPI") +@patch("blop.queueserver.REManagerAPI") def test_qserver_client_check_environment_raises_when_not_ready(mock_re_manager): """Test check_environment raises RuntimeError when environment not open.""" client = QServerClient() From dc1a2a2613b8461e1ffa3823e92cdae8113c144c Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Mon, 9 Mar 2026 13:59:32 -0400 Subject: [PATCH 05/13] Starting fixing tests --- src/blop/tests/test_queueserver.py | 76 +++++++++++++----------------- 1 file changed, 33 insertions(+), 43 deletions(-) diff --git a/src/blop/tests/test_queueserver.py b/src/blop/tests/test_queueserver.py index b5de7e18..5dbeb658 100644 --- a/src/blop/tests/test_queueserver.py +++ b/src/blop/tests/test_queueserver.py @@ -2,33 +2,25 @@ import pytest -from blop.protocols import OptimizationProblem -from blop.queueserver import ConsumerCallback, QServerClient, QServerOptimizationRunner +from blop.protocols import RemoteOptimizationProblem +from blop.queueserver import ConsumerCallback, QueueserverClient, QueueserverOptimizationRunner @pytest.fixture(scope="function") -def mock_optimization_problem(): +def mock_remote_optimization_problem(): """Create a mock OptimizationProblem with necessary components.""" mock_optimizer = MagicMock() mock_optimizer.suggest.return_value = [ {"_id": 0, "motor1": 5.0, "motor2": 3.0}, ] - mock_actuator1 = MagicMock() - mock_actuator1.name = "motor1" - mock_actuator2 = MagicMock() - mock_actuator2.name = "motor2" - - mock_sensor = MagicMock() - mock_sensor.name = "detector" - mock_eval_func = MagicMock() mock_eval_func.return_value = [{"_id": 0, "objective": 1.0}] - return OptimizationProblem( + return RemoteOptimizationProblem( optimizer=mock_optimizer, - actuators=[mock_actuator1, mock_actuator2], - sensors=[mock_sensor], + actuators=["motor1", "motor2"], + sensors=["detector"], evaluation_function=mock_eval_func, ) @@ -37,7 +29,7 @@ def test_consumer_callback_caches_start_and_calls_on_stop(): """Test ConsumerCallback caches start doc and calls callback on stop.""" mock_callback = MagicMock() callback = ConsumerCallback(callback=mock_callback) - start_doc = {"uid": "test-uid", "time": 123} + start_doc = {"uid": "test-uid", "blop_correlation_uid": "123", "time": 123} stop_doc = {"uid": "test-uid", "exit_status": "success"} callback.start(start_doc) @@ -49,8 +41,9 @@ def test_consumer_callback_caches_start_and_calls_on_stop(): def test_consumer_callback_clears_cache_after_stop(): """Test ConsumerCallback clears cache after stop is called.""" - callback = ConsumerCallback(callback=MagicMock()) - start_doc = {"uid": "test-uid"} + mock_callback = MagicMock() + callback = ConsumerCallback(callback=mock_callback) + start_doc = {"uid": "test-uid", "blop_correlation_uid": "123"} stop_doc = {"uid": "test-uid"} callback.start(start_doc) @@ -58,14 +51,14 @@ def test_consumer_callback_clears_cache_after_stop(): # Second stop should not call callback (no cached start doc) callback.stop(stop_doc) - assert callback._callback.call_count == 1 + assert mock_callback.call_count == 1 @patch("blop.queueserver.REManagerAPI") def test_qserver_client_check_environment_raises_when_not_ready(mock_re_manager): """Test check_environment raises RuntimeError when environment not open.""" - client = QServerClient() - client._rm.status.return_value = {"worker_environment_exists": False} + mock_re_manager.status.return_value = {"worker_environment_exists": False} + client = QueueserverClient(mock_re_manager, "inproc://test") with pytest.raises(RuntimeError, match="queueserver environment is not open"): client.check_environment() @@ -74,8 +67,8 @@ def test_qserver_client_check_environment_raises_when_not_ready(mock_re_manager) @patch("blop.ax.queueserver.REManagerAPI") def test_qserver_client_check_devices_raises_for_missing_device(mock_re_manager): """Test check_devices_available raises ValueError for missing devices.""" - client = QServerClient() - client._rm.devices_allowed.return_value = {"devices_allowed": {"motor1": {}}} + mock_re_manager.devices_allowed.return_value = {"devices_allowed": {"motor1": {}}} + client = QueueserverClient(mock_re_manager, "inproc://test") with pytest.raises(ValueError, match="Device 'motor2' is not available"): client.check_devices_available(["motor1", "motor2"]) @@ -84,8 +77,8 @@ def test_qserver_client_check_devices_raises_for_missing_device(mock_re_manager) @patch("blop.ax.queueserver.REManagerAPI") def test_qserver_client_check_plan_raises_for_missing_plan(mock_re_manager): """Test check_plan_available raises ValueError for missing plan.""" - client = QServerClient() - client._rm.plans_allowed.return_value = {"plans_allowed": {"other_plan": {}}} + mock_re_manager.plans_allowed.return_value = {"plans_allowed": {"other_plan": {}}} + client = QueueserverClient(mock_re_manager, "inproc://test") with pytest.raises(ValueError, match="Plan 'my_plan' is not available"): client.check_plan_available("my_plan") @@ -94,35 +87,34 @@ def test_qserver_client_check_plan_raises_for_missing_plan(mock_re_manager): @patch("blop.ax.queueserver.REManagerAPI") def test_qserver_client_submit_plan_with_autostart(mock_re_manager): """Test submit_plan adds item and starts queue when autostart=True.""" - client = QServerClient() + client = QueueserverClient(mock_re_manager, "inproc://test") mock_plan = MagicMock() client.submit_plan(mock_plan, autostart=True) - client._rm.item_add.assert_called_once_with(mock_plan) - client._rm.wait_for_idle_or_paused.assert_called_once() - client._rm.queue_start.assert_called_once() + mock_re_manager.item_add.assert_called_once_with(mock_plan) + mock_re_manager.wait_for_idle_or_paused.assert_called_once() + mock_re_manager.queue_start.assert_called_once() @patch("blop.ax.queueserver.REManagerAPI") def test_qserver_client_submit_plan_without_autostart(mock_re_manager): """Test submit_plan only adds item when autostart=False.""" - client = QServerClient() + client = QueueserverClient(mock_re_manager, "inproc://test") mock_plan = MagicMock() client.submit_plan(mock_plan, autostart=False) - client._rm.item_add.assert_called_once_with(mock_plan) - client._rm.queue_start.assert_not_called() + mock_re_manager.item_add.assert_called_once_with(mock_plan) + mock_re_manager.queue_start.assert_not_called() -@patch("blop.ax.queueserver.REManagerAPI") -def test_runner_run_validates_environment(mock_re_manager, mock_optimization_problem): +def test_runner_run_validates_environment(mock_optimization_problem): """Test run() validates qserver environment before starting.""" - mock_client = MagicMock(spec=QServerClient) + mock_client = MagicMock(spec=QueueserverClient) mock_client.check_environment.side_effect = RuntimeError("not open") - runner = QServerOptimizationRunner( + runner = QueueserverOptimizationRunner( optimization_problem=mock_optimization_problem, qserver_client=mock_client, ) @@ -133,11 +125,10 @@ def test_runner_run_validates_environment(mock_re_manager, mock_optimization_pro mock_client.check_environment.assert_called_once() -@patch("blop.ax.queueserver.REManagerAPI") -def test_runner_run_submits_suggestions_to_qserver(mock_re_manager, mock_optimization_problem): +def test_runner_run_submits_suggestions_to_qserver(mock_optimization_problem): """Test run() gets suggestions from optimizer and submits plan to qserver.""" - mock_client = MagicMock(spec=QServerClient) - runner = QServerOptimizationRunner( + mock_client = MagicMock(spec=QueueserverClient) + runner = QueueserverOptimizationRunner( optimization_problem=mock_optimization_problem, qserver_client=mock_client, acquisition_plan_name="my_acquire", @@ -154,11 +145,10 @@ def test_runner_run_submits_suggestions_to_qserver(mock_re_manager, mock_optimiz assert submitted_plan.name == "my_acquire" -@patch("blop.ax.queueserver.REManagerAPI") -def test_runner_stop_sets_finished_state(mock_re_manager, mock_optimization_problem): +def test_runner_stop_sets_finished_state(mock_optimization_problem): """Test stop() marks the runner as finished and stops listener.""" - mock_client = MagicMock(spec=QServerClient) - runner = QServerOptimizationRunner( + mock_client = MagicMock(spec=QueueserverClient) + runner = QueueserverOptimizationRunner( optimization_problem=mock_optimization_problem, qserver_client=mock_client, ) From e1a6333ab15a9f2dc4ce3c90bf627d1633850654 Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Mon, 9 Mar 2026 16:48:30 -0400 Subject: [PATCH 06/13] Finished updating queueserver unit tests --- src/blop/queueserver.py | 7 ++-- src/blop/tests/test_queueserver.py | 58 +++++++++--------------------- 2 files changed, 21 insertions(+), 44 deletions(-) diff --git a/src/blop/queueserver.py b/src/blop/queueserver.py index 275ed86c..6b9cc43d 100644 --- a/src/blop/queueserver.py +++ b/src/blop/queueserver.py @@ -10,7 +10,7 @@ import uuid from collections.abc import Callable, Sequence from dataclasses import dataclass, field -from typing import Any +from typing import Any, Literal from bluesky.callbacks import CallbackBase from bluesky.callbacks.zmq import RemoteDispatcher @@ -25,6 +25,7 @@ DEFAULT_ACQUIRE_PLAN_NAME: str = default_acquire.__name__ +CORRELATION_UID_KEY: Literal["blop_correlation_uid"] = "blop_correlation_uid" class ConsumerCallback(CallbackBase): @@ -45,7 +46,7 @@ def __init__(self, callback: Callable[[RunStart, RunStop], None] | None = None): def start(self, doc: RunStart) -> None: """Caches the start document if it came from Blop""" - if doc.get("blop_correlation_id", None): + if doc.get(CORRELATION_UID_KEY, None): self._start_doc_cache = doc else: self._start_doc_cache = None @@ -327,7 +328,7 @@ def _build_plan(self) -> BPlan: raise RuntimeError("_build_plan called before run()") # Build metadata md: dict[str, Any] = { - "blop_correlation_id": self._state.current_uid, + CORRELATION_UID_KEY: self._state.current_uid, "blop_suggestions": self._state.current_suggestions, } diff --git a/src/blop/tests/test_queueserver.py b/src/blop/tests/test_queueserver.py index 5dbeb658..4587ad96 100644 --- a/src/blop/tests/test_queueserver.py +++ b/src/blop/tests/test_queueserver.py @@ -3,7 +3,7 @@ import pytest from blop.protocols import RemoteOptimizationProblem -from blop.queueserver import ConsumerCallback, QueueserverClient, QueueserverOptimizationRunner +from blop.queueserver import CORRELATION_UID_KEY, ConsumerCallback, QueueserverClient, QueueserverOptimizationRunner @pytest.fixture(scope="function") @@ -29,7 +29,7 @@ def test_consumer_callback_caches_start_and_calls_on_stop(): """Test ConsumerCallback caches start doc and calls callback on stop.""" mock_callback = MagicMock() callback = ConsumerCallback(callback=mock_callback) - start_doc = {"uid": "test-uid", "blop_correlation_uid": "123", "time": 123} + start_doc = {"uid": "test-uid", CORRELATION_UID_KEY: "123", "time": 123} stop_doc = {"uid": "test-uid", "exit_status": "success"} callback.start(start_doc) @@ -43,7 +43,7 @@ def test_consumer_callback_clears_cache_after_stop(): """Test ConsumerCallback clears cache after stop is called.""" mock_callback = MagicMock() callback = ConsumerCallback(callback=mock_callback) - start_doc = {"uid": "test-uid", "blop_correlation_uid": "123"} + start_doc = {"uid": "test-uid", CORRELATION_UID_KEY: "123"} stop_doc = {"uid": "test-uid"} callback.start(start_doc) @@ -64,7 +64,7 @@ def test_qserver_client_check_environment_raises_when_not_ready(mock_re_manager) client.check_environment() -@patch("blop.ax.queueserver.REManagerAPI") +@patch("blop.queueserver.REManagerAPI") def test_qserver_client_check_devices_raises_for_missing_device(mock_re_manager): """Test check_devices_available raises ValueError for missing devices.""" mock_re_manager.devices_allowed.return_value = {"devices_allowed": {"motor1": {}}} @@ -74,7 +74,7 @@ def test_qserver_client_check_devices_raises_for_missing_device(mock_re_manager) client.check_devices_available(["motor1", "motor2"]) -@patch("blop.ax.queueserver.REManagerAPI") +@patch("blop.queueserver.REManagerAPI") def test_qserver_client_check_plan_raises_for_missing_plan(mock_re_manager): """Test check_plan_available raises ValueError for missing plan.""" mock_re_manager.plans_allowed.return_value = {"plans_allowed": {"other_plan": {}}} @@ -84,7 +84,7 @@ def test_qserver_client_check_plan_raises_for_missing_plan(mock_re_manager): client.check_plan_available("my_plan") -@patch("blop.ax.queueserver.REManagerAPI") +@patch("blop.queueserver.REManagerAPI") def test_qserver_client_submit_plan_with_autostart(mock_re_manager): """Test submit_plan adds item and starts queue when autostart=True.""" client = QueueserverClient(mock_re_manager, "inproc://test") @@ -97,7 +97,7 @@ def test_qserver_client_submit_plan_with_autostart(mock_re_manager): mock_re_manager.queue_start.assert_called_once() -@patch("blop.ax.queueserver.REManagerAPI") +@patch("blop.queueserver.REManagerAPI") def test_qserver_client_submit_plan_without_autostart(mock_re_manager): """Test submit_plan only adds item when autostart=False.""" client = QueueserverClient(mock_re_manager, "inproc://test") @@ -109,13 +109,13 @@ def test_qserver_client_submit_plan_without_autostart(mock_re_manager): mock_re_manager.queue_start.assert_not_called() -def test_runner_run_validates_environment(mock_optimization_problem): +def test_runner_run_validates_environment(mock_remote_optimization_problem): """Test run() validates qserver environment before starting.""" mock_client = MagicMock(spec=QueueserverClient) mock_client.check_environment.side_effect = RuntimeError("not open") runner = QueueserverOptimizationRunner( - optimization_problem=mock_optimization_problem, + optimization_problem=mock_remote_optimization_problem, qserver_client=mock_client, ) @@ -125,11 +125,11 @@ def test_runner_run_validates_environment(mock_optimization_problem): mock_client.check_environment.assert_called_once() -def test_runner_run_submits_suggestions_to_qserver(mock_optimization_problem): +def test_runner_run_submits_suggestions_to_qserver(mock_remote_optimization_problem): """Test run() gets suggestions from optimizer and submits plan to qserver.""" mock_client = MagicMock(spec=QueueserverClient) runner = QueueserverOptimizationRunner( - optimization_problem=mock_optimization_problem, + optimization_problem=mock_remote_optimization_problem, qserver_client=mock_client, acquisition_plan_name="my_acquire", ) @@ -137,7 +137,7 @@ def test_runner_run_submits_suggestions_to_qserver(mock_optimization_problem): runner.run(iterations=1, num_points=1) # Verify optimizer.suggest was called - mock_optimization_problem.optimizer.suggest.assert_called_once_with(1) + mock_remote_optimization_problem.optimizer.suggest.assert_called_once_with(1) # Verify plan was submitted mock_client.submit_plan.assert_called_once() @@ -145,43 +145,19 @@ def test_runner_run_submits_suggestions_to_qserver(mock_optimization_problem): assert submitted_plan.name == "my_acquire" -def test_runner_stop_sets_finished_state(mock_optimization_problem): +def test_runner_stop_sets_finished_state(mock_remote_optimization_problem): """Test stop() marks the runner as finished and stops listener.""" mock_client = MagicMock(spec=QueueserverClient) runner = QueueserverOptimizationRunner( - optimization_problem=mock_optimization_problem, + optimization_problem=mock_remote_optimization_problem, qserver_client=mock_client, ) - runner.run(iterations=10) + # The acquisiton completion callback never fires here due to the mocked client, therefore + # the first plan runs forever + runner.run(10) assert runner.is_running is True runner.stop() - assert runner.is_running is False mock_client.stop_listener.assert_called() - - -@patch("blop.ax.queueserver.REManagerAPI") -def test_runner_ingests_outcomes_on_acquisition_complete(mock_re_manager, mock_optimization_problem): - """Test that outcomes are ingested into optimizer when acquisition completes.""" - mock_client = MagicMock(spec=QServerClient) - runner = QServerOptimizationRunner( - optimization_problem=mock_optimization_problem, - qserver_client=mock_client, - ) - - # Start the runner (sets up state) - runner.run(iterations=1, num_points=1) - - # Simulate acquisition completion callback - runner._on_acquisition_complete( - start_doc={"uid": "run-uid"}, - stop_doc={"exit_status": "success"}, - ) - - # Verify evaluation function was called - mock_optimization_problem.evaluation_function.assert_called_once() - - # Verify outcomes were ingested - mock_optimization_problem.optimizer.ingest.assert_called_once() From 979328019e05c1d153105450f46b6b0e5ee212a5 Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Mon, 9 Mar 2026 16:50:32 -0400 Subject: [PATCH 07/13] Rename argument qserver -> queueserver --- src/blop/queueserver.py | 6 +++--- src/blop/tests/test_queueserver.py | 22 +++++++++++----------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/blop/queueserver.py b/src/blop/queueserver.py index 6b9cc43d..34a618bd 100644 --- a/src/blop/queueserver.py +++ b/src/blop/queueserver.py @@ -220,7 +220,7 @@ class QueueserverOptimizationRunner: optimization_problem : RemoteOptimizationProblem The optimization problem to solve, containing the optimizer, actuators, sensors, and evaluation function. - qserver_client : QueueserverClient + queueserver_client : QueueserverClient Client for communicating with the queueserver. acquisition_plan_name : str Name of the acquisition plan registered in the queueserver. @@ -229,11 +229,11 @@ class QueueserverOptimizationRunner: def __init__( self, optimization_problem: RemoteOptimizationProblem, - qserver_client: QueueserverClient, + queueserver_client: QueueserverClient, acquisition_plan_name: str = DEFAULT_ACQUIRE_PLAN_NAME, ): self._problem = optimization_problem - self._client = qserver_client + self._client = queueserver_client self._plan_name = acquisition_plan_name self._state: _OptimizationState | None = None self._continuous = True diff --git a/src/blop/tests/test_queueserver.py b/src/blop/tests/test_queueserver.py index 4587ad96..1b6d5b14 100644 --- a/src/blop/tests/test_queueserver.py +++ b/src/blop/tests/test_queueserver.py @@ -55,7 +55,7 @@ def test_consumer_callback_clears_cache_after_stop(): @patch("blop.queueserver.REManagerAPI") -def test_qserver_client_check_environment_raises_when_not_ready(mock_re_manager): +def test_queueserver_client_check_environment_raises_when_not_ready(mock_re_manager): """Test check_environment raises RuntimeError when environment not open.""" mock_re_manager.status.return_value = {"worker_environment_exists": False} client = QueueserverClient(mock_re_manager, "inproc://test") @@ -65,7 +65,7 @@ def test_qserver_client_check_environment_raises_when_not_ready(mock_re_manager) @patch("blop.queueserver.REManagerAPI") -def test_qserver_client_check_devices_raises_for_missing_device(mock_re_manager): +def test_queueserver_client_check_devices_raises_for_missing_device(mock_re_manager): """Test check_devices_available raises ValueError for missing devices.""" mock_re_manager.devices_allowed.return_value = {"devices_allowed": {"motor1": {}}} client = QueueserverClient(mock_re_manager, "inproc://test") @@ -75,7 +75,7 @@ def test_qserver_client_check_devices_raises_for_missing_device(mock_re_manager) @patch("blop.queueserver.REManagerAPI") -def test_qserver_client_check_plan_raises_for_missing_plan(mock_re_manager): +def test_queueserver_client_check_plan_raises_for_missing_plan(mock_re_manager): """Test check_plan_available raises ValueError for missing plan.""" mock_re_manager.plans_allowed.return_value = {"plans_allowed": {"other_plan": {}}} client = QueueserverClient(mock_re_manager, "inproc://test") @@ -85,7 +85,7 @@ def test_qserver_client_check_plan_raises_for_missing_plan(mock_re_manager): @patch("blop.queueserver.REManagerAPI") -def test_qserver_client_submit_plan_with_autostart(mock_re_manager): +def test_queueserver_client_submit_plan_with_autostart(mock_re_manager): """Test submit_plan adds item and starts queue when autostart=True.""" client = QueueserverClient(mock_re_manager, "inproc://test") mock_plan = MagicMock() @@ -98,7 +98,7 @@ def test_qserver_client_submit_plan_with_autostart(mock_re_manager): @patch("blop.queueserver.REManagerAPI") -def test_qserver_client_submit_plan_without_autostart(mock_re_manager): +def test_queueserver_client_submit_plan_without_autostart(mock_re_manager): """Test submit_plan only adds item when autostart=False.""" client = QueueserverClient(mock_re_manager, "inproc://test") mock_plan = MagicMock() @@ -110,13 +110,13 @@ def test_qserver_client_submit_plan_without_autostart(mock_re_manager): def test_runner_run_validates_environment(mock_remote_optimization_problem): - """Test run() validates qserver environment before starting.""" + """Test run() validates queueserver environment before starting.""" mock_client = MagicMock(spec=QueueserverClient) mock_client.check_environment.side_effect = RuntimeError("not open") runner = QueueserverOptimizationRunner( optimization_problem=mock_remote_optimization_problem, - qserver_client=mock_client, + queueserver_client=mock_client, ) with pytest.raises(RuntimeError, match="not open"): @@ -125,12 +125,12 @@ def test_runner_run_validates_environment(mock_remote_optimization_problem): mock_client.check_environment.assert_called_once() -def test_runner_run_submits_suggestions_to_qserver(mock_remote_optimization_problem): - """Test run() gets suggestions from optimizer and submits plan to qserver.""" +def test_runner_run_submits_suggestions_to_queueserver(mock_remote_optimization_problem): + """Test run() gets suggestions from optimizer and submits plan to queueserver.""" mock_client = MagicMock(spec=QueueserverClient) runner = QueueserverOptimizationRunner( optimization_problem=mock_remote_optimization_problem, - qserver_client=mock_client, + queueserver_client=mock_client, acquisition_plan_name="my_acquire", ) @@ -150,7 +150,7 @@ def test_runner_stop_sets_finished_state(mock_remote_optimization_problem): mock_client = MagicMock(spec=QueueserverClient) runner = QueueserverOptimizationRunner( optimization_problem=mock_remote_optimization_problem, - qserver_client=mock_client, + queueserver_client=mock_client, ) # The acquisiton completion callback never fires here due to the mocked client, therefore From 8ed3322fc50debe4d6da1266e1c653e39c07f03c Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Tue, 10 Mar 2026 12:48:59 -0400 Subject: [PATCH 08/13] Refactoring common interfaces between local vs remote agents --- src/blop/ax/agent.py | 372 +++++++++++++++++++---------- src/blop/ax/dof.py | 14 +- src/blop/protocols.py | 63 ++--- src/blop/queueserver.py | 13 +- src/blop/tests/test_queueserver.py | 4 +- 5 files changed, 297 insertions(+), 169 deletions(-) diff --git a/src/blop/ax/agent.py b/src/blop/ax/agent.py index 76a7057d..25285a4d 100644 --- a/src/blop/ax/agent.py +++ b/src/blop/ax/agent.py @@ -1,7 +1,7 @@ import importlib.util import logging from collections.abc import Sequence -from typing import Any, TypeGuard +from typing import Any, TypeGuard, cast from ax import Client from ax.analysis import ContourPlot @@ -14,10 +14,19 @@ from ax.analysis.analysis_card import AnalysisCardBase # type: ignore[import-untyped] # =============================== from bluesky.utils import MsgGenerator +from bluesky_queueserver_api.zmq import REManagerAPI from ..plans import acquire_baseline, optimize, sample_suggestions -from ..protocols import AcquisitionPlan, Actuator, EvaluationFunction, OptimizationProblem, Sensor +from ..protocols import ( + AcquisitionPlan, + Actuator, + EvaluationFunction, + OptimizationProblem, + QueueserverOptimizationProblem, + Sensor, +) from ..utils import InferredReadable +from ..queueserver import QueueserverClient, QueueserverOptimizationRunner from .dof import DOF, DOFConstraint from .objective import Objective, OutcomeConstraint, to_ax_objective_str from .optimizer import AxOptimizer @@ -33,7 +42,147 @@ def _has_dof_keys(d: dict[DOF, Any] | dict[str, Any]) -> TypeGuard[dict[DOF, Any return all(isinstance(key, DOF) for key in d.keys()) -class Agent: +class _AxAgentMixin: + """ + Mixin providing Ax-related functionality shared by agents. + Expects subclasses to define `self._optimizer` as an `AxOptimizer`. + """ + + _optimizer: AxOptimizer + + @property + def ax_client(self) -> Client: + return self._optimizer.ax_client + + @property + def checkpoint_path(self) -> str | None: + return self._optimizer.checkpoint_path + + @property + def fixed_dofs(self) -> dict[str, Any] | None: + return self._optimizer.fixed_parameters + + @fixed_dofs.setter + def fixed_dofs(self, fixed_dofs: dict[DOF, Any] | dict[str, Any] | None) -> None: + """ + Fix degrees of freedom to a certain value for future optimizations. + + Parameters + ---------- + fixed_dofs : dict[DOF, Any] | dict[str, Any] | None + A mapping of DOFs or DOF names to the values they should be fixed to. + + """ + if not fixed_dofs: + self._optimizer.fixed_parameters = None + return + + if _has_str_keys(fixed_dofs): + self._optimizer.fixed_parameters = fixed_dofs + elif _has_dof_keys(fixed_dofs): + self._optimizer.fixed_parameters = {dof.parameter_name: value for dof, value in fixed_dofs.items()} + else: + raise ValueError( + f"Keys must all be either {type(DOF)} or {type(str)}, but got {type(list(fixed_dofs.keys())[0])}" + ) + + def suggest(self, num_points: int = 1) -> list[dict]: + """ + Get the next point(s) to evaluate in the search space. + + Uses the Bayesian optimization algorithm to suggest promising points based + on all previously acquired data. Each suggestion includes an "_id" key for + tracking. + + Parameters + ---------- + num_points : int, optional + The number of points to suggest. Default is 1. Higher values enable + batch optimization but may reduce optimization efficiency per iteration. + + Returns + ------- + list[dict] + A list of dictionaries, each containing a parameterization of a point to + evaluate next. Each dictionary includes an "_id" key for identification. + """ + return self._optimizer.suggest(num_points) + + def ingest(self, points: list[dict]) -> None: + """ + Ingest evaluation results into the optimizer. + + Updates the optimizer's model with new data. Can ingest both suggested points + (with "_id" key) and external data (without "_id" key). + + Parameters + ---------- + points : list[dict] + A list of dictionaries, each containing outcomes for a trial. For suggested + points, include the "_id" key. For external data, include DOF names and + objective values, and omit "_id". + + Notes + ----- + This method is typically called automatically by :meth:`optimize`. Manual usage + is only needed for custom workflows or when ingesting external data. + + For complete examples, see :doc:`/how-to-guides/attach-data-to-experiments`. + """ + self._optimizer.ingest(points) + + def plot_objective( + self, x_dof_name: str, y_dof_name: str, objective_name: str, *args: Any, **kwargs: Any + ) -> list[AnalysisCardBase]: + """ + Plot the predicted objective as a function of two DOFs. + + Creates a contour plot showing the model's prediction of an objective across + the space defined by two DOFs. Useful for visualizing the optimization landscape. + + Parameters + ---------- + x_dof_name : str + The name of the DOF to plot on the x-axis. + y_dof_name : str + The name of the DOF to plot on the y-axis. + objective_name : str + The name of the objective to plot. + *args : Any + Additional positional arguments passed to Ax's compute_analyses. + **kwargs : Any + Additional keyword arguments passed to Ax's compute_analyses. + + Returns + ------- + list[AnalysisCard] + The computed analysis cards containing the plot data. + + See Also + -------- + ax.analysis.ContourPlot : Pre-built analysis for plotting objectives. + ax.analysis.AnalysisCard : Contains the raw and computed data. + """ + return self.ax_client.compute_analyses( + [ + ContourPlot( + x_parameter_name=x_dof_name, + y_parameter_name=y_dof_name, + metric_name=objective_name, + ), + ], + *args, + **kwargs, + ) + + def checkpoint(self) -> None: + """ + Save the agent's state to a JSON file. + """ + self._optimizer.checkpoint() + + +class Agent(_AxAgentMixin): """ An interface that uses Ax as the backend for optimization and experiment tracking. @@ -93,8 +242,13 @@ def __init__( checkpoint_path: str | None = None, **kwargs: Any, ): + if any(isinstance(dof.actuator, str) for dof in dofs): + dof_actuator_strs = [dof.actuator for dof in dofs if isinstance(dof.actuator, str)] + raise ValueError( + f"DOFs with actuators must be `Actuator` instances, not strings. Got strings for: {dof_actuator_strs}" + ) self._sensors = sensors - self._actuators = [dof.actuator for dof in dofs if dof.actuator is not None] + self._actuators: Sequence[Actuator] = [cast(Actuator, dof.actuator) for dof in dofs if dof.actuator is not None] self._evaluation_function = evaluation_function self._acquisition_plan = acquisition_plan self._optimizer = AxOptimizer( @@ -165,42 +319,6 @@ def evaluation_function(self) -> EvaluationFunction: def acquisition_plan(self) -> AcquisitionPlan | None: return self._acquisition_plan - @property - def ax_client(self) -> Client: - return self._optimizer.ax_client - - @property - def checkpoint_path(self) -> str | None: - return self._optimizer.checkpoint_path - - @property - def fixed_dofs(self) -> dict[str, Any] | None: - return self._optimizer.fixed_parameters - - @fixed_dofs.setter - def fixed_dofs(self, fixed_dofs: dict[DOF, Any] | dict[str, Any] | None) -> None: - """ - Fix degrees of freedom to a certain value for future optimizations. - - Parameters - ---------- - fixed_dofs : dict[DOF, Any] | dict[str, Any] | None - A mapping of DOFs or DOF names to the values they should be fixed to. - - """ - if not fixed_dofs: - self._optimizer.fixed_parameters = None - return - - if _has_str_keys(fixed_dofs): - self._optimizer.fixed_parameters = fixed_dofs - elif _has_dof_keys(fixed_dofs): - self._optimizer.fixed_parameters = {dof.parameter_name: value for dof, value in fixed_dofs.items()} - else: - raise ValueError( - f"Keys must all be either {type(DOF)} or {type(str)}, but got {type(list(fixed_dofs.keys())[0])}" - ) - def to_optimization_problem(self) -> OptimizationProblem: """ Construct an optimization problem from the agent. @@ -227,51 +345,6 @@ def to_optimization_problem(self) -> OptimizationProblem: acquisition_plan=self.acquisition_plan, ) - def suggest(self, num_points: int = 1) -> list[dict]: - """ - Get the next point(s) to evaluate in the search space. - - Uses the Bayesian optimization algorithm to suggest promising points based - on all previously acquired data. Each suggestion includes an "_id" key for - tracking. - - Parameters - ---------- - num_points : int, optional - The number of points to suggest. Default is 1. Higher values enable - batch optimization but may reduce optimization efficiency per iteration. - - Returns - ------- - list[dict] - A list of dictionaries, each containing a parameterization of a point to - evaluate next. Each dictionary includes an "_id" key for identification. - """ - return self._optimizer.suggest(num_points) - - def ingest(self, points: list[dict]) -> None: - """ - Ingest evaluation results into the optimizer. - - Updates the optimizer's model with new data. Can ingest both suggested points - (with "_id" key) and external data (without "_id" key). - - Parameters - ---------- - points : list[dict] - A list of dictionaries, each containing outcomes for a trial. For suggested - points, include the "_id" key. For external data, include DOF names and - objective values, and omit "_id". - - Notes - ----- - This method is typically called automatically by :meth:`optimize`. Manual usage - is only needed for custom workflows or when ingesting external data. - - For complete examples, see :doc:`/how-to-guides/attach-data-to-experiments`. - """ - self._optimizer.ingest(points) - def acquire_baseline(self, parameterization: dict[str, Any] | None = None) -> MsgGenerator[None]: """ Acquire a baseline reading for reference. @@ -363,52 +436,97 @@ def sample_suggestions(self, suggestions: list[dict]) -> MsgGenerator[tuple[str, ) ) - def plot_objective( - self, x_dof_name: str, y_dof_name: str, objective_name: str, *args: Any, **kwargs: Any - ) -> list[AnalysisCardBase]: - """ - Plot the predicted objective as a function of two DOFs. - Creates a contour plot showing the model's prediction of an objective across - the space defined by two DOFs. Useful for visualizing the optimization landscape. +class QueueserverAgent(_AxAgentMixin): + """ + An asynchronous interface that uses Ax as the backend for optimization and experiment tracking + and the bluesky-queueserver-api for scheduling plan execution. - Parameters - ---------- - x_dof_name : str - The name of the DOF to plot on the x-axis. - y_dof_name : str - The name of the DOF to plot on the y-axis. - objective_name : str - The name of the objective to plot. - *args : Any - Additional positional arguments passed to Ax's compute_analyses. - **kwargs : Any - Additional keyword arguments passed to Ax's compute_analyses. + Parameters + ---------- + re_manager_api : REManagerAPI + The manager API for interaction with Bluesky queueserver. + zmq_consumer_addr : str + A ZMQ address to consume Bluesky messages from, to react to plan execution on the + remote server. + sensors : Sequence[str] + The sensors to use for acquisition. These should be the minimal set + of sensors that are needed to compute the objectives. + dofs : Sequence[DOF] + The degrees of freedom that the agent can control, which determine the search space. + objectives : Sequence[Objective] + The objectives which the agent will try to optimize. + evaluation_function : EvaluationFunction + The function to evaluate acquired data and produce outcomes. + acquisition_plan : str | None, optional + The acquisition plan to use for acquiring data from the beamline. If not provided, + :func:`blop.plans.default_acquire` will be assumed. + dof_constraints : Sequence[DOFConstraint] | None, optional + Constraints on DOFs to refine the search space. + outcome_constraints : Sequence[OutcomeConstraint] | None, optional + Constraints on outcomes to be satisfied during optimization. + checkpoint_path : str | None, optional + The path to the checkpoint file to save the optimizer's state to. + **kwargs : Any + Additional keyword arguments to configure the Ax experiment. - Returns - ------- - list[AnalysisCard] - The computed analysis cards containing the plot data. + See Also + -------- + blop.protocols.Sensor : The protocol for sensors. + blop.ax.dof.RangeDOF : For continuous parameters. + blop.ax.dof.ChoiceDOF : For discrete parameters. + blop.ax.objective.Objective : For defining objectives. + blop.ax.optimizer.AxOptimizer : The optimizer used internally. + blop.queueserver.QueueservverOptimizatonRunner : Runner that handles interaction with bluesky-queueserver. + """ - See Also - -------- - ax.analysis.ContourPlot : Pre-built analysis for plotting objectives. - ax.analysis.AnalysisCard : Contains the raw and computed data. - """ - return self.ax_client.compute_analyses( - [ - ContourPlot( - x_parameter_name=x_dof_name, - y_parameter_name=y_dof_name, - metric_name=objective_name, - ), - ], - *args, + def __init__( + self, + re_manager_api: REManagerAPI, + zmq_consumer_addr: str, + sensors: Sequence[str], + dofs: Sequence[DOF], + objectives: Sequence[Objective], + evaluation_function: EvaluationFunction, + acquisition_plan: str | None = None, + dof_constraints: Sequence[DOFConstraint] | None = None, + outcome_constraints: Sequence[OutcomeConstraint] | None = None, + checkpoint_path: str | None = None, + **kwargs: Any, + ): + self._sensors = sensors + self._actuators: Sequence[str] = [] + for dof in dofs: + if dof.actuator is not None: + if isinstance(dof.actuator, Actuator): + self._actuators.append(dof.actuator.name) + else: + self._actuators.append(dof.actuator) + self._evaluation_function = evaluation_function + self._acquisition_plan = acquisition_plan + self._optimizer = AxOptimizer( + parameters=[dof.to_ax_parameter_config() for dof in dofs], + objective=to_ax_objective_str(objectives), + parameter_constraints=[constraint.ax_constraint for constraint in dof_constraints] if dof_constraints else None, + outcome_constraints=[constraint.ax_constraint for constraint in outcome_constraints] + if outcome_constraints + else None, + checkpoint_path=checkpoint_path, **kwargs, ) + self._runner = QueueserverOptimizationRunner( + self.to_optimization_problem(), + QueueserverClient(re_manager_api, zmq_consumer_addr), + ) - def checkpoint(self) -> None: - """ - Save the agent's state to a JSON file. - """ - self._optimizer.checkpoint() + def to_optimization_problem(self) -> QueueserverOptimizationProblem: + return QueueserverOptimizationProblem( + optimizer=self._optimizer, + actuators=self._actuators, + sensors=self._sensors, + evaluation_function=self._evaluation_function, + acquisition_plan=self._acquisition_plan, + ) + + def run(self, iterations=1, n_points=1) -> None: + self._runner.run(iterations, n_points) diff --git a/src/blop/ax/dof.py b/src/blop/ax/dof.py index abcd3e3f..a469f0a6 100644 --- a/src/blop/ax/dof.py +++ b/src/blop/ax/dof.py @@ -24,8 +24,8 @@ class DOF(ABC): ---------- name : str | None The name of the DOF. Provide a name if the DOF is not an actuator. - actuator : Actuator | None - The actuator to use for the DOF. Provide an actuator if the DOF is controllable by Bluesky. + actuator : Actuator | str | None + The actuator or its name to use for the DOF. Provide an actuator if the DOF is controllable by Bluesky. Notes ----- @@ -42,7 +42,7 @@ class DOF(ABC): """ name: str | None = None - actuator: Actuator | None = None + actuator: Actuator | str | None = None def __post_init__(self) -> None: if not (bool(self.name) ^ bool(self.actuator)): @@ -50,7 +50,13 @@ def __post_init__(self) -> None: @property def parameter_name(self) -> str: - return self.name or cast(Actuator, self.actuator).name + if isinstance(self.actuator, Actuator): + param_name = self.actuator.name + elif isinstance(self.actuator, str): + param_name = self.actuator + else: + param_name = cast(str, self.name) + return param_name @abstractmethod def to_ax_parameter_config(self) -> RangeParameterConfig | ChoiceParameterConfig: ... diff --git a/src/blop/protocols.py b/src/blop/protocols.py index 39566656..4c9f36ab 100644 --- a/src/blop/protocols.py +++ b/src/blop/protocols.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Literal, Protocol, runtime_checkable +from typing import Any, Literal, Protocol, runtime_checkable, TypeVar, Generic from bluesky.protocols import EventCollectable, EventPageCollectable, Flyable, NamedMovable, Readable from bluesky.utils import MsgGenerator, plan @@ -9,6 +9,13 @@ Actuator = NamedMovable | Flyable Sensor = Readable | EventCollectable | EventPageCollectable +TActuator = TypeVar("TActuator") +"""Actuator generic type""" +TSensor = TypeVar("TSensor") +"""Sensor generic type""" +TPlan = TypeVar("TPlan") +"""Plan generic type""" + @runtime_checkable class CanRegisterSuggestions(Protocol): @@ -205,7 +212,15 @@ def __call__( @dataclass(frozen=True) -class OptimizationProblem: +class BaseOptimizationProblem(Generic[TActuator, TSensor, TPlan]): + optimizer: Optimizer + actuators: Sequence[TActuator] + sensors: Sequence[TSensor] + evaluation_function: EvaluationFunction + acquisition_plan: TPlan | None = None + + +class OptimizationProblem(BaseOptimizationProblem[Actuator, Sensor, AcquisitionPlan]): """ An optimization problem to solve. Immutable once initialized. @@ -233,47 +248,37 @@ class OptimizationProblem: blop.plans.optimize : Bluesky plan that uses an OptimizationProblem. """ - optimizer: Optimizer - actuators: Sequence[Actuator] - sensors: Sequence[Sensor] - evaluation_function: EvaluationFunction - acquisition_plan: AcquisitionPlan | None = None + ... -@dataclass(frozen=True) -class RemoteOptimizationProblem: +class QueueserverOptimizationProblem(BaseOptimizationProblem[str, str, str]): """ An optimization problem to solve. Immutable once initialized. This dataclass encapsulates all components needed for optimization into a single - immutable structure. It is typically used with optimization plans that exist on - a remote server (e.g. bluesky-queueserver). - - Since the remote server manages the instances of the actuators, sensors, and acquistion_plan, - we only refer to them by their names here. + immutable structure. It is typically created via :meth:`blop.ax.QueueserverAgent.to_optimization_problem` + and used with bluesky-queueserver-api. Actuators, sensors, and the acquisition plan are referenced + by their names, since their instances live on a remote server. Attributes ---------- - optimizer : Optimizer + optimizer: Optimizer Suggests points to evaluate and ingests outcomes to inform the optimization. - actuators : Sequence[str] - Names of objects that can be moved. - sensors : Sequence[str] - Names of objects that can produce data. - evaluation_function : EvaluationFunction + actuators: Sequence[str] + Names of objects that can be moved to control the beamline using the Bluesky RunEngine. + A subset of the actuators' names must match the names of suggested parameterizations. + sensors: Sequence[str] + Names of objects that can produce data to acquire data from the beamline using the Bluesky RunEngine. + evaluation_function: EvaluationFunction A callable to evaluate data from a Bluesky run and produce outcomes. acquisition_plan: str, optional - Name of a Bluesky plan to acquire data from the beamline. If not provided, a default plan will be used. - Function signature must match `blop.protocols.AcquisitionPlan`. + The name of a Bluesky plan to acquire data from the beamline. If not provided, a default plan name will be used. + The plan must match the arguments of :ref:`AcquisitionPlan`. See Also -------- - OptimizationProblem : Alternative configuration when control is local. - blop.queueserver.QueueserverOptimizationRunner : Runner for remote optimization problems using bluesky-queueserver + blop.ax.QueueserverAgent.to_optimization_problem : Creates a QueueserverOptimizationProblem from an agent. + blop.queueserver.QueueserverOptimizationRunner : Runs the optimization loop using the bluesky-queueserver-api. """ - optimizer: Optimizer - actuators: Sequence[str] - sensors: Sequence[str] - evaluation_function: EvaluationFunction - acquisition_plan: str | None = None + ... diff --git a/src/blop/queueserver.py b/src/blop/queueserver.py index 34a618bd..2d343e8b 100644 --- a/src/blop/queueserver.py +++ b/src/blop/queueserver.py @@ -19,7 +19,7 @@ from event_model import RunStart, RunStop from .plans import default_acquire -from .protocols import RemoteOptimizationProblem +from .protocols import QueueserverOptimizationProblem logger = logging.getLogger("blop") @@ -217,7 +217,7 @@ class QueueserverOptimizationRunner: Parameters ---------- - optimization_problem : RemoteOptimizationProblem + optimization_problem : QueueserverOptimizationProblem The optimization problem to solve, containing the optimizer, actuators, sensors, and evaluation function. queueserver_client : QueueserverClient @@ -228,19 +228,18 @@ class QueueserverOptimizationRunner: def __init__( self, - optimization_problem: RemoteOptimizationProblem, + optimization_problem: QueueserverOptimizationProblem, queueserver_client: QueueserverClient, - acquisition_plan_name: str = DEFAULT_ACQUIRE_PLAN_NAME, ): self._problem = optimization_problem self._client = queueserver_client - self._plan_name = acquisition_plan_name + self._plan_name = optimization_problem.acquisition_plan or DEFAULT_ACQUIRE_PLAN_NAME self._state: _OptimizationState | None = None self._continuous = True self._autostart = True @property - def optimization_problem(self) -> RemoteOptimizationProblem: + def optimization_problem(self) -> QueueserverOptimizationProblem: """The optimization problem being solved.""" return self._problem @@ -276,7 +275,7 @@ def run(self, iterations: int = 1, num_points: int = 1) -> None: ValueError If required devices or plans are not available. """ - # TODO: What if there is already a run? + # TODO: What if there is already a run in-progress? self._validate() self._state = _OptimizationState(max_iterations=iterations, num_points=num_points) self._continuous = True diff --git a/src/blop/tests/test_queueserver.py b/src/blop/tests/test_queueserver.py index 1b6d5b14..e8ce2028 100644 --- a/src/blop/tests/test_queueserver.py +++ b/src/blop/tests/test_queueserver.py @@ -2,7 +2,7 @@ import pytest -from blop.protocols import RemoteOptimizationProblem +from blop.protocols import QueueserverOptimizationProblem from blop.queueserver import CORRELATION_UID_KEY, ConsumerCallback, QueueserverClient, QueueserverOptimizationRunner @@ -17,7 +17,7 @@ def mock_remote_optimization_problem(): mock_eval_func = MagicMock() mock_eval_func.return_value = [{"_id": 0, "objective": 1.0}] - return RemoteOptimizationProblem( + return QueueserverOptimizationProblem( optimizer=mock_optimizer, actuators=["motor1", "motor2"], sensors=["detector"], From 59df4b832a243415f1c6616cdfbde02c81af9b6a Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Tue, 10 Mar 2026 13:02:46 -0400 Subject: [PATCH 09/13] Add ability to submit suggestions to the queue --- src/blop/ax/agent.py | 40 +++++++++++++++++++++++++++++++++ src/blop/queueserver.py | 49 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/src/blop/ax/agent.py b/src/blop/ax/agent.py index 25285a4d..c33a8fe2 100644 --- a/src/blop/ax/agent.py +++ b/src/blop/ax/agent.py @@ -529,4 +529,44 @@ def to_optimization_problem(self) -> QueueserverOptimizationProblem: ) def run(self, iterations=1, n_points=1) -> None: + """ + Start the optimization loop. + + Validates the queueserver state, then begins the suggest -> acquire -> ingest + cycle. This method returns immediately; the optimization runs asynchronously + via callbacks. + + Parameters + ---------- + iterations : int + Number of optimization iterations to run. + num_points : int + Number of points to suggest per iteration. + + Raises + ------ + RuntimeError + If the queueserver environment is not ready. + ValueError + If required devices or plans are not available. + """ + self._runner.run(iterations, n_points) + + def submit_suggestions(self, suggestions: list[dict]) -> None: + """ + Evaluate specific parameter combinations. + + Acquires data for given suggestions and ingests results. Supports both + optimizer suggestions and manual points. + + Parameters + ---------- + suggestions : list[dict] + Either optimizer suggestions (with "_id") or manual points (without "_id"). + + See Also + -------- + suggest : Get optimizer suggestions. + """ + self._runner.submit_suggestions(suggestions) diff --git a/src/blop/queueserver.py b/src/blop/queueserver.py index 2d343e8b..a9ce989e 100644 --- a/src/blop/queueserver.py +++ b/src/blop/queueserver.py @@ -19,7 +19,7 @@ from event_model import RunStart, RunStop from .plans import default_acquire -from .protocols import QueueserverOptimizationProblem +from .protocols import QueueserverOptimizationProblem, CanRegisterSuggestions, ID_KEY logger = logging.getLogger("blop") @@ -282,6 +282,25 @@ def run(self, iterations: int = 1, num_points: int = 1) -> None: self._client.start_listener(on_stop=self._on_acquisition_complete) self._submit_next() + def submit_suggestions(self, suggestions: list[dict]) -> None: + """ + Manually submit suggestions to the queue. This method returns immediately; the + optimization runs asynchronously via callbacks. + + Parameters + ---------- + suggestions : list[dict] + Parameter combinations to evaluate. Can be: + + - Optimizer suggestions (with "_id" keys from suggest()) + - Manual points (without "_id", requires CanRegisterSuggestions protocol) + """ + self._validate() + self._state = _OptimizationState(max_iterations=1, num_points=len(suggestions)) + self._continuous = False + self._client.start_listener(on_stop=self._on_acquisition_complete) + self._submit_next_manual(suggestions) + def stop(self) -> None: """ Stop the optimization loop gracefully. @@ -305,6 +324,32 @@ def _validate(self) -> None: self._client.check_plan_available(self._plan_name) + def _submit_next_manual(self, suggestions: list[dict]) -> None: + """Get suggestions from optimizer and submit plan to queueserver.""" + if self._state is None: + raise RuntimeError("_submit_next called before run()") + + # Ensure the suggestions have an ID_KEY or register them with the optimizer + if not isinstance(self.optimization_problem.optimizer, CanRegisterSuggestions) and any( + ID_KEY not in suggestion for suggestion in suggestions + ): + raise ValueError( + f"All suggestions must contain an '{ID_KEY}' key to later match with the outcomes or your optimizer must " + "implement the `blop.protocols.CanRegisterSuggestions` protocol. Please review your optimizer " + f"implementation. Got suggestions: {suggestions}" + ) + elif isinstance(self.optimization_problem.optimizer, CanRegisterSuggestions): + suggestions = self.optimization_problem.optimizer.register_suggestions(suggestions) + + self._state.current_iteration += 1 + self._state.current_suggestions = suggestions + self._state.current_uid = str(uuid.uuid4()) + + logger.info(f"Submitting manually specified suggestion(s) with correlation uid: {self._state.current_uid}") + + plan = self._build_plan() + self._client.submit_plan(plan, autostart=self._autostart) + def _submit_next(self) -> None: """Get suggestions from optimizer and submit plan to queueserver.""" if self._state is None: @@ -315,7 +360,7 @@ def _submit_next(self) -> None: logger.info( f"Submitting iteration {self._state.current_iteration}/{self._state.max_iterations} " - f"with suggestion uid: {self._state.current_uid}" + f"with correlation uid: {self._state.current_uid}" ) plan = self._build_plan() From a80879cd83e3ea9e372a5f2b8ee3a05f25769cf2 Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Tue, 10 Mar 2026 14:24:29 -0400 Subject: [PATCH 10/13] Fix test_queueserver.py --- src/blop/ax/agent.py | 2 +- src/blop/protocols.py | 2 +- src/blop/queueserver.py | 2 +- src/blop/tests/test_queueserver.py | 24 +++++++++++++++--------- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/blop/ax/agent.py b/src/blop/ax/agent.py index c33a8fe2..9d2bbc5f 100644 --- a/src/blop/ax/agent.py +++ b/src/blop/ax/agent.py @@ -25,8 +25,8 @@ QueueserverOptimizationProblem, Sensor, ) -from ..utils import InferredReadable from ..queueserver import QueueserverClient, QueueserverOptimizationRunner +from ..utils import InferredReadable from .dof import DOF, DOFConstraint from .objective import Objective, OutcomeConstraint, to_ax_objective_str from .optimizer import AxOptimizer diff --git a/src/blop/protocols.py b/src/blop/protocols.py index 4c9f36ab..8b9883b7 100644 --- a/src/blop/protocols.py +++ b/src/blop/protocols.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Literal, Protocol, runtime_checkable, TypeVar, Generic +from typing import Any, Generic, Literal, Protocol, TypeVar, runtime_checkable from bluesky.protocols import EventCollectable, EventPageCollectable, Flyable, NamedMovable, Readable from bluesky.utils import MsgGenerator, plan diff --git a/src/blop/queueserver.py b/src/blop/queueserver.py index a9ce989e..7015e7f3 100644 --- a/src/blop/queueserver.py +++ b/src/blop/queueserver.py @@ -19,7 +19,7 @@ from event_model import RunStart, RunStop from .plans import default_acquire -from .protocols import QueueserverOptimizationProblem, CanRegisterSuggestions, ID_KEY +from .protocols import ID_KEY, CanRegisterSuggestions, QueueserverOptimizationProblem logger = logging.getLogger("blop") diff --git a/src/blop/tests/test_queueserver.py b/src/blop/tests/test_queueserver.py index e8ce2028..252a1ba3 100644 --- a/src/blop/tests/test_queueserver.py +++ b/src/blop/tests/test_queueserver.py @@ -7,7 +7,7 @@ @pytest.fixture(scope="function") -def mock_remote_optimization_problem(): +def mock_optimization_problem(): """Create a mock OptimizationProblem with necessary components.""" mock_optimizer = MagicMock() mock_optimizer.suggest.return_value = [ @@ -109,13 +109,13 @@ def test_queueserver_client_submit_plan_without_autostart(mock_re_manager): mock_re_manager.queue_start.assert_not_called() -def test_runner_run_validates_environment(mock_remote_optimization_problem): +def test_runner_run_validates_environment(mock_optimization_problem): """Test run() validates queueserver environment before starting.""" mock_client = MagicMock(spec=QueueserverClient) mock_client.check_environment.side_effect = RuntimeError("not open") runner = QueueserverOptimizationRunner( - optimization_problem=mock_remote_optimization_problem, + optimization_problem=mock_optimization_problem, queueserver_client=mock_client, ) @@ -125,19 +125,25 @@ def test_runner_run_validates_environment(mock_remote_optimization_problem): mock_client.check_environment.assert_called_once() -def test_runner_run_submits_suggestions_to_queueserver(mock_remote_optimization_problem): +def test_runner_run_submits_suggestions_to_queueserver(): """Test run() gets suggestions from optimizer and submits plan to queueserver.""" mock_client = MagicMock(spec=QueueserverClient) + mock_optimization_problem = QueueserverOptimizationProblem( + optimizer=MagicMock(), + actuators=["motor1"], + sensors=["det"], + evaluation_function=MagicMock(), + acquisition_plan="my_acquire", + ) runner = QueueserverOptimizationRunner( - optimization_problem=mock_remote_optimization_problem, + optimization_problem=mock_optimization_problem, queueserver_client=mock_client, - acquisition_plan_name="my_acquire", ) runner.run(iterations=1, num_points=1) # Verify optimizer.suggest was called - mock_remote_optimization_problem.optimizer.suggest.assert_called_once_with(1) + mock_optimization_problem.optimizer.suggest.assert_called_once_with(1) # Verify plan was submitted mock_client.submit_plan.assert_called_once() @@ -145,11 +151,11 @@ def test_runner_run_submits_suggestions_to_queueserver(mock_remote_optimization_ assert submitted_plan.name == "my_acquire" -def test_runner_stop_sets_finished_state(mock_remote_optimization_problem): +def test_runner_stop_sets_finished_state(mock_optimization_problem): """Test stop() marks the runner as finished and stops listener.""" mock_client = MagicMock(spec=QueueserverClient) runner = QueueserverOptimizationRunner( - optimization_problem=mock_remote_optimization_problem, + optimization_problem=mock_optimization_problem, queueserver_client=mock_client, ) From 836bf2802d8bff681de3f6b797419eaeb07fd68f Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Tue, 10 Mar 2026 14:58:52 -0400 Subject: [PATCH 11/13] Wrote unit tests for agent --- src/blop/ax/agent.py | 16 +++++++ src/blop/tests/ax/test_agent.py | 79 ++++++++++++++++++++++++++++++++- 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/src/blop/ax/agent.py b/src/blop/ax/agent.py index 9d2bbc5f..ba6ce71c 100644 --- a/src/blop/ax/agent.py +++ b/src/blop/ax/agent.py @@ -519,6 +519,22 @@ def __init__( QueueserverClient(re_manager_api, zmq_consumer_addr), ) + @property + def evaluation_function(self) -> EvaluationFunction: + return self._evaluation_function + + @property + def actuators(self) -> Sequence[str]: + return self._actuators + + @property + def sensors(self) -> Sequence[str]: + return self._sensors + + @property + def acquisition_plan(self) -> str | None: + return self._acquisition_plan + def to_optimization_problem(self) -> QueueserverOptimizationProblem: return QueueserverOptimizationProblem( optimizer=self._optimizer, diff --git a/src/blop/tests/ax/test_agent.py b/src/blop/tests/ax/test_agent.py index 9f67f821..222bd261 100644 --- a/src/blop/tests/ax/test_agent.py +++ b/src/blop/tests/ax/test_agent.py @@ -1,10 +1,11 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import numpy as np import pytest from ax import Client +from bluesky_queueserver_api.zmq import REManagerAPI -from blop.ax.agent import Agent +from blop.ax.agent import Agent, QueueserverAgent from blop.ax.dof import DOFConstraint, RangeDOF from blop.ax.objective import Objective from blop.ax.optimizer import AxOptimizer @@ -23,6 +24,11 @@ def mock_acquisition_plan(): return MagicMock(spec=AcquisitionPlan) +@pytest.fixture(scope="function") +def mock_re_manager_api(): + return MagicMock(spec=REManagerAPI) + + def test_agent_init(mock_evaluation_function, mock_acquisition_plan): """Test that the agent can be initialized.""" movable1 = MovableSignal(name="test_movable1") @@ -211,3 +217,72 @@ def test_ingest_baseline(mock_evaluation_function): summary_df = agent.ax_client.summarize() assert len(summary_df) == 1 assert summary_df["arm_name"].values[0] == "baseline" + + +def test_queueserver_agent_init(mock_re_manager_api, mock_evaluation_function): + dof1 = RangeDOF(actuator="test_motor1", bounds=(0, 10), parameter_type="float") + dof2 = RangeDOF(actuator="test_motor2", bounds=(0, 10), parameter_type="float") + agent = QueueserverAgent( + mock_re_manager_api, + "inproc://test", + ["det"], + [dof1, dof2], + [Objective(name="obj1", minimize=False)], + mock_evaluation_function, + ) + assert agent.sensors == ["det"] + assert agent.actuators == [dof1.actuator, dof2.actuator] + assert agent.evaluation_function == mock_evaluation_function + assert agent.acquisition_plan is None + assert isinstance(agent.ax_client, Client) + + problem = agent.to_optimization_problem() + assert problem.acquisition_plan is None + assert problem.actuators == [dof1.parameter_name, dof2.parameter_name] + assert problem.sensors == ["det"] + assert problem.evaluation_function == mock_evaluation_function + + +@patch("blop.ax.agent.QueueserverClient") +@patch("blop.ax.agent.QueueserverOptimizationRunner") +def test_queueserver_agent_run( + mock_queueserver_runner_cls, mock_queueserver_client_cls, mock_re_manager_api, mock_evaluation_function +): + dof1 = RangeDOF(actuator="test_motor1", bounds=(0, 10), parameter_type="float") + dof2 = RangeDOF(actuator="test_motor2", bounds=(0, 10), parameter_type="float") + agent = QueueserverAgent( + mock_re_manager_api, + "inproc://test", + ["det"], + [dof1, dof2], + [Objective(name="obj1", minimize=False)], + mock_evaluation_function, + ) + mock_queueserver_client_cls.assert_called_once() + mock_queueserver_runner_cls.assert_called_once() + + agent.run() + mock_queueserver_runner_cls.return_value.run.assert_called_once_with(1, 1) + + +@patch("blop.ax.agent.QueueserverClient") +@patch("blop.ax.agent.QueueserverOptimizationRunner") +def test_queueserver_agent_submit_suggestions( + mock_queueserver_runner_cls, mock_queueserver_client_cls, mock_re_manager_api, mock_evaluation_function +): + dof1 = RangeDOF(actuator="test_motor1", bounds=(0, 10), parameter_type="float") + dof2 = RangeDOF(actuator="test_motor2", bounds=(0, 10), parameter_type="float") + agent = QueueserverAgent( + mock_re_manager_api, + "inproc://test", + ["det"], + [dof1, dof2], + [Objective(name="obj1", minimize=False)], + mock_evaluation_function, + ) + mock_queueserver_client_cls.assert_called_once() + mock_queueserver_runner_cls.assert_called_once() + + suggestions = [{"test_motor1": 5, "test_motor2": 9}] + agent.submit_suggestions(suggestions) + mock_queueserver_runner_cls.return_value.submit_suggestions.assert_called_once_with(suggestions) From 1900a331c62d02803a1a5bcd50e4856203ad533b Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Tue, 10 Mar 2026 17:15:18 -0400 Subject: [PATCH 12/13] Improve test coverage --- src/blop/tests/ax/test_agent.py | 25 ++++ src/blop/tests/test_queueserver.py | 177 ++++++++++++++++++++++++++++- 2 files changed, 201 insertions(+), 1 deletion(-) diff --git a/src/blop/tests/ax/test_agent.py b/src/blop/tests/ax/test_agent.py index 222bd261..abababfc 100644 --- a/src/blop/tests/ax/test_agent.py +++ b/src/blop/tests/ax/test_agent.py @@ -219,6 +219,15 @@ def test_ingest_baseline(mock_evaluation_function): assert summary_df["arm_name"].values[0] == "baseline" +def test_agent_init_actuator_string_raises(mock_evaluation_function): + dof1 = RangeDOF(actuator="test_movable1", bounds=(0, 10), parameter_type="float") + dof2 = RangeDOF(actuator="test_movable2", bounds=(0, 10), parameter_type="float") + objective = Objective(name="test_objective", minimize=False) + + with pytest.raises(ValueError, match="not strings"): + Agent(sensors=[], dofs=[dof1, dof2], objectives=[objective], evaluation_function=mock_evaluation_function) + + def test_queueserver_agent_init(mock_re_manager_api, mock_evaluation_function): dof1 = RangeDOF(actuator="test_motor1", bounds=(0, 10), parameter_type="float") dof2 = RangeDOF(actuator="test_motor2", bounds=(0, 10), parameter_type="float") @@ -243,6 +252,22 @@ def test_queueserver_agent_init(mock_re_manager_api, mock_evaluation_function): assert problem.evaluation_function == mock_evaluation_function +def test_queueserver_agent_init_actuator_instance(mock_re_manager_api, mock_evaluation_function): + movable1 = MovableSignal(name="test_movable1") + dof1 = RangeDOF(actuator=movable1, bounds=(0, 10), parameter_type="float") + dof2 = RangeDOF(actuator="test_movable2", bounds=(0, 10), parameter_type="float") + agent = QueueserverAgent( + mock_re_manager_api, + "inproc://test", + ["det"], + [dof1, dof2], + [Objective(name="obj1", minimize=False)], + mock_evaluation_function, + ) + + assert agent.actuators == [movable1.name, dof2.parameter_name] + + @patch("blop.ax.agent.QueueserverClient") @patch("blop.ax.agent.QueueserverOptimizationRunner") def test_queueserver_agent_run( diff --git a/src/blop/tests/test_queueserver.py b/src/blop/tests/test_queueserver.py index 252a1ba3..594e7c93 100644 --- a/src/blop/tests/test_queueserver.py +++ b/src/blop/tests/test_queueserver.py @@ -2,7 +2,7 @@ import pytest -from blop.protocols import QueueserverOptimizationProblem +from blop.protocols import CanRegisterSuggestions, Optimizer, QueueserverOptimizationProblem from blop.queueserver import CORRELATION_UID_KEY, ConsumerCallback, QueueserverClient, QueueserverOptimizationRunner @@ -109,6 +109,84 @@ def test_queueserver_client_submit_plan_without_autostart(mock_re_manager): mock_re_manager.queue_start.assert_not_called() +@patch("blop.queueserver.threading.Thread") +@patch("blop.queueserver.RemoteDispatcher") +@patch("blop.queueserver.REManagerAPI") +def test_queueserver_client_start_listener(mock_re_manager, mock_dispatcher_cls, mock_thread_cls): + """Test start_listener creates dispatcher, subscribes callback, and starts thread.""" + mock_re_manager.status.return_value = {"worker_environment_exists": True} + mock_re_manager.devices_allowed.return_value = {"devices_allowed": {"motor1": {}, "detector": {}}} + mock_re_manager.plans_allowed.return_value = {"plans_allowed": {"default_acquire": {}}} + + client = QueueserverClient(mock_re_manager, "tcp://localhost:5578") + mock_callback = MagicMock() + + client.start_listener(on_stop=mock_callback) + + mock_dispatcher_cls.assert_called_once_with("tcp://localhost:5578") + mock_dispatcher = mock_dispatcher_cls.return_value + mock_dispatcher.subscribe.assert_called_once() + subscribed_callback = mock_dispatcher.subscribe.call_args[0][0] + assert isinstance(subscribed_callback, ConsumerCallback) + assert subscribed_callback._callback is mock_callback + + mock_thread_cls.assert_called_once() + call_kwargs = mock_thread_cls.call_args[1] + assert call_kwargs["target"] == mock_dispatcher.start + mock_thread_cls.return_value.start.assert_called_once() + + +@patch("blop.queueserver.threading.Thread") +@patch("blop.queueserver.RemoteDispatcher") +@patch("blop.queueserver.REManagerAPI") +def test_queueserver_client_start_listener_already_running_returns_early( + mock_re_manager, mock_dispatcher_cls, mock_thread_cls +): + """Test start_listener returns early when listener is already running.""" + mock_re_manager.status.return_value = {"worker_environment_exists": True} + mock_re_manager.devices_allowed.return_value = {"devices_allowed": {"motor1": {}, "detector": {}}} + mock_re_manager.plans_allowed.return_value = {"plans_allowed": {"default_acquire": {}}} + + client = QueueserverClient(mock_re_manager, "tcp://localhost:5578") + client._listener_thread = MagicMock() # Simulate already running + + client.start_listener(on_stop=MagicMock()) + + mock_dispatcher_cls.assert_not_called() + mock_thread_cls.assert_not_called() + + +@patch("blop.queueserver.threading.Thread") +@patch("blop.queueserver.RemoteDispatcher") +@patch("blop.queueserver.REManagerAPI") +def test_queueserver_client_stop_listener(mock_re_manager, mock_dispatcher_cls, mock_thread_cls): + """Test stop_listener stops dispatcher and clears state.""" + mock_re_manager.status.return_value = {"worker_environment_exists": True} + mock_re_manager.devices_allowed.return_value = {"devices_allowed": {"motor1": {}, "detector": {}}} + mock_re_manager.plans_allowed.return_value = {"plans_allowed": {"default_acquire": {}}} + + client = QueueserverClient(mock_re_manager, "tcp://localhost:5578") + client.start_listener(on_stop=MagicMock()) + + client.stop_listener() + + mock_dispatcher_cls.return_value.stop.assert_called_once() + assert client._dispatcher is None + assert client._consumer_callback is None + assert client._listener_thread is None + + +@patch("blop.queueserver.REManagerAPI") +def test_queueserver_client_stop_listener_when_not_started(mock_re_manager): + """Test stop_listener is safe to call when listener was never started.""" + client = QueueserverClient(mock_re_manager, "inproc://test") + + client.stop_listener() # Should not raise + + assert client._dispatcher is None + assert client._listener_thread is None + + def test_runner_run_validates_environment(mock_optimization_problem): """Test run() validates queueserver environment before starting.""" mock_client = MagicMock(spec=QueueserverClient) @@ -139,6 +217,7 @@ def test_runner_run_submits_suggestions_to_queueserver(): optimization_problem=mock_optimization_problem, queueserver_client=mock_client, ) + assert runner.optimization_problem == mock_optimization_problem runner.run(iterations=1, num_points=1) @@ -167,3 +246,99 @@ def test_runner_stop_sets_finished_state(mock_optimization_problem): runner.stop() assert runner.is_running is False mock_client.stop_listener.assert_called() + + +def test_runner_submit_suggestions_to_queueserver(): + """Test run() gets suggestions from optimizer and submits plan to queueserver.""" + mock_client = MagicMock(spec=QueueserverClient) + + class CustomOptimizer(Optimizer, CanRegisterSuggestions): ... + + mock_optimization_problem = QueueserverOptimizationProblem( + optimizer=MagicMock(spec=CustomOptimizer), + actuators=["motor1"], + sensors=["det"], + evaluation_function=MagicMock(), + acquisition_plan="my_acquire", + ) + runner = QueueserverOptimizationRunner( + optimization_problem=mock_optimization_problem, + queueserver_client=mock_client, + ) + + suggestions = [{"motor1": 5}] + runner.submit_suggestions(suggestions) + + # Verify optimizer.suggest was NOT called + mock_optimization_problem.optimizer.suggest.assert_not_called() + mock_optimization_problem.optimizer.register_suggestions.assert_called_once_with(suggestions) + + # Verify plan was submitted + mock_client.submit_plan.assert_called_once() + submitted_plan = mock_client.submit_plan.call_args[0][0] + assert submitted_plan.name == "my_acquire" + + +def test_runner_run_full_cycle(mock_optimization_problem): + """Test run() completes full suggest -> acquire -> ingest cycle across 3 iterations.""" + # Configure for num_points=2: suggest returns 2 items, evaluation_function returns 2 outcomes + mock_optimization_problem.optimizer.suggest.return_value = [ + {"_id": 0, "motor1": 5.0, "motor2": 3.0}, + {"_id": 1, "motor1": 6.0, "motor2": 4.0}, + ] + mock_optimization_problem.evaluation_function.return_value = [ + {"_id": 0, "objective": 1.0}, + {"_id": 1, "objective": 2.0}, + ] + + mock_client = MagicMock(spec=QueueserverClient) + + def capture_callback(on_stop): + mock_client._on_stop = on_stop + + mock_client.start_listener.side_effect = capture_callback + + runner = QueueserverOptimizationRunner( + optimization_problem=mock_optimization_problem, + queueserver_client=mock_client, + ) + + runner.run(iterations=3, num_points=2) + + # Simulate 3 acquisition completions by invoking the captured callback + for _ in range(3): + current_uid = runner._state.current_uid + uid = f"fake-uid-{_}" + start_doc = {"uid": uid, CORRELATION_UID_KEY: current_uid} + stop_doc = {"uid": uid} + mock_client._on_stop(start_doc, stop_doc) + + assert mock_client.submit_plan.call_count == 3 + assert mock_optimization_problem.optimizer.suggest.call_count == 3 + assert mock_optimization_problem.optimizer.ingest.call_count == 3 + assert mock_optimization_problem.evaluation_function.call_count == 3 + assert runner.is_running is False + + +def test_runner_on_acquisition_complete_raises_on_uid_mismatch(mock_optimization_problem): + """Test _on_acquisition_complete raises RuntimeError when blop_correlation_uid does not match.""" + mock_client = MagicMock(spec=QueueserverClient) + + def capture_callback(on_stop): + mock_client._on_stop = on_stop + + mock_client.start_listener.side_effect = capture_callback + + runner = QueueserverOptimizationRunner( + optimization_problem=mock_optimization_problem, + queueserver_client=mock_client, + ) + + runner.run(iterations=1, num_points=1) + + # Callback with wrong blop_correlation_uid should raise + start_doc = {"uid": "fake-uid", CORRELATION_UID_KEY: "wrong-uid"} + stop_doc = {"uid": "fake-uid"} + + with pytest.raises(RuntimeError, match="current_uid did not match start document"): + mock_client._on_stop(start_doc, stop_doc) From a9e481734001af64a705b8d0fc579628ee195e0c Mon Sep 17 00:00:00 2001 From: thomashopkins32 Date: Wed, 11 Mar 2026 09:55:07 -0400 Subject: [PATCH 13/13] Lower the type constraint on Actuators --- src/blop/protocols.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/blop/protocols.py b/src/blop/protocols.py index 8b9883b7..9373ecab 100644 --- a/src/blop/protocols.py +++ b/src/blop/protocols.py @@ -2,11 +2,26 @@ from dataclasses import dataclass from typing import Any, Generic, Literal, Protocol, TypeVar, runtime_checkable -from bluesky.protocols import EventCollectable, EventPageCollectable, Flyable, NamedMovable, Readable +from bluesky.protocols import EventCollectable, EventPageCollectable, Flyable, HasName, Movable, Readable from bluesky.utils import MsgGenerator, plan + +@runtime_checkable +class MovableHasName(Movable, HasName, Protocol): + """ + A movable that has a name. + + We use this instead of `bluesky.protocols.NamedMovable` since + we do not want to require `HasHints` on the movable. + + A `Movable` and `HasName` is sufficient. `HasHints` should be optional. + """ + + ... + + ID_KEY: Literal["_id"] = "_id" -Actuator = NamedMovable | Flyable +Actuator = MovableHasName | Flyable Sensor = Readable | EventCollectable | EventPageCollectable TActuator = TypeVar("TActuator")