|
22 | 22 |
|
23 | 23 | import logging
|
24 | 24 | import numpy as np
|
25 |
| -import six |
26 | 25 | import copy
|
27 | 26 | logger = logging.getLogger(__name__)
|
28 | 27 |
|
@@ -125,7 +124,7 @@ def filter_metrics_schema(self, white_set):
|
125 | 124 |
|
126 | 125 | def add_ad_hoc_plot_blob(self, blob, dtype=None):
|
127 | 126 | assert isinstance(
|
128 |
| - blob, (six.string_types, core.BlobReference) |
| 127 | + blob, (str, core.BlobReference) |
129 | 128 | ), "expect type str or BlobReference, but got {}".format(type(blob))
|
130 | 129 | dtype = dtype or (np.float, (1, ))
|
131 | 130 | self.add_metric_field(str(blob), schema.Scalar(dtype, blob))
|
@@ -173,7 +172,7 @@ def initializer(blob_name):
|
173 | 172 | def add_global_constant(
|
174 | 173 | self, name, array=None, dtype=None, initializer=None
|
175 | 174 | ):
|
176 |
| - assert isinstance(name, six.string_types), ( |
| 175 | + assert isinstance(name, str), ( |
177 | 176 | 'name should be a string as we are using it as map key')
|
178 | 177 | # This is global namescope for constants. They will be created in all
|
179 | 178 | # init_nets and there should be very few of them.
|
@@ -310,7 +309,7 @@ def create_param(self, param_name, shape, initializer, optimizer=None,
|
310 | 309 | ps_param=None, regularizer=None):
|
311 | 310 | if isinstance(param_name, core.BlobReference):
|
312 | 311 | param_name = str(param_name)
|
313 |
| - elif isinstance(param_name, six.string_types): |
| 312 | + elif isinstance(param_name, str): |
314 | 313 | # Parameter name will be equal to current Namescope that got
|
315 | 314 | # resolved with the respect of parameter sharing of the scopes.
|
316 | 315 | param_name = parameter_sharing_context.get_parameter_name(
|
@@ -750,6 +749,6 @@ def breakdown_map(self, breakdown_map):
|
750 | 749 | # TODO(xlwang): provide more rich feature information in breakdown_map;
|
751 | 750 | # and change the assertion accordingly
|
752 | 751 | assert isinstance(breakdown_map, dict)
|
753 |
| - assert all(isinstance(k, six.string_types) for k in breakdown_map) |
| 752 | + assert all(isinstance(k, str) for k in breakdown_map) |
754 | 753 | assert sorted(breakdown_map.values()) == list(range(len(breakdown_map)))
|
755 | 754 | self._breakdown_map = breakdown_map
|
0 commit comments