-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Description
New Issue Checklist
- I have read the Contribution Guidelines
- I searched for existing GitHub issues
Issue Description
希望将tensorlayer用于自定义estimator,Dropout layer在estimator的训练、预测模式切换时报错。
Reproducible Code
版本:
tensorflow = 1.13.1
tensorlayer = 1.11.1
# ======================================================== #
###### THIS CODE IS AN EXAMPLE, REPLACE WITH YOUR OWN ######
# ======================================================== #
import time
import tensorflow as tf
import tensorlayer as tl
import numpy as np
import pandas as pd
tf.logging.set_verbosity(tf.logging.DEBUG)
tl.logging.set_verbosity(tl.logging.DEBUG)
def inference(x, reuse=False, is_training=True):
# x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
keep = 0.5 # if is_training else 1
# define the network
with tf.variable_scope("model", reuse=reuse):
network = [
tl.layers.ReshapeLayer(tl.layers.InputLayer(x[k], name=f'input_{k}'), shape=(-1, 1), name=f'reshape_{k}')
for k in x.keys()]
network = tl.layers.ConcatLayer(network, concat_dim=1)
network = tl.layers.DropoutLayer(network, keep=keep, name='drop1', is_fix=True, is_train=is_training)
network = tl.layers.DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu1')
network = tl.layers.DropoutLayer(network, keep=keep, name='drop2', is_fix=True, is_train=is_training)
network = tl.layers.DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu2')
network = tl.layers.DropoutLayer(network, keep=keep, name='drop3', is_fix=True, is_train=is_training)
# the softmax is implemented internally in tl.cost.cross_entropy(y, y_) to
# speed up computation, so we use identity here.
# see tf.nn.sparse_softmax_cross_entropy_with_logits()
network = tl.layers.DenseLayer(network, n_units=3, act=None, name='output')
# define cost function and metric.
y = network.outputs
return y
def model_fn(features, labels, mode, params):
"""
Model_fn for estimator model
Args:
features (Tensor): Input features to the model.
labels (Tensor): Labels tensor for training and evaluation.
mode (ModeKeys): Specifies if training, evaluation or prediction.
params (HParams): hyper-parameters for estimator model
Returns:
(EstimatorSpec): Model to be run by Estimator.
"""
# check if training stage
if mode == tf.estimator.ModeKeys.TRAIN:
is_training = True
reuse = False
else:
is_training = False
reuse = True
# is_training = False # 1
x = features
logits = inference(x, reuse, is_training)
predicted_classes = tf.argmax(logits, 1) # 预测的结果中最大值即种类
# provide a tf.estimator spec for PREDICT
predictions_dict = {"score": logits,
"label": predicted_classes}
if mode == tf.estimator.ModeKeys.PREDICT:
predictions_output = tf.estimator.export.PredictOutput(predictions_dict)
return tf.estimator.EstimatorSpec(mode=mode,
predictions=predictions_dict,
export_outputs={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: predictions_output
})
# calculate loss
# loss = focal_loss(onehot_labels, logits, gamma=1.5)
loss = tf.losses.sparse_softmax_cross_entropy(labels=tf.cast(labels, tf.int32), logits=logits)
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = tf.train.AdagradOptimizer(learning_rate=0.1) # 用它优化损失函数,达到损失最少精度最高
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step()) # 执行优化!
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
else:
# 评价
accuracy = tf.metrics.accuracy(labels=labels,
predictions=predicted_classes,
name='acc_op') # 计算精度
metrics = {'accuracy': accuracy} # 返回格式
tf.summary.scalar('accuracy', accuracy[1]) # 仅为了后面图表统计使用
if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']
train_path = tf.keras.utils.get_file(
"iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_path = tf.keras.utils.get_file(
"iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")
train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0).astype('float32')
test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0).astype('float32')
train_y = train.pop('Species').astype('int32')
test_y = test.pop('Species').astype('int32')
# 针对测试的喂食函数
def eval_input_fn(features, labels, batch_size=256):
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
dataset = dataset.batch(batch_size)
# return dataset
return dataset.make_one_shot_iterator().get_next()
def input_fn(features, labels, training=True, batch_size=256):
"""An input function for training or evaluating"""
# 将输入转换为数据集。
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# 如果在训练模式下混淆并重复数据。
if training:
dataset = dataset.shuffle(1000).repeat(batch_size * 10)
return dataset.batch(batch_size)
print('-----------------------define model____________________________')
model = tf.estimator.Estimator(model_fn)
print('-----------------------train model____________________________')
model.train(
input_fn=lambda: input_fn(train, train_y, training=True),
steps=500)
print('-----------------------eval model____________________________')
print(model.evaluate(input_fn=lambda: eval_input_fn(test, test_y)))
def eval_pred_fn(features, batch_size=256):
dataset = tf.data.Dataset.from_tensor_slices((dict(features)))
dataset = dataset.batch(batch_size)
# return dataset
return dataset.make_one_shot_iterator().get_next()
print('-----------------------pred model____________________________')
result = list(model.predict(input_fn=lambda: eval_pred_fn(test)))
报错信息
-----------------------define model____________________________
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: C:\Users\dengyin\AppData\Local\Temp\tmpw0csioqr
INFO:tensorflow:Using config: {'model_dir': 'C:\Users\dengyin\AppData\Local\Temp\tmpw0csioqr', 'tf_random_seed': None, 'save_summary_steps': 100, 'save_checkpoints_steps': None, 'save_checkpoints_secs': 600, 'session_config': allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
}
}
, 'keep_checkpoint_max': 5, 'keep_checkpoint_every_n_hours': 10000, 'log_step_count_steps': 100, 'train_distribute': None, 'device_fn': None, 'protocol': None, 'eval_distribute': None, 'experimental_distribute': None, 'service': None, 'cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x0000020297831668>, 'task_type': 'worker', 'task_id': 0, 'global_id_in_cluster': 0, 'master': '', 'evaluation_master': '', 'is_chief': True, 'num_ps_replicas': 0, 'num_worker_replicas': 1}
WARNING:tensorflow:Estimator's model_fn (<function model_fn at 0x000002028E861E18>) includes params argument, but params are not passed to Estimator.
-----------------------train model____
INFO:tensorflow:Calling model_fn.
[TL] InputLayer model/input_SepalLength: (?,)
[TL] ReshapeLayer model/reshape_SepalLength: (?, 1)
[TL] InputLayer model/input_SepalWidth: (?,)
[TL] ReshapeLayer model/reshape_SepalWidth: (?, 1)
[TL] InputLayer model/input_PetalLength: (?,)
[TL] ReshapeLayer model/reshape_PetalLength: (?, 1)
[TL] InputLayer model/input_PetalWidth: (?,)
[TL] ReshapeLayer model/reshape_PetalWidth: (?, 1)
[TL] ConcatLayer model/concat_layer: axis: 1
[TL] DropoutLayer model/drop1: keep: 0.500000 is_fix: True
WARNING:tensorflow:From D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorlayer\layers\dropout.py:100: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use rate
instead of keep_prob
. Rate should be set to rate = 1 - keep_prob
.
[TL] DenseLayer model/relu1: 800 relu
[TL] DropoutLayer model/drop2: keep: 0.500000 is_fix: True
[TL] DenseLayer model/relu2: 800 relu
[TL] DropoutLayer model/drop3: keep: 0.500000 is_fix: True
[TL] DenseLayer model/output: 3 No Activation
WARNING:tensorflow:From D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorflow\python\ops\losses\losses_impl.py:209: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
2021-07-15 11:42:38.703488: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into C:\Users\dengyin\AppData\Local\Temp\tmpw0csioqr\model.ckpt.
INFO:tensorflow:loss = 5.330352, step = 1
INFO:tensorflow:global_step/sec: 54.1661
INFO:tensorflow:loss = 0.72782683, step = 101 (1.846 sec)
INFO:tensorflow:global_step/sec: 55.258
INFO:tensorflow:loss = 0.6588603, step = 201 (1.810 sec)
INFO:tensorflow:global_step/sec: 55.136
INFO:tensorflow:loss = 0.56736594, step = 301 (1.814 sec)
INFO:tensorflow:global_step/sec: 55.9102
INFO:tensorflow:loss = 0.5070057, step = 401 (1.789 sec)
INFO:tensorflow:Saving checkpoints for 500 into C:\Users\dengyin\AppData\Local\Temp\tmpw0csioqr\model.ckpt.
INFO:tensorflow:Loss for final step: 0.5071666.
-----------------------eval model____________________________
INFO:tensorflow:Calling model_fn.
[TL] InputLayer model/input_SepalLength: (?,)
[TL] ReshapeLayer model/reshape_SepalLength: (?, 1)
[TL] InputLayer model/input_SepalWidth: (?,)
[TL] ReshapeLayer model/reshape_SepalWidth: (?, 1)
[TL] InputLayer model/input_PetalLength: (?,)
[TL] ReshapeLayer model/reshape_PetalLength: (?, 1)
[TL] InputLayer model/input_PetalWidth: (?,)
[TL] ReshapeLayer model/reshape_PetalWidth: (?, 1)
[TL] ConcatLayer model/concat_layer: axis: 1
[TL] DropoutLayer model/drop1: keep: 0.500000 is_fix: True
[TL] skip DropoutLayer
[TL] DenseLayer model/relu1: 800 relu
Traceback (most recent call last):
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\IPython\core\interactiveshell.py", line 3343, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "", line 128, in
print(model.evaluate(input_fn=lambda: eval_input_fn(test, test_y)))
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 469, in evaluate
name=name)
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 511, in _actual_eval
return _evaluate()
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 493, in _evaluate
self._evaluate_build_graph(input_fn, hooks, checkpoint_path))
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1424, in _evaluate_build_graph
self._call_model_fn_eval(input_fn, self.config))
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1460, in _call_model_fn_eval
features, labels, model_fn_lib.ModeKeys.EVAL, config)
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1112, in _call_model_fn
model_fn_results = self._model_fn(features=features, **kwargs)
File "", line 55, in model_fn
logits = inference(x, reuse, is_training)
File "", line 21, in inference
network = tl.layers.DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu1')
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorlayer\decorators\deprecated_alias.py", line 24, in wrapper
return f(*args, **kwargs)
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorlayer\layers\dense\base_dense.py", line 90, in init
name='W', shape=(n_in, n_units), initializer=W_init, dtype=LayersConfig.tf_dtype, **self.W_init_args
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1479, in get_variable
aggregation=aggregation)
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1220, in get_variable
aggregation=aggregation)
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 547, in get_variable
aggregation=aggregation)
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 499, in _true_getter
aggregation=aggregation)
File "D:\Program Files\Anaconda3\envs\tf1\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 866, in _get_single_variable
"reuse=tf.AUTO_REUSE in VarScope?" % name)
ValueError: Variable model/relu1/W does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?