Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
max-parallel: 1
matrix:
os: [ubuntu-latest] # TODO: add windows and macos to matrix
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.11", "3.12"] # xopt >= 3.0 requires python >= 3.11
env:
DISPLAY: ':99.0'
QT_MAC_WANTS_LAYER: 1 # PyQT gui tests involving qtbot interaction on macOS will fail without this
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ dependencies = [
"pillow",
"requests",
"xopt>=2.6.11",
"xopt @ git+https://github.com/xopt-org/Xopt.git@v3.0",
]
# ^ temporarily point to xopt 3.0 brach for testing (revert b4 this patch gets merged to main)

dynamic = ["version"]
[tool.setuptools_scm]
version_file = "src/badger/_version.py"
Expand Down
4 changes: 3 additions & 1 deletion src/badger/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from badger.routine import Routine
from badger.utils import curr_ts_to_str, dump_state

from xopt.vocs import select_best


def check_run_status(active_callback):
while True:
Expand All @@ -25,7 +27,7 @@ def check_run_status(active_callback):
def convert_to_solution(result: DataFrame, routine: Routine):
vocs = routine.vocs
try:
best_idx, _, _ = vocs.select_best(routine.sorted_data, n=1)
best_idx, _, _ = select_best(vocs, routine.sorted_data, n=1)
if best_idx.size > 0:
best_idx = int(best_idx[0]) # convert numpy array to int
if best_idx != len(routine.data) - 1:
Expand Down
4 changes: 3 additions & 1 deletion src/badger/core_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from badger.routine import Routine
from badger.log import configure_process_logging
from xopt.errors import FeasibilityError, XoptError
from xopt.vocs import select_best


logger = logging.getLogger(__name__)

