Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tensorflow #13543

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
5 changes: 5 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda",
"python-envs.pythonProjects": []
}
3 changes: 3 additions & 0 deletions official/common/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"""The central place to define flags."""

from absl import flags
import tensorflow as tf

# NOTE(review): `tensorflow.keras.mixed_precision.experimental` (and its
# `set_policy`) was removed in TF 2.4+; the stable API is
# `tf.keras.mixed_precision.set_global_policy`. Also note this runs as an
# import-time side effect of the flags module, forcing mixed_float16 on every
# importer — prefer configuring precision from the trainer after flags are
# parsed (the Model Garden already does this via its performance helpers).
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)


def define_flags():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from absl import logging
import numpy as np
import tensorflow as tf, tf_keras

# NOTE(review): enabling XLA globally at import time affects every graph built
# by any importer of this module; prefer opting in per-function with
# `tf.function(jit_compile=True)` on the hot paths. `set_jit(True)` is a
# deprecated alias — the documented value is "autoclustering".
tf.config.optimizer.set_jit("autoclustering")  # Enable XLA auto-clustering



def expand_vector(v: np.ndarray) -> np.ndarray:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
import tensorflow as tf, tf_keras

# NOTE(review): same concern as the sibling file — a module-level JIT toggle is
# a global side effect on every importer; `set_jit(True)` is a deprecated
# alias, the documented value is "autoclustering". Prefer
# `tf.function(jit_compile=True)` on the specific functions instead.
tf.config.optimizer.set_jit("autoclustering")  # Enable XLA auto-clustering

from official.modeling.fast_training.experimental import tf2_utils_2x_wide

Expand Down
1 change: 1 addition & 0 deletions official/modeling/fast_training/progressive/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from official.modeling import performance
from official.modeling.fast_training.progressive import train_lib


FLAGS = flags.FLAGS


Expand Down
18 changes: 18 additions & 0 deletions official/vision/modeling/backbones/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from official.vision.modeling.backbones import factory
from official.vision.modeling.layers import nn_blocks
from official.vision.modeling.layers import nn_layers
from tensorflow.keras.mixed_precision import experimental as mixed_precision

layers = tf_keras.layers

Expand Down Expand Up @@ -130,6 +131,7 @@ def __init__(
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
num_classes=1000,
kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
Expand Down Expand Up @@ -183,6 +185,22 @@ def __init__(
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._bn_trainable = bn_trainable
# Enable mixed precision inside the model
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)

self.conv1 = tf.keras.layers.Conv2D(64, 7, activation='relu')
self.pool = tf.keras.layers.MaxPooling2D()
self.flatten = tf.keras.layers.Flatten()
self.fc = tf.keras.layers.Dense(num_classes, activation='softmax', dtype='float32') # Keep output in float32

def call(self, inputs):
    """Forward pass: conv1 -> max-pool -> flatten -> dense softmax head.

    NOTE(review): in the posted diff this method is spliced into the middle of
    ResNet's `__init__` (the original `__init__` body resumes right after it
    with the `image_data_format` check), so the class as diffed is
    syntactically broken — placement must be fixed before merge. It also
    replaces the real ResNet trunk with a single conv layer; presumably
    unintended — confirm with the author.
    """
    x = self.conv1(inputs)
    x = self.pool(x)
    x = self.flatten(x)
    # `fc` was declared with dtype='float32', keeping the softmax output in
    # full precision under the mixed_float16 policy.
    return self.fc(x)



if tf_keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
Expand Down
8 changes: 8 additions & 0 deletions official/vision/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@
from official.modeling import performance
from official.vision import registry_imports # pylint: disable=unused-import
from official.vision.utils import summary_manager
import tensorflow as tf

# NOTE(review): the `experimental` mixed-precision namespace was removed in
# TF 2.4+; use the stable `tf.keras.mixed_precision` API. Forcing
# mixed_float16 unconditionally at import time also overrides the precision
# the trainer normally derives from the experiment config
# (official.modeling.performance) — presumably this should be config-driven;
# TODO confirm against the trainer flow before merge.
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)



FLAGS = flags.FLAGS
Expand Down