prime->[update] optimization.py and modeling.py for TensorFlow 2.x.
MOTIVATION
PR to resolve `AttributeError: module 'tensorflow._api.v2.train' has no attribute 'Optimizer'`
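
A quick repro sketch of the failure (illustrative only, not part of the diff; assumes a stock TensorFlow 2.x install):

```python
import tensorflow as tf

# Under TF 2.x the TF1 optimizer base class is gone from tf.train ...
try:
    tf.train.Optimizer
except AttributeError as err:
    print(err)  # e.g. module 'tensorflow._api.v2.train' has no attribute 'Optimizer'

# ... but it is still reachable through the compat shim:
print(tf.compat.v1.train.Optimizer)
```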

`optimization.py`:
  - change `tf.train.Optimizer` -> `tf.compat.v1.train.Optimizer`, since the TF1 optimizer base class is only exposed through `tf.compat.v1` in TensorFlow 2.x (sketch below).
  src: [ https://www.tensorflow.org/api_docs/python/tf/compat/v1 ]
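
A minimal sketch (not the repo's code) of the pattern this rename keeps working: `AdamWeightDecayOptimizer` in `optimization.py` subclasses the TF1 optimizer base class. The class name and attribute below are illustrative stand-ins.

```python
import tensorflow as tf

class ToyOptimizer(tf.compat.v1.train.Optimizer):
    """Illustrative stub; the real optimizer also implements apply_gradients()."""

    def __init__(self, learning_rate, name="ToyOptimizer"):
        # The TF1 base __init__ signature is (use_locking, name).
        super(ToyOptimizer, self).__init__(False, name)
        self.learning_rate = learning_rate

opt = ToyOptimizer(learning_rate=1e-4)
print(isinstance(opt, tf.compat.v1.train.Optimizer))  # True
```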

`modeling.py`: (companion changes needed once `optimization.py` is updated)
  - add the `@tf.function` decorator above `get_shape_list(...)`, since eager execution is enabled by default in TensorFlow 2.x (illustrated below).
  src: [ https://www.tensorflow.org/guide/function ]
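
A hedged illustration of what the decorator does (not the project's function): `tf.function` traces the Python body into a graph, so shape queries inside it run in graph mode even though TF 2.x is eager by default, and dynamic dimensions come back as tensors, which is what `get_shape_list` returns for non-static dims.

```python
import tensorflow as tf

@tf.function
def dynamic_shape(x):
    # Inside the traced function this builds graph ops; tf.shape(x) is a tensor.
    return tf.shape(x)

print(dynamic_shape(tf.zeros([2, 5])))  # tf.Tensor([2 5], shape=(2,), dtype=int32)
```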

  - change `tf.get_variable` -> `tf.compat.v1.get_variable`.
  - change `tf.variable_scope` -> `tf.compat.v1.variable_scope` (usage sketch after this list).
  - change `tf.truncated_normal_initializer` -> `tf.compat.v1.truncated_normal_initializer`.
  - change `tf.assert_less_equal` -> `tf.compat.v1.assert_less_equal`.
  - change `tf.get_variable_scope` -> `tf.compat.v1.get_variable_scope`.
  src: [ https://www.tensorflow.org/api_docs/python/tf/compat/v1 ]
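
A short sketch of the TF1-style variable pattern these renames preserve (illustrative names and shapes, not the diff itself); it is built inside an explicit `tf.Graph` because `modeling.py` is graph-mode code.

```python
import tensorflow as tf

with tf.Graph().as_default():
    with tf.compat.v1.variable_scope("bert", reuse=tf.compat.v1.AUTO_REUSE):
        table = tf.compat.v1.get_variable(
            name="word_embeddings",
            shape=[30522, 768],  # toy vocab/hidden sizes
            initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))
        print(tf.compat.v1.get_variable_scope().name)  # "bert"

    # The TF1 assertion op used in embedding_postprocessor also lives under compat.v1:
    assert_op = tf.compat.v1.assert_less_equal(tf.constant(128), tf.constant(512))
```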

  - change `tf.layers.dense` -> `tf.keras.layers.Dense`, since `tf.layers` is gone in TensorFlow 2.x and Keras 3 provides the replacement layers (sketch below).
  - change `tf.contrib.layers.layer_norm` -> `tf.keras.layers.LayerNormalization`, since `tf.contrib` was removed in TensorFlow 2.x.
  src: [ https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense ]
  src: [ https://stackoverflow.com/a/62357941/11492382 ]
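
A minimal sketch of the call-site change (not the diff itself; the sizes and the `TruncatedNormal` initializer are illustrative stand-ins): the functional `tf.layers.dense(x, units, ...)` becomes "construct a Keras layer, then call it on the tensor", and the same pattern applies to `LayerNormalization`.

```python
import tensorflow as tf

x = tf.random.normal([2, 128])  # toy [batch, hidden] input

dense = tf.keras.layers.Dense(
    64,
    activation=tf.tanh,
    kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
pooled = dense(x)  # shape (2, 64)

layer_norm = tf.keras.layers.LayerNormalization(axis=-1)
normalized = layer_norm(pooled)  # normalized over the last axis

print(pooled.shape, normalized.shape)  # (2, 64) (2, 64)
```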

CONFIGURATION
software  : Windows 11 Home 21H2; WSL 2.2.4.0; openSUSE Tumbleweed 20240629
hardware  : 64-bit operating system, x64-based processor
conda_env :
  cudnn      8.9.7.29  conda-forge
  python     3.12.2    conda-forge
  tensorflow 2.16.1    conda-forge
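
A quick, optional sanity check (illustrative, not part of the PR) that the running environment matches the versions above:

```python
import sys
import tensorflow as tf

print(sys.version.split()[0])  # expect 3.12.x
print(tf.__version__)          # expect 2.16.1
# cuDNN info shows up only in GPU builds; .get() avoids a KeyError on CPU-only installs.
print(tf.sysconfig.get_build_info().get("cudnn_version"))
```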

Reported-by  : Майкл Шодеке
Signed-off-by: Майкл Шодеке
p0lyMth committed Jul 20, 2024
1 parent eedf571 commit 938a485
Showing 2 changed files with 48 additions and 44 deletions.
89 changes: 47 additions & 42 deletions modeling.py
@@ -121,7 +121,7 @@ class BertModel(object):
model = modeling.BertModel(config=config, is_training=True,
input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
label_embeddings = tf.get_variable(...)
label_embeddings = tf.compat.v1.get_variable(...)
pooled_output = model.get_pooled_output()
logits = tf.matmul(pooled_output, label_embeddings)
...
@@ -168,8 +168,8 @@ def __init__(self,
if token_type_ids is None:
token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

with tf.variable_scope(scope, default_name="bert"):
with tf.variable_scope("embeddings"):
with tf.compat.v1.variable_scope(scope, default_name="bert"):
with tf.compat.v1.variable_scope("embeddings"):
# Perform embedding lookup on the word ids.
(self.embedding_output, self.embedding_table) = embedding_lookup(
input_ids=input_ids,
@@ -193,7 +193,7 @@ def __init__(self,
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)

with tf.variable_scope("encoder"):
with tf.compat.v1.variable_scope("encoder"):
# This converts a 2D mask of shape [batch_size, seq_length] to a 3D
# mask of shape [batch_size, seq_length, seq_length] which is used
# for the attention scores.
@@ -221,15 +221,16 @@ def __init__(self,
# [batch_size, hidden_size]. This is necessary for segment-level
# (or segment-pair-level) classification tasks where we need a fixed
# dimensional representation of the segment.
with tf.variable_scope("pooler"):

with tf.compat.v1.variable_scope("pooler"):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token. We assume that this has been pre-trained
first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
self.pooled_output = tf.layers.dense(
first_token_tensor,
self.pooled_output = tf.keras.layers.Dense(
config.hidden_size,
activation=tf.tanh,
kernel_initializer=create_initializer(config.initializer_range))
kernel_initializer=create_initializer(config.initializer_range)
)(first_token_tensor)

def get_pooled_output(self):
return self.pooled_output
@@ -361,8 +362,9 @@ def dropout(input_tensor, dropout_prob):

def layer_norm(input_tensor, name=None):
"""Run layer normalization on the last dimension of the tensor."""
return tf.contrib.layers.layer_norm(
inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
# NEW <-- tf.contrib.layers.layer_norm -> tf.keras.layers.LayerNormalization
layer_norma = tf.keras.layers.LayerNormalization(axis=-1)
return layer_norma(input_tensor)


def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
@@ -374,7 +376,7 @@ def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):

def create_initializer(initializer_range=0.02):
"""Creates a `truncated_normal_initializer` with the given range."""
return tf.truncated_normal_initializer(stddev=initializer_range)
return tf.compat.v1.truncated_normal_initializer(stddev=initializer_range)


def embedding_lookup(input_ids,
@@ -406,7 +408,7 @@ def embedding_lookup(input_ids,
if input_ids.shape.ndims == 2:
input_ids = tf.expand_dims(input_ids, axis=[-1])

embedding_table = tf.get_variable(
embedding_table = tf.compat.v1.get_variable(
name=word_embedding_name,
shape=[vocab_size, embedding_size],
initializer=create_initializer(initializer_range))
@@ -473,7 +475,7 @@ def embedding_postprocessor(input_tensor,
if token_type_ids is None:
raise ValueError("`token_type_ids` must be specified if"
"`use_token_type` is True.")
token_type_table = tf.get_variable(
token_type_table = tf.compat.v1.get_variable(
name=token_type_embedding_name,
shape=[token_type_vocab_size, width],
initializer=create_initializer(initializer_range))
@@ -487,9 +489,9 @@ def embedding_postprocessor(input_tensor,
output += token_type_embeddings

if use_position_embeddings:
assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
assert_op = tf.compat.v1.assert_less_equal(seq_length, max_position_embeddings)
with tf.control_dependencies([assert_op]):
full_position_embeddings = tf.get_variable(
full_position_embeddings = tf.compat.v1.get_variable(
name=position_embedding_name,
shape=[max_position_embeddings, width],
initializer=create_initializer(initializer_range))
@@ -663,28 +665,28 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
to_tensor_2d = reshape_to_matrix(to_tensor)

# `query_layer` = [B*F, N*H]
query_layer = tf.layers.dense(
from_tensor_2d,
query_layer = tf.keras.layers.Dense(
num_attention_heads * size_per_head,
activation=query_act,
name="query",
kernel_initializer=create_initializer(initializer_range))
kernel_initializer=create_initializer(initializer_range)
)(from_tensor_2d)

# `key_layer` = [B*T, N*H]
key_layer = tf.layers.dense(
to_tensor_2d,
key_layer = tf.keras.layers.Dense(
num_attention_heads * size_per_head,
activation=key_act,
name="key",
kernel_initializer=create_initializer(initializer_range))
kernel_initializer=create_initializer(initializer_range)
)(to_tensor_2d)

# `value_layer` = [B*T, N*H]
value_layer = tf.layers.dense(
to_tensor_2d,
value_layer = tf.keras.layers.Dense(
num_attention_heads * size_per_head,
activation=value_act,
name="value",
kernel_initializer=create_initializer(initializer_range))
kernel_initializer=create_initializer(initializer_range)
)(to_tensor_2d)

# `query_layer` = [B, N, F, H]
query_layer = transpose_for_scores(query_layer, batch_size,
@@ -824,12 +826,11 @@ def transformer_model(input_tensor,

all_layer_outputs = []
for layer_idx in range(num_hidden_layers):
with tf.variable_scope("layer_%d" % layer_idx):
with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
layer_input = prev_output

with tf.variable_scope("attention"):
with tf.compat.v1.variable_scope("attention"):
attention_heads = []
with tf.variable_scope("self"):
with tf.compat.v1.variable_scope("self"):
attention_head = attention_layer(
from_tensor=layer_input,
to_tensor=layer_input,
@@ -854,28 +855,31 @@ def transformer_model(input_tensor,

# Run a linear projection of `hidden_size` then add a residual
# with `layer_input`.
with tf.variable_scope("output"):
attention_output = tf.layers.dense(
attention_output,

with tf.compat.v1.variable_scope("output"):
attention_output = tf.keras.layers.Dense(
hidden_size,
kernel_initializer=create_initializer(initializer_range))
kernel_initializer=create_initializer(initializer_range)
)(attention_output)
attention_output = dropout(attention_output, hidden_dropout_prob)
attention_output = layer_norm(attention_output + layer_input)

# The activation is only applied to the "intermediate" hidden layer.
with tf.variable_scope("intermediate"):
intermediate_output = tf.layers.dense(
attention_output,

with tf.compat.v1.variable_scope("intermediate"):
intermediate_output = tf.keras.layers.Dense(
intermediate_size,
activation=intermediate_act_fn,
kernel_initializer=create_initializer(initializer_range))
kernel_initializer=create_initializer(initializer_range)
)(attention_output)

# Down-project back to `hidden_size` then add the residual.
with tf.variable_scope("output"):
layer_output = tf.layers.dense(
intermediate_output,

with tf.compat.v1.variable_scope("output"):
layer_output = tf.keras.layers.Dense(
hidden_size,
kernel_initializer=create_initializer(initializer_range))
kernel_initializer=create_initializer(initializer_range)
)(intermediate_output,)
layer_output = dropout(layer_output, hidden_dropout_prob)
layer_output = layer_norm(layer_output + attention_output)
prev_output = layer_output
@@ -892,6 +896,7 @@ def transformer_model(input_tensor,
return final_output


@tf.function
def get_shape_list(tensor, expected_rank=None, name=None):
"""Returns a list of the shape of tensor, preferring static dimensions.
@@ -908,7 +913,7 @@ def get_shape_list(tensor, expected_rank=None, name=None):
as tf.Tensor scalars.
"""
if name is None:
name = tensor.name
name = tensor.name

if expected_rank is not None:
assert_rank(tensor, expected_rank, name)
@@ -979,7 +984,7 @@ def assert_rank(tensor, expected_rank, name=None):

actual_rank = tensor.shape.ndims
if actual_rank not in expected_rank_dict:
scope_name = tf.get_variable_scope().name
scope_name = tf.compat.v1.get_variable_scope().name
raise ValueError(
"For the tensor `%s` in scope `%s`, the actual rank "
"`%d` (shape = %s) is not equal to the expected rank `%s`" %
3 changes: 1 addition & 2 deletions optimization.py
@@ -83,8 +83,7 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
train_op = tf.group(train_op, [global_step.assign(new_global_step)])
return train_op


class AdamWeightDecayOptimizer(tf.train.Optimizer):
class AdamWeightDecayOptimizer(tf.compat.v1.train.Optimizer):
"""A basic Adam optimizer that includes "correct" L2 weight decay."""

def __init__(self,
