Commit 0f0c774

Merged commit includes the following changes: (tensorflow#8739)
318417714 by jonathanhuang: Internal change.
318367213 by sbeery: Pointing users to more documentation for beam
318358685 by sbeery: Context R-CNN sample config for GPU
318309800 by rathodv: Internal
318303364 by ronnyvotel: Adding the option for parsing and including DensePose annotations. http://densepose.org/
318291319 by aom: Adds conv_bn_act conv_block option, and naming convention changes for BiFPN utils.
318200598 by ronnyvotel: Updating the TF Example Decoder to parse DensePose annotations.
318174065 by jonathanhuang: Internal change.
318167805 by rathodv: Add use_tpu flag to TF2 binary.
318145285 by aom: Adds option for convolutional keras box predictor to force use_bias.

PiperOrigin-RevId: 318417714
1 parent 1e4fd82 commit 0f0c774

34 files changed, +967 -1200 lines changed

research/object_detection/core/model.py

Lines changed: 36 additions & 1 deletion
@@ -391,7 +391,9 @@ def regularization_losses(self):
     pass

   @abc.abstractmethod
-  def restore_map(self, fine_tune_checkpoint_type='detection'):
+  def restore_map(self,
+                  fine_tune_checkpoint_type='detection',
+                  load_all_detection_checkpoint_vars=False):
     """Returns a map of variables to load from a foreign checkpoint.

     Returns a map of variable names to load from a checkpoint to variables in
@@ -407,13 +409,46 @@ def restore_map(self, fine_tune_checkpoint_type='detection'):
         checkpoint (with compatible variable names) or to restore from a
         classification checkpoint for initialization prior to training.
         Valid values: `detection`, `classification`. Default 'detection'.
+      load_all_detection_checkpoint_vars: whether to load all variables (when
+        `fine_tune_checkpoint_type` is `detection`). If False, only variables
+        within the feature extractor scope are included. Default False.

     Returns:
       A dict mapping variable names (to load from a checkpoint) to variables in
       the model graph.
     """
     pass

+  @abc.abstractmethod
+  def restore_from_objects(self, fine_tune_checkpoint_type='detection'):
+    """Returns a map of variables to load from a foreign checkpoint.
+
+    Returns a dictionary of Tensorflow 2 Trackable objects (e.g. tf.Module
+    or Checkpoint). This enables the model to initialize based on weights from
+    another task. For example, the feature extractor variables from a
+    classification model can be used to bootstrap training of an object
+    detector. When loading from an object detection model, the checkpoint model
+    should have the same parameters as this detection model with exception of
+    the num_classes parameter.
+
+    Note that this function is intended to be used to restore Keras-based
+    models when running Tensorflow 2, whereas restore_map (above) is intended
+    to be used to restore Slim-based models when running Tensorflow 1.x.
+
+    TODO(jonathanhuang,rathodv): Check tf_version and raise unimplemented
+    error for both restore_map and restore_from_objects depending on version.
+
+    Args:
+      fine_tune_checkpoint_type: whether to restore from a full detection
+        checkpoint (with compatible variable names) or to restore from a
+        classification checkpoint for initialization prior to training.
+        Valid values: `detection`, `classification`. Default 'detection'.
+
+    Returns:
+      A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
+    """
+    pass
+
   @abc.abstractmethod
   def updates(self):
     """Returns a list of update operators for this model.

research/object_detection/core/model_test.py

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,9 @@ def updates(self):
   def restore_map(self):
     return {}

+  def restore_from_objects(self, fine_tune_checkpoint_type):
+    pass
+
   def regularization_losses(self):
     return []

research/object_detection/core/standard_fields.py

Lines changed: 8 additions & 0 deletions
@@ -66,6 +66,11 @@ class InputDataFields(object):
     groundtruth_keypoint_weights: groundtruth weight factor for keypoints.
     groundtruth_label_weights: groundtruth label weights.
     groundtruth_weights: groundtruth weight factor for bounding boxes.
+    groundtruth_dp_num_points: The number of DensePose sampled points for each
+      instance.
+    groundtruth_dp_part_ids: Part indices for DensePose points.
+    groundtruth_dp_surface_coords: Image locations and UV coordinates for
+      DensePose points.
     num_groundtruth_boxes: number of groundtruth boxes.
     is_annotated: whether an image has been labeled or not.
     true_image_shapes: true shapes of images in the resized images, as resized
@@ -108,6 +113,9 @@ class InputDataFields(object):
   groundtruth_keypoint_weights = 'groundtruth_keypoint_weights'
   groundtruth_label_weights = 'groundtruth_label_weights'
   groundtruth_weights = 'groundtruth_weights'
+  groundtruth_dp_num_points = 'groundtruth_dp_num_points'
+  groundtruth_dp_part_ids = 'groundtruth_dp_part_ids'
+  groundtruth_dp_surface_coords = 'groundtruth_dp_surface_coords'
   num_groundtruth_boxes = 'num_groundtruth_boxes'
   is_annotated = 'is_annotated'
   true_image_shape = 'true_image_shape'
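
For orientation, a small sketch of how downstream code might read the three new DensePose entries out of a decoded tensor dictionary. The helper name and the tensor_dict argument are illustrative, not an API added by this commit; the shape comments follow the docstrings above and the decoder changes below.

from object_detection.core import standard_fields as fields


def get_densepose_groundtruth(tensor_dict):
  # Expected shapes (num_points is the per-image maximum over instances):
  #   groundtruth_dp_num_points:     [num_instances]                 (int32)
  #   groundtruth_dp_part_ids:       [num_instances, num_points]     (int32)
  #   groundtruth_dp_surface_coords: [num_instances, num_points, 4]  (float32)
  return (tensor_dict[fields.InputDataFields.groundtruth_dp_num_points],
          tensor_dict[fields.InputDataFields.groundtruth_dp_part_ids],
          tensor_dict[fields.InputDataFields.groundtruth_dp_surface_coords])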

research/object_detection/data_decoders/tf_example_decoder.py

Lines changed: 131 additions & 1 deletion
@@ -30,6 +30,7 @@
 from object_detection.core import standard_fields as fields
 from object_detection.protos import input_reader_pb2
 from object_detection.utils import label_map_util
+from object_detection.utils import shape_utils

 # pylint: disable=g-import-not-at-top
 try:
@@ -170,7 +171,8 @@ def __init__(self,
                num_additional_channels=0,
                load_multiclass_scores=False,
                load_context_features=False,
-               expand_hierarchy_labels=False):
+               expand_hierarchy_labels=False,
+               load_dense_pose=False):
     """Constructor sets keys_to_features and items_to_handlers.

     Args:
@@ -201,6 +203,7 @@ def __init__(self,
         account the provided hierarchy in the label_map_proto_file. For positive
         classes, the labels are extended to ancestor. For negative classes,
         the labels are expanded to descendants.
+      load_dense_pose: Whether to load DensePose annotations.

     Raises:
       ValueError: If `instance_mask_type` option is not one of
@@ -371,6 +374,34 @@ def __init__(self,
               self._decode_png_instance_masks))
     else:
       raise ValueError('Did not recognize the `instance_mask_type` option.')
+    if load_dense_pose:
+      self.keys_to_features['image/object/densepose/num'] = (
+          tf.VarLenFeature(tf.int64))
+      self.keys_to_features['image/object/densepose/part_index'] = (
+          tf.VarLenFeature(tf.int64))
+      self.keys_to_features['image/object/densepose/x'] = (
+          tf.VarLenFeature(tf.float32))
+      self.keys_to_features['image/object/densepose/y'] = (
+          tf.VarLenFeature(tf.float32))
+      self.keys_to_features['image/object/densepose/u'] = (
+          tf.VarLenFeature(tf.float32))
+      self.keys_to_features['image/object/densepose/v'] = (
+          tf.VarLenFeature(tf.float32))
+      self.items_to_handlers[
+          fields.InputDataFields.groundtruth_dp_num_points] = (
+              slim_example_decoder.Tensor('image/object/densepose/num'))
+      self.items_to_handlers[fields.InputDataFields.groundtruth_dp_part_ids] = (
+          slim_example_decoder.ItemHandlerCallback(
+              ['image/object/densepose/part_index',
+               'image/object/densepose/num'], self._dense_pose_part_indices))
+      self.items_to_handlers[
+          fields.InputDataFields.groundtruth_dp_surface_coords] = (
+              slim_example_decoder.ItemHandlerCallback(
+                  ['image/object/densepose/x', 'image/object/densepose/y',
+                   'image/object/densepose/u', 'image/object/densepose/v',
+                   'image/object/densepose/num'],
+                  self._dense_pose_surface_coordinates))
+
     if label_map_proto_file:
       # If the label_map_proto is provided, try to use it in conjunction with
       # the class text, and fall back to a materialized ID.
@@ -547,6 +578,14 @@ def expand_field(field_name):
       group_of = fields.InputDataFields.groundtruth_group_of
       tensor_dict[group_of] = tf.cast(tensor_dict[group_of], dtype=tf.bool)

+    if fields.InputDataFields.groundtruth_dp_num_points in tensor_dict:
+      tensor_dict[fields.InputDataFields.groundtruth_dp_num_points] = tf.cast(
+          tensor_dict[fields.InputDataFields.groundtruth_dp_num_points],
+          dtype=tf.int32)
+      tensor_dict[fields.InputDataFields.groundtruth_dp_part_ids] = tf.cast(
+          tensor_dict[fields.InputDataFields.groundtruth_dp_part_ids],
+          dtype=tf.int32)
+
     return tensor_dict

   def _reshape_keypoints(self, keys_to_tensors):
@@ -697,6 +736,97 @@ def decode_png_mask(image_buffer):
         lambda: tf.map_fn(decode_png_mask, png_masks, dtype=tf.float32),
         lambda: tf.zeros(tf.cast(tf.stack([0, height, width]), dtype=tf.int32)))

+  def _dense_pose_part_indices(self, keys_to_tensors):
+    """Creates a tensor that contains part indices for each DensePose point.
+
+    Args:
+      keys_to_tensors: a dictionary from keys to tensors.
+
+    Returns:
+      A 2-D int32 tensor of shape [num_instances, num_points] where each element
+      contains the DensePose part index (0-23). The value `num_points`
+      corresponds to the maximum number of sampled points across all instances
+      in the image. Note that instances with less sampled points will be padded
+      with zeros in the last dimension.
+    """
+    num_points_per_instances = keys_to_tensors['image/object/densepose/num']
+    part_index = keys_to_tensors['image/object/densepose/part_index']
+    if isinstance(num_points_per_instances, tf.SparseTensor):
+      num_points_per_instances = tf.sparse_tensor_to_dense(
+          num_points_per_instances)
+    if isinstance(part_index, tf.SparseTensor):
+      part_index = tf.sparse_tensor_to_dense(part_index)
+    part_index = tf.cast(part_index, dtype=tf.int32)
+    max_points_per_instance = tf.cast(
+        tf.math.reduce_max(num_points_per_instances), dtype=tf.int32)
+    num_points_cumulative = tf.concat([
+        [0], tf.math.cumsum(num_points_per_instances)], axis=0)
+
+    def pad_parts_tensor(instance_ind):
+      points_range_start = num_points_cumulative[instance_ind]
+      points_range_end = num_points_cumulative[instance_ind + 1]
+      part_inds = part_index[points_range_start:points_range_end]
+      return shape_utils.pad_or_clip_nd(part_inds,
+                                        output_shape=[max_points_per_instance])
+
+    return tf.map_fn(pad_parts_tensor,
+                     tf.range(tf.size(num_points_per_instances)),
+                     dtype=tf.int32)
+
+  def _dense_pose_surface_coordinates(self, keys_to_tensors):
+    """Creates a tensor that contains surface coords for each DensePose point.
+
+    Args:
+      keys_to_tensors: a dictionary from keys to tensors.
+
+    Returns:
+      A 3-D float32 tensor of shape [num_instances, num_points, 4] where each
+      point contains (y, x, v, u) data for each sampled DensePose point. The
+      (y, x) coordinate has normalized image locations for the point, and (v, u)
+      contains the surface coordinate (also normalized) for the part. The value
+      `num_points` corresponds to the maximum number of sampled points across
+      all instances in the image. Note that instances with less sampled points
+      will be padded with zeros in dim=1.
+    """
+    num_points_per_instances = keys_to_tensors['image/object/densepose/num']
+    dp_y = keys_to_tensors['image/object/densepose/y']
+    dp_x = keys_to_tensors['image/object/densepose/x']
+    dp_v = keys_to_tensors['image/object/densepose/v']
+    dp_u = keys_to_tensors['image/object/densepose/u']
+    if isinstance(num_points_per_instances, tf.SparseTensor):
+      num_points_per_instances = tf.sparse_tensor_to_dense(
+          num_points_per_instances)
+    if isinstance(dp_y, tf.SparseTensor):
+      dp_y = tf.sparse_tensor_to_dense(dp_y)
+    if isinstance(dp_x, tf.SparseTensor):
+      dp_x = tf.sparse_tensor_to_dense(dp_x)
+    if isinstance(dp_v, tf.SparseTensor):
+      dp_v = tf.sparse_tensor_to_dense(dp_v)
+    if isinstance(dp_u, tf.SparseTensor):
+      dp_u = tf.sparse_tensor_to_dense(dp_u)
+    max_points_per_instance = tf.cast(
+        tf.math.reduce_max(num_points_per_instances), dtype=tf.int32)
+    num_points_cumulative = tf.concat([
+        [0], tf.math.cumsum(num_points_per_instances)], axis=0)
+
+    def pad_surface_coordinates_tensor(instance_ind):
+      """Pads DensePose surface coordinates for each instance."""
+      points_range_start = num_points_cumulative[instance_ind]
+      points_range_end = num_points_cumulative[instance_ind + 1]
+      y = dp_y[points_range_start:points_range_end]
+      x = dp_x[points_range_start:points_range_end]
+      v = dp_v[points_range_start:points_range_end]
+      u = dp_u[points_range_start:points_range_end]
+      # Create [num_points_i, 4] tensor, where num_points_i is the number of
+      # sampled points for instance i.
+      unpadded_tensor = tf.stack([y, x, v, u], axis=1)
+      return shape_utils.pad_or_clip_nd(
+          unpadded_tensor, output_shape=[max_points_per_instance, 4])
+
+    return tf.map_fn(pad_surface_coordinates_tensor,
+                     tf.range(tf.size(num_points_per_instances)),
+                     dtype=tf.float32)
+
   def _expand_image_label_hierarchy(self, image_classes, image_confidences):
     """Expand image level labels according to the hierarchy.
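
The two new item handlers turn flat, per-image DensePose features into per-instance tensors by padding every instance's points to the per-image maximum (tf.map_fn over instances plus shape_utils.pad_or_clip_nd). As a quick sanity check of that layout, here is a standalone TF2 sketch that reproduces the same padding for the part indices using tf.RaggedTensor; the values mirror the test below, and the RaggedTensor route is only an illustration, not what the decoder itself does.

import tensorflow as tf

# Flat per-image features as they appear in the tf.Example: point data for
# all instances concatenated, plus the number of points per instance.
densepose_num = tf.constant([0, 4, 2], dtype=tf.int64)
densepose_part_index = tf.constant([2, 2, 3, 4, 2, 9], dtype=tf.int32)

# Group the flat values by instance, then pad to the per-image maximum,
# matching the [num_instances, max_points] layout the decoder produces.
part_ids = tf.RaggedTensor.from_row_lengths(
    densepose_part_index, row_lengths=densepose_num).to_tensor()
print(part_ids.numpy())
# [[0 0 0 0]
#  [2 2 3 4]
#  [2 9 0 0]]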

research/object_detection/data_decoders/tf_example_decoder_test.py

Lines changed: 91 additions & 6 deletions
@@ -1096,8 +1096,8 @@ def graph_fn():
       return example_decoder.decode(tf.convert_to_tensor(example))

     tensor_dict = self.execute_cpu(graph_fn, [])
-    self.assertTrue(
-        fields.InputDataFields.groundtruth_instance_masks not in tensor_dict)
+    self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
+                     tensor_dict)

   def testDecodeImageLabels(self):
     image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
@@ -1116,8 +1116,7 @@ def graph_fn_1():
       return example_decoder.decode(tf.convert_to_tensor(example))

     tensor_dict = self.execute_cpu(graph_fn_1, [])
-    self.assertTrue(
-        fields.InputDataFields.groundtruth_image_classes in tensor_dict)
+    self.assertIn(fields.InputDataFields.groundtruth_image_classes, tensor_dict)
     self.assertAllEqual(
         tensor_dict[fields.InputDataFields.groundtruth_image_classes],
         np.array([1, 2]))
@@ -1152,8 +1151,7 @@ def graph_fn_2():
       return example_decoder.decode(tf.convert_to_tensor(example))

     tensor_dict = self.execute_cpu(graph_fn_2, [])
-    self.assertTrue(
-        fields.InputDataFields.groundtruth_image_classes in tensor_dict)
+    self.assertIn(fields.InputDataFields.groundtruth_image_classes, tensor_dict)
     self.assertAllEqual(
         tensor_dict[fields.InputDataFields.groundtruth_image_classes],
         np.array([1, 3]))
@@ -1345,6 +1343,93 @@ def graph_fn():
           expected_image_confidence,
           tensor_dict[fields.InputDataFields.groundtruth_image_confidences])

+  def testDecodeDensePose(self):
+    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
+    encoded_jpeg, _ = self._create_encoded_and_decoded_data(
+        image_tensor, 'jpeg')
+    bbox_ymins = [0.0, 4.0, 2.0]
+    bbox_xmins = [1.0, 5.0, 8.0]
+    bbox_ymaxs = [2.0, 6.0, 1.0]
+    bbox_xmaxs = [3.0, 7.0, 3.3]
+    densepose_num = [0, 4, 2]
+    densepose_part_index = [2, 2, 3, 4, 2, 9]
+    densepose_x = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
+    densepose_y = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4]
+    densepose_u = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06]
+    densepose_v = [0.99, 0.98, 0.97, 0.96, 0.95, 0.94]
+
+    def graph_fn():
+      example = tf.train.Example(
+          features=tf.train.Features(
+              feature={
+                  'image/encoded':
+                      dataset_util.bytes_feature(encoded_jpeg),
+                  'image/format':
+                      dataset_util.bytes_feature(six.b('jpeg')),
+                  'image/object/bbox/ymin':
+                      dataset_util.float_list_feature(bbox_ymins),
+                  'image/object/bbox/xmin':
+                      dataset_util.float_list_feature(bbox_xmins),
+                  'image/object/bbox/ymax':
+                      dataset_util.float_list_feature(bbox_ymaxs),
+                  'image/object/bbox/xmax':
+                      dataset_util.float_list_feature(bbox_xmaxs),
+                  'image/object/densepose/num':
+                      dataset_util.int64_list_feature(densepose_num),
+                  'image/object/densepose/part_index':
+                      dataset_util.int64_list_feature(densepose_part_index),
+                  'image/object/densepose/x':
+                      dataset_util.float_list_feature(densepose_x),
+                  'image/object/densepose/y':
+                      dataset_util.float_list_feature(densepose_y),
+                  'image/object/densepose/u':
+                      dataset_util.float_list_feature(densepose_u),
+                  'image/object/densepose/v':
+                      dataset_util.float_list_feature(densepose_v),
+
+              })).SerializeToString()
+
+      example_decoder = tf_example_decoder.TfExampleDecoder(
+          load_dense_pose=True)
+      output = example_decoder.decode(tf.convert_to_tensor(example))
+      dp_num_points = output[fields.InputDataFields.groundtruth_dp_num_points]
+      dp_part_ids = output[fields.InputDataFields.groundtruth_dp_part_ids]
+      dp_surface_coords = output[
+          fields.InputDataFields.groundtruth_dp_surface_coords]
+      return dp_num_points, dp_part_ids, dp_surface_coords
+
+    dp_num_points, dp_part_ids, dp_surface_coords = self.execute_cpu(
+        graph_fn, [])
+
+    expected_dp_num_points = [0, 4, 2]
+    expected_dp_part_ids = [
+        [0, 0, 0, 0],
+        [2, 2, 3, 4],
+        [2, 9, 0, 0]
+    ]
+    expected_dp_surface_coords = np.array(
+        [
+            # Instance 0 (no points).
+            [[0., 0., 0., 0.],
+             [0., 0., 0., 0.],
+             [0., 0., 0., 0.],
+             [0., 0., 0., 0.]],
+            # Instance 1 (4 points).
+            [[0.9, 0.1, 0.99, 0.01],
+             [0.8, 0.2, 0.98, 0.02],
+             [0.7, 0.3, 0.97, 0.03],
+             [0.6, 0.4, 0.96, 0.04]],
+            # Instance 2 (2 points).
+            [[0.5, 0.5, 0.95, 0.05],
+             [0.4, 0.6, 0.94, 0.06],
+             [0., 0., 0., 0.],
+             [0., 0., 0., 0.]],
+        ], dtype=np.float32)
+
+    self.assertAllEqual(dp_num_points, expected_dp_num_points)
+    self.assertAllEqual(dp_part_ids, expected_dp_part_ids)
+    self.assertAllClose(dp_surface_coords, expected_dp_surface_coords)
+

 if __name__ == '__main__':
   tf.test.main()

research/object_detection/dataset_tools/context_rcnn/generate_detection_data_tf1_test.py

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ def postprocess(self, prediction_dict, true_image_shapes):
   def restore_map(self, checkpoint_path, fine_tune_checkpoint_type):
     pass

+  def restore_from_objects(self, fine_tune_checkpoint_type):
+    pass
+
   def loss(self, prediction_dict, true_image_shapes):
     pass
