Commit

chore: more types.
Mah Neh committed Sep 24, 2024
1 parent 11a7f27 commit 26421f0
Showing 3 changed files with 135 additions and 115 deletions.
18 changes: 10 additions & 8 deletions keras_tuner/engine/base_tuner.py
@@ -17,6 +17,8 @@
 import os
 import traceback
 import warnings
+from pathlib import Path
+from typing import cast

 from keras_tuner import backend, errors, utils
 from keras_tuner import config as config_module
@@ -425,7 +427,7 @@ def results_summary(self, num_trials: int = 10):
         trial.summary()

     @property
-    def remaining_trials(self):
+    def remaining_trials(self) -> int | None:
         """Returns the number of trials remaining.

         Will return `None` if `max_trials` is not set. This is useful when
@@ -439,7 +441,7 @@ def get_state(self):
     def set_state(self, state):
         pass

-    def _is_worker(self):
+    def _is_worker(self) -> bool:
         """Return true only if in parallel tuning and is a worker tuner."""
         return (
             dist_utils.has_chief_oracle() and not dist_utils.is_chief_oracle()
@@ -458,15 +460,15 @@ def reload(self) -> None:
         super().reload(self._get_tuner_fname())

     @property
-    def project_dir(self):
-        dirname = os.path.join(str(self.directory), self.project_name)
+    def project_dir(self) -> str:
+        dirname = str(Path(str(self.directory), self.project_name))
         utils.create_directory(dirname)
         return dirname

-    def get_trial_dir(self, trial_id):
-        dirname = os.path.join(str(self.project_dir), f"trial_{trial_id!s}")
+    def get_trial_dir(self, trial_id: str) -> str:
+        dirname = str(Path(self.project_dir, f"trial_{trial_id!s}"))
         utils.create_directory(dirname)
         return dirname

-    def _get_tuner_fname(self):
-        return os.path.join(str(self.project_dir), f"{self.tuner_id!s}.json")
+    def _get_tuner_fname(self) -> str:
+        return str(Path(self.project_dir, f"{self.tuner_id!s}.json"))
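
Every path helper in this file switches from `os.path.join(...)` to `pathlib.Path`, wrapped in `str(...)` so the return type stays a plain string for existing callers. A minimal standalone sketch of the same pattern, using a hypothetical `trial_dir` helper (not keras-tuner API):

    from pathlib import Path


    def trial_dir(base: str, trial_id: str) -> str:
        """Join path segments with pathlib and return a plain string."""
        # Path(...) accepts multiple segments, like os.path.join;
        # str(...) preserves the str return type the annotation promises.
        return str(Path(base, f"trial_{trial_id}"))


    print(trial_dir("results", "3"))  # results/trial_3 on POSIX systems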
52 changes: 28 additions & 24 deletions keras_tuner/engine/tuner.py
@@ -25,6 +25,9 @@

 from keras_tuner import errors, utils
 from keras_tuner.engine import base_tuner, tuner_utils
+from keras_tuner.engine.hypermodel import HyperModel
+from keras_tuner.engine.oracle import Oracle
+from keras_tuner.engine.trial import Trial


 class Tuner(base_tuner.BaseTuner):
@@ -86,28 +89,29 @@ class Tuner(base_tuner.BaseTuner):

     def __init__(
         self,
-        oracle,
-        hypermodel=None,
-        max_model_size=None,
+        oracle: Oracle,
+        hypermodel: HyperModel | None = None,
+        max_model_size: int | None = None,
         optimizer=None,
         loss=None,
         metrics=None,
         distribution_strategy=None,
-        directory=None,
-        project_name=None,
+        directory: str = ".",
+        project_name: str = "untitled_project",
         logger=None,
         tuner_id=None,
-        overwrite=False,
-        executions_per_trial=1,
+        overwrite: bool = False,
+        executions_per_trial: int = 1,
         **kwargs,
     ):
         if hypermodel is None and self.__class__.run_trial is Tuner.run_trial:
-            raise ValueError(
+            msg = (
                 "Received `hypermodel=None`. We only allow not specifying "
                 "`hypermodel` if the user defines the search space in "
                 "`Tuner.run_trial()` by subclassing a `Tuner` class without "
                 "using a `HyperModel` instance."
             )
+            raise ValueError(msg)

         super().__init__(
             oracle=oracle,
@@ -154,17 +158,17 @@ def _try_build(self, hp):
         model = self._build_hypermodel(hp)
         # Stop if `build()` does not return a valid model.
         if not isinstance(model, keras.models.Model):
-            raise errors.FatalTypeError(
+            msg = (
                 "Expected the model-building function, or HyperModel.build() "
                 "to return a valid Keras Model instance. "
                 f"Received: {model} of type {type(model)}."
             )
+            raise errors.FatalTypeError(msg)
         # Check model size.
         size = maybe_compute_model_size(model)
         if self.max_model_size and size > self.max_model_size:
-            raise errors.FailedTrialError(
-                f"Oversized model: {size} parameters. Skip model."
-            )
+            msg = f"Oversized model: {size} parameters. Skip model."
+            raise errors.FailedTrialError(msg)
         return model

     def _filter_metrics(self, metrics):
@@ -200,7 +204,7 @@ def _override_compile_args(self, model):
             compile_kwargs["metrics"] = self.metrics
         model.compile(**compile_kwargs)

-    def _build_and_fit_model(self, trial, *args, **kwargs):
+    def _build_and_fit_model(self, trial: Trial, *args, **kwargs):
         """For AutoKeras to override.

         DO NOT REMOVE this function. AutoKeras overrides the function to tune
@@ -234,7 +238,7 @@ def _build_and_fit_model(self, trial, *args, **kwargs):
         )
         return results

-    def run_trial(self, trial, *args, **kwargs):
+    def run_trial(self, trial: Trial, *args, **kwargs):
         """Evaluate a set of hyperparameter values.

         This method is called multiple times during `search` to build and
@@ -306,7 +310,7 @@ def run_trial(self, trial, *args, **kwargs):
             histories.append(obj_value)
         return histories

-    def load_model(self, trial):
+    def load_model(self, trial: Trial):
         model = self._try_build(trial.hyperparameters)

         # Build model to create the weights.
@@ -321,7 +325,7 @@ def load_model(self, trial):
         model.load_weights(self._get_checkpoint_fname(trial.trial_id))
         return model

-    def on_batch_begin(self, trial, model, batch, logs):
+    def on_batch_begin(self, trial: Trial, model, batch, logs):
         """Called at the beginning of a batch.

         Args:
@@ -332,7 +336,7 @@ def on_batch_begin(self, trial, model, batch, logs):
""" # noqa: D401

def on_batch_end(self, trial, model, batch, logs=None):
def on_batch_end(self, trial: Trial, model, batch, logs=None):
"""Called at the end of a batch.
Args:
@@ -343,7 +347,7 @@ def on_batch_end(self, trial, model, batch, logs=None):
""" # noqa: D401

def on_epoch_begin(self, trial, model, epoch, logs=None):
def on_epoch_begin(self, trial: Trial, model, epoch, logs=None):
"""Called at the beginning of an epoch.
Args:
@@ -354,7 +358,7 @@ def on_epoch_begin(self, trial, model, epoch, logs=None):
""" # noqa: D401

def on_epoch_end(self, trial, model, epoch, logs=None):
def on_epoch_end(self, trial: Trial, model, epoch, logs=None):
"""Called at the end of an epoch.
Args:
@@ -368,7 +372,7 @@ def on_epoch_end(self, trial, model, epoch, logs=None):
         # Intermediate results are not passed to the Oracle, and
         # checkpointing is handled via a `SaveBestEpoch` callback.

-    def get_best_models(self, num_models=1):
+    def get_best_models(self, num_models: int = 1):
         """Returns the best model(s), as determined by the tuner's objective.

         The models are loaded with the weights corresponding to
@@ -401,7 +405,7 @@ def _deepcopy_callbacks(self, callbacks):
             raise errors.FatalValueError(msg) from None
         return callbacks

-    def _configure_tensorboard_dir(self, callbacks, trial, execution=0):
+    def _configure_tensorboard_dir(self, callbacks, trial: Trial, execution=0):
         # Only import tensorboard when using tensorflow backend to avoid
         # importing tensorflow with other backend (tensorboard would import
         # tensorflow).
@@ -427,17 +431,17 @@ def _configure_tensorboard_dir(self, callbacks, trial, execution=0):
             )
         )

-    def _get_tensorboard_dir(self, logdir, trial_id, execution):
+    def _get_tensorboard_dir(self, logdir, trial_id: str, execution):
         return Path(str(logdir), str(trial_id), f"execution{execution!s}")

-    def _get_checkpoint_fname(self, trial_id):
+    def _get_checkpoint_fname(self, trial_id: str):
         return Path(
             # Each checkpoint is saved in its own directory.
             self.get_trial_dir(trial_id),
             "checkpoint",
         )

-    def _get_build_config_fname(self, trial_id):
+    def _get_build_config_fname(self, trial_id: str):
         return Path(
             # Each checkpoint is saved in its own directory.
             self.get_trial_dir(trial_id),
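
Beyond the `Trial` annotations, the hunks above consistently assign exception text to a `msg` variable before raising, a style commonly enforced by Ruff's flake8-errmsg rules (EM101/EM102); that motivation is an assumption here, since the commit message only mentions types. A minimal sketch of both patterns together:

    def check_model_size(size: int, max_size: int | None) -> None:
        """Fail when a model exceeds the configured parameter budget."""
        if max_size is not None and size > max_size:
            # Building the message first keeps the raise statement short,
            # so the traceback shows `raise ValueError(msg)` instead of a
            # long inline f-string.
            msg = f"Oversized model: {size} parameters. Skip model."
            raise ValueError(msg)

Note that the `int | None` union syntax used throughout requires Python 3.10+, or `from __future__ import annotations` on earlier versions.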
(diff for the third changed file not loaded)
