diff --git a/arch_search_tf.py b/arch_search_tf.py
new file mode 100644
index 0000000..fb616a0
--- /dev/null
+++ b/arch_search_tf.py
@@ -0,0 +1,225 @@
+from functools import total_ordering
+from typing import Dict, List, Union
+
+from tensorflow import keras
+import tensorflow as tf
+import tensorflow_probability as tfp
+from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
+
+
+@total_ordering
+class NasModel(keras.Model):
+    def __init__(self, model):
+        super().__init__()
+        self.mod = model
+
+    def compile(self, *inputs, arch_optimizer, **kwargs):
+        self.arch_optimizer = arch_optimizer
+
+        # Collect the architecture parameters of every mixed module; every
+        # other trainable variable is treated as a regular weight.
+        self.arch_params = []
+        for mod in self.submodules:
+            if isinstance(mod, MixedModuleTf):
+                assert mod.built
+                self.arch_params.append(mod.gumbel_arch_params)
+        self.non_arch_params = []
+        for v in self.trainable_variables:
+            if not any(v is arch_param for arch_param in self.arch_params):
+                self.non_arch_params.append(v)
+
+        assert len(self.trainable_variables) == len(self.arch_params) + len(self.non_arch_params)
+
+        self.concat_params = self.non_arch_params + self.arch_params
+
+        print("Arch Parameters:", len(self.arch_params))
+        print("Non-arch Parameters:", len(self.non_arch_params))
+        print("All Parameters:", len(self.concat_params))
+
+        super().compile(*inputs, **kwargs)
+
+    def train_step(self, data, slow_assert=False):
+        # Unpack the data. Its structure depends on the model and on what is
+        # passed to `fit()`.
+        x, y = data
+
+        with tf.GradientTape() as tape:
+            y_pred = self(x, training=True)  # Forward pass
+            # Compute the loss value (the loss function is configured in
+            # `compile()`).
+            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
+
+        # Compute gradients for all parameters in one pass, then split them
+        # into the non-arch and arch groups.
+        gradients = tape.gradient(loss, self.concat_params)
+
+        non_arch_gradients = gradients[0:len(self.non_arch_params)]
+        arch_gradients = gradients[len(self.non_arch_params):]
+
+        if slow_assert:
+            assert all(first is second for first, second in zip(self.non_arch_params, self.concat_params[0:len(self.non_arch_params)]))
+            assert all(first is second for first, second in zip(self.arch_params, self.concat_params[len(self.non_arch_params):]))
+
+        # Update the regular (non-arch) weights
+        self.optimizer.apply_gradients(zip(non_arch_gradients, self.non_arch_params))
+
+        # Update the architecture weights
+        self.arch_optimizer.apply_gradients(zip(arch_gradients, self.arch_params))
+
+        # Update metrics (includes the metric that tracks the loss)
+        self.compiled_metrics.update_state(y, y_pred)
+        # Return a dict mapping metric names to current values
+        return {m.name: m.result() for m in self.metrics}
+
+    def __gt__(self, other):
+        # Arbitrary but stable ordering; @total_ordering supplies the
+        # remaining comparison operators.
+        return id(self) > id(other)
+
+    def call(self, inp):
+        return self.mod(inp)
+
+
+def get_flops_inputs(model, input_shape):
+    real_model = tf.function(model).get_concrete_function(tf.TensorSpec(input_shape, tf.float32))
+    frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model)
+
+    run_meta = tf.compat.v1.RunMetadata()
+    opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
+    flops = tf.compat.v1.profiler.profile(graph=frozen_func.graph,
+                                          run_meta=run_meta, cmd='op', options=opts)
+    return flops.total_float_ops
+
+
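+# A minimal usage sketch for get_flops_inputs (illustrative only: the lone
+# conv layer and the 28x28x1 input shape below are hypothetical examples,
+# not part of the search code).
+def _example_profile_conv():
+    conv = keras.layers.Conv2D(8, kernel_size=(3, 3), padding="same")
+    conv.build((None, 28, 28, 1))
+    # Total float ops for a single forward pass at batch size 1.
+    return get_flops_inputs(conv, [1, 28, 28, 1])
+
+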
+class MixedModuleTf(keras.layers.Layer):
+    def __init__(self, ops: Union[List[keras.layers.Layer], Dict[str, keras.layers.Layer]],
+                 cost_loss_multiplier=0.0):
+        super().__init__()
+        if isinstance(ops, list):
+            ops = {str(i): op for i, op in enumerate(ops)}
+        assert len(ops) > 1
+        for name, module in ops.items():
+            self.add_module(name, module)
+        self.op_names = list(ops.keys())
+        self.cost_loss_multiplier = cost_loss_multiplier
+
+    def build(self, input_shape):
+        self.gumbel_temperature = self.add_weight(
+            name='gumbel_temperature',
+            shape=(),
+            initializer=tf.constant_initializer(1),
+            trainable=False
+        )
+
+        # TODO(ashaw596): Figure out what to do about op cost
+        self.ops_cost_static = self.add_weight(
+            name='ops_cost_static',
+            shape=(1, len(self.op_names)),
+            initializer=tf.constant_initializer(0),
+            trainable=False
+        )
+
+        self.gumbel_arch_params = self.add_weight(
+            name='gumbel_arch_params',
+            shape=(len(self.op_names),),
+            initializer=tf.constant_initializer(1),
+            trainable=True
+        )
+
+        self.gumbel_dist = tfp.distributions.RelaxedOneHotCategorical(
+            logits=self.gumbel_arch_params,
+            temperature=self.gumbel_temperature,
+        )
+
+        for name in self.op_names:
+            self.get_module(name).build(input_shape)
+        super().build(input_shape)
+
+        # Profile each candidate op at batch size 1 and cache its FLOP count;
+        # the cached costs weight the differentiable cost loss in call().
+        for i, op_name in enumerate(self.op_names):
+            flops = get_flops_inputs(self.get_module(op_name), [1] + list(input_shape[1:]))
+            self.ops_cost_static[0, i].assign(flops)
+            print("flops", op_name, flops)
+
+    def add_module(self, name, module):
+        # setattr with a fixed prefix so Keras auto-tracks the sublayer.
+        setattr(self, 'sublayer_' + name, module)
+
+    def get_module(self, name):
+        return getattr(self, 'sublayer_' + name)
+
+    def call(self, inp, *inputs, **kwargs):
+        # Sample one relaxed one-hot weighting per example in the batch.
+        batch_size = tf.shape(inp)[0]
+        gumbel_weights = self.gumbel_dist.sample(batch_size)
+
+        outputs = []
+        for name in self.op_names:
+            outputs.append(self.get_module(name)(inp, *inputs, **kwargs))
+
+        concat_outputs = tf.stack(outputs, axis=1)
+        # Reshape the (batch, num_ops) weights so they broadcast over the
+        # stacked op outputs.
+        orig_shape = tf.shape(gumbel_weights)
+        rank = len(concat_outputs.shape)
+        reshaped_gumbel_weights = tf.reshape(gumbel_weights, shape=[orig_shape[0], orig_shape[1]] + [1] * (rank - 2))
+        weighted_outputs = reshaped_gumbel_weights * concat_outputs
+
+        output = tf.math.reduce_sum(weighted_outputs, axis=1)
+
+        # TODO(ashaw596): cost loss
+        cost = self.ops_cost_static * gumbel_weights
+        cost_loss = tf.reduce_mean(tf.reduce_sum(cost * self.cost_loss_multiplier, axis=1))
+        self.add_loss(cost_loss)
+
+        return output
+
+
+class SupernetArchWatcherCallback(keras.callbacks.Callback):
+    def __init__(self, model):
+        super().__init__()
+        self.gumbel_arch_params = []
+        self.op_names = []
+        for mod in model.submodules:
+            if isinstance(mod, MixedModuleTf):
+                assert mod.built
+                self.gumbel_arch_params.append(mod.gumbel_arch_params)
+                self.op_names.append(mod.op_names)
+
+    def on_epoch_end(self, epoch, logs=None):
+        genotype = []
+        for names, params in zip(self.op_names, self.gumbel_arch_params):
+            probs = tf.nn.softmax(params)
+            gene = {}
+            for i, name in enumerate(names):
+                gene[name] = float(probs[i])
+            genotype.append(gene)
+        print("Genotype: epoch:", epoch, genotype)
+
+
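+# Hedged helper sketch (derive_genotype is not referenced elsewhere in this
+# module): once the search has converged, the discrete architecture can be
+# read off by taking the highest-probability op in each mixed module.
+def derive_genotype(model):
+    genotype = []
+    for mod in model.submodules:
+        if isinstance(mod, MixedModuleTf):
+            probs = tf.nn.softmax(mod.gumbel_arch_params)
+            genotype.append(mod.op_names[int(tf.argmax(probs))])
+    return genotype
+
+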
+class SupernetTemperatureCallback(keras.callbacks.Callback):
+    def __init__(self, model, start_epoch, final_epoch, start_temp, end_temp):
+        super().__init__()
+        self.temperature_variables = []
+        for mod in model.submodules:
+            if isinstance(mod, MixedModuleTf):
+                assert mod.built
+                self.temperature_variables.append(mod.gumbel_temperature)
+
+        print("Temperature Variables Found:", len(self.temperature_variables))
+
+        self.start_epoch = start_epoch
+        self.final_epoch = final_epoch
+        self.start_temp = start_temp
+        self.end_temp = end_temp
+
+        assert self.start_temp > self.end_temp
+
+    def on_epoch_begin(self, epoch, logs=None):
+        print("Epoch", epoch)
+
+        # Anneal linearly from start_temp to end_temp between start_epoch and
+        # final_epoch, clamping to the endpoints outside that window.
+        delta_temp = (self.end_temp - self.start_temp) / (self.final_epoch - self.start_epoch)
+        temperature = max(self.start_temp + delta_temp * max(epoch - self.start_epoch, 0), self.end_temp)
+        print("Temperature", temperature)
+
+        for temp_var in self.temperature_variables:
+            temp_var.assign(temperature)
diff --git a/requirements.txt b/requirements.txt
index 237bcee..b509a95 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,6 @@
 rope==0.18.0
 autopep8==1.5.5
-thop
\ No newline at end of file
+thop
+torch
+tensorflow
+tensorflow_probability
\ No newline at end of file
diff --git a/test_arch_search_tf.py b/test_arch_search_tf.py
new file mode 100644
index 0000000..8059b9a
--- /dev/null
+++ b/test_arch_search_tf.py
@@ -0,0 +1,85 @@
+from unittest import TestCase
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+
+from arch_search_tf import MixedModuleTf, SupernetTemperatureCallback, NasModel, SupernetArchWatcherCallback
+
+
+class TestMixedModuleTf(TestCase):
+    def test_build(self):
+        # To run on a TPU instead, swap in a TPUStrategy:
+        # resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+        # tf.config.experimental_connect_to_cluster(resolver)
+        # tf.tpu.experimental.initialize_tpu_system(resolver)
+        # strategy = tf.distribute.TPUStrategy(resolver)
+
+        strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
+
+        (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+        num_classes = 10
+        input_shape = (28, 28, 1)
+
+        # Scale images to the [0, 1] range
+        x_train = x_train.astype("float32") / 255
+        x_test = x_test.astype("float32") / 255
+        # Make sure images have shape (28, 28, 1)
+        x_train = np.expand_dims(x_train, -1)
+        x_test = np.expand_dims(x_test, -1)
+        print("x_train shape:", x_train.shape)
+        print(x_train.shape[0], "train samples")
+        print(x_test.shape[0], "test samples")
+
+        # Convert class vectors to binary class matrices
+        y_train = keras.utils.to_categorical(y_train, num_classes)
+        y_test = keras.utils.to_categorical(y_test, num_classes)
+
+        with strategy.scope():
+            model = keras.Sequential(
+                [
+                    keras.layers.InputLayer(input_shape=input_shape),
+                    MixedModuleTf({
+                        '3x3': layers.Conv2D(32, kernel_size=(3, 3), activation="relu", padding="same"),
+                        '1x1': layers.Conv2D(32, kernel_size=(1, 1), activation="relu", padding="same")
+                    }, cost_loss_multiplier=1e-6),
+                    layers.MaxPooling2D(pool_size=(2, 2)),
+                    MixedModuleTf([
+                        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
+                        layers.Conv2D(64, kernel_size=(3, 3), activation="relu")
+                    ]),
+                    layers.MaxPooling2D(pool_size=(2, 2)),
+                    layers.Flatten(),
+                    layers.Dropout(0.5),
+                    MixedModuleTf([
+                        layers.Dense(num_classes, activation="softmax"),
+                        layers.Dense(num_classes, activation="softmax"),
+                    ])
+                ]
+            )
+
+            arch_optim = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0)
+
+            model = NasModel(model)
+            model.build([None] + list(input_shape))
+
+            model.summary()
+            batch_size = 128
+
+            callbacks = [
+                SupernetTemperatureCallback(model, start_epoch=2, final_epoch=5, start_temp=5, end_temp=1),
+                SupernetArchWatcherCallback(model),
+            ]
+
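+            # With these settings the Gumbel temperature holds at 5 through
+            # epoch 2, then falls linearly by (1 - 5) / (5 - 2) = -4/3 per
+            # epoch (epoch 3 -> ~3.67, epoch 4 -> ~2.33) and is clamped at 1
+            # from epoch 5 onward.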
+            model.compile(loss="categorical_crossentropy",
+                          optimizer="adam",
+                          metrics=["accuracy"],
+                          arch_optimizer=arch_optim)
+            model.fit(x_train, y_train,
+                      batch_size=batch_size,
+                      steps_per_epoch=10,
+                      epochs=10,
+                      validation_split=0.1,
+                      callbacks=callbacks)