Skip to content

Commit

Permalink
Decouple tf.estimator.Head from AdaNet Keras API. #1
Browse files Browse the repository at this point in the history
Create a private `_KerasHead` object to provide necessary values to create a `tf.estimator.EstimatorSpec` within AdaNet without the need of a `tf.estimator.Head`.

PiperOrigin-RevId: 278870525
  • Loading branch information
csvillalta authored and cweill committed Nov 6, 2019
1 parent 11147e8 commit a673986
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 131 deletions.
13 changes: 8 additions & 5 deletions adanet/autoensemble/keras_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class KerasTest(parameterized.TestCase, tf.test.TestCase):
feature_columns=feature_columns,
optimizer=optimizer),
},
"metrics": ["mae"],
"want_metrics_names": ["loss", "mae"]
"metrics": [tf.keras.metrics.MeanAbsoluteError],
"want_metrics_names": ["loss", "mean_absolute_error"]
})
# pylint: enable=g-long-lambda

Expand All @@ -64,7 +64,8 @@ def test_auto_ensemble_lifecycle(self,
candidate_pool=candidate_pool(regression_head.RegressionHead(),
feature_columns, optimizer),
max_iteration_steps=10)
keras_model.compile(loss="mse", metrics=metrics)
keras_model.compile(loss=tf.keras.losses.MeanSquaredError(),
metrics=metrics)
if want_metrics_names is None:
want_metrics_names = ["loss"]
self.assertEqual(want_metrics_names, keras_model.metrics_names)
Expand All @@ -75,12 +76,14 @@ def test_auto_ensemble_lifecycle(self,

eval_results = keras_model.evaluate(train_data, steps=3)
# TODO: Currently model training and evaluation are not
# producing deterministic results. Look into properly
# seeding the subnetworks to make this test deterministic.
# producing deterministic results. Look into properly
# seeding the subnetworks to make this test deterministic.
self.assertIsNotNone(eval_results[0])
if metrics:
self.assertLen(eval_results[1:], len(metrics))

# TODO: Change the assertion to actually check the values rather
# than the length of the returned predictions array.
predict_data = lambda: tf.data.Dataset.from_tensors(({"x": [[1., 0.]]}))
predictions = keras_model.predict(predict_data)
self.assertLen(predictions, 1)
Expand Down
198 changes: 95 additions & 103 deletions adanet/keras/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,42 +25,74 @@
import tensorflow as tf


def _dataset_to_input_fn(dataset):
"""Converts a `tf.data.Dataset` to an input_fn."""
class _KerasHead(object):
"""A `tf.estimator.Head`-like alternative for usage within AdaNet."""

def input_fn(params=None):
del params # unused
return dataset()
def __init__(self, logits_dimension, loss, metrics):
"""Initialize a _KerasHead object.
return input_fn
Args:
logits_dimension: The dimension of the final layer of any subnetworks.
loss: A `tf.keras.losses.Loss`. Note: must set `from_logits` to True if
the loss is a non-regression loss.
metrics: List of lambdas that return `tf.keras.metrics.Metric` objects.
Each metric object must have `name` set to some string and `from_logits`
set to True if it is a non-regression metric.
Raises:
ValueError: If `from_logits` isn't `True` for a non-regression `loss`.
"""

self.logits_dimension = logits_dimension
self.metrics = metrics

if hasattr(loss, "from_logits") and not loss.from_logits:
raise ValueError("from_logits must be True for non-regression losses.")
self.loss = loss

def create_estimator_spec(self, features, mode, logits, labels, train_op_fn):
  """Builds the `tf.estimator.EstimatorSpec` for the given `mode`.

  Args:
    features: Ignored; accepted only so the signature matches what callers
      pass in.
    mode: One of the `tf.estimator.ModeKeys` values.
    logits: Subnetwork/ensemble outputs. They are used as-is both as the
      predictions and as `y_pred` for the loss and the metrics.
    labels: Ground truth passed as `y_true` to the loss and the metrics.
    train_op_fn: Ignored; the returned train op is always `tf.no_op()`.

  Returns:
    A `tf.estimator.EstimatorSpec`. Fields that do not apply to `mode` are
    left as `None`.
  """
  del features, train_op_fn  # unused

  mode_keys = tf.estimator.ModeKeys
  # Everything except `mode` and `predictions` defaults to None and is filled
  # in by the branch that matches the requested mode.
  spec_fields = {
      "loss": None,
      "eval_metric_ops": None,
      "export_outputs": None,
      "train_op": None,
  }

  def _mean_loss():
    # Reduce whatever the configured loss returns to a single scalar.
    return tf.math.reduce_mean(self.loss(y_true=labels, y_pred=logits))

  if mode == mode_keys.PREDICT:
    # TODO: Populate export_outputs for SavedModel.
    spec_fields["export_outputs"] = {}
  elif mode == mode_keys.EVAL:
    metric_ops = {}
    for make_metric in self.metrics:
      # Each entry is a factory rather than a metric instance, since Estimator
      # subnetworks need the metric to be created within their graphs.
      built_metric = make_metric()
      built_metric.update_state(y_true=labels, y_pred=logits)
      metric_ops[built_metric.name] = built_metric
    spec_fields["eval_metric_ops"] = metric_ops
    spec_fields["loss"] = _mean_loss()
  elif mode == mode_keys.TRAIN:
    spec_fields["loss"] = _mean_loss()
    spec_fields["train_op"] = tf.no_op()

  # TODO: Currently the predictions are the raw logits which
  # means that the predictions will not be correct for anything other than
  # regression. Should look into how Keras handles this.
  return tf.estimator.EstimatorSpec(
      mode=mode, predictions={"predictions": logits}, **spec_fields)


class Model(object):
"""A `tf.keras.Model`-like object for training, evaluation, and serving."""

# Usage of lambdas here to defer the instantiation of these objects. This
# behavior is required since Estimator subnetworks expect these objects to
# be created within functions.

# pylint: disable=g-long-lambda
_metrics_map = {
"auc":
lambda: tf.keras.metrics.AUC(name="auc"),
"accuracy":
lambda: tf.keras.metrics.Accuracy(name="accuracy"),
"precision":
lambda: tf.keras.metrics.Precision(name="precision"),
"mae":
lambda: tf.keras.metrics.MeanAbsoluteError(name="mae"),
"mean_absolute_error":
lambda: tf.keras.metrics.MeanAbsoluteError(name="mean_absolute_error"
),
"recall":
lambda: tf.keras.metrics.Recall(name="recall"),
}

# pylint: enable=g-long-lambda

def __init__(self,
subnetwork_generator,
max_iteration_steps,
Expand Down Expand Up @@ -119,24 +151,6 @@ def __init__(self,
self._model = None
self._metrics_names = ["loss"]

# Import here to avoid strict BUILD deps check.
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
from tensorflow_estimator.python.estimator.head import binary_class_head
from tensorflow_estimator.python.estimator.head import multi_class_head
from tensorflow_estimator.python.estimator.head import regression_head
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top

self._loss_head_map = {
"binary_crossentropy":
lambda: binary_class_head.BinaryClassHead(), # pylint: disable=unnecessary-lambda
"mse":
lambda: regression_head.RegressionHead(self._logits_dimension),
"mean_squared_error":
lambda: regression_head.RegressionHead(self._logits_dimension),
"sparse_categorical_crossentropy":
lambda: multi_class_head.MultiClassHead(self._logits_dimension),
}

# Read-only view of the names of the values this model reports. The list
# starts as ["loss"] in `__init__`, `compile` appends the `name` of every
# metric it is given, and `evaluate` returns its results in this same order.
@property
def metrics_names(self):
# Hands back the internal list object itself, not a copy.
return self._metrics_names
Expand Down Expand Up @@ -165,8 +179,7 @@ def fit(self, x, epochs=1, steps_per_epoch=None, callbacks=None):

if self._model is not None:
for _ in range(epochs):
self._model.train(
input_fn=_dataset_to_input_fn(x), steps=steps_per_epoch)
self._model.train(input_fn=x, steps=steps_per_epoch)
else:
raise RuntimeError(
"You must compile your model before training. Use `model.compile(loss)`."
Expand Down Expand Up @@ -196,8 +209,7 @@ def evaluate(self, x, steps=None, callbacks=None):
logging.warning("Callbacks are currently not supported.")

if self._model is not None:
results = self._model.evaluate(
input_fn=_dataset_to_input_fn(x), steps=steps)
results = self._model.evaluate(input_fn=x, steps=steps)
return [results[result] for result in self._metrics_names]
else:
raise RuntimeError(
Expand Down Expand Up @@ -233,8 +245,9 @@ def predict(
logging.warning("Callbacks are currently not supported.")

if self._model is not None:
results = self._model.predict(
_dataset_to_input_fn(x), yield_single_examples=False)
results = self._model.predict(x, yield_single_examples=False)
# TODO: Make predictions match the format of the task class.
logging.warning("Prediction results are in raw logit form.")
# Convert the generator object returned by Estimator's predict method to a
# numpy array of all the predictions.
return next(results)["predictions"]
Expand All @@ -247,63 +260,42 @@ def compile(self, loss, metrics=None):
"""Configures the model for training.
Args:
loss: String of a built in `tf.keras.Loss` function.
metrics: List of metric string names and functions that return metric
objects. (e.g. [lambda: tf.keras.metrics.Accuracy(), "mae"]). If passing
in a function that returns a metric, it is necessary for it to have a
name.
loss: A `tf.keras.losses.Loss`. Note: must set `from_logits` to True if
the loss is a non-regression loss.
metrics: List of lambdas that return `tf.keras.metrics.Metric` objects.
Each metric object must have `name` set to some string and
`from_logits` set to True if it is a non-regression metric.
Raises:
ValueError: If the loss is not a supported loss.
ValueError: If one of the metrics passed into metrics is not a supported
metric.
ValueError: If a metric does not have a name.
"""

if metrics is None:
metrics = []

for metric in metrics:
if callable(metric):
self._metrics_names.append(metric().name)
elif metric in Model._metrics_map:
self._metrics_names.append(metric)
else:
raise ValueError(
"'{}' is not a currently supported metric. Currently supported metrics are: {}"
.format(metric, Model._metrics_map.keys()))

def _metric_fn(predictions, features, labels):
"""Internal metric_fn to add passed in metrics to underlying Estimator."""
del features # unused

eval_results = {}
for metric in metrics:
if not callable(metric):
metric = Model._metrics_map[metric]
# We wrap the metric within a function since Estimator subnetworks
# need to have this created within their graphs.
metric = metric()
metric.update_state(y_true=labels, y_pred=predictions["predictions"])
eval_results[metric.name] = metric

return eval_results

head = self._loss_head_map.get(loss, None)
if head is not None:
self._model = core.Estimator(
head=head(),
metric_fn=_metric_fn,
max_iteration_steps=self._max_iteration_steps,
ensemblers=self._ensemblers,
ensemble_strategies=self._ensemble_strategies,
evaluator=self._evaluator,
adanet_loss_decay=self._adanet_loss_decay,
model_dir=self._filepath,
subnetwork_generator=self._subnetwork_generator)
else:
raise ValueError(
"'{}' is not a currently supported loss. Currently supported losses are: {}."
.format(loss, self._loss_head_map.keys()))
# TODO: Assure `from_logits=True` for every metric.
logging.warning(
"Assure non-regression metrics initialized with `from_logits=True`.")

for metric in metrics:
metric = metric()
if metric.name is None:
raise ValueError("Metrics must have names.")
self._metrics_names.append(metric.name)

keras_head = _KerasHead(
logits_dimension=self._logits_dimension, loss=loss, metrics=metrics)

self._model = core.Estimator(
head=keras_head,
subnetwork_generator=self._subnetwork_generator,
max_iteration_steps=self._max_iteration_steps,
ensemblers=self._ensemblers,
ensemble_strategies=self._ensemble_strategies,
evaluator=self._evaluator,
adanet_loss_decay=self._adanet_loss_decay,
model_dir=self._filepath)

# TODO: Implement `adanet.Model#save`.
# Placeholder for persisting the model: saving is not implemented, so every
# call fails loudly instead of silently doing nothing.
def save(self):
# Unconditionally raises NotImplementedError; no instance state is touched.
raise NotImplementedError("Saving is currently not supported.")
53 changes: 30 additions & 23 deletions adanet/keras/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,31 +137,25 @@ def build_subnetwork_train_op(self, subnetwork, loss, var_list, labels,

class ModelTest(tu.AdanetTestCase):

# pylint: disable=g-long-lambda
@parameterized.named_parameters(
{
"testcase_name": "one_step_binary_crossentropy_loss",
"loss": "binary_crossentropy",
"loss": tf.keras.losses.BinaryCrossentropy(from_logits=True),
"metrics":
[lambda: tf.keras.metrics.BinaryCrossentropy(name="bin_acc",
from_logits=True)],
"subnetwork_generator": SimpleGenerator([_DNNBuilder("dnn")]),
"max_iteration_steps": 1,
"epochs": 1,
"steps_per_epoch": 3,
"want_metrics_names": ["loss", "bin_acc"],
"want_loss": 0.7690,
"want_metrics": [0.7690]
},
{
"testcase_name": "one_step_mse_loss",
"loss": "mse",
"metrics": ["mean_absolute_error"],
"subnetwork_generator": SimpleGenerator([_DNNBuilder("dnn")]),
"max_iteration_steps": 1,
"epochs": 1,
"steps_per_epoch": 3,
"want_metrics_names": ["loss", "mean_absolute_error"],
"want_loss": 0.6354,
"want_metrics": [0.6191]
},
{
"testcase_name": "lambda_metric",
"loss": "mse",
"loss": tf.keras.losses.MeanSquaredError(),
"metrics": [lambda: tf.keras.metrics.MeanAbsoluteError(name="mae")],
"subnetwork_generator": SimpleGenerator([_DNNBuilder("dnn")]),
"max_iteration_steps": 1,
Expand All @@ -173,7 +167,8 @@ class ModelTest(tu.AdanetTestCase):
},
{
"testcase_name": "one_step_sparse_categorical_crossentropy_loss",
"loss": "sparse_categorical_crossentropy",
"loss":
tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
"subnetwork_generator": SimpleGenerator([_DNNBuilder("dnn")]),
"max_iteration_steps": 1,
"epochs": 1,
Expand All @@ -183,6 +178,7 @@ class ModelTest(tu.AdanetTestCase):
"dataset": lambda: tf.data.Dataset.from_tensors(({"x": XOR_FEATURES}, # pylint: disable=g-long-lambda
XOR_CLASS_LABELS))
})
# pylint: enable=g-long-lambda
@test_util.run_in_graph_and_eager_modes
def test_lifecycle(self,
loss,
Expand Down Expand Up @@ -228,15 +224,14 @@ def test_lifecycle(self,
if metrics:
self.assertAllClose(want_metrics, eval_results[1:], 1e-3, 1e-3)

# TODO: Predict not currently working for BinaryClassHead and
# MultiClassHead.
if loss == "mse":
prediction_data = lambda: tf.data.Dataset.from_tensors(({ # pylint: disable=g-long-lambda
"x": XOR_FEATURES
}))
prediction_data = lambda: tf.data.Dataset.from_tensors(({ # pylint: disable=g-long-lambda
"x": XOR_FEATURES
}))

predictions = keras_model.predict(prediction_data)
self.assertLen(predictions, 4)
# TODO: Change the assertion to actually check the values rather
# than the length of the returned predictions array.
predictions = keras_model.predict(prediction_data)
self.assertLen(predictions, 4)

@test_util.run_in_graph_and_eager_modes
def test_compile_exceptions(self):
Expand All @@ -255,6 +250,18 @@ def test_compile_exceptions(self):
with self.assertRaises(RuntimeError):
keras_model.predict(predict_data)

@test_util.run_in_graph_and_eager_modes
def test_loss_exceptions(self):
  """Verifies `compile` rejects a loss constructed with from_logits=False."""
  # A loss that exposes `from_logits` and has it switched off is exactly the
  # configuration that `compile` is expected to refuse.
  non_logits_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)

  adanet_model = model.Model(
      subnetwork_generator=SimpleGenerator([_DNNBuilder("dnn")]),
      max_iteration_steps=1)

  with self.assertRaises(ValueError):
    adanet_model.compile(loss=non_logits_loss)


if __name__ == "__main__":
tf.enable_v2_behavior()
Expand Down

0 comments on commit a673986

Please sign in to comment.