diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index c4fa0e7ce..d18a22b9d 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -1,7 +1,6 @@ import copy import json import logging.handlers -import math import multiprocessing import os import platform @@ -13,7 +12,7 @@ import uuid import warnings from abc import abstractmethod -from typing import Any, Callable, Dict, List, Optional, Union, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast from ConfigSpace.configuration_space import Configuration, ConfigurationSpace @@ -47,7 +46,7 @@ from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score, get_metrics from autoPyTorch.utils.backend import Backend, create -from autoPyTorch.utils.common import FitRequirement, replace_string_bool_to_bool +from autoPyTorch.utils.common import replace_string_bool_to_bool from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates from autoPyTorch.utils.logging_ import ( PicklableClientLogger, @@ -170,13 +169,13 @@ def __init__( os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json')))) self.search_space: Optional[ConfigurationSpace] = None - self._dataset_requirements: Optional[List[FitRequirement]] = None self._metric: Optional[autoPyTorchMetric] = None self._logger: Optional[PicklableClientLogger] = None self.run_history: RunHistory = RunHistory() self.trajectory: Optional[List] = None - self.dataset_name: Optional[str] = None + self.dataset_name: str = "" self.cv_models_: Dict = {} + self.experiment_task_name: str = 'runSearch' # By default try to use the TCP logging port or get a new port self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT @@ -469,43 +468,23 @@ def _load_best_individual_model(self) -> SingleBest: run_history=self.run_history, backend=self._backend, ) + msg = "No valid ensemble was created. Please check the log" \ + f"file for errors. Default to the best individual estimator:{ensemble.identifiers_}" + if self._logger is None: - warnings.warn( - "No valid ensemble was created. Please check the log" - "file for errors. Default to the best individual estimator:{}".format( - ensemble.identifiers_ - ) - ) + warnings.warn(msg) else: - self._logger.exception( - "No valid ensemble was created. Please check the log" - "file for errors. Default to the best individual estimator:{}".format( - ensemble.identifiers_ - ) - ) + self._logger.exception(msg) return ensemble - def _do_dummy_prediction(self) -> None: - - assert self._metric is not None - assert self._logger is not None - - # For dummy estimator, we always expect the num_run to be 1 - num_run = 1 - - self._logger.info("Starting to create dummy predictions.") - - memory_limit = self._memory_limit - if memory_limit is not None: - memory_limit = int(math.ceil(memory_limit)) - + def _get_target_algorithm(self, wallclock_limit: int) -> ExecuteTaFuncWithQueue: scenario_mock = unittest.mock.Mock() - scenario_mock.wallclock_limit = self._time_for_task - # This stats object is a hack - maybe the SMAC stats object should - # already be generated here! 
+        scenario_mock.wallclock_limit = wallclock_limit
         stats = Stats(scenario_mock)
         stats.start_timing()
+
+        assert self._metric is not None
         ta = ExecuteTaFuncWithQueue(
             backend=self._backend,
             seed=self.seed,
@@ -513,48 +492,77 @@ def _do_dummy_prediction(self) -> None:
             logger_port=self._logger_port,
             cost_for_crash=get_cost_of_crash(self._metric),
             abort_on_first_run_crash=False,
-            initial_num_run=num_run,
+            initial_num_run=self._backend.get_next_num_run(),
             stats=stats,
-            memory_limit=memory_limit,
+            memory_limit=self._memory_limit,
             disable_file_output=True if len(self._disable_file_output) > 0 else False,
             all_supported_metrics=self._all_supported_metrics
         )
+        return ta
+
+    def _logging_failed_prediction(self, additional_info: Any,
+                                   header: str, raise_error: bool) -> None:
+        assert self._logger is not None
+
+        if additional_info.get('exitcode') == -6:
+            err_msg = "The error suggests that the provided memory limits were too tight. Please " \
+                      "increase the 'ml_memory_limit' and try again. If this does not solve your " \
+                      "problem, please open an issue and paste the additional output. " \
+                      f"Additional output: {str(additional_info)}."
+            output = f"{header}. {err_msg}"
+            self._logger.error(output)
+            if raise_error:
+                raise ValueError(output)

-        status, cost, runtime, additional_info = ta.run(num_run, cutoff=self._time_for_task)
-        if status == StatusType.SUCCESS:
-            self._logger.info("Finished creating dummy predictions.")
         else:
-            if additional_info.get('exitcode') == -6:
-                self._logger.error(
-                    "Dummy prediction failed with run state %s. "
-                    "The error suggests that the provided memory limits were too tight. Please "
-                    "increase the 'ml_memory_limit' and try again. If this does not solve your "
-                    "problem, please open an issue and paste the additional output. "
-                    "Additional output: %s.",
-                    str(status), str(additional_info),
-                )
-                # Fail if dummy prediction fails.
-                raise ValueError(
-                    "Dummy prediction failed with run state %s. "
-                    "The error suggests that the provided memory limits were too tight. Please "
-                    "increase the 'ml_memory_limit' and try again. If this does not solve your "
-                    "problem, please open an issue and paste the additional output. "
-                    "Additional output: %s." %
-                    (str(status), str(additional_info)),
-                )
+            output = f"{header} and additional output: {str(additional_info)}."
+            self._logger.error(output)
+            if raise_error:
+                raise ValueError(output)
+
+    def _parallel_worker_allocation(self, num_future_jobs: int, run_history: RunHistory,
+                                    dask_futures: List[Tuple[str, Any]]
+                                    ) -> None:
+        """
+        Allocate queued jobs to idle dask workers and run them.
+        The history is recorded in run_history.
+ + Args: + num_future_jobs (int): The number of jobs to run + dask_futures (List[Tuple[str, Any]]): + The list of pairs of the name of the classifier to run and + the function to train the classifier + run_history (RunHistory): + The running history of the experiment + + Note: + - `dask_futures.pop(0)` gives a classifier and a next job to run + - `future.result()` calls a submitted job in and return the results + - We have to wait for the return of `future.result()` + once the number of running jobs reaches `num_workers` in self.dask_client + """ + assert self._logger is not None + + while num_future_jobs >= 1: + num_future_jobs -= 1 + classifier, future = dask_futures.pop(0) + # call the training by future.result() + status, cost, runtime, additional_info = future.result() + + if status == StatusType.SUCCESS: + self._logger.info( + f"Fitting {classifier} took {runtime}s, performance:{cost}/{additional_info}") + configuration = additional_info['pipeline_configuration'] + origin = additional_info['configuration_origin'] + run_history.add(config=configuration, cost=cost, + time=runtime, status=status, seed=self.seed, + origin=origin) else: - self._logger.error( - "Dummy prediction failed with run state %s and additional output: %s.", - str(status), str(additional_info), - ) - # Fail if dummy prediction fails. - raise ValueError( - "Dummy prediction failed with run state %s and additional output: %s." - % (str(status), str(additional_info)) - ) + header = f"Traditional prediction for {classifier} failed with run state {str(status)}" + self._logging_failed_prediction(additional_info, header, raise_error=False) - def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs: int) -> None: + def _traditional_predictions(self, time_left: int) -> None: """ Fits traditional machine learning algorithms to the provided dataset, while complying with time resource allocation. @@ -576,13 +584,11 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs: assert self._logger is not None assert self._dask_client is not None - self._logger.info("Starting to create traditional classifier predictions.") + self._logger.info("Start to create traditional classifier predictions.") # Initialise run history for the traditional classifiers run_history = RunHistory() - memory_limit = self._memory_limit - if memory_limit is not None: - memory_limit = int(math.ceil(memory_limit)) + available_classifiers = get_available_classifiers() dask_futures = [] @@ -591,88 +597,33 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs: # Only launch a task if there is time start_time = time.time() - if time_left >= func_eval_time_limit_secs: - self._logger.info(f"{n_r}: Started fitting {classifier} with cutoff={func_eval_time_limit_secs}") - scenario_mock = unittest.mock.Mock() - scenario_mock.wallclock_limit = time_left - # This stats object is a hack - maybe the SMAC stats object should - # already be generated here! 
- stats = Stats(scenario_mock) - stats.start_timing() - ta = ExecuteTaFuncWithQueue( - backend=self._backend, - seed=self.seed, - metric=self._metric, - logger_port=self._logger_port, - cost_for_crash=get_cost_of_crash(self._metric), - abort_on_first_run_crash=False, - initial_num_run=self._backend.get_next_num_run(), - stats=stats, - memory_limit=memory_limit, - disable_file_output=True if len(self._disable_file_output) > 0 else False, - all_supported_metrics=self._all_supported_metrics - ) + sufficient_time_available = (time_left >= self._func_eval_time_limit_secs) + + if sufficient_time_available: + self._logger.info(f"{n_r}: Start fitting {classifier} with cutoff={self._func_eval_time_limit_secs}") + ta = self._get_target_algorithm(time_left) dask_futures.append([ classifier, self._dask_client.submit( ta.run, config=classifier, - cutoff=func_eval_time_limit_secs, - ) - ]) - - # When managing time, we need to take into account the allocated time resources, - # which are dependent on the number of cores. 'dask_futures' is a proxy to the number - # of workers /n_jobs that we have, in that if there are 4 cores allocated, we can run at most - # 4 task in parallel. Every 'cutoff' seconds, we generate up to 4 tasks. - # If we only have 4 workers and there are 4 futures in dask_futures, it means that every - # worker has a task. We would not like to launch another job until a worker is available. To this - # end, the following if-statement queries the number of active jobs, and forces to wait for a job - # completion via future.result(), so that a new worker is available for the next iteration. + cutoff=self._func_eval_time_limit_secs, + )]) + if len(dask_futures) >= self.n_jobs: + last_iteration = (n_r >= total_number_classifiers - 1) + num_future_jobs = 1 + # If it is the last iteration, we have to run all the jobs + if not sufficient_time_available or last_iteration: + num_future_jobs = len(dask_futures) + + self._parallel_worker_allocation(num_future_jobs=num_future_jobs, + run_history=run_history, + dask_futures=dask_futures) - # How many workers to wait before starting fitting the next iteration - workers_to_wait = 1 - if n_r >= total_number_classifiers - 1 or time_left <= func_eval_time_limit_secs: - # If on the last iteration, flush out all tasks - workers_to_wait = len(dask_futures) - - while workers_to_wait >= 1: - workers_to_wait -= 1 - # We launch dask jobs only when there are resources available. - # This allow us to control time allocation properly, and early terminate - # the traditional machine learning pipeline - cls, future = dask_futures.pop(0) - status, cost, runtime, additional_info = future.result() - if status == StatusType.SUCCESS: - self._logger.info( - f"Fitting {cls} took {runtime}s, performance:{cost}/{additional_info}") - configuration = additional_info['pipeline_configuration'] - origin = additional_info['configuration_origin'] - run_history.add(config=configuration, cost=cost, - time=runtime, status=status, seed=self.seed, - origin=origin) - else: - if additional_info.get('exitcode') == -6: - self._logger.error( - "Traditional prediction for %s failed with run state %s. " - "The error suggests that the provided memory limits were too tight. Please " - "increase the 'ml_memory_limit' and try again. If this does not solve your " - "problem, please open an issue and paste the additional output. 
" - "Additional output: %s.", - cls, str(status), str(additional_info), - ) - else: - self._logger.error( - "Traditional prediction for %s failed with run state %s and additional output: %s.", - cls, str(status), str(additional_info), - ) - - # In the case of a serial execution, calling submit halts the run for a resource - # dynamically adjust time in this case time_left -= int(time.time() - start_time) + self.num_run = n_r - # Exit if no more time is available for a new classifier - if time_left < func_eval_time_limit_secs: + if time_left < self._func_eval_time_limit_secs: self._logger.warning("Not enough time to fit all traditional machine learning models." "Please consider increasing the run time to further improve performance.") break @@ -684,6 +635,283 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs: save_external=True) return + def _run_dummy_predictions(self) -> None: + assert self._metric is not None + assert self._logger is not None + + # For dummy estimator, we always expect the num_run to be 1 + num_run = 1 + + dummy_task_name = 'runDummy' + self._stopwatch.start_task(dummy_task_name) + self._logger.info("Start to create dummy predictions.") + ta = self._get_target_algorithm(self._total_walltime_limit) + status, cost, runtime, additional_info = ta.run(num_run, cutoff=self._total_walltime_limit) + if status == StatusType.SUCCESS: + self._logger.info("Finish creating dummy predictions.") + else: + header = f"Dummy prediction failed with run state {str(status)}" + self._logging_failed_prediction(additional_info=additional_info, + header=header, raise_error=True) + self._stopwatch.stop_task(dummy_task_name) + + def _run_traditional_ml(self) -> None: + """We would like to obtain training time for at least 1 Neural network in SMAC""" + assert self._logger is not None + + if STRING_TO_TASK_TYPES[self.task_type] in REGRESSION_TASKS: + self._logger.warning("Traditional Pipeline is not enabled for regression. Skipping...") + else: + traditional_task_name = 'runTraditional' + self._stopwatch.start_task(traditional_task_name) + elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name) + + assert self._func_eval_time_limit_secs is not None + time_for_traditional = int( + self._total_walltime_limit - elapsed_time - self._func_eval_time_limit_secs + ) + self._traditional_predictions(time_left=time_for_traditional) + self._stopwatch.stop_task(traditional_task_name) + + def _run_ensemble(self, dataset: BaseDataset, optimize_metric: str, + precision: int) -> Optional[EnsembleBuilderManager]: + + assert self._logger is not None + assert self._metric is not None + + elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name) + time_left_for_ensembles = max(0, self._total_walltime_limit - elapsed_time) + proc_ensemble = None + if time_left_for_ensembles <= 0 and self.ensemble_size > 0: + raise ValueError("Could not run ensemble builder because there " + "is no time left. 
Try increasing the value " + "of total_walltime_limit.") + elif self.ensemble_size <= 0: + self._logger.info("Could not run ensemble builder as ensemble size is non-positive.") + else: + self._logger.info("Run ensemble") + ensemble_task_name = 'ensemble' + self._stopwatch.start_task(ensemble_task_name) + proc_ensemble = EnsembleBuilderManager( + start_time=time.time(), + time_left_for_ensembles=time_left_for_ensembles, + backend=copy.deepcopy(self._backend), + dataset_name=dataset.dataset_name, + output_type=STRING_TO_OUTPUT_TYPES[dataset.output_type], + task_type=STRING_TO_TASK_TYPES[self.task_type], + metrics=[self._metric], + opt_metric=optimize_metric, + ensemble_size=self.ensemble_size, + ensemble_nbest=self.ensemble_nbest, + max_models_on_disc=self.max_models_on_disc, + seed=self.seed, + max_iterations=None, + read_at_most=sys.maxsize, + ensemble_memory_limit=self._memory_limit, + random_state=self.seed, + precision=precision, + logger_port=self._logger_port + ) + self._stopwatch.stop_task(ensemble_task_name) + + return proc_ensemble + + def _get_budget_config(self, budget_type: Optional[str] = None, + budget: Optional[float] = None) -> Dict[str, Union[float, str]]: + + budget_config: Dict[str, Union[float, str]] = {} + if budget_type is not None and budget is not None: + budget_config['budget_type'] = budget_type + budget_config[budget_type] = budget + elif budget_type is not None or budget is not None: + raise ValueError("budget type was not specified in budget_config") + + return budget_config + + def _start_smac(self, proc_smac: AutoMLSMBO) -> None: + assert self._logger is not None + + try: + run_history, self.trajectory, budget_type = \ + proc_smac.run_smbo() + self.run_history.update(run_history, DataOrigin.INTERNAL) + trajectory_filename = os.path.join( + self._backend.get_smac_output_directory_for_run(self.seed), + 'trajectory.json') + + assert self.trajectory is not None + + saveable_trajectory = \ + [list(entry[:2]) + [entry[2].get_dictionary()] + list(entry[3:]) + for entry in self.trajectory] + except Exception as e: + self._logger.exception(str(e)) + raise + else: + try: + with open(trajectory_filename, 'w') as fh: + json.dump(saveable_trajectory, fh) + except Exception as e: + self._logger.warning(f"Could not save {trajectory_filename} due to {e}...") + + def _run_smac(self, dataset: BaseDataset, proc_ensemble: Optional[EnsembleBuilderManager], + budget_type: Optional[str] = None, budget: Optional[float] = None, + get_smac_object_callback: Optional[Callable] = None, + smac_scenario_args: Optional[Dict[str, Any]] = None) -> None: + + assert self._logger is not None + + smac_task_name = 'runSMAC' + self._stopwatch.start_task(smac_task_name) + elapsed_time = self._stopwatch.wall_elapsed(self.experiment_task_name) + time_left_for_smac = max(0, self._total_walltime_limit - elapsed_time) + + self._logger.info(f"Run SMAC with {time_left_for_smac:.2f} sec time left") + if time_left_for_smac <= 0: + self._logger.warning(" Could not run SMAC because there is no time left") + else: + budget_config = self._get_budget_config(budget_type=budget_type, budget=budget) + + assert self._func_eval_time_limit_secs is not None + assert self._metric is not None + proc_smac = AutoMLSMBO( + config_space=self.search_space, + dataset_name=dataset.dataset_name, + backend=self._backend, + total_walltime_limit=self._total_walltime_limit, + func_eval_time_limit_secs=self._func_eval_time_limit_secs, + dask_client=self._dask_client, + memory_limit=self._memory_limit, + n_jobs=self.n_jobs, + 
watcher=self._stopwatch, + metric=self._metric, + seed=self.seed, + include=self.include_components, + exclude=self.exclude_components, + disable_file_output=self._disable_file_output, + all_supported_metrics=self._all_supported_metrics, + smac_scenario_args=smac_scenario_args, + get_smac_object_callback=get_smac_object_callback, + pipeline_config={**self.pipeline_options, **budget_config}, + ensemble_callback=proc_ensemble, + logger_port=self._logger_port, + start_num_run=self._backend.get_next_num_run(peek=True), + search_space_updates=self.search_space_updates + ) + + self._start_smac(proc_smac) + + def _search_settings(self, dataset: BaseDataset, disable_file_output: List, + optimize_metric: str, memory_limit: Optional[int] = 4096, + func_eval_time_limit_secs: Optional[int] = None, + total_walltime_limit: int = 100, + all_supported_metrics: bool = True) -> None: + + """Initialise information needed for the experiment""" + self.experiment_task_name = 'runSearch' + dataset_requirements = get_dataset_requirements( + info=self._get_required_dataset_properties(dataset)) + dataset_properties = dataset.get_dataset_properties(dataset_requirements) + + self._stopwatch.start_task(self.experiment_task_name) + self.dataset_name = dataset.dataset_name + self._all_supported_metrics = all_supported_metrics + self._disable_file_output = disable_file_output + self._memory_limit = memory_limit + self._total_walltime_limit = total_walltime_limit + self._func_eval_time_limit_secs = func_eval_time_limit_secs + self._metric = get_metrics( + names=[optimize_metric], dataset_properties=dataset_properties)[0] + + if self._logger is None: + self._logger = self._get_logger(str(self.dataset_name)) + + # Save start time to backend + self._backend.save_start_time(str(self.seed)) + self._backend.save_datamanager(dataset) + + # Print debug information to log + self._print_debug_info_to_log() + + self.search_space = self.get_search_space(dataset) + + # If no dask client was provided, we create one, so that we can + # start a ensemble process in parallel to smbo optimize + if ( + self._dask_client is None and (self.ensemble_size > 0 or self.n_jobs is not None and self.n_jobs > 1) + ): + self._create_dask_client() + else: + self._is_dask_client_internally_created = False + + def _adapt_time_resource_allocation(self) -> None: + assert self._logger is not None + + # Handle time resource allocation + elapsed_time = self._stopwatch.wall_elapsed(self.experiment_task_name) + time_left_for_modelfit = int(max(0, self._total_walltime_limit - elapsed_time)) + if self._func_eval_time_limit_secs is None or self._func_eval_time_limit_secs > time_left_for_modelfit: + self._logger.warning( + 'Time limit for a single run is higher than total time ' + 'limit. 
Capping the limit for a single run to the total '
+                'time given to SMAC (%f)' % time_left_for_modelfit
+            )
+            self._func_eval_time_limit_secs = time_left_for_modelfit
+
+        # Make sure that at least 2 models are created for the ensemble process
+        num_models = time_left_for_modelfit // self._func_eval_time_limit_secs
+        if num_models < 2:
+            self._func_eval_time_limit_secs = time_left_for_modelfit // 2
+            self._logger.warning(
+                "Capping the func_eval_time_limit_secs to {} to have "
+                "time for at least 2 models to ensemble.".format(
+                    self._func_eval_time_limit_secs
+                )
+            )
+
+    def _save_ensemble_performance_history(self, proc_ensemble: EnsembleBuilderManager) -> None:
+        assert self._logger is not None
+
+        if len(proc_ensemble.futures) > 0:
+            # Add ensemble runs that did not finish within the SMAC time limit
+            # to the ensemble history
+            self._logger.info("Ensemble script still running, waiting for it to finish.")
+            result = proc_ensemble.futures.pop().result()
+            if result:
+                ensemble_history, _, _, _ = result
+                self.ensemble_performance_history.extend(ensemble_history)
+            self._logger.info("Ensemble script finished, continue shutdown.")
+
+        # save the ensemble performance history file
+        if len(self.ensemble_performance_history) > 0:
+            pd.DataFrame(self.ensemble_performance_history).to_json(
+                os.path.join(self._backend.internals_directory, 'ensemble_history.json'))
+
+    def _finish_experiment(self, proc_ensemble: Optional[EnsembleBuilderManager],
+                           load_models: bool) -> None:
+
+        assert self._logger is not None
+        # Wait until the ensemble process is finished to avoid shutting down
+        # while the ensemble builder tries to access the data
+        self._logger.info("Start Shutdown")
+
+        if proc_ensemble is not None:
+            self.ensemble_performance_history = list(proc_ensemble.history)
+            self._save_ensemble_performance_history(proc_ensemble)
+
+        self._logger.info("Close the dask infrastructure")
+        self._close_dask_client()
+        self._logger.info("Finished closing the dask infrastructure")
+
+        if load_models:
+            self._logger.info("Load models...")
+            self._load_models()
+            self._logger.info("Finished loading models...")
+
+        # Clean up the logger
+        self._logger.info("Start to clean up the logger")
+        self._clean_logger()
+
     def _search(
         self,
         optimize_metric: str,
@@ -699,7 +927,7 @@ def _search(
         all_supported_metrics: bool = True,
         precision: int = 32,
         disable_file_output: List = [],
-        load_models: bool = True,
+        load_models: bool = True
     ) -> 'BaseTask':
         """
         Search for the best pipeline configuration for the given dataset.
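
For readers skimming the diff, here is a minimal standalone sketch (not part of this change) of the per-run time cap that `_adapt_time_resource_allocation` applies above; the helper name, the argument names, and the zero-division guard are illustrative assumptions.

from typing import Optional


def cap_per_run_limit(total_budget: int, elapsed: float,
                      per_run_limit: Optional[int]) -> int:
    # Remaining wall-clock budget for model fitting.
    time_left = int(max(0, total_budget - elapsed))
    # A single evaluation may not exceed the remaining budget.
    if per_run_limit is None or per_run_limit > time_left:
        per_run_limit = time_left
    # Leave room for at least two models so an ensemble can be built.
    if per_run_limit and time_left // per_run_limit < 2:
        per_run_limit = time_left // 2
    return per_run_limit


# Example: total_walltime_limit=300s with 20s already elapsed and a 200s per-run
# limit only fits one model, so the limit is capped to 280 // 2 == 140 seconds.
assert cap_per_run_limit(300, 20, 200) == 140
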
@@ -780,230 +1008,31 @@ def _search( raise ValueError("Incompatible dataset entered for current task," "expected dataset to have task type :{} got " ":{}".format(self.task_type, dataset.task_type)) - - # Initialise information needed for the experiment - experiment_task_name = 'runSearch' - dataset_requirements = get_dataset_requirements( - info=self._get_required_dataset_properties(dataset)) - self._dataset_requirements = dataset_requirements - dataset_properties = dataset.get_dataset_properties(dataset_requirements) - self._stopwatch.start_task(experiment_task_name) - self.dataset_name = dataset.dataset_name - if self._logger is None: - self._logger = self._get_logger(self.dataset_name) - self._all_supported_metrics = all_supported_metrics - self._disable_file_output = disable_file_output - self._memory_limit = memory_limit - self._time_for_task = total_walltime_limit - # Save start time to backend - self._backend.save_start_time(str(self.seed)) - - self._backend.save_datamanager(dataset) - - # Print debug information to log - self._print_debug_info_to_log() - - self._metric = get_metrics( - names=[optimize_metric], dataset_properties=dataset_properties)[0] - - self.search_space = self.get_search_space(dataset) - - budget_config: Dict[str, Union[float, str]] = {} - if budget_type is not None and budget is not None: - budget_config['budget_type'] = budget_type - budget_config[budget_type] = budget - elif budget_type is not None or budget is not None: - raise ValueError( - "budget type was not specified in budget_config" - ) - if self.task_type is None: raise ValueError("Cannot interpret task type from the dataset") + if precision not in [16, 32, 64]: + raise ValueError(f"precision must be either [16, 32, 64], but got {precision}") - # If no dask client was provided, we create one, so that we can - # start a ensemble process in parallel to smbo optimize - if ( - self._dask_client is None and (self.ensemble_size > 0 or self.n_jobs is not None and self.n_jobs > 1) - ): - self._create_dask_client() - else: - self._is_dask_client_internally_created = False + self._search_settings(dataset=dataset, disable_file_output=disable_file_output, + optimize_metric=optimize_metric, memory_limit=memory_limit, + all_supported_metrics=all_supported_metrics, + func_eval_time_limit_secs=func_eval_time_limit_secs, + total_walltime_limit=total_walltime_limit) - # Handle time resource allocation - elapsed_time = self._stopwatch.wall_elapsed(experiment_task_name) - time_left_for_modelfit = int(max(0, total_walltime_limit - elapsed_time)) - if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_left_for_modelfit: - self._logger.warning( - 'Time limit for a single run is higher than total time ' - 'limit. 
Capping the limit for a single run to the total ' - 'time given to SMAC (%f)' % time_left_for_modelfit - ) - func_eval_time_limit_secs = time_left_for_modelfit - - # Make sure that at least 2 models are created for the ensemble process - num_models = time_left_for_modelfit // func_eval_time_limit_secs - if num_models < 2: - func_eval_time_limit_secs = time_left_for_modelfit // 2 - self._logger.warning( - "Capping the func_eval_time_limit_secs to {} to have " - "time for a least 2 models to ensemble.".format( - func_eval_time_limit_secs - ) - ) - - # ============> Run dummy predictions - dummy_task_name = 'runDummy' - self._stopwatch.start_task(dummy_task_name) - self._do_dummy_prediction() - self._stopwatch.stop_task(dummy_task_name) - - # ============> Run traditional ml + self._adapt_time_resource_allocation() + self._run_dummy_predictions() if enable_traditional_pipeline: - if STRING_TO_TASK_TYPES[self.task_type] in REGRESSION_TASKS: - self._logger.warning("Traditional Pipeline is not enabled for regression. Skipping...") - else: - traditional_task_name = 'runTraditional' - self._stopwatch.start_task(traditional_task_name) - elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name) - # We want time for at least 1 Neural network in SMAC - time_for_traditional = int( - self._time_for_task - elapsed_time - func_eval_time_limit_secs - ) - self._do_traditional_prediction( - func_eval_time_limit_secs=func_eval_time_limit_secs, - time_left=time_for_traditional, - ) - self._stopwatch.stop_task(traditional_task_name) - - # ============> Starting ensemble - elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name) - time_left_for_ensembles = max(0, total_walltime_limit - elapsed_time) - proc_ensemble = None - if time_left_for_ensembles <= 0: - # Fit only raises error when ensemble_size is not zero but - # time_left_for_ensembles is zero. - if self.ensemble_size > 0: - raise ValueError("Not starting ensemble builder because there " - "is no time left. 
Try increasing the value " - "of time_left_for_this_task.") - elif self.ensemble_size <= 0: - self._logger.info("Not starting ensemble builder as ensemble size is 0") - else: - self._logger.info("Starting ensemble") - ensemble_task_name = 'ensemble' - self._stopwatch.start_task(ensemble_task_name) - proc_ensemble = EnsembleBuilderManager( - start_time=time.time(), - time_left_for_ensembles=time_left_for_ensembles, - backend=copy.deepcopy(self._backend), - dataset_name=dataset.dataset_name, - output_type=STRING_TO_OUTPUT_TYPES[dataset.output_type], - task_type=STRING_TO_TASK_TYPES[self.task_type], - metrics=[self._metric], - opt_metric=optimize_metric, - ensemble_size=self.ensemble_size, - ensemble_nbest=self.ensemble_nbest, - max_models_on_disc=self.max_models_on_disc, - seed=self.seed, - max_iterations=None, - read_at_most=sys.maxsize, - ensemble_memory_limit=self._memory_limit, - random_state=self.seed, - precision=precision, - logger_port=self._logger_port, - ) - self._stopwatch.stop_task(ensemble_task_name) - - # ==> Run SMAC - smac_task_name = 'runSMAC' - self._stopwatch.start_task(smac_task_name) - elapsed_time = self._stopwatch.wall_elapsed(experiment_task_name) - time_left_for_smac = max(0, total_walltime_limit - elapsed_time) + self._run_traditional_ml() - self._logger.info("Starting SMAC with %5.2f sec time left" % time_left_for_smac) - if time_left_for_smac <= 0: - self._logger.warning(" Not starting SMAC because there is no time left") - else: + proc_ensemble = self._run_ensemble(dataset=dataset, precision=precision, + optimize_metric=optimize_metric) - _proc_smac = AutoMLSMBO( - config_space=self.search_space, - dataset_name=dataset.dataset_name, - backend=self._backend, - total_walltime_limit=total_walltime_limit, - func_eval_time_limit_secs=func_eval_time_limit_secs, - dask_client=self._dask_client, - memory_limit=self._memory_limit, - n_jobs=self.n_jobs, - watcher=self._stopwatch, - metric=self._metric, - seed=self.seed, - include=self.include_components, - exclude=self.exclude_components, - disable_file_output=self._disable_file_output, - all_supported_metrics=self._all_supported_metrics, - smac_scenario_args=smac_scenario_args, - get_smac_object_callback=get_smac_object_callback, - pipeline_config={**self.pipeline_options, **budget_config}, - ensemble_callback=proc_ensemble, - logger_port=self._logger_port, - # We do not increase the num_run here, this is something - # smac does internally - start_num_run=self._backend.get_next_num_run(peek=True), - search_space_updates=self.search_space_updates - ) - try: - run_history, self.trajectory, budget_type = \ - _proc_smac.run_smbo() - self.run_history.update(run_history, DataOrigin.INTERNAL) - trajectory_filename = os.path.join( - self._backend.get_smac_output_directory_for_run(self.seed), - 'trajectory.json') - saveable_trajectory = \ - [list(entry[:2]) + [entry[2].get_dictionary()] + list(entry[3:]) - for entry in self.trajectory] - try: - with open(trajectory_filename, 'w') as fh: - json.dump(saveable_trajectory, fh) - except Exception as e: - self._logger.warning(f"Cannot save {trajectory_filename} due to {e}...") - except Exception as e: - self._logger.exception(str(e)) - raise - # Wait until the ensemble process is finished to avoid shutting down - # while the ensemble builder tries to access the data - self._logger.info("Starting Shutdown") - - if proc_ensemble is not None: - self.ensemble_performance_history = list(proc_ensemble.history) - - if len(proc_ensemble.futures) > 0: - # Also add ensemble runs that did not 
finish within smac time - # and add them into the ensemble history - self._logger.info("Ensemble script still running, waiting for it to finish.") - result = proc_ensemble.futures.pop().result() - if result: - ensemble_history, _, _, _ = result - self.ensemble_performance_history.extend(ensemble_history) - self._logger.info("Ensemble script finished, continue shutdown.") - - # save the ensemble performance history file - if len(self.ensemble_performance_history) > 0: - pd.DataFrame(self.ensemble_performance_history).to_json( - os.path.join(self._backend.internals_directory, 'ensemble_history.json')) - - self._logger.info("Closing the dask infrastructure") - self._close_dask_client() - self._logger.info("Finished closing the dask infrastructure") - - if load_models: - self._logger.info("Loading models...") - self._load_models() - self._logger.info("Finished loading models...") + self._run_smac(budget=budget, budget_type=budget_type, proc_ensemble=proc_ensemble, + dataset=dataset, get_smac_object_callback=get_smac_object_callback, + smac_scenario_args=smac_scenario_args) - # Clean up the logger - self._logger.info("Starting to clean up the logger") - self._clean_logger() + self._finish_experiment(proc_ensemble=proc_ensemble, load_models=load_models) return self @@ -1035,7 +1064,7 @@ def refit( Returns: self """ - if self.dataset_name is None: + if self.dataset_name == "": self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) if self._logger is None: @@ -1067,12 +1096,12 @@ def refit( for identifier in self.models_: model = self.models_[identifier] - # this updates the model inplace, it can then later be used in + # It updates the model inplace, it can then later be used in # predict method - # try to fit the model. If it fails, shuffle the data. This - # could alleviate the problem in algorithms that depend on - # the ordering of the data. + # Fit the model to check if it fails. 
+            # If it fails, shuffle the data to alleviate problems
+            # in algorithms that depend on the ordering of the data.
             fit_and_suppress_warnings(self._logger, model, X, y=None)
 
         self._clean_logger()
 
@@ -1105,7 +1134,7 @@ def fit(self,
         Returns:
             (BasePipeline): fitted pipeline
         """
-        if self.dataset_name is None:
+        if self.dataset_name == "":
             self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
 
         if self._logger is None:
@@ -1238,15 +1267,11 @@ def __del__(self) -> None:
             self._backend.context.delete_directories(force=False)
 
     @typing.no_type_check
-    def get_incumbent_results(
-        self
-    ):
+    def get_incumbent_results(self):
         pass
 
     @typing.no_type_check
-    def get_incumbent_config(
-        self
-    ):
+    def get_incumbent_config(self):
         pass
 
     def get_models_with_weights(self) -> List:
@@ -1270,7 +1295,7 @@ def _print_debug_info_to_log(self) -> None:
         Prints to the log file debug information about the current estimator
         """
         assert self._logger is not None
-        self._logger.debug("Starting to print environment information")
+        self._logger.debug("Start to print environment information")
         self._logger.debug('  Python version: %s', sys.version.split('\n'))
         self._logger.debug('  System: %s', platform.system())
         self._logger.debug('  Machine: %s', platform.machine())
diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py
index 7866e7674..518cb0e56 100644
--- a/test/test_api/test_api.py
+++ b/test/test_api/test_api.py
@@ -391,7 +391,7 @@ def test_tabular_input_support(openml_id, backend):
         ensemble_size=0,
     )
 
-    estimator._do_dummy_prediction = unittest.mock.MagicMock()
+    estimator._run_dummy_predictions = unittest.mock.MagicMock()
 
     with unittest.mock.patch.object(AutoMLSMBO, 'run_smbo') as AutoMLSMBOMock:
         AutoMLSMBOMock.return_value = (RunHistory(), {}, 'epochs')
@@ -407,7 +407,7 @@ def test_tabular_input_support(openml_id, backend):
 
 
 @pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
-def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
+def test_run_dummy_predictions(dask_client, fit_dictionary_tabular):
     backend = fit_dictionary_tabular['backend']
     estimator = TabularClassificationTask(
         backend=backend,
@@ -424,7 +424,7 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
     estimator._disable_file_output = []
     estimator._all_supported_metrics = False
 
-    estimator._do_dummy_prediction()
+    estimator._run_dummy_predictions()
 
     # Ensure that the dummy predictions are not in the current working
     # directory, but in the temporary directory.
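
A small usage sketch (not part of the diff) of the contract implemented by the new `_get_budget_config` helper: `budget_type` and `budget` must be supplied together or not at all, and the returned keys mirror the code above. The standalone function name is an illustrative assumption.

from typing import Dict, Optional, Union


def get_budget_config(budget_type: Optional[str] = None,
                      budget: Optional[float] = None) -> Dict[str, Union[float, str]]:
    # Both given: record the budget type and the budget under that type's key.
    if budget_type is not None and budget is not None:
        return {'budget_type': budget_type, budget_type: budget}
    # Only one of the two given: reject the call, as in the diff.
    if budget_type is not None or budget is not None:
        raise ValueError("budget type was not specified in budget_config")
    # Neither given: SMAC runs without an explicit budget configuration.
    return {}


assert get_budget_config() == {}
assert get_budget_config('epochs', 5.0) == {'budget_type': 'epochs', 'epochs': 5.0}
# get_budget_config('epochs') or get_budget_config(budget=5.0) raise ValueError.
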