Skip to content

Commit

Permalink
Remove AdaNet controller. #1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 288714068
  • Loading branch information
csvillalta authored and cweill committed Jan 14, 2020
1 parent 712bc8e commit 36531c5
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 213 deletions.
16 changes: 0 additions & 16 deletions adanet/experimental/controllers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,3 @@ py_library(
"//adanet/experimental/work_units:work_unit",
],
)

py_library(
name = "adanet_controller",
srcs = ["adanet_controller.py"],
srcs_version = "PY3",
visibility = ["//adanet/experimental:__subpackages__"],
deps = [
":controller",
"//adanet/experimental/keras:ensemble_model",
"//adanet/experimental/phases:phase",
"//adanet/experimental/storages:in_memory_storage",
"//adanet/experimental/storages:storage",
"//adanet/experimental/work_units:keras_trainer",
"//adanet/experimental/work_units:work_unit",
],
)
151 changes: 0 additions & 151 deletions adanet/experimental/controllers/adanet_controller.py

This file was deleted.

1 change: 0 additions & 1 deletion adanet/experimental/keras/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ py_strict_test(
":ensemble_model",
":model_search",
":testing_utils",
"//adanet/experimental/controllers:adanet_controller",
"//adanet/experimental/controllers:sequential_controller",
"//adanet/experimental/phases:keras_tuner_phase",
"//adanet/experimental/phases:train_keras_models_phase",
Expand Down
48 changes: 3 additions & 45 deletions adanet/experimental/keras/model_search_test.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
# Lint as: python3
# Copyright 2019 The AdaNet Authors. All Rights Reserved.
#

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#

# https://www.apache.org/licenses/LICENSE-2.0
#

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for adanet.experimental.keras.ModelSearch."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import sys

from absl import flags
from absl.testing import parameterized
from adanet.experimental.controllers.adanet_controller import AdaNetCandidatePhase
from adanet.experimental.controllers.adanet_controller import AdaNetController
from adanet.experimental.controllers.adanet_controller import AdaNetEnsemblePhase
from adanet.experimental.controllers.sequential_controller import SequentialController
from adanet.experimental.keras import testing_utils
from adanet.experimental.keras.ensemble_model import MeanEnsemble
Expand Down Expand Up @@ -157,41 +150,6 @@ def build_ensemble():
self.assertIsInstance(
model_search.get_best_models(num_models=1)[0], MeanEnsemble)

def test_adanet_controller_end_to_end(self):
train_dataset, test_dataset = testing_utils.get_test_data(
train_samples=1280,
test_samples=640,
input_shape=(10,),
num_classes=10,
random_seed=42)

train_dataset = train_dataset.batch(32)
test_dataset = test_dataset.batch(32)

candidate_phase = AdaNetCandidatePhase(
train_dataset,
candidates_per_iteration=2,
optimizer='adam',
loss='sparse_categorical_crossentropy',
output_units=10)
# TODO: Setting candidates_per_iteration greater than the one
# for the candidate phase will lead to unexpected behavior.
ensemble_phase = AdaNetEnsemblePhase(
train_dataset,
candidates_per_iteration=2,
optimizer='adam',
loss='sparse_categorical_crossentropy')

adanet_controller = AdaNetController(
candidate_phase,
ensemble_phase,
iterations=5)

model_search = ModelSearch(adanet_controller)
model_search.run()
self.assertIsInstance(
model_search.get_best_models(num_models=1)[0], MeanEnsemble)


if __name__ == '__main__':
tf.enable_v2_behavior()
Expand Down

0 comments on commit 36531c5

Please sign in to comment.