Skip to content

Commit

Permalink
chore: adjusting types and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Mah Neh committed Sep 25, 2024
1 parent 8608003 commit 21dd995
Show file tree
Hide file tree
Showing 31 changed files with 354 additions and 254 deletions.
25 changes: 13 additions & 12 deletions keras_tuner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
"""Keras Tuner Version."""

from keras_tuner import oracles, tuners
from keras_tuner.engine.hypermodel import HyperModel
from keras_tuner.engine.hyperparameters import HyperParameter, HyperParameters
# from keras_tuner import oracles, tuners
# from keras_tuner.engine.hypermodel import HyperModel
# from keras_tuner.engine.hyperparameters import HyperParameter, HyperParameters
from keras_tuner.engine.objective import Objective
from keras_tuner.engine.oracle import Oracle, synchronized
from keras_tuner.engine.tuner import Tuner
from keras_tuner.tuners import (
BayesianOptimization,
GridSearch,
Hyperband,
RandomSearch,
SklearnTuner,
)

# from keras_tuner.engine.oracle import Oracle, synchronized
# from keras_tuner.engine.tuner import Tuner
# from keras_tuner.tuners import (
# BayesianOptimization,
# GridSearch,
# Hyperband,
# RandomSearch,
# SklearnTuner,
# )

__version__ = "2.0.0"
7 changes: 5 additions & 2 deletions keras_tuner/distribute/oracle_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
"""OracleClient class."""

import os
from typing import TYPE_CHECKING

import grpc

from keras_tuner import protos
from keras_tuner.engine.hyperparameters import HyperParameters
from keras_tuner.engine.oracle import Oracle
from keras_tuner.engine.trial import Trial

if TYPE_CHECKING:
from keras_tuner.engine.oracle import Oracle
from keras_tuner.engine.trial import Trial

# The timeout is so high to prevent a rare race condition from happening.
# We need clients to wait till chief oracle server starts. This normally takes
Expand Down
17 changes: 11 additions & 6 deletions keras_tuner/engine/base_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,23 @@
import traceback
import warnings
from pathlib import Path
from typing import cast
from typing import TYPE_CHECKING, Any, cast

from keras_tuner import backend, errors, utils
from keras_tuner import config as config_module
from keras_tuner.distribute import utils as dist_utils
from keras_tuner.engine import hypermodel as hm_module
from keras_tuner.engine import stateful, tuner_utils
from keras_tuner.engine.hypermodel import HyperModel
from keras_tuner.engine.oracle import Oracle
from keras_tuner.engine.trial import Trial, TrialStatus
from keras_tuner.engine.hypermodel import HyperModel, get_hypermodel
from keras_tuner.engine.trial import TrialStatus
from keras_tuner.types import _Verbose

if TYPE_CHECKING:
from keras_tuner.engine.oracle import Oracle
from keras_tuner.engine.trial import Trial
else:
Oracle = Any
Trial = Any


class BaseTuner(stateful.Stateful):
"""Tuner base class.
Expand Down Expand Up @@ -107,7 +112,7 @@ def __init__(
raise ValueError(msg)

self.oracle = oracle
self.hypermodel = hm_module.get_hypermodel(hypermodel)
self.hypermodel = get_hypermodel(hypermodel)

# Ops and metadata
self.directory = directory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class HyperParameter:
"""

def __init__(self, name, default=None, conditions=None):
def __init__(self, name: str, default=None, conditions=None):
self.name = name
self._default = default

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from keras_tuner import protos
from keras_tuner.engine import conditions as conditions_mod
from keras_tuner.engine.hyperparameters import hp_types
from keras_tuner.engine.hyperparameters import hyperparameter as hp_module
from keras_tuner.engine.hyperparameters.HyperParameter import HyperParameter


class HyperParameters:
Expand All @@ -38,7 +38,7 @@ class HyperParameters:
"""

def __init__(self):
def __init__(self) -> None:
# Current name scopes.
self._name_scopes = []
# Current `Condition`s, managed by `conditional_scope`.
Expand Down Expand Up @@ -78,7 +78,7 @@ def name_scope(self, name):

@contextlib.contextmanager
def conditional_scope(self, parent_name, parent_values):
"""Opens a scope to create conditional HyperParameters.
"""Open a scope to create conditional HyperParameters.
All `HyperParameter`s created under this scope will only be active when
the parent `HyperParameter` specified by `parent_name` is equal to one
Expand Down Expand Up @@ -127,9 +127,8 @@ def build(self, hp):
"""
parent_name = self._get_name(parent_name) # Add name_scopes.
if not self._exists(parent_name):
raise ValueError(
f"`HyperParameter` named: {parent_name} not defined."
)
msg = f"`HyperParameter` named: {parent_name} not defined."
raise ValueError(msg)

