diff --git a/flaml/tune/searcher/blendsearch.py b/flaml/tune/searcher/blendsearch.py index 0d264fcbde..f76b276f56 100644 --- a/flaml/tune/searcher/blendsearch.py +++ b/flaml/tune/searcher/blendsearch.py @@ -3,6 +3,8 @@ # * Licensed under the MIT License. See LICENSE file in the # * project root for license information. from typing import Dict, Optional, List, Tuple, Callable, Union +from collections import defaultdict +from functools import cmp_to_key import numpy as np import time import pickle @@ -26,7 +28,7 @@ from .flow2 import FLOW2 from ..space import add_cost_to_space, indexof, normalize, define_by_run_func from ..result import TIME_TOTAL_S - +from ..utils import get_lexico_bound import logging SEARCH_THREAD_EPS = 1.0 @@ -150,6 +152,7 @@ def __init__( """ self._eps = SEARCH_THREAD_EPS self._input_cost_attr = cost_attr + self.lexico_objectives = lexico_objectives if cost_attr == "auto": if time_budget_s is not None: self.cost_attr = TIME_TOTAL_S @@ -161,8 +164,9 @@ def __init__( self._cost_budget = cost_budget self.penalty = PENALTY # penalty term for constraints self._metric, self._mode = metric, mode + self.f_best = None + self.histories = None self._use_incumbent_result_in_evaluation = use_incumbent_result_in_evaluation - self.lexico_objectives = lexico_objectives init_config = low_cost_partial_config or {} if not init_config: logger.info( @@ -194,9 +198,14 @@ def __init__( self._config_constraints = config_constraints self._metric_constraints = metric_constraints if metric_constraints: - assert all(x[1] in ["<=", ">="] for x in metric_constraints), "sign of metric constraints must be <= or >=." - # metric modified by lagrange - metric += self.lagrange + if self.lexico_objectives: + raise ValueError("Metric constraints should be provided via targets in lexicographic objectives.") + else: + assert all( + x[1] in ["<=", ">="] for x in metric_constraints + ), "sign of metric constraints must be <= or >=." 
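Context for the new check above: when `lexico_objectives` is supplied, hard requirements are expressed through its `targets` entry rather than through `metric_constraints`. A minimal illustration of such a dictionary; only the keys follow the schema used in this patch, while the metric names and numbers are made up:

```python
# Illustrative only. With lexicographic tuning, "flops must stay below 100"
# is encoded as a target rather than as a metric constraint.
lexico_objectives = {
    "metrics": ["error_rate", "flops"],              # ordered by priority
    "modes": ["min", "min"],
    "tolerances": {"error_rate": 0.02, "flops": 0.0},
    "targets": {"error_rate": 0.0, "flops": 100.0},  # acts like the constraint flops <= 100
}
```

Because the feasibility bound used later is `max(tolerance_bound, target)` (see `update_fbest` below), any configuration whose `flops` value stays below the target satisfies that objective regardless of the incumbent.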
+ # metric modified by lagrange + metric += self.lagrange self._cat_hp_cost = cat_hp_cost or {} if space: add_cost_to_space(space, init_config, self._cat_hp_cost) @@ -225,10 +234,18 @@ def __init__( gs_space = space gs_seed = seed - 10 if (seed - 10) >= 0 else seed - 11 + (1 << 32) self._gs_seed = gs_seed + if self.lexico_objectives: + metric, mode = self.lexico_objectives["metrics"], self.lexico_objectives["modes"] if experimental: import optuna as ot - sampler = ot.samplers.TPESampler(seed=gs_seed, multivariate=True, group=True) + if not self.lexico_objectives: + sampler = ot.samplers.TPESampler(seed=gs_seed, multivariate=True, group=True) + else: + n_startup_trials = 11 * len(self.lexico_objectives["metrics"]) - 1 + sampler = ot.samplers.MOTPESampler( + seed=gs_seed, n_startup_trials=n_startup_trials, n_ehvi_candidates=24 + ) else: sampler = None try: @@ -250,6 +267,8 @@ def __init__( seed=gs_seed, sampler=sampler, ) + if self.lexico_objectives: + self._gs.lexico_objectives = self.lexico_objectives self._gs.space = space else: self._gs = None @@ -265,6 +284,33 @@ def __init__( if space is not None: self._init_search() + def update_fbest( + self, + ): + obj_initial = self.lexico_objectives["metrics"][0] + feasible_index = np.array([*range(len(self.histories[obj_initial]))]) + for k_metric in self.lexico_objectives["metrics"]: + k_values = np.array(self.histories[k_metric]) + feasible_value = k_values.take(feasible_index) + self.f_best[k_metric] = np.min(feasible_value) + if not isinstance(self.lexico_objectives["tolerances"][k_metric], str): + tolerance_bound = self.f_best[k_metric] + self.lexico_objectives["tolerances"][k_metric] + else: + assert ( + self.lexico_objectives["tolerances"][k_metric][-1] == "%" + ), "String tolerance of {} should use %% as the suffix".format(k_metric) + tolerance_bound = self.f_best[k_metric] * ( + 1 + 0.01 * float(self.lexico_objectives["tolerances"][k_metric].replace("%", "")) + ) + feasible_index_filter = np.where( + feasible_value + <= max( + tolerance_bound, + self.lexico_objectives["targets"][k_metric], + ) + )[0] + feasible_index = feasible_index.take(feasible_index_filter) + def set_search_properties( self, metric: Optional[str] = None, @@ -305,6 +351,8 @@ def set_search_properties( mode=mode, seed=self._gs_seed, ) + if self.lexico_objectives: + self._gs.lexico_objectives = self.lexico_objective self._gs.space = self._ls.space self._init_search() if spec: @@ -346,7 +394,8 @@ def _init_search(self): self._set_deadline() self._is_ls_ever_converged = False self._subspace = {} # the subspace for each trial id - self._metric_target = np.inf * self._ls.metric_op + if not self.lexico_objectives: + self._metric_target = np.inf * self._ls.metric_op self._search_thread_pool = { # id: int -> thread: SearchThread 0: SearchThread(self._ls.mode, self._gs, self.cost_attr, self._eps) @@ -440,6 +489,12 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: self._search_thread_pool[thread_id].on_trial_complete(trial_id, result, error) del self._trial_proposed_by[trial_id] if result: + if self.lexico_objectives: + if self.histories is None: + self.histories, self.f_best = defaultdict(list), {} + for k_metric, k_mode in zip(self.lexico_objectives["metrics"], self.lexico_objectives["modes"]): + self.histories[k_metric].append(result[k_metric] if k_mode == "min" else -1 * result[k_metric]) + self.update_fbest() config = result.get("config", {}) if not config: for key, value in result.items(): @@ -454,11 +509,12 @@ def on_trial_complete(self, 
trial_id: str, result: Optional[Dict] = None, error: self._cost_used += result.get(self.cost_attr, 0) self._result[signature] = result # update target metric if improved - objective = result[self._ls.metric] - if (objective - self._metric_target) * self._ls.metric_op < 0: - self._metric_target = objective - if self._ls.resource: - self._best_resource = config[self._ls.resource_attr] + if not self.lexico_objectives: + objective = result[self._ls.metric] + if (objective - self._metric_target) * self._ls.metric_op < 0: + self._metric_target = objective + if self._ls.resource: + self._best_resource = config[self._ls.resource_attr] if thread_id: if not self._metric_constraint_satisfied: # no point has been found to satisfy metric constraint @@ -477,7 +533,7 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: # thread creator thread_id = self._thread_count self._started_from_given = self._candidate_start_points and trial_id in self._candidate_start_points - if self._started_from_given: + if self._started_from_given: # check del self._candidate_start_points[trial_id] else: self._started_from_low_cost = True @@ -495,7 +551,7 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: del self._subspace[trial_id] def _create_thread(self, config, result, space): - if self.lexico_objectives is None: + if not self.lexico_objectives: obj = result[self._ls.metric] else: obj = {k: result[k] for k in self.lexico_objectives["metrics"]} @@ -564,8 +620,17 @@ def _create_condition(self, result: Dict) -> bool: """create thread condition""" if len(self._search_thread_pool) < 2: return True - obj_median = np.median([thread.obj_best1 for id, thread in self._search_thread_pool.items() if id]) - return result[self._ls.metric] * self._ls.metric_op < obj_median + if not self.lexico_objectives: + obj_median = np.median([thread.obj_best1 for id, thread in self._search_thread_pool.items() if id]) + return result[self._ls.metric] * self._ls.metric_op < obj_median + else: + thread_pools = sorted( + [thread.obj_best1 for id, thread in self._search_thread_pool.items() if id], + key=cmp_to_key(self._lexico_inferior), + ) + obj_median = thread_pools[round(len(thread_pools) / 2)] + result = self._unify_op(result) + return self._lexico_inferior(obj_median, result) def _clean(self, thread_id: int): """delete thread and increase admissible region if converged, @@ -591,7 +656,7 @@ def _clean(self, thread_id: int): self._ls_bound_max, self._search_thread_pool[thread_id].space, ) - if self._candidate_start_points: + if self._candidate_start_points and not self.lexico_objectives: if not self._started_from_given: # remove start points whose perf is worse than the converged obj = self._search_thread_pool[thread_id].obj_best1 @@ -615,9 +680,11 @@ def _create_thread_from_best_candidate(self): best_trial_id = None obj_best = None for trial_id, r in self._candidate_start_points.items(): - if r and (best_trial_id is None or r[self._ls.metric] * self._ls.metric_op < obj_best): - best_trial_id = trial_id - obj_best = r[self._ls.metric] * self._ls.metric_op + if r: + r = self._unify_op(r) + if best_trial_id is None or self._lexico_inferior(obj_best, r): + best_trial_id = trial_id + obj_best = r if best_trial_id: # create a new thread config = {} @@ -643,11 +710,39 @@ def _expand_admissible_region(self, lower, upper, space): upper[key] += self._ls.STEPSIZE lower[key] -= self._ls.STEPSIZE + def _lexico_inferior(self, obj_1: Union[dict, float], obj_2: Union[dict, float]) -> bool: + if 
self.lexico_objectives: + for k_metric, k_mode in zip(self.lexico_objectives["metrics"], self.lexico_objectives["modes"]): + bound = get_lexico_bound(k_metric, k_mode, self.lexico_objectives, self.f_best) + if (obj_1[k_metric] < bound) and (obj_2[k_metric] < bound): + continue + elif obj_1[k_metric] < obj_2[k_metric]: + return False + else: + return True + for k_metr in self.lexico_objectives["metrics"]: + if obj_1[k_metr] == obj_2[k_metr]: + continue + elif obj_1[k_metr] < obj_2[k_metr]: + return False + else: + return True + else: + return obj_1 > obj_2 + + def _unify_op(self, result: Union[dict, float]): + if isinstance(result, dict): + for k_metric, k_mode in zip(self.lexico_objectives["metrics"], self.lexico_objectives["modes"]): + result[k_metric] = -1 * result[k_metric] if k_mode == "max" else result[k_metric] + else: + result[self._ls.metric] = result[self._ls.metric] * self._ls.metric_op + return result + def _inferior(self, id1: int, id2: int) -> bool: """whether thread id1 is inferior to id2""" t1 = self._search_thread_pool[id1] t2 = self._search_thread_pool[id2] - if t1.obj_best1 < t2.obj_best2: + if self._lexico_inferior(t2.obj_best2, t1.obj_best1): return False elif t1.resource and t1.resource < t2.resource: return False @@ -664,6 +759,12 @@ def on_trial_result(self, trial_id: str, result: Dict): return if result and self._metric_constraints: result[self._metric + self.lagrange] = result[self._metric] + if self.lexico_objectives: + if self.histories is None: + self.histories, self.f_best = defaultdict(list), {} + for k_metric, k_mode in zip(self.lexico_objectives["metrics"], self.lexico_objectives["modes"]): + self.histories[k_metric].append(result[k_metric] if k_mode == "min" else -1 * result[k_metric]) + self.update_fbest() self._search_thread_pool[thread_id].on_trial_result(trial_id, result) def suggest(self, trial_id: str) -> Optional[Dict]: @@ -748,6 +849,10 @@ def suggest(self, trial_id: str) -> Optional[Dict]: self._subspace[trial_id] = space else: # use init config if self._candidate_start_points is not None and self._points_to_evaluate: + if self.lexico_objectives: + raise ValueError( + "Providing points_to_evaluate in lexicographic optimization is not supported for now." 
+ ) self._candidate_start_points[trial_id] = None reward = None if self._points_to_evaluate: @@ -806,11 +911,17 @@ def _violate_config_constriants(self, config, config_signature): or sign == "<" and value > threshold ): - self._result[config_signature] = { - self._metric: np.inf * self._ls.metric_op, - "time_total_s": 1, - } - return True + if self.lexico_objectives: + for k_metric, k_mode in zip(self.lexico_objectives["metrics"], self.lexico_objectives["modes"]): + self._result[config_signature] = {} + self._result[config_signature][k_metric] = np.inf * -1 if k_mode == "max" else np.inf + self._result[config_signature]["time_total_s"] = 1 + else: + self._result[config_signature] = { + self._metric: np.inf * self._ls.metric_op, + "time_total_s": 1, + } + return True return False def _should_skip(self, choice, trial_id, config, space) -> bool: @@ -866,26 +977,46 @@ def _select_thread(self) -> Tuple: num_proposed = num_finished + len(self._trial_proposed_by) min_eci = max(self._num_samples - num_proposed, 0) # update priority - max_speed = 0 - for thread in self._search_thread_pool.values(): - if thread.speed > max_speed: - max_speed = thread.speed + if not self.lexico_objectives: + max_speed = 0 + min_speed = float("inf") + for thread in self._search_thread_pool.values(): + if thread.speed > max_speed: + max_speed = thread.speed + if thread.speed < min_speed: + min_speed = thread.speed + else: + max_speed, min_speed = {k: 0 for k in self.lexico_objectives["metrics"]}, { + k: float("inf") for k in self.lexico_objectives["metrics"] + } + for k_metric in self.lexico_objectives["metrics"]: + for thread in self._search_thread_pool.values(): + if thread.speed[k_metric] > max_speed[k_metric]: + max_speed[k_metric] = thread.speed[k_metric] + if thread.speed[k_metric] < min_speed[k_metric]: + min_speed[k_metric] = thread.speed[k_metric] for thread in self._search_thread_pool.values(): - thread.update_eci(self._metric_target, max_speed) + if not self.lexico_objectives: + thread.update_eci(self._metric_target, max_speed, min_speed) + else: + _metric_1st = self.lexico_objectives["metrics"][0] + _op_1st = self.lexico_objectives["modes"][0] + _lexico_target = self.f_best[_metric_1st] if _op_1st == "min" else -1 * self.f_best[_metric_1st] + thread.update_eci(_lexico_target, max_speed, min_speed) if thread.eci < min_eci: min_eci = thread.eci for thread in self._search_thread_pool.values(): thread.update_priority(min_eci) - top_thread_id = backup_thread_id = 0 + # how to compare global search priority with local search priority1 = priority2 = self._search_thread_pool[0].priority for thread_id, thread in self._search_thread_pool.items(): if thread_id and thread.can_suggest: priority = thread.priority - if priority > priority1: + if not self._lexico_inferior(priority, priority1): priority1 = priority top_thread_id = thread_id - if priority > priority2 or backup_thread_id == 0: + if not self._lexico_inferior(priority, priority2) or backup_thread_id == 0: priority2 = priority backup_thread_id = thread_id return top_thread_id, backup_thread_id @@ -1053,6 +1184,8 @@ def update_search_space(self, search_space): mode=self._mode, sampler=self._gs._sampler, ) + if self.lexico_objectives: + self._gs.lexico_objectives = self.lexico_objectives self._gs.space = config self._init_search() diff --git a/flaml/tune/searcher/flow2.py b/flaml/tune/searcher/flow2.py index fc9d5212dc..a742047404 100644 --- a/flaml/tune/searcher/flow2.py +++ b/flaml/tune/searcher/flow2.py @@ -122,7 +122,7 @@ def __init__( self.resource_attr = 
resource_attr self.min_resource = min_resource self.lexico_objectives = lexico_objectives - if self.lexico_objectives is not None: + if self.lexico_objectives: if "modes" not in self.lexico_objectives.keys(): self.lexico_objectives["modes"] = ["min"] * len(self.lexico_objectives["metrics"]) for t_metric, t_mode in zip(self.lexico_objectives["metrics"], self.lexico_objectives["modes"]): @@ -134,9 +134,9 @@ def __init__( self.cost_attr = cost_attr self.max_resource = max_resource self._resource = None - self._f_best = None # only use for lexico_comapre. It represent the best value achieved by lexico_flow. + self.f_best = None # only use for lexico_comapre. It represent the best value achieved by lexico_flow. self._step_lb = np.Inf - self._histories = None # only use for lexico_comapre. It records the result of historical configurations. + self.histories = None # only use for lexico_comapre. It records the result of historical configurations. if space is not None: self._init_search() @@ -337,18 +337,18 @@ def update_fbest( self, ): obj_initial = self.lexico_objectives["metrics"][0] - feasible_index = np.array([*range(len(self._histories[obj_initial]))]) + feasible_index = np.array([*range(len(self.histories[obj_initial]))]) for k_metric in self.lexico_objectives["metrics"]: - k_values = np.array(self._histories[k_metric]) + k_values = np.array(self.histories[k_metric]) feasible_value = k_values.take(feasible_index) - self._f_best[k_metric] = np.min(feasible_value) + self.f_best[k_metric] = np.min(feasible_value) if not isinstance(self.lexico_objectives["tolerances"][k_metric], str): - tolerance_bound = self._f_best[k_metric] + self.lexico_objectives["tolerances"][k_metric] + tolerance_bound = self.f_best[k_metric] + self.lexico_objectives["tolerances"][k_metric] else: assert ( self.lexico_objectives["tolerances"][k_metric][-1] == "%" ), "String tolerance of {} should use %% as the suffix".format(k_metric) - tolerance_bound = self._f_best[k_metric] * ( + tolerance_bound = self.f_best[k_metric] * ( 1 + 0.01 * float(self.lexico_objectives["tolerances"][k_metric].replace("%", "")) ) feasible_index_filter = np.where( @@ -361,15 +361,15 @@ def update_fbest( feasible_index = feasible_index.take(feasible_index_filter) def lexico_compare(self, result) -> bool: - if self._histories is None: - self._histories, self._f_best = defaultdict(list), {} + if self.histories is None: + self.histories, self.f_best = defaultdict(list), {} for k in self.lexico_objectives["metrics"]: - self._histories[k].append(result[k]) + self.histories[k].append(result[k]) self.update_fbest() return True else: for k in self.lexico_objectives["metrics"]: - self._histories[k].append(result[k]) + self.histories[k].append(result[k]) self.update_fbest() for k_metric, k_mode in zip(self.lexico_objectives["metrics"], self.lexico_objectives["modes"]): k_target = ( @@ -378,12 +378,12 @@ def lexico_compare(self, result) -> bool: else -self.lexico_objectives["targets"][k_metric] ) if not isinstance(self.lexico_objectives["tolerances"][k_metric], str): - tolerance_bound = self._f_best[k_metric] + self.lexico_objectives["tolerances"][k_metric] + tolerance_bound = self.f_best[k_metric] + self.lexico_objectives["tolerances"][k_metric] else: assert ( self.lexico_objectives["tolerances"][k_metric][-1] == "%" ), "String tolerance of {} should use %% as the suffix".format(k_metric) - tolerance_bound = self._f_best[k_metric] * ( + tolerance_bound = self.f_best[k_metric] * ( 1 + 0.01 * float(self.lexico_objectives["tolerances"][k_metric].replace("%", 
"")) ) if (result[k_metric] < max(tolerance_bound, k_target)) and ( @@ -416,7 +416,7 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: if not error and result: obj = ( result.get(self._metric) - if self.lexico_objectives is None + if not self.lexico_objectives else {k: result[k] for k in self.lexico_objectives["metrics"]} ) if obj: @@ -473,7 +473,7 @@ def on_trial_result(self, trial_id: str, result: Dict): if result: obj = ( result.get(self._metric) - if self.lexico_objectives is None + if not self.lexico_objectives else {k: result[k] for k in self.lexico_objectives["metrics"]} ) if obj: diff --git a/flaml/tune/searcher/search_thread.py b/flaml/tune/searcher/search_thread.py index f0488c8181..dd2408b05b 100644 --- a/flaml/tune/searcher/search_thread.py +++ b/flaml/tune/searcher/search_thread.py @@ -2,9 +2,10 @@ # * Copyright (c) Microsoft Corporation. All rights reserved. # * Licensed under the MIT License. See LICENSE file in the # * project root for license information. -from typing import Dict, Optional +from typing import Dict, Optional, Union import numpy as np + try: from ray import __version__ as ray_version @@ -18,7 +19,9 @@ from .flow2 import FLOW2 from ..space import add_cost_to_space, unflatten_hierarchical from ..result import TIME_TOTAL_S +from ..utils import get_lexico_bound import logging +from collections import defaultdict logger = logging.getLogger(__name__) @@ -41,11 +44,10 @@ def __init__( self.cost_best = self.cost_last = self.cost_total = self.cost_best1 = getattr(search_alg, "cost_incumbent", 0) self._eps = eps self.cost_best2 = 0 - self.obj_best1 = self.obj_best2 = getattr(search_alg, "best_obj", np.inf) # inherently minimize + self.lexico_objectives = getattr(self._search_alg, "lexico_objectives", None) self.best_result = None # eci: estimated cost for improvement self.eci = self.cost_best - self.priority = self.speed = 0 self._init_config = True self.running = 0 # the number of running trials from the thread self.cost_attr = cost_attr @@ -55,6 +57,22 @@ def __init__( # remember const config self._const = add_cost_to_space(self.space, {}, {}) + if self.lexico_objectives: + # lexicographic tuning setting + self.f_best, self.histories = {}, defaultdict(list) # only use for lexico_comapre. 
+ self.obj_best1 = self.obj_best2 = {} + for k_metric in self.lexico_objectives["metrics"]: + self.obj_best1[k_metric] = self.obj_best2[k_metric] = ( + np.inf if getattr(search_alg, "best_obj", None) is None else search_alg.best_obj[k_metric] + ) + self.priority, self.speed = {}, {} + for k_metric in self.lexico_objectives["metrics"]: + self.priority[k_metric] = self.speed[k_metric] = 0 + else: + # normal tuning setting + self.obj_best1 = self.obj_best2 = getattr(search_alg, "best_obj", np.inf) # inherently minimize + self.priority = self.speed = 0 + def suggest(self, trial_id: str) -> Optional[Dict]: """Use the suggest() of the underlying search algorithm.""" if isinstance(self._search_alg, FLOW2): @@ -74,28 +92,120 @@ def suggest(self, trial_id: str) -> Optional[Dict]: self.running += 1 return config + def update_lexicoPara(self, result): + # update histories, f_best + if self.lexico_objectives: + for k_metric, k_mode in zip(self.lexico_objectives["metrics"], self.lexico_objectives["modes"]): + self.histories[k_metric].append(result[k_metric]) if k_mode == "min" else self.histories[ + k_metric + ].append(result[k_metric] * -1) + obj_initial = self.lexico_objectives["metrics"][0] + feasible_index = np.array([*range(len(self.histories[obj_initial]))]) + for k_metric in self.lexico_objectives["metrics"]: + k_values = np.array(self.histories[k_metric]) + feasible_value = k_values.take(feasible_index) + self.f_best[k_metric] = np.min(feasible_value) + if not isinstance(self.lexico_objectives["tolerances"][k_metric], str): + tolerance_bound = self.f_best[k_metric] + self.lexico_objectives["tolerances"][k_metric] + else: + assert ( + self.lexico_objectives["tolerances"][k_metric][-1] == "%" + ), "String tolerance of {} should use %% as the suffix".format(k_metric) + tolerance_bound = self.f_best[k_metric] * ( + 1 + 0.01 * float(self.lexico_objectives["tolerances"][k_metric].replace("%", "")) + ) + feasible_index_filter = np.where( + feasible_value + <= max( + tolerance_bound, + self.lexico_objectives["targets"][k_metric], + ) + )[0] + feasible_index = feasible_index.take(feasible_index_filter) + def update_priority(self, eci: Optional[float] = 0): # optimistic projection - self.priority = eci * self.speed - self.obj_best1 + if self.lexico_objectives: + for k_metric, k_mode in zip(self.lexico_objectives["metrics"], self.lexico_objectives["modes"]): + self.priority[k_metric] = eci * self.speed[k_metric] - self.obj_best1[k_metric] + else: + self.priority = eci * self.speed - self.obj_best1 - def update_eci(self, metric_target: float, max_speed: Optional[float] = np.inf): + def update_eci(self, metric_target: float, max_speed: Optional[float] = np.inf, min_speed: Optional[float] = 1e-9): # calculate eci: estimated cost for improvement over metric_target - best_obj = metric_target * self._metric_op - if not self.speed: - self.speed = max_speed + if not self.lexico_objectives: + _metric_op = self._metric_op + if not self.speed: + self.speed = max_speed + else: + _metric_1st = self.lexico_objectives["metrics"][0] + _metric_op = 1 if self.lexico_objectives["modes"][0] == "min" else -1 + if self.speed[_metric_1st] == 0: + self.speed[_metric_1st] = max_speed[_metric_1st] + elif self.speed[_metric_1st] == -1: + self.speed[_metric_1st] = min_speed[_metric_1st] + best_obj = metric_target * _metric_op self.eci = max(self.cost_total - self.cost_best1, self.cost_best1 - self.cost_best2) - if self.obj_best1 > best_obj and self.speed > 0: - self.eci = max(self.eci, 2 * (self.obj_best1 - best_obj) / self.speed) + 
obj_best1 = self.obj_best1 if not self.lexico_objectives else self.obj_best1[_metric_1st] + speed = self.speed if not self.lexico_objectives else self.speed[_metric_1st] + if obj_best1 > best_obj and speed > 0: + self.eci = max(self.eci, 2 * (obj_best1 - best_obj) / speed) + + def _better(self, obj_1: Union[dict, float], obj_2: Union[dict, float]): + if self.lexico_objectives: + for k_metric, k_mode in zip(self.lexico_objectives["metrics"], self.lexico_objectives["modes"]): + _f_best = self._search_alg.f_best if self._is_ls else self.f_best + bound = get_lexico_bound(k_metric, k_mode, self.lexico_objectives, _f_best) + if (obj_1[k_metric] < bound) and (obj_2[k_metric] < bound): + continue + elif obj_1[k_metric] < obj_2[k_metric]: + return True, k_metric + else: + return False, None + for k_metr in self.lexico_objectives["metrics"]: + if obj_1[k_metr] == obj_2[k_metr]: + continue + elif obj_1[k_metr] < obj_2[k_metr]: + return True, k_metric + else: + return False, None + return False, None + else: + if obj_1 < obj_2: + return True, None + else: + return False, None def _update_speed(self): # calculate speed; use 0 for invalid speed temporarily - if self.obj_best2 > self.obj_best1: - # discount the speed if there are unfinished trials - self.speed = ( - (self.obj_best2 - self.obj_best1) / self.running / (max(self.cost_total - self.cost_best2, self._eps)) - ) + if not self.lexico_objectives: + if self.obj_best1 < self.obj_best2: + self.speed = ( + (self.obj_best2 - self.obj_best1) + / self.running + / (max(self.cost_total - self.cost_best2, self._eps)) + ) + else: + self.speed = 0 + elif (self._is_ls and self._search_alg.histories) or (not self._is_ls and self.histories): + _is_better, _op_dimension = self._better(self.obj_best1, self.obj_best2) + if _is_better: + op_index = self.lexico_objectives["metrics"].index(_op_dimension) + self.speed[_op_dimension] = ( + (self.obj_best2[_op_dimension] - self.obj_best1[_op_dimension]) + / self.running + / (max(self.cost_total - self.cost_best2, self._eps)) + ) + for i in range(0, len(self.lexico_objectives["metrics"])): + if i < op_index: + self.speed[self.lexico_objectives["metrics"][i]] = -1 + elif i > op_index: + self.speed[self.lexico_objectives["metrics"][i]] = 0 + else: + for k_metric in self.lexico_objectives["metrics"]: + self.speed[k_metric] = 0 else: - self.speed = 0 + return def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: bool = False): """Update the statistics of the thread.""" @@ -106,6 +216,8 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: if self._is_ls or not self._init_config: try: self._search_alg.on_trial_complete(trial_id, result, error) + if not self._is_ls: + self.update_lexicoPara(result) except RuntimeError as e: # rs is used in place of optuna sometimes if not str(e).endswith("has already finished and can not be updated."): @@ -117,21 +229,33 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: if result: self.cost_last = result.get(self.cost_attr, 1) self.cost_total += self.cost_last - if self._search_alg.metric in result and (getattr(self._search_alg, "lexico_objectives", None) is None): - # TODO: Improve this behavior. When lexico_objectives is provided to CFO, - # related variables are not callable. 
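`update_lexicoPara` and `_better` above together define how a thread ranks results: histories are filtered objective by objective to obtain `f_best`, and pairwise comparisons then use the bound `max(tolerance_bound, target)` from the new `get_lexico_bound` helper. A condensed, self-contained sketch of both steps; the metric names and history values are made up, and `flaml.tune.utils.get_lexico_bound` is assumed to exist as added in this patch:

```python
import numpy as np
from flaml.tune.utils import get_lexico_bound  # helper added in this patch

lexico_objectives = {
    "metrics": ["error_rate", "flops"],          # ordered by priority
    "modes": ["min", "min"],
    "tolerances": {"error_rate": "10%", "flops": 0.0},
    "targets": {"error_rate": 0.0, "flops": 0.0},
}
# Histories are stored min-oriented ("max" metrics are negated before appending).
histories = {"error_rate": [0.30, 0.22, 0.20, 0.25], "flops": [1.2, 0.9, 1.5, 0.8]}

# Step 1: feasibility filtering, as in update_lexicoPara()/update_fbest().
f_best = {}
feasible_index = np.arange(len(histories[lexico_objectives["metrics"][0]]))
for k_metric in lexico_objectives["metrics"]:
    feasible_value = np.array(histories[k_metric]).take(feasible_index)
    f_best[k_metric] = np.min(feasible_value)
    tolerance = lexico_objectives["tolerances"][k_metric]
    if isinstance(tolerance, str):  # percentage tolerance such as "10%"
        tolerance_bound = f_best[k_metric] * (1 + 0.01 * float(tolerance.replace("%", "")))
    else:  # absolute tolerance
        tolerance_bound = f_best[k_metric] + tolerance
    keep = np.where(feasible_value <= max(tolerance_bound, lexico_objectives["targets"][k_metric]))[0]
    feasible_index = feasible_index.take(keep)
print(f_best)  # {'error_rate': 0.2, 'flops': 0.9} for the data above


# Step 2: pairwise comparison, as in _better()/_lexico_inferior().
def lexico_better(obj_1: dict, obj_2: dict) -> bool:
    """True if obj_1 is lexicographically better than obj_2 (values min-oriented)."""
    for k_metric, k_mode in zip(lexico_objectives["metrics"], lexico_objectives["modes"]):
        bound = get_lexico_bound(k_metric, k_mode, lexico_objectives, f_best)
        if obj_1[k_metric] < bound and obj_2[k_metric] < bound:
            continue  # both within tolerance: defer to the next objective
        return obj_1[k_metric] < obj_2[k_metric]
    for k_metric in lexico_objectives["metrics"]:  # all within tolerance: raw tie-break
        if obj_1[k_metric] != obj_2[k_metric]:
            return obj_1[k_metric] < obj_2[k_metric]
    return False


print(lexico_better({"error_rate": 0.21, "flops": 0.9}, {"error_rate": 0.20, "flops": 1.5}))  # True
```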
- obj = result[self._search_alg.metric] * self._metric_op - if obj < self.obj_best1 or self.best_result is None: + _metric_exists = ( + self._search_alg.metric in result + if not self.lexico_objectives + else all(x in result for x in self.lexico_objectives["metrics"]) + ) + if _metric_exists: + if not self.lexico_objectives: + obj = result[self._search_alg.metric] * self._metric_op + else: + obj = {} + for k, m in zip( + self._search_alg.lexico_objectives["metrics"], self._search_alg.lexico_objectives["modes"] + ): + obj[k] = -1 * result[k] if m == "max" else result[k] + if self.best_result is None or self._better(obj, self.obj_best1)[0]: self.cost_best2 = self.cost_best1 self.cost_best1 = self.cost_total - self.obj_best2 = obj if np.isinf(self.obj_best1) else self.obj_best1 + if not self.lexico_objectives: + self.obj_best2 = obj if np.isinf(self.obj_best1) else self.obj_best1 + else: + self.obj_best2 = ( + obj if np.isinf(self.obj_best1[self.lexico_objectives["metrics"][0]]) else self.obj_best1 + ) self.obj_best1 = obj self.cost_best = self.cost_last self.best_result = result - if getattr(self._search_alg, "lexico_objectives", None) is None: - # TODO: Improve this behavior. When lexico_objectives is provided to CFO, - # related variables are not callable. - self._update_speed() + self._update_speed() self.running -= 1 assert self.running >= 0 @@ -142,6 +266,8 @@ def on_trial_result(self, trial_id: str, result: Dict): if not hasattr(self._search_alg, "_ot_trials") or (trial_id in self._search_alg._ot_trials): try: self._search_alg.on_trial_result(trial_id, result) + if not self._is_ls: + self.update_lexicoPara(result) except RuntimeError as e: # rs is used in place of optuna sometimes if not str(e).endswith("has already finished and can not be updated."): @@ -149,7 +275,6 @@ def on_trial_result(self, trial_id: str, result: Dict): new_cost = result.get(self.cost_attr, 1) if self.cost_last < new_cost: self.cost_last = new_cost - # self._update_speed() @property def converged(self) -> bool: diff --git a/flaml/tune/searcher/suggestion.py b/flaml/tune/searcher/suggestion.py index 518bab9da9..57758a3ce1 100644 --- a/flaml/tune/searcher/suggestion.py +++ b/flaml/tune/searcher/suggestion.py @@ -19,6 +19,7 @@ import functools import warnings import copy +import numpy as np import logging from typing import Any, Dict, Optional, Union, List, Tuple, Callable import pickle @@ -33,6 +34,8 @@ Uniform, ) from ..trial import flatten_dict, unflatten_dict +from packaging import version +from collections import defaultdict logger = logging.getLogger(__name__) @@ -309,6 +312,7 @@ def validate_warmstart( class _OptunaTrialSuggestCaptor: """Utility to capture returned values from Optuna's suggest_ methods. + This will wrap around the ``optuna.Trial` object and decorate all `suggest_` callables with a function capturing the returned value, which will be saved in the ``captured_values`` dict. @@ -338,81 +342,198 @@ def __getattr__(self, item_name: str) -> Any: class OptunaSearch(Searcher): """A wrapper around Optuna to provide trial suggestions. - [Optuna](https://optuna.org/) - is a hyperparameter optimization library. - In contrast to other libraries, it employs define-by-run style - hyperparameter definitions. - This Searcher is a thin wrapper around Optuna's search algorithms. - You can pass any Optuna sampler, which will be used to generate - hyperparameter suggestions. - Args: - space (dict|Callable): Hyperparameter search space definition for - Optuna's sampler. 
This can be either a class `dict` with - parameter names as keys and ``optuna.distributions`` as values, - or a Callable - in which case, it should be a define-by-run - function using ``optuna.trial`` to obtain the hyperparameter - values. The function should return either a class `dict` of - constant values with names as keys, or None. - For more information, see - [tutorial](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/002_configurations.html). - Warning - No actual computation should take place in the define-by-run + + `Optuna `_ is a hyperparameter optimization library. + In contrast to other libraries, it employs define-by-run style + hyperparameter definitions. + + This Searcher is a thin wrapper around Optuna's search algorithms. + You can pass any Optuna sampler, which will be used to generate + hyperparameter suggestions. + + Multi-objective optimization is supported. + + Args: + space: Hyperparameter search space definition for + Optuna's sampler. This can be either a :class:`dict` with + parameter names as keys and ``optuna.distributions`` as values, + or a Callable - in which case, it should be a define-by-run + function using ``optuna.trial`` to obtain the hyperparameter + values. The function should return either a :class:`dict` of + constant values with names as keys, or None. + For more information, see https://optuna.readthedocs.io\ +/en/stable/tutorial/10_key_features/002_configurations.html. + + .. warning:: + No actual computation should take place in the define-by-run function. Instead, put the training logic inside the function - or class trainable passed to tune.run. - metric (str): The training result objective value attribute. If None - but a mode was passed, the anonymous metric `_metric` will be used - per default. - mode (str): One of {min, max}. Determines whether objective is - minimizing or maximizing the metric attribute. - points_to_evaluate (list): Initial parameter suggestions to be run - first. This is for when you already have some good parameters - you want to run first to help the algorithm make better suggestions - for future parameters. Needs to be a list of dicts containing the - configurations. - sampler (optuna.samplers.BaseSampler): Optuna sampler used to - draw hyperparameter configurations. Defaults to ``TPESampler``. - seed (int): Seed to initialize sampler with. This parameter is only - used when ``sampler=None``. In all other cases, the sampler - you pass should be initialized with the seed already. - evaluated_rewards (list): If you have previously evaluated the - parameters passed in as points_to_evaluate you can avoid - re-running those trials by passing in the reward attributes - as a list so the optimiser can be told the results without - needing to re-compute the trial. Must be the same length as - points_to_evaluate. - - Tune automatically converts search spaces to Optuna's format: - - ````python - from ray.tune.suggest.optuna import OptunaSearch # ray version < 2 - config = { "a": tune.uniform(6, 8), - "b": tune.loguniform(1e-4, 1e-2)} - optuna_search = OptunaSearch(metric="loss", mode="min") - tune.run(trainable, config=config, search_alg=optuna_search) - ```` - - If you would like to pass the search space manually, the code would - look like this: + or class trainable passed to ``tune.run``. + + metric: The training result objective value attribute. If + None but a mode was passed, the anonymous metric ``_metric`` + will be used per default. Can be a list of metrics for + multi-objective optimization. 
+ mode: One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. Can be a list of + modes for multi-objective optimization (corresponding to + ``metric``). + points_to_evaluate: Initial parameter suggestions to be run + first. This is for when you already have some good parameters + you want to run first to help the algorithm make better suggestions + for future parameters. Needs to be a list of dicts containing the + configurations. + sampler: Optuna sampler used to + draw hyperparameter configurations. Defaults to ``MOTPESampler`` + for multi-objective optimization with Optuna<2.9.0, and + ``TPESampler`` in every other case. + + .. warning:: + Please note that with Optuna 2.10.0 and earlier + default ``MOTPESampler``/``TPESampler`` suffer + from performance issues when dealing with a large number of + completed trials (approx. >100). This will manifest as + a delay when suggesting new configurations. + This is an Optuna issue and may be fixed in a future + Optuna release. + + seed: Seed to initialize sampler with. This parameter is only + used when ``sampler=None``. In all other cases, the sampler + you pass should be initialized with the seed already. + evaluated_rewards: If you have previously evaluated the + parameters passed in as points_to_evaluate you can avoid + re-running those trials by passing in the reward attributes + as a list so the optimiser can be told the results without + needing to re-compute the trial. Must be the same length as + points_to_evaluate. + + .. warning:: + When using ``evaluated_rewards``, the search space ``space`` + must be provided as a :class:`dict` with parameter names as + keys and ``optuna.distributions`` instances as values. The + define-by-run search space definition is not yet supported with + this functionality. + + Tune automatically converts search spaces to Optuna's format: + + .. code-block:: python + + from ray.tune.suggest.optuna import OptunaSearch + + config = { + "a": tune.uniform(6, 8) + "b": tune.loguniform(1e-4, 1e-2) + } + + optuna_search = OptunaSearch( + metric="loss", + mode="min") + + tune.run(trainable, config=config, search_alg=optuna_search) + + If you would like to pass the search space manually, the code would + look like this: + + .. code-block:: python + + from ray.tune.suggest.optuna import OptunaSearch + import optuna + + space = { + "a": optuna.distributions.UniformDistribution(6, 8), + "b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2), + } + + optuna_search = OptunaSearch( + space, + metric="loss", + mode="min") + + tune.run(trainable, search_alg=optuna_search) + + # Equivalent Optuna define-by-run function approach: + + def define_search_space(trial: optuna.Trial): + trial.suggest_float("a", 6, 8) + trial.suggest_float("b", 1e-4, 1e-2, log=True) + # training logic goes into trainable, this is just + # for search space definition + + optuna_search = OptunaSearch( + define_search_space, + metric="loss", + mode="min") + + tune.run(trainable, search_alg=optuna_search) + + Multi-objective optimization is supported: + + .. code-block:: python + + from ray.tune.suggest.optuna import OptunaSearch + import optuna + + space = { + "a": optuna.distributions.UniformDistribution(6, 8), + "b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2), + } + + # Note you have to specify metric and mode here instead of + # in tune.run + optuna_search = OptunaSearch( + space, + metric=["loss1", "loss2"], + mode=["min", "max"]) + + # Do not specify metric and mode here! 
+ tune.run( + trainable, + search_alg=optuna_search + ) + + You can pass configs that will be evaluated first using + ``points_to_evaluate``: + + .. code-block:: python + + from ray.tune.suggest.optuna import OptunaSearch + import optuna + + space = { + "a": optuna.distributions.UniformDistribution(6, 8), + "b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2), + } + + optuna_search = OptunaSearch( + space, + points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}] + metric="loss", + mode="min") + + tune.run(trainable, search_alg=optuna_search) + + Avoid re-running evaluated trials by passing the rewards together with + `points_to_evaluate`: + + .. code-block:: python + + from ray.tune.suggest.optuna import OptunaSearch + import optuna + + space = { + "a": optuna.distributions.UniformDistribution(6, 8), + "b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2), + } + + optuna_search = OptunaSearch( + space, + points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}] + evaluated_rewards=[0.89, 0.42] + metric="loss", + mode="min") + + tune.run(trainable, search_alg=optuna_search) - ```python - from ray.tune.suggest.optuna import OptunaSearch # ray version < 2 - import optuna - config = { "a": optuna.distributions.UniformDistribution(6, 8), - "b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2)} - optuna_search = OptunaSearch(space,metric="loss",mode="min") - tune.run(trainable, search_alg=optuna_search) - # Equivalent Optuna define-by-run function approach: - def define_search_space(trial: optuna.Trial): - trial.suggest_float("a", 6, 8) - trial.suggest_float("b", 1e-4, 1e-2, log=True) - # training logic goes into trainable, this is just - # for search space definition - optuna_search = OptunaSearch( - define_search_space, - metric="loss", - mode="min") - tune.run(trainable, search_alg=optuna_search) .. versionadded:: 0.8.8 - ``` """ @@ -425,15 +546,15 @@ def __init__( Callable[["OptunaTrial"], Optional[Dict[str, Any]]], ] ] = None, - metric: Optional[str] = None, - mode: Optional[str] = None, + metric: Optional[Union[str, List[str]]] = None, + mode: Optional[Union[str, List[str]]] = None, points_to_evaluate: Optional[List[Dict]] = None, sampler: Optional["BaseSampler"] = None, seed: Optional[int] = None, evaluated_rewards: Optional[List] = None, ): assert ot is not None, "Optuna must be installed! Run `pip install optuna`." - super(OptunaSearch, self).__init__(metric=metric, mode=mode, max_concurrent=None, use_early_stopped_trials=None) + super(OptunaSearch, self).__init__(metric=metric, mode=mode) if isinstance(space, dict) and space: resolved_vars, domain_vars, grid_vars = parse_spec_vars(space) @@ -457,33 +578,57 @@ def __init__( "`seed` parameter has to be passed to the sampler directly " "and will be ignored." ) + elif sampler: + assert isinstance(sampler, BaseSampler), ( + "You can only pass an instance of " "`optuna.samplers.BaseSampler` " "as a sampler to `OptunaSearcher`." + ) - self._sampler = sampler or ot.samplers.TPESampler(seed=seed) + self._sampler = sampler + self._seed = seed - assert isinstance(self._sampler, BaseSampler), ( - "You can only pass an instance of `optuna.samplers.BaseSampler` " "as a sampler to `OptunaSearcher`." 
- ) + self._completed_trials = set() self._ot_trials = {} self._ot_study = None if self._space: self._setup_study(mode) - def _setup_study(self, mode: str): + def _setup_study(self, mode: Union[str, list]): if self._metric is None and self._mode: + if isinstance(self._mode, list): + raise ValueError( + "If ``mode`` is a list (multi-objective optimization " "case), ``metric`` must be defined." + ) # If only a mode was passed, use anonymous metric self._metric = DEFAULT_METRIC pruner = ot.pruners.NopPruner() storage = ot.storages.InMemoryStorage() + if self._sampler: + sampler = self._sampler + elif isinstance(mode, list) and version.parse(ot.__version__) < version.parse("2.9.0"): + # MOTPESampler deprecated in Optuna>=2.9.0 + sampler = ot.samplers.MOTPESampler(seed=self._seed) + else: + sampler = ot.samplers.TPESampler(seed=self._seed) + + if isinstance(mode, list): + study_direction_args = dict( + directions=["minimize" if m == "min" else "maximize" for m in mode], + ) + else: + study_direction_args = dict( + direction="minimize" if mode == "min" else "maximize", + ) + self._ot_study = ot.study.create_study( storage=storage, - sampler=self._sampler, + sampler=sampler, pruner=pruner, study_name=self._study_name, - direction="minimize" if mode == "min" else "maximize", load_if_exists=True, + **study_direction_args, ) if self._points_to_evaluate: @@ -500,7 +645,7 @@ def _setup_study(self, mode: str): for point in self._points_to_evaluate: self._ot_study.enqueue_trial(point) - def set_search_properties(self, metric: Optional[str], mode: Optional[str], config: Dict) -> bool: + def set_search_properties(self, metric: Optional[str], mode: Optional[str], config: Dict, **spec) -> bool: if self._space: return False space = self.convert_search_space(config) @@ -510,7 +655,7 @@ def set_search_properties(self, metric: Optional[str], mode: Optional[str], conf if mode: self._mode = mode - self._setup_study(mode) + self._setup_study(self._mode) return True def _suggest_from_define_by_run_func( @@ -553,21 +698,8 @@ def suggest(self, trial_id: str) -> Optional[Dict]: raise RuntimeError( UNDEFINED_METRIC_MODE.format(cls=self.__class__.__name__, metric=self._metric, mode=self._mode) ) - - if isinstance(self._space, list): - # Keep for backwards compatibility - # Deprecate: 1.5 - if trial_id not in self._ot_trials: - self._ot_trials[trial_id] = self._ot_study.ask() - - ot_trial = self._ot_trials[trial_id] - - # getattr will fetch the trial.suggest_ function on Optuna trials - params = { - args[0] if len(args) > 0 else kwargs["name"]: getattr(ot_trial, fn)(*args, **kwargs) - for (fn, args, kwargs) in self._space - } - elif callable(self._space): + if callable(self._space): + # Define-by-run case if trial_id not in self._ot_trials: self._ot_trials[trial_id] = self._ot_study.ask() @@ -584,15 +716,36 @@ def suggest(self, trial_id: str) -> Optional[Dict]: return unflatten_dict(params) def on_trial_result(self, trial_id: str, result: Dict): + if isinstance(self.metric, list): + # Optuna doesn't support incremental results + # for multi-objective optimization + return + if trial_id in self._completed_trials: + logger.warning( + f"Received additional result for trial {trial_id}, but " f"it already finished. 
Result: {result}" + ) + return metric = result[self.metric] step = result[TRAINING_ITERATION] ot_trial = self._ot_trials[trial_id] ot_trial.report(metric, step) def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: bool = False): + if trial_id in self._completed_trials: + logger.warning( + f"Received additional completion for trial {trial_id}, but " f"it already finished. Result: {result}" + ) + return + ot_trial = self._ot_trials[trial_id] - val = result.get(self.metric, None) if result else None + if result: + if isinstance(self.metric, list): + val = [result.get(metric, None) for metric in self.metric] + else: + val = result.get(self.metric, None) + else: + val = None ot_trial_state = OptunaTrialState.COMPLETE if val is None: if error: @@ -601,9 +754,11 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: ot_trial_state = OptunaTrialState.PRUNED try: self._ot_study.tell(ot_trial, val, state=ot_trial_state) - except ValueError as exc: + except Exception as exc: logger.warning(exc) # E.g. if NaN was reported + self._completed_trials.add(trial_id) + def add_evaluated_point( self, parameters: Dict, @@ -618,6 +773,13 @@ def add_evaluated_point( raise RuntimeError( UNDEFINED_METRIC_MODE.format(cls=self.__class__.__name__, metric=self._metric, mode=self._mode) ) + if callable(self._space): + raise TypeError( + "Define-by-run function passed in `space` argument is not " + "yet supported when using `evaluated_rewards`. Please provide " + "an `OptunaDistribution` dict or pass a Ray Tune " + "search space to `tune.run()`." + ) ot_trial_state = OptunaTrialState.COMPLETE if error: diff --git a/flaml/tune/tune.py b/flaml/tune/tune.py index 37a91774e8..50adb04969 100644 --- a/flaml/tune/tune.py +++ b/flaml/tune/tune.py @@ -52,14 +52,14 @@ def __init__(self, trials, metric, mode, lexico_objectives=None): @property def best_trial(self) -> Trial: - if self.lexico_objectives is None: + if not self.lexico_objectives: return super().best_trial else: return self.get_best_trial(self.default_metric, self.default_mode) @property def best_config(self) -> Dict: - if self.lexico_objectives is None: + if not self.lexico_objectives: return super().best_config else: return self.get_best_config(self.default_metric, self.default_mode) @@ -82,7 +82,7 @@ def lexico_best(self, trials): for k_metric, k_mode in zip(self.lexico_objectives["metrics"], self.lexico_objectives["modes"]): k_values = np.array(histories[k_metric]) k_target = ( - -self.lexico_objectives["targets"][k_metric] + -1 * self.lexico_objectives["targets"][k_metric] if k_mode == "max" else self.lexico_objectives["targets"][k_metric] ) @@ -110,7 +110,7 @@ def get_best_trial( scope: str = "last", filter_nan_and_inf: bool = True, ) -> Optional[Trial]: - if self.lexico_objectives is not None: + if self.lexico_objectives: best_trial = self.lexico_best(self.trials) else: best_trial = super().get_best_trial(metric, mode, scope, filter_nan_and_inf) @@ -127,19 +127,15 @@ def best_result(self) -> Dict: def report(_metric=None, **kwargs): """A function called by the HPO application to report final or intermediate results. 
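The multi-objective branch added to `OptunaSearch.on_trial_complete` above reports one value per metric to an Optuna study created with a `directions` list. A standalone sketch of that flow, assuming a recent Optuna release that provides the ask-and-tell interface; the search space and objective values are made up:

```python
import optuna

# Roughly what OptunaSearch sets up for metric=["loss1", "loss2"], mode=["min", "max"].
study = optuna.create_study(directions=["minimize", "maximize"])
trial = study.ask()
x = trial.suggest_float("x", 6.0, 8.0)
# On completion, one value per objective is reported, in the same order as `directions`.
study.tell(trial, [(x - 7.0) ** 2, x / 10.0])
print(study.best_trials)  # multi-objective studies expose a Pareto front, not a single best trial
```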
- Example: - ```python import time from flaml import tune - def compute_with_config(config): current_time = time.time() metric2minimize = (round(config['x'])-95000)**2 time2eval = time.time() - current_time tune.report(metric2minimize=metric2minimize, time2eval=time2eval) - analysis = tune.run( compute_with_config, config={ @@ -148,15 +144,12 @@ def compute_with_config(config): }, metric='metric2minimize', mode='min', num_samples=1000000, time_budget_s=60, use_ray=False) - print(analysis.trials[-1].last_result) ``` - Args: _metric: Optional default anonymous metric for ``tune.report(value)``. (For compatibility with ray.tune.report) **kwargs: Any key value pair to be reported. - Raises: StopIteration (when not using ray, i.e., _use_ray=False): A StopIteration exception is raised if the trial has been signaled to stop. @@ -234,11 +227,9 @@ def run( """The function-based way of performing HPO. Example: - ```python import time from flaml import tune - def compute_with_config(config): current_time = time.time() metric2minimize = (round(config['x'])-95000)**2 @@ -250,7 +241,6 @@ def compute_with_config(config): # if the failure indicates a config is bad, # report a bad metric value like np.inf or -np.inf # depending on metric mode being min or max - analysis = tune.run( compute_with_config, config={ @@ -259,10 +249,8 @@ def compute_with_config(config): }, metric='metric2minimize', mode='min', num_samples=-1, time_budget_s=60, use_ray=False) - print(analysis.trials[-1].last_result) ``` - Args: evaluation_function: A user-defined evaluation function. It takes a configuration as input, outputs a evaluation @@ -274,7 +262,6 @@ def compute_with_config(config): low_cost_partial_config: A dictionary from a subset of controlled dimensions to the initial low-cost values. e.g., ```{'n_estimators': 4, 'max_leaves': 4}``` - cat_hp_cost: A dictionary from a subset of categorical dimensions to the relative cost of each choice. e.g., ```{'tree_method': [1, 1, 2]}``` @@ -293,7 +280,6 @@ def compute_with_config(config): needing to re-compute the trial. Must be the same or shorter length than points_to_evaluate. e.g., - ```python points_to_evaluate = [ {"b": .99, "cost_related": {"a": 3}}, @@ -301,10 +287,8 @@ def compute_with_config(config): ] evaluated_rewards = [3.0] ``` - means that you know the reward for the first config in points_to_evaluate is 3.0 and want to inform run(). - resource_attr: A string to specify the resource dimension used by the scheduler via "scheduler". min_resource: A float of the minimal resource to use for the resource_attr. @@ -347,7 +331,6 @@ def easy_objective(config): search_alg: An instance of BlendSearch as the search algorithm to be used. The same instance can be used for iterative tuning. e.g., - ```python from flaml import BlendSearch algo = BlendSearch(metric='val_loss', mode='min', @@ -358,7 +341,6 @@ def easy_objective(config): search_alg=algo, use_ray=False) print(analysis.trials[-1].last_result) ``` - verbose: 0, 1, 2, or 3. If ray or spark backend is used, their verbosity will be affected by this argument. 0 = silent, 1 = only status updates, 2 = status and brief trial results, 3 = status and detailed trial results. @@ -372,7 +354,6 @@ def easy_objective(config): [parallel tuning](../../Use-Cases/Tune-User-Defined-Function#parallel-tuning). config_constraints: A list of config constraints to be satisfied. e.g., ```config_constraints = [(mem_size, '<=', 1024**3)]``` - mem_size is a function which produces a float number for the bytes needed for a config. 
It is used to skip configs which do not fit in memory. @@ -484,7 +465,7 @@ def easy_objective(config): from .searcher.blendsearch import BlendSearch, CFO if lexico_objectives is not None: - logger.warning("If lexico_objectives is not None, search_alg is forced to be CFO") + logger.warning("If lexico_objectives is not None, search_alg is forced to be CFO or Blendsearch") search_alg = None if search_alg is None: flaml_scheduler_resource_attr = ( @@ -511,7 +492,11 @@ def easy_objective(config): logger.warning("Using CFO for search. To use BlendSearch, run: pip install flaml[blendsearch]") metric = metric or DEFAULT_METRIC else: - SearchAlgorithm = CFO + assert "lexico_algorithm" in lexico_objectives and lexico_objectives["lexico_algorithm"] in [ + "CFO", + "BlendSearch", + ], 'When performing lexicographic optimization, "lexico_algorithm" should be provided in lexicographic objectives (CFO or BlendSearch).' + SearchAlgorithm = CFO if lexico_objectives["lexico_algorithm"] == "CFO" else BlendSearch logger.info("Using search algorithm {}.".format(SearchAlgorithm.__name__)) metric = lexico_objectives["metrics"][0] or DEFAULT_METRIC search_alg = SearchAlgorithm( @@ -633,11 +618,9 @@ def easy_objective(config): launch one trial per executor. However, sometimes we can launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. - `max_concurrent` is the maximum number of concurrent trials defined by `search_alg`, `FLAML_MAX_CONCURRENT` will also be used to override `max_concurrent` if `search_alg` is not an instance of `ConcurrencyLimiter`. - The final number of concurrent trials is the minimum of `max_concurrent` and `num_executors` if `n_concurrent_trials<=0` (default, automl cases), otherwise the minimum of `max_concurrent` and `n_concurrent_trials` (tuning cases). diff --git a/flaml/tune/utils.py b/flaml/tune/utils.py index 9398162a3c..919f131ed4 100644 --- a/flaml/tune/utils.py +++ b/flaml/tune/utils.py @@ -25,3 +25,19 @@ def choice(categories: Sequence, order=None): domain = sample.Categorical(categories).uniform() domain.ordered = order if order is not None else all(isinstance(x, (int, float)) for x in categories) return domain + + +def get_lexico_bound(metric, mode, lexico_objectives, f_best): + """Get targeted vector according to the historical points. + LexiFlow uses targeted vector to justify the order of different configurations. + """ + k_target = lexico_objectives["targets"][metric] if mode == "min" else -1 * lexico_objectives["targets"][metric] + if not isinstance(lexico_objectives["tolerances"][metric], str): + tolerance_bound = f_best[metric] + lexico_objectives["tolerances"][metric] + else: + assert ( + lexico_objectives["tolerances"][metric][-1] == "%" + ), "String tolerance of {} should use %% as the suffix".format(metric) + tolerance_bound = f_best[metric] * (1 + 0.01 * float(lexico_objectives["tolerances"][metric].replace("%", ""))) + bound = max(tolerance_bound, k_target) + return bound diff --git a/test/tune/test_lexiflow.py b/test/tune/test_lexiflow.py index 2d0274634a..974f095190 100644 --- a/test/tune/test_lexiflow.py +++ b/test/tune/test_lexiflow.py @@ -103,6 +103,7 @@ def evaluate_function(configuration): lexico_objectives = {} lexico_objectives["metrics"] = ["error_rate", "flops"] + lexico_objectives["lexico_algorithm"] = "CFO" search_space = { "n_layers": tune.randint(lower=1, upper=3), @@ -145,6 +146,7 @@ def evaluate_function(configuration): # 1. 
lexico tune: absolute tolerance lexico_objectives["tolerances"] = {"error_rate": 0.02, "flops": 0.0} + lexico_objectives["lexico_algorithm"] = "CFO" analysis = tune.run( evaluate_function, num_samples=5, @@ -159,6 +161,22 @@ def evaluate_function(configuration): # 2. lexico tune: percentage tolerance lexico_objectives["tolerances"] = {"error_rate": "10%", "flops": "0%"} + lexico_objectives["lexico_algorithm"] = "CFO" + analysis = tune.run( + evaluate_function, + num_samples=5, + config=search_space, + use_ray=False, + lexico_objectives=lexico_objectives, + low_cost_partial_config=low_cost_partial_config, + ) + print(analysis.best_trial) + print(analysis.best_config) + print(analysis.best_result) + + # 3. lexico tune - blendsearch: percentage tolerance + lexico_objectives["tolerances"] = {"error_rate": "10%", "flops": "0%"} + lexico_objectives["lexico_algorithm"] = "BlendSearch" analysis = tune.run( evaluate_function, num_samples=5, @@ -178,6 +196,7 @@ def test_lexiflow_performance(): lexico_objectives["tolerances"] = {"brain": 10.0, "currin": 0.0} lexico_objectives["targets"] = {"brain": 0.0, "currin": 0.0} lexico_objectives["modes"] = ["min", "min"] + lexico_objectives["lexico_algorithm"] = "CFO" search_space = { "x1": tune.uniform(lower=0.000001, upper=1.0),