Expand All @@ -30,7 +32,7 @@ def convert_to_solution(result: DataFrame, routine: Routine):
"""
vocs = routine.vocs
try:
best_idx, _, _ = vocs.select_best(routine.sorted_data, n=1)
best_idx, _, _ = select_best(vocs, routine.sorted_data, n=1)
logger.debug(f"Selected best index: {best_idx}")
if best_idx.size > 0:
if best_idx[0] != len(routine.data) - 1:
Expand Down
10 changes: 8 additions & 2 deletions src/badger/gui/components/env_cbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
NoHoverFocusComboBox,
)
from badger.utils import strtobool
from xopt.vocs import VOCS, ConstraintEnum
from xopt.vocs import VOCS
from gest_api.vocs import ContinuousVariable

import logging

LABEL_WIDTH = 96
Expand Down Expand Up @@ -617,7 +619,7 @@ def compose_vocs(self) -> tuple[VOCS, list[str]]:
(rule,) = objective[obj_name]
objectives[obj_name] = rule

constraints: dict[str, list[float | ConstraintEnum]] = {}
constraints: dict[str, list[float]] = {}
critical_constraints: list[str] = []
for constraint in self.con_table.export_data():
con_name = next(iter(constraint))
Expand All @@ -632,6 +634,10 @@ def compose_vocs(self) -> tuple[VOCS, list[str]]:
observables.append(obs_name)

try:
variables = {
k: list(v) if not isinstance(v, ContinuousVariable) else v
for k, v in variables.items()
}
vocs = VOCS(
variables=variables,
objectives=objectives,
Expand Down
4 changes: 2 additions & 2 deletions src/badger/gui/components/pydantic_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def get_parameters_dict(self):


class BadgerPydanticEditor(QTreeWidget):
vocs: VOCS = VOCS()
vocs: VOCS = VOCS(variables={})
defaults: dict[str, Any] = {}
generator_name: str = ""
model_class: type[BaseModel] | None = None
Expand Down Expand Up @@ -687,7 +687,7 @@ def set_params_from_generator(
):
logger.debug(f"vocs: {vocs}")
logger.debug(f"defaults: {defaults}")
self.vocs = vocs or VOCS()
self.vocs = vocs or VOCS(variables={})
self.generator_name = generator_name
self.defaults = defaults

Expand Down
36 changes: 19 additions & 17 deletions src/badger/gui/components/routine_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from PyQt5.QtWidgets import QTableWidgetItem, QPlainTextEdit
from coolname import generate_slug
from xopt import VOCS
from xopt.vocs import random_inputs
from xopt.generators import (
get_generator_defaults,
all_generator_names,
get_generator_dynamic,
)
from xopt.utils import get_local_region
from gest_api.vocs import GreaterThanConstraint, LessThanConstraint
from pydantic import ValidationError

from badger.gui.components.generator_cbox import BadgerAlgoBox
Expand Down Expand Up @@ -66,21 +68,13 @@
ts_float_to_str,
)


import logging

logger = logging.getLogger(__name__)

LABEL_WIDTH = 96
CONS_RELATION_DICT = {
">": "GREATER_THAN",
"<": "LESS_THAN",
}
CONS_RELATION_DICT_INV = {
"GREATER_THAN": ">",
"LESS_THAN": "<",
}

logger = logging.getLogger(__name__)
LABEL_WIDTH = 96


def format_validation_error(e: ValidationError) -> str:
Expand All @@ -93,6 +87,15 @@ def format_validation_error(e: ValidationError) -> str:
return "\n".join(messages)


def extract_constraint_symbol_and_value(constraint):
if isinstance(constraint, GreaterThanConstraint):
return ">", constraint.value
if isinstance(constraint, LessThanConstraint):
return "<", constraint.value
else: # Expand for other constraints if needed
return "", 0


class BadgerRoutinePage(QWidget):
sig_updated = pyqtSignal(str, str) # routine name, routine description
sig_load_template = pyqtSignal(str) # template path
Expand Down Expand Up @@ -523,9 +526,8 @@ def set_options_from_template(self, template_dict: dict[str, Any]):
status[name] = False # selected
constraints.append(cons)
for name, val in vocs.constraints.items():
relation, thres = val
relation, thres = extract_constraint_symbol_and_value(val)
critical = name in critical_constraint_names
relation = CONS_RELATION_DICT_INV[relation]

idx = constraints_names_full.index(name)
if idx == -1:
Expand Down Expand Up @@ -876,9 +878,8 @@ def refresh_ui(self, routine: Routine | None = None, silent: bool = False):
status[name] = False # selected
constraints.append(cons)
for name, val in routine.vocs.constraints.items():
relation, thres = val
relation, thres = extract_constraint_symbol_and_value(val)
critical = name in routine.critical_constraint_names
relation = CONS_RELATION_DICT_INV[relation]

idx = constraints_names_full.index(name)
if idx == -1:
Expand Down Expand Up @@ -1250,7 +1251,8 @@ def add_rand_in_init_table(self, add_rand_config=None, record=True):
# get small region around current point to sample
try:
vocs, _ = self.env_box.compose_vocs()
except Exception:
except Exception as e:
print(str(e))
# Switch to manual mode to allow the user fixing the vocs issue
QMessageBox.warning(
self,
Expand All @@ -1264,8 +1266,8 @@ def add_rand_in_init_table(self, add_rand_config=None, record=True):
random_sample_region = get_local_region(var_curr, vocs, fraction=fraction)
with warnings.catch_warnings(record=True) as caught_warnings:
try:
random_points = vocs.random_inputs(
n_point, custom_bounds=random_sample_region
random_points = random_inputs(
vocs, n_point, custom_bounds=random_sample_region
)
except ValueError:
raise VariableRangeError(
Expand Down
12 changes: 6 additions & 6 deletions src/badger/gui/components/run_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
QVBoxLayout,
QWidget,
)
from xopt import VOCS
from xopt.vocs import VOCS, normalize_inputs, select_best

from badger.archive import archive_run, BADGER_ARCHIVE_ROOT
from badger.gui.components.pydantic_editor import BadgerPydanticEditor
Expand Down Expand Up @@ -532,7 +532,7 @@ def update(self, results: pd.DataFrame) -> None:

def update_curves(self, results=None):
use_time_axis = self.plot_x_axis == 1
normalize_inputs = self.x_plot_y_axis == 1
norm_inputs = self.x_plot_y_axis == 1

if results is not None:
self.routine.data = results
Expand All @@ -552,8 +552,8 @@ def update_curves(self, results=None):
variable_names = self.vocs.variable_names

# if normalize x, normalize using vocs
if normalize_inputs:
input_data = self.vocs.normalize_inputs(data_copy)
if norm_inputs:
input_data = normalize_inputs(self.vocs, data_copy)
else:
input_data = data_copy[variable_names]

Expand Down Expand Up @@ -905,8 +905,8 @@ def load_checkpoint(self):

def jump_to_optimal(self):
try:
best_idx, _, _ = self.routine.vocs.select_best(
self.routine.sorted_data, n=1
best_idx, _, _ = select_best(
self.routine.vocs, self.routine.sorted_data, n=1
)
# print(best_idx, _)
best_idx = int(best_idx[0])
Expand Down
20 changes: 17 additions & 3 deletions src/badger/gui/components/var_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from badger.errors import BadgerInterfaceChannelError
from badger.gui.windows.expandable_message_box import ExpandableMessageBox

from gest_api.vocs import ContinuousVariable

import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -309,7 +311,9 @@ def _convert_bounds_to_tuple(self, bounds: Any) -> dict[str, tuple[float, float]
return {k: (v[0], v[1]) for k, v in bounds.items()}

def update_variables(
self, variables: list[dict[str, tuple[float, float]]], filtered: int = 0
self,
variables: list[dict[str, tuple[float, float] | ContinuousVariable]],
filtered: int = 0,
):
# filtered = 0: completely refresh
# filtered = 1: filtered by keyword
Expand Down Expand Up @@ -364,12 +368,22 @@ def update_variables(
self.setItem(i, 1, item)

_bounds = self.bounds[name]
default_val = (
_bounds.domain[0]
if isinstance(_bounds, ContinuousVariable)
else _bounds[0]
)
sb_lower = RobustSpinBox(
default_value=_bounds[0], lower_bound=vrange[0], upper_bound=vrange[1]
default_value=default_val, lower_bound=vrange[0], upper_bound=vrange[1]
)
sb_lower.valueChanged.connect(self.update_bounds)
default_val = (
_bounds.domain[1]
if isinstance(_bounds, ContinuousVariable)
else _bounds[0]
)
sb_upper = RobustSpinBox(
default_value=_bounds[1], lower_bound=vrange[0], upper_bound=vrange[1]
default_value=default_val, lower_bound=vrange[0], upper_bound=vrange[1]
)
sb_upper.valueChanged.connect(self.update_bounds)
self.setCellWidget(i, 2, sb_lower)
Expand Down
8 changes: 6 additions & 2 deletions src/badger/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,19 @@ def load_or_create_config(cls, config_path: str) -> BadgerConfig:

# Convert each entry in config_data to an instance of Setting
for key, value in config_data.items():
if isinstance(value, dict) and "value" in value:
if isinstance(value, dict):
value_arg = None
if "value" in value:
value_arg = value["value"]

# Convert to Setting instance
config_data[key] = Setting(
display_name=value.get("display_name", key),
description=value.get(
"description",
f"Setting for {key.replace('_', ' ').lower()}",
),
value=value["value"],
value=value_arg,
is_path=value.get("is_path", key),
)
else:
Expand Down
12 changes: 9 additions & 3 deletions src/badger/tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from xopt.generators import get_generator, get_generator_defaults
from xopt.resources.testing import TEST_VOCS_BASE
from gest_api.vocs import ExploreObjective


class TestFactory:
Expand All @@ -16,7 +17,7 @@ def test_generator_generation(self):
gen_config = get_generator_defaults(name)
gen_class = get_generator(name)

if name in ["mobo"]:
if name == "mobo":
test_vocs = deepcopy(TEST_VOCS_BASE)
test_vocs.objectives = {"y1": "MINIMIZE", "y2": "MINIMIZE"}
gen_config["reference_point"] = {"y1": 10.0, "y2": 1.0}
Expand All @@ -29,12 +30,17 @@ def test_generator_generation(self):
json.dumps(gen_config)

gen_class(vocs=test_vocs, **gen_config)
elif name in ["bayesian_exploration"]:
elif name == "bayesian_exploration":
test_vocs = deepcopy(TEST_VOCS_BASE)
test_vocs.objectives = {}
test_vocs.observables = ["f"]
json.dumps(gen_config)

gen_class(vocs=test_vocs, **gen_config)
elif name == "latin_hypercube":
test_vocs = deepcopy(TEST_VOCS_BASE)
test_vocs.objectives = {
k: ExploreObjective() for k in test_vocs.objectives.keys()
}
gen_class(vocs=test_vocs, **gen_config)
else:
json.dumps(gen_config)
Expand Down
Loading
Loading