condition = conditions_mod.Parent(parent_name, parent_values)
self._conditions.append(condition)
Expand All @@ -145,7 +144,7 @@ def build(self, hp):
self._conditions.pop()

def is_active(self, hyperparameter):
"""Checks if a hyperparameter is currently active for a `Trial`.
"""Check if a hyperparameter is currently active for a `Trial`.
A hyperparameter is considered active if and only if all its parent
conditions are active, and not affected by whether the hyperparameter
Expand All @@ -163,7 +162,7 @@ def is_active(self, hyperparameter):
"""
hp = hyperparameter
if isinstance(hp, hp_module.HyperParameter):
if isinstance(hp, HyperParameter):
return self._conditions_are_active(hp.conditions)
hp_name = str(hp)
return any(
Expand All @@ -175,7 +174,7 @@ def _conditions_are_active(self, conditions):
return all(condition.is_active(self.values) for condition in conditions)

def _exists(self, name, conditions=None):
"""Checks for a hyperparameter with the same name and conditions."""
"""Check for a hyperparameter with the same name and conditions."""
if conditions is None:
conditions = self._conditions

Expand All @@ -187,7 +186,7 @@ def _exists(self, name, conditions=None):
return False

def _retrieve(self, hp):
"""Gets or creates a hyperparameter.
"""Get or creates a hyperparameter.
Args:
hp: A `HyperParameter` instance.
Expand All @@ -204,7 +203,7 @@ def _retrieve(self, hp):
return self._register(hp)

def _register(self, hyperparameter, overwrite=False):
"""Registers a hyperparameter in this container.
"""Register a hyperparameter in this container.
Args:
hp: A `HyperParameter` instance.
Expand Down Expand Up @@ -502,8 +501,8 @@ def Boolean(
)
return self._retrieve(hp)

def Fixed(self, name, value, parent_name=None, parent_values=None):
"""Fixed, untunable value.
def Fixed(self, name: str, value, parent_name=None, parent_values=None):
"""Untunable value.
Args:
name: A string. the name of parameter. Must be unique for each
Expand Down Expand Up @@ -659,7 +658,8 @@ def to_proto(self):
elif isinstance(hp, hp_types.Boolean):
boolean_space.append(hp.to_proto())
else:
raise ValueError(f"Unrecognized HP type: {hp}")
msg = f"Unrecognized HP type: {hp}"
raise TypeError(msg)

values = {}
for name, value in self.values.items():
Expand Down Expand Up @@ -705,7 +705,7 @@ def _get_name(self, name, name_scopes=None):
else str(name)
)

def _validate_name(self, name):
def _validate_name(self, name: str) -> None:
for condition in self._conditions:
if condition.name == name:
raise ValueError(
Expand Down
10 changes: 5 additions & 5 deletions keras_tuner/engine/hyperparameters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@

from keras import utils

from keras_tuner.engine.hyperparameters import hp_types

# these below are needed apparently
# we re-export to make them available further up the chain.
from keras_tuner.engine.hyperparameters.hp_types import (
Boolean,
Choice,
Fixed,
Float,
Int,
)
from keras_tuner.engine.hyperparameters.hyperparameter import HyperParameter
from keras_tuner.engine.hyperparameters.hyperparameters import HyperParameters

from . import hp_types
from .HyperParameter import HyperParameter
from .HyperParameters import HyperParameters

OBJECTS = [*hp_types.OBJECTS, HyperParameter, HyperParameters]

Expand Down
12 changes: 7 additions & 5 deletions keras_tuner/engine/hyperparameters/hp_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

from keras import utils

from keras_tuner.engine.hyperparameters.hp_types.boolean_hp import Boolean
from keras_tuner.engine.hyperparameters.hp_types.choice_hp import Choice
from keras_tuner.engine.hyperparameters.hp_types.fixed_hp import Fixed
from keras_tuner.engine.hyperparameters.hp_types.float_hp import Float
from keras_tuner.engine.hyperparameters.hp_types.int_hp import Int
# we export the modules' main functions so that they can be imported from
# hp_types instead.
from .boolean_hp import Boolean
from .choice_hp import Choice
from .fixed_hp import Fixed
from .float_hp import Float
from .int_hp import Int

OBJECTS = (
Fixed,
Expand Down
4 changes: 2 additions & 2 deletions keras_tuner/engine/hyperparameters/hp_types/boolean_hp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

from keras_tuner import protos
from keras_tuner.engine import conditions as conditions_mod
from keras_tuner.engine.hyperparameters import hyperparameter
from keras_tuner.engine.hyperparameters.HyperParameter import HyperParameter


class Boolean(hyperparameter.HyperParameter):
class Boolean(HyperParameter):
"""Choice between True and False.
Args:
Expand Down
16 changes: 8 additions & 8 deletions keras_tuner/engine/hyperparameters/hp_types/boolean_hp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@

