diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 879f741e..61928441 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,7 +23,7 @@ jobs: max-parallel: 1 matrix: os: [ubuntu-latest] # TODO: add windows and macos to matrix - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.11", "3.12"] # xopt >= 3.0 requires python >= 3.11 env: DISPLAY: ':99.0' QT_MAC_WANTS_LAYER: 1 # PyQT gui tests involving qtbot interaction on macOS will fail without this diff --git a/pyproject.toml b/pyproject.toml index 8ab2e6a3..43569f42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,10 @@ dependencies = [ "pillow", "requests", "xopt>=2.6.11", + "xopt @ git+https://github.com/xopt-org/Xopt.git@v3.0", ] +# ^ temporarily point to xopt 3.0 brach for testing (revert b4 this patch gets merged to main) + dynamic = ["version"] [tool.setuptools_scm] version_file = "src/badger/_version.py" diff --git a/src/badger/core.py b/src/badger/core.py index f889f335..dbb6bd40 100644 --- a/src/badger/core.py +++ b/src/badger/core.py @@ -9,6 +9,8 @@ from badger.routine import Routine from badger.utils import curr_ts_to_str, dump_state +from xopt.vocs import select_best + def check_run_status(active_callback): while True: @@ -25,7 +27,7 @@ def check_run_status(active_callback): def convert_to_solution(result: DataFrame, routine: Routine): vocs = routine.vocs try: - best_idx, _, _ = vocs.select_best(routine.sorted_data, n=1) + best_idx, _, _ = select_best(vocs, routine.sorted_data, n=1) if best_idx.size > 0: best_idx = int(best_idx[0]) # convert numpy array to int if best_idx != len(routine.data) - 1: diff --git a/src/badger/core_subprocess.py b/src/badger/core_subprocess.py index efb86ad8..7cc24209 100644 --- a/src/badger/core_subprocess.py +++ b/src/badger/core_subprocess.py @@ -14,6 +14,8 @@ from badger.routine import Routine from badger.log import configure_process_logging from xopt.errors import FeasibilityError, XoptError +from xopt.vocs import select_best + logger = logging.getLogger(__name__) @@ -30,7 +32,7 @@ def convert_to_solution(result: DataFrame, routine: Routine): """ vocs = routine.vocs try: - best_idx, _, _ = vocs.select_best(routine.sorted_data, n=1) + best_idx, _, _ = select_best(vocs, routine.sorted_data, n=1) logger.debug(f"Selected best index: {best_idx}") if best_idx.size > 0: if best_idx[0] != len(routine.data) - 1: diff --git a/src/badger/gui/components/env_cbox.py b/src/badger/gui/components/env_cbox.py index 944f272d..40d3f61f 100644 --- a/src/badger/gui/components/env_cbox.py +++ b/src/badger/gui/components/env_cbox.py @@ -32,7 +32,9 @@ NoHoverFocusComboBox, ) from badger.utils import strtobool -from xopt.vocs import VOCS, ConstraintEnum +from xopt.vocs import VOCS +from gest_api.vocs import ContinuousVariable + import logging LABEL_WIDTH = 96 @@ -617,7 +619,7 @@ def compose_vocs(self) -> tuple[VOCS, list[str]]: (rule,) = objective[obj_name] objectives[obj_name] = rule - constraints: dict[str, list[float | ConstraintEnum]] = {} + constraints: dict[str, list[float]] = {} critical_constraints: list[str] = [] for constraint in self.con_table.export_data(): con_name = next(iter(constraint)) @@ -632,6 +634,10 @@ def compose_vocs(self) -> tuple[VOCS, list[str]]: observables.append(obs_name) try: + variables = { + k: list(v) if not isinstance(v, ContinuousVariable) else v + for k, v in variables.items() + } vocs = VOCS( variables=variables, objectives=objectives, diff --git a/src/badger/gui/components/pydantic_editor.py b/src/badger/gui/components/pydantic_editor.py index 044d70b2..6c1e8a39 100644 --- a/src/badger/gui/components/pydantic_editor.py +++ b/src/badger/gui/components/pydantic_editor.py @@ -560,7 +560,7 @@ def get_parameters_dict(self): class BadgerPydanticEditor(QTreeWidget): - vocs: VOCS = VOCS() + vocs: VOCS = VOCS(variables={}) defaults: dict[str, Any] = {} generator_name: str = "" model_class: type[BaseModel] | None = None @@ -687,7 +687,7 @@ def set_params_from_generator( ): logger.debug(f"vocs: {vocs}") logger.debug(f"defaults: {defaults}") - self.vocs = vocs or VOCS() + self.vocs = vocs or VOCS(variables={}) self.generator_name = generator_name self.defaults = defaults diff --git a/src/badger/gui/components/routine_page.py b/src/badger/gui/components/routine_page.py index 45954175..a9fa5125 100644 --- a/src/badger/gui/components/routine_page.py +++ b/src/badger/gui/components/routine_page.py @@ -15,12 +15,14 @@ from PyQt5.QtWidgets import QTableWidgetItem, QPlainTextEdit from coolname import generate_slug from xopt import VOCS +from xopt.vocs import random_inputs from xopt.generators import ( get_generator_defaults, all_generator_names, get_generator_dynamic, ) from xopt.utils import get_local_region +from gest_api.vocs import GreaterThanConstraint, LessThanConstraint from pydantic import ValidationError from badger.gui.components.generator_cbox import BadgerAlgoBox @@ -66,21 +68,13 @@ ts_float_to_str, ) + import logging logger = logging.getLogger(__name__) -LABEL_WIDTH = 96 -CONS_RELATION_DICT = { - ">": "GREATER_THAN", - "<": "LESS_THAN", -} -CONS_RELATION_DICT_INV = { - "GREATER_THAN": ">", - "LESS_THAN": "<", -} -logger = logging.getLogger(__name__) +LABEL_WIDTH = 96 def format_validation_error(e: ValidationError) -> str: @@ -93,6 +87,15 @@ def format_validation_error(e: ValidationError) -> str: return "\n".join(messages) +def extract_constraint_symbol_and_value(constraint): + if isinstance(constraint, GreaterThanConstraint): + return ">", constraint.value + if isinstance(constraint, LessThanConstraint): + return "<", constraint.value + else: # Expand for other constraints if needed + return "", 0 + + class BadgerRoutinePage(QWidget): sig_updated = pyqtSignal(str, str) # routine name, routine description sig_load_template = pyqtSignal(str) # template path @@ -523,9 +526,8 @@ def set_options_from_template(self, template_dict: dict[str, Any]): status[name] = False # selected constraints.append(cons) for name, val in vocs.constraints.items(): - relation, thres = val + relation, thres = extract_constraint_symbol_and_value(val) critical = name in critical_constraint_names - relation = CONS_RELATION_DICT_INV[relation] idx = constraints_names_full.index(name) if idx == -1: @@ -876,9 +878,8 @@ def refresh_ui(self, routine: Routine | None = None, silent: bool = False): status[name] = False # selected constraints.append(cons) for name, val in routine.vocs.constraints.items(): - relation, thres = val + relation, thres = extract_constraint_symbol_and_value(val) critical = name in routine.critical_constraint_names - relation = CONS_RELATION_DICT_INV[relation] idx = constraints_names_full.index(name) if idx == -1: @@ -1250,7 +1251,8 @@ def add_rand_in_init_table(self, add_rand_config=None, record=True): # get small region around current point to sample try: vocs, _ = self.env_box.compose_vocs() - except Exception: + except Exception as e: + print(str(e)) # Switch to manual mode to allow the user fixing the vocs issue QMessageBox.warning( self, @@ -1264,8 +1266,8 @@ def add_rand_in_init_table(self, add_rand_config=None, record=True): random_sample_region = get_local_region(var_curr, vocs, fraction=fraction) with warnings.catch_warnings(record=True) as caught_warnings: try: - random_points = vocs.random_inputs( - n_point, custom_bounds=random_sample_region + random_points = random_inputs( + vocs, n_point, custom_bounds=random_sample_region ) except ValueError: raise VariableRangeError( diff --git a/src/badger/gui/components/run_monitor.py b/src/badger/gui/components/run_monitor.py index 113c0873..4c7d0f9c 100644 --- a/src/badger/gui/components/run_monitor.py +++ b/src/badger/gui/components/run_monitor.py @@ -23,7 +23,7 @@ QVBoxLayout, QWidget, ) -from xopt import VOCS +from xopt.vocs import VOCS, normalize_inputs, select_best from badger.archive import archive_run, BADGER_ARCHIVE_ROOT from badger.gui.components.pydantic_editor import BadgerPydanticEditor @@ -532,7 +532,7 @@ def update(self, results: pd.DataFrame) -> None: def update_curves(self, results=None): use_time_axis = self.plot_x_axis == 1 - normalize_inputs = self.x_plot_y_axis == 1 + norm_inputs = self.x_plot_y_axis == 1 if results is not None: self.routine.data = results @@ -552,8 +552,8 @@ def update_curves(self, results=None): variable_names = self.vocs.variable_names # if normalize x, normalize using vocs - if normalize_inputs: - input_data = self.vocs.normalize_inputs(data_copy) + if norm_inputs: + input_data = normalize_inputs(self.vocs, data_copy) else: input_data = data_copy[variable_names] @@ -905,8 +905,8 @@ def load_checkpoint(self): def jump_to_optimal(self): try: - best_idx, _, _ = self.routine.vocs.select_best( - self.routine.sorted_data, n=1 + best_idx, _, _ = select_best( + self.routine.vocs, self.routine.sorted_data, n=1 ) # print(best_idx, _) best_idx = int(best_idx[0]) diff --git a/src/badger/gui/components/var_table.py b/src/badger/gui/components/var_table.py index 006c48ba..5773a18f 100644 --- a/src/badger/gui/components/var_table.py +++ b/src/badger/gui/components/var_table.py @@ -31,6 +31,8 @@ from badger.errors import BadgerInterfaceChannelError from badger.gui.windows.expandable_message_box import ExpandableMessageBox +from gest_api.vocs import ContinuousVariable + import logging logger = logging.getLogger(__name__) @@ -309,7 +311,9 @@ def _convert_bounds_to_tuple(self, bounds: Any) -> dict[str, tuple[float, float] return {k: (v[0], v[1]) for k, v in bounds.items()} def update_variables( - self, variables: list[dict[str, tuple[float, float]]], filtered: int = 0 + self, + variables: list[dict[str, tuple[float, float] | ContinuousVariable]], + filtered: int = 0, ): # filtered = 0: completely refresh # filtered = 1: filtered by keyword @@ -364,12 +368,22 @@ def update_variables( self.setItem(i, 1, item) _bounds = self.bounds[name] + default_val = ( + _bounds.domain[0] + if isinstance(_bounds, ContinuousVariable) + else _bounds[0] + ) sb_lower = RobustSpinBox( - default_value=_bounds[0], lower_bound=vrange[0], upper_bound=vrange[1] + default_value=default_val, lower_bound=vrange[0], upper_bound=vrange[1] ) sb_lower.valueChanged.connect(self.update_bounds) + default_val = ( + _bounds.domain[1] + if isinstance(_bounds, ContinuousVariable) + else _bounds[0] + ) sb_upper = RobustSpinBox( - default_value=_bounds[1], lower_bound=vrange[0], upper_bound=vrange[1] + default_value=default_val, lower_bound=vrange[0], upper_bound=vrange[1] ) sb_upper.valueChanged.connect(self.update_bounds) self.setCellWidget(i, 2, sb_lower) diff --git a/src/badger/settings.py b/src/badger/settings.py index 7438f9d8..184529fe 100644 --- a/src/badger/settings.py +++ b/src/badger/settings.py @@ -157,7 +157,11 @@ def load_or_create_config(cls, config_path: str) -> BadgerConfig: # Convert each entry in config_data to an instance of Setting for key, value in config_data.items(): - if isinstance(value, dict) and "value" in value: + if isinstance(value, dict): + value_arg = None + if "value" in value: + value_arg = value["value"] + # Convert to Setting instance config_data[key] = Setting( display_name=value.get("display_name", key), @@ -165,7 +169,7 @@ def load_or_create_config(cls, config_path: str) -> BadgerConfig: "description", f"Setting for {key.replace('_', ' ').lower()}", ), - value=value["value"], + value=value_arg, is_path=value.get("is_path", key), ) else: diff --git a/src/badger/tests/test_factory.py b/src/badger/tests/test_factory.py index ef9ab54a..74732d57 100644 --- a/src/badger/tests/test_factory.py +++ b/src/badger/tests/test_factory.py @@ -3,6 +3,7 @@ from xopt.generators import get_generator, get_generator_defaults from xopt.resources.testing import TEST_VOCS_BASE +from gest_api.vocs import ExploreObjective class TestFactory: @@ -16,7 +17,7 @@ def test_generator_generation(self): gen_config = get_generator_defaults(name) gen_class = get_generator(name) - if name in ["mobo"]: + if name == "mobo": test_vocs = deepcopy(TEST_VOCS_BASE) test_vocs.objectives = {"y1": "MINIMIZE", "y2": "MINIMIZE"} gen_config["reference_point"] = {"y1": 10.0, "y2": 1.0} @@ -29,12 +30,17 @@ def test_generator_generation(self): json.dumps(gen_config) gen_class(vocs=test_vocs, **gen_config) - elif name in ["bayesian_exploration"]: + elif name == "bayesian_exploration": test_vocs = deepcopy(TEST_VOCS_BASE) test_vocs.objectives = {} test_vocs.observables = ["f"] json.dumps(gen_config) - + gen_class(vocs=test_vocs, **gen_config) + elif name == "latin_hypercube": + test_vocs = deepcopy(TEST_VOCS_BASE) + test_vocs.objectives = { + k: ExploreObjective() for k in test_vocs.objectives.keys() + } gen_class(vocs=test_vocs, **gen_config) else: json.dumps(gen_config) diff --git a/src/badger/tests/test_routine_page.py b/src/badger/tests/test_routine_page.py index 8d2be96e..bcee0c60 100644 --- a/src/badger/tests/test_routine_page.py +++ b/src/badger/tests/test_routine_page.py @@ -1,8 +1,15 @@ +import json import pandas as pd import pytest from pytestqt.qtbot import QtBot from PyQt5.QtCore import Qt, QTimer from PyQt5.QtWidgets import QApplication +from gest_api.vocs import ( + LessThanConstraint, + ContinuousVariable, + MinimizeObjective, + Observable, +) def test_routine_page_init(qtbot: QtBot): @@ -60,8 +67,10 @@ def test_routine_generation(qtbot: QtBot): qtbot.mouseClick(window.env_box.btn_add_curr, Qt.LeftButton) routine = window._compose_routine() - assert routine.vocs.variables == {"x0": (-1, 1)} - assert routine.vocs.objectives == {"f": "MINIMIZE"} + assert routine.vocs.variables == { + "x0": ContinuousVariable(dtype=None, default_value=None, domain=[-1.0, 1.0]) + } + assert routine.vocs.objectives == {"f": MinimizeObjective(dtype=None)} # assert routine.initial_points.empty # Test if badger and xopt version match with the current version @@ -156,10 +165,21 @@ def test_ui_update(qtbot: QtBot): idx = window.generators.index(routine.generator.name) window.select_generator(idx) - assert ( - window.generator_box.edit.get_parameters_yaml() - == '{"vocs":{"variables":{"x0":"(-1.0, 1.0)","x1":"(-1.0, 1.0)","x2":"(-1.0, 1.0)","x3":"(-1.0, 1.0)"},"constraints":{"c":"(\'GREATER_THAN\', 0.0)"},"objectives":{"f":"MAXIMIZE"},"constants":{},"observables":[]}}' - ) + expected_params = { + "returns_id": False, + "vocs": { + "variables": "{'x0': {'dtype': None, 'default_value': None, 'domain': [-1.0, 1.0], 'type': 'ContinuousVariable'}, " + "'x1': {'dtype': None, 'default_value': None, 'domain': [-1.0, 1.0], 'type': 'ContinuousVariable'}, " + "'x2': {'dtype': None, 'default_value': None, 'domain': [-1.0, 1.0], 'type': 'ContinuousVariable'}, " + "'x3': {'dtype': None, 'default_value': None, 'domain': [-1.0, 1.0], 'type': 'ContinuousVariable'}}", + "constraints": "{'c': {'dtype': None, 'value': 0.0, 'type': 'GreaterThanConstraint'}}", + "objectives": "{'f': {'dtype': None, 'type': 'MaximizeObjective'}}", + "constants": "{}", + "observables": "{}", + }, + } + actual_params = json.loads(window.generator_box.edit.get_parameters_yaml()) + assert actual_params == expected_params def test_constraints(qtbot: QtBot): @@ -181,7 +201,7 @@ def test_constraints(qtbot: QtBot): con_widget_critical.setChecked(True) routine = window._compose_routine() - assert routine.vocs.constraints == {"c": ("LESS_THAN", 0.0)} + assert routine.vocs.constraints == {"c": LessThanConstraint(dtype=None, value=0.0)} assert routine.critical_constraint_names == ["c"] @@ -204,7 +224,7 @@ def test_observables(qtbot: QtBot): window.env_box.sta_table.cellWidget(1, 0).setChecked(True) routine = window._compose_routine() - assert routine.vocs.observables == ["c"] + assert routine.vocs.observables == {"c": Observable(dtype=None)} def test_add_random_points(qtbot: QtBot):