Skip to content

Commit

Permalink
chore(imports): improve some, broke some.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mah Neh committed Sep 24, 2024
1 parent 26421f0 commit 8608003
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 137 deletions.
68 changes: 38 additions & 30 deletions keras_tuner/engine/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
# limitations under the License.

from keras_tuner.engine import metrics_tracking
from keras_tuner.types import (
_EpochLogs,
_SomeObjective,
_SomeObjectiveOrName,
)


class Objective:
Expand All @@ -25,65 +30,60 @@ class Objective:
"""

def __init__(self, name, direction):
def __init__(self, name: str, direction: str):
self.name = name
self.direction = direction

def has_value(self, logs):
def has_value(self, logs: _EpochLogs) -> bool:
"""Check if objective value exists in logs.
Args:
logs: A dictionary with the metric names as the keys and the metric
values as the values, which is in the same format as the `logs`
argument for `Callback.on_epoch_end()`.
logs: metric_name to metric_value dict.
Returns:
Boolean, whether we can compute objective value from the logs.
Is objective in log.
"""
return self.name in logs

def get_value(self, logs):
def get_value(self, logs: _EpochLogs) -> float:
"""Get the objective value from the metrics logs.
Args:
logs: A dictionary with the metric names as the keys and the metric
values as the values, which is in the same format as the `logs`
argument for `Callback.on_epoch_end()`.
logs: metric_name to metric_value dict.
Returns:
The objective value.
"""
return logs[self.name]

def better_than(self, a, b):
"""Whether the first objective value is better than the second.
Args:
a: A float, an objective value to compare.
b: A float, another objective value to compare.
def better_than(self, new_val: float, reference: float) -> bool:
"""Check whether `new_val` is better than `reference`.
Returns:
Boolean, whether the first objective value is better than the
second.
Whether the new_val is an improvement over reference.
"""
return (a > b and self.direction == "max") or (
a < b and self.direction == "min"
return (new_val > reference and self.direction == "max") or (
new_val < reference and self.direction == "min"
)

def __eq__(self, obj):
return self.name == obj.name and self.direction == obj.direction
def __eq__(self, obj: object) -> bool:
"""Check if `obj` has the same name and direction, and class."""
if isinstance(obj, Objective | DefaultObjective):
return self.name == obj.name and self.direction == obj.direction
return False

def __str__(self):
def __str__(self) -> str:
"""Provide a human-readable string for when a user prints the class."""
return f'Objective(name="{self.name}", direction="{self.direction}")'


class DefaultObjective(Objective):
"""Default objective to minimize if not provided by the user."""

def __init__(self):
def __init__(self) -> None:
super().__init__(name="default_objective", direction="min")


Expand All @@ -95,19 +95,24 @@ class MultiObjective(Objective):
"""

def __init__(self, objectives):
def __init__(self, objectives: list[Objective]):
super().__init__(name="multi_objective", direction="min")
self.objectives = objectives
self.name_to_direction = {
objective.name: objective.direction for objective in self.objectives
}

def has_value(self, logs):
def has_value(self, logs: _EpochLogs) -> bool:
"""Check whether all objectives have a log."""
return all(key in logs for key in self.name_to_direction)

def get_value(self, logs):
def get_value(self, logs: _EpochLogs) -> float:
"""Reduce metrics values to single value."""
obj_value = 0
for metric_name, metric_value in logs.items():
if isinstance(metric_value, list):
msg = "Metric value must be a number."
raise TypeError(msg)
if metric_name not in self.name_to_direction:
continue
if self.name_to_direction[metric_name] == "min":
Expand All @@ -131,19 +136,22 @@ def __str__(self):
)


def create_objective(objective):
def create_objective(objective: _SomeObjectiveOrName | None) -> _SomeObjective:
"""Create an objective class given any of the possibilities."""
if objective is None:
return DefaultObjective()
if isinstance(objective, list):
return MultiObjective([create_objective(obj) for obj in objective])
if isinstance(objective, Objective):
return objective
if not isinstance(objective, str):
raise TypeError(
if not isinstance(objective, str): # check for users and debugging.
msg = (
"`objective` not understood, expected str or "
f"`Objective` object, found: {objective}"
)
raise TypeError(msg)

# try to infer direction using string name.
direction = metrics_tracking.infer_metric_direction(objective)
if direction is None:
error_msg = (
Expand Down
52 changes: 34 additions & 18 deletions keras_tuner/engine/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import hashlib
import random
import time
from typing import cast

from keras_tuner import protos, utils
from keras_tuner.engine import hyperparameters as hp_module
from keras_tuner.engine import metrics_tracking, stateful
from keras_tuner.engine.hyperparameters.hyperparameters import HyperParameters


class TrialStatus:
Expand Down Expand Up @@ -56,7 +57,8 @@ def to_proto(status):
return ts.COMPLETED
if status == TrialStatus.FAILED:
return ts.FAILED
raise ValueError(f"Unknown status {status}")
msg = f"Unknown status {status}"
raise ValueError(msg)

@staticmethod
def from_proto(proto):
Expand All @@ -75,7 +77,8 @@ def from_proto(proto):
return TrialStatus.COMPLETED
if proto == ts.FAILED:
return TrialStatus.FAILED
raise ValueError(f"Unknown status {proto}")
msg = f"Unknown status {proto}"
raise ValueError(msg)


class Trial(stateful.Stateful):
Expand All @@ -99,10 +102,10 @@ class Trial(stateful.Stateful):

def __init__(
self,
hyperparameters,
trial_id=None,
hyperparameters: HyperParameters | None,
trial_id: str | None = None,
status=TrialStatus.RUNNING,
message=None,
message: str | None = None,
):
self.hyperparameters = hyperparameters
self.trial_id: str = (
Expand All @@ -115,8 +118,8 @@ def __init__(
self.status = status
self.message = message

def summary(self):
"""Displays a summary of this Trial."""
def summary(self) -> None:
"""Display hyperparameters, score and any messages."""
print(f"Trial {self.trial_id} summary")

print("Hyperparameters:")
Expand All @@ -128,14 +131,19 @@ def summary(self):
if self.message is not None:
print(self.message)

def display_hyperparameters(self):
def display_hyperparameters(self) -> None:
"""Print HP-values to user."""
if self.hyperparameters is None:
raise TypeError("HyperParameters should be defined but found None.")
if self.hyperparameters.values:
for hp, value in self.hyperparameters.values.items():
print(f"{hp}:", value)
else:
print("default configuration")

def get_state(self):
if self.hyperparameters is None:
raise TypeError("HyperParameters should be defined but found None.")
return {
"trial_id": self.trial_id,
"hyperparameters": self.hyperparameters.get_config(),
Expand All @@ -148,7 +156,7 @@ def get_state(self):

def set_state(self, state):
self.trial_id = state["trial_id"]
hp = hp_module.HyperParameters.from_config(state["hyperparameters"])
hp = HyperParameters.from_config(state["hyperparameters"])
self.hyperparameters = hp
self.metrics = metrics_tracking.MetricsTracker.from_config(
state["metrics"]
Expand All @@ -165,33 +173,40 @@ def from_state(cls, state):
return trial

@classmethod
def load(cls, fname):
def load(cls: type["Trial"], fname: str) -> "Trial":
"""Load the Trial from the json-configuration file."""
return cls.from_state(utils.load_json(fname))

def to_proto(self):
if self.score is not None and self.best_step is not None:
is_score = self.score is not None
tuple_length = 2
is_location = (
isinstance(self.best_step, tuple)
and len(self.best_step) == tuple_length
)
if is_score and is_location:
best_step = cast(tuple, self.best_step)
score = protos.get_proto().Trial.Score(
value=self.score,
step={
"exec_idx": self.best_step[0],
"epoch_idx": self.best_step[1],
"exec_idx": best_step[0],
"epoch_idx": best_step[1],
},
)
else:
score = None
proto = protos.get_proto().Trial(
return protos.get_proto().Trial(
trial_id=self.trial_id,
hyperparameters=self.hyperparameters.to_proto(),
score=score,
status=TrialStatus.to_proto(self.status),
metrics=self.metrics.to_proto(),
)
return proto

@classmethod
def from_proto(cls, proto):
instance = cls(
hp_module.HyperParameters.from_proto(proto.hyperparameters),
HyperParameters.from_proto(proto.hyperparameters),
trial_id=proto.trial_id,
status=TrialStatus.from_proto(proto.status),
)
Expand All @@ -204,6 +219,7 @@ def from_proto(cls, proto):
return instance


def generate_trial_id():
def generate_trial_id() -> str:
"""Hash-like ID to identify the trial."""
s = str(time.time()) + str(random.randint(1, int(1e7)))
return hashlib.sha256(s.encode("utf-8")).hexdigest()[:32]
Loading

0 comments on commit 8608003

Please sign in to comment.