import pytest

from keras_tuner.engine import hyperparameters as hp_module
from . import Boolean


def test_boolean():
# Test default default
boolean = hp_module.Boolean("bool")
boolean = Boolean("bool")
assert boolean.default is False
# Test default setting
boolean = hp_module.Boolean("bool", default=True)
boolean = Boolean("bool", default=True)
assert boolean.default is True
# Wrong default type
with pytest.raises(ValueError, match="must be a Python boolean"):
hp_module.Boolean("bool", default=None)
Boolean("bool", default=None)
# Test serialization
boolean = hp_module.Boolean("bool", default=True)
boolean = hp_module.Boolean.from_config(boolean.get_config())
boolean = Boolean("bool", default=True)
boolean = Boolean.from_config(boolean.get_config())
assert boolean.default is True
assert boolean.name == "bool"

Expand All @@ -43,8 +43,8 @@ def test_boolean():


def test_boolean_repr():
assert repr(hp_module.Boolean("bool")) == repr(hp_module.Boolean("bool"))
assert repr(Boolean("bool")) == repr(Boolean("bool"))


def test_boolean_values_property():
assert list(hp_module.Boolean("bool").values) == [True, False]
assert list(Boolean("bool").values) == [True, False]
24 changes: 13 additions & 11 deletions keras_tuner/engine/hyperparameters/hp_types/choice_hp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

from keras_tuner import protos
from keras_tuner.engine import conditions as conditions_mod
from keras_tuner.engine.hyperparameters import hp_utils, hyperparameter
from keras_tuner.engine.hyperparameters import hp_utils
from keras_tuner.engine.hyperparameters.HyperParameter import HyperParameter


class Choice(hyperparameter.HyperParameter):
class Choice(HyperParameter):
"""Choice of one value among a predefined set of possible values.
Args:
Expand All @@ -39,15 +40,17 @@ class Choice(hyperparameter.HyperParameter):
def __init__(self, name, values, ordered=None, default=None, **kwargs):
super().__init__(name=name, default=default, **kwargs)
if not values:
raise ValueError("`values` must be provided for `Choice`.")
msg = "`values` must be provided for `Choice`."
raise ValueError(msg)

# Type checking.
types = {type(v) for v in values}
if len(types) > 1:
raise TypeError(
msg = (
"A `Choice` can contain only one type of value, "
f"found values: {values!s} with types {types}."
)
raise TypeError(msg)

# Standardize on str, int, float, bool.
if isinstance(values[0], str):
Expand All @@ -67,26 +70,25 @@ def __init__(self, name, values, ordered=None, default=None, **kwargs):
self._values = values

if default is not None and default not in values:
raise ValueError(
msg = (
"The default value should be one of the choices. "
f"You passed: values={values}, default={default}"
)
raise ValueError(msg)
self._default = default

# Get or infer ordered.
self.ordered = ordered
is_numeric = isinstance(values[0], (int | float))
if self.ordered and not is_numeric:
raise ValueError("`ordered` must be `False` for non-numeric types.")
msg = "`ordered` must be `False` for non-numeric types."
raise ValueError(msg)
if self.ordered is None:
self.ordered = is_numeric

def __repr__(self):
return (
f"Choice(name: '{self.name}', "
+ f"values: {self._values}, "
+ f"ordered: {self.ordered}, default: {self.default})"
)
return f"""Choice(name: '{self.name}', values: {self._values},
ordered: {self.ordered}, default: {self.default})"""

@property
def values(self):
Expand Down
Loading

0 comments on commit 21dd995

Please sign in to comment.