
Commit 5b0f400

Flamefire authored and facebook-github-bot committed
Replace list(map(...)) constructs by list comprehensions (pytorch#46461)
Summary:
As discussed in pytorch#46392, this makes the code more readable and possibly more performant. It also fixes a bug uncovered by the conversion, where the argument order of `map` was confused: Flamefire@030a249#diff-5bb26bd3a23ee3bb540aeadcc0385df2a4e48de39f87ed9ea76b21990738fe98L1537-R1537

Fixes pytorch#46392

Pull Request resolved: pytorch#46461

Reviewed By: ailzhang

Differential Revision: D24367015

Pulled By: ezyang

fbshipit-source-id: d55a67933cc22346b00544c9671f09982ad920e7
1 parent 3d421b3 commit 5b0f400
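
For illustration, the refactoring pattern applied across all 25 files is sketched below. The snippet is not taken from the diff; `values` and `to_index` are made-up names used only to show the before/after forms.

    # Illustrative sketch only -- these names are not from the changed files.
    values = ["0", "1", "2"]

    def to_index(s):
        return int(s)

    # Before: map a callable over an iterable, then materialize a list.
    indices_old = list(map(to_index, values))

    # After: the equivalent list comprehension used throughout this commit.
    indices_new = [to_index(v) for v in values]

    assert indices_old == indices_new == [0, 1, 2]

    # map() does not check that its first argument is callable, so a swapped
    # call such as map(values, indices_old) builds a map object and only
    # fails once the result is consumed; the comprehension spells out the
    # call site and makes that kind of mistake much harder to write.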

25 files changed (+63, -63 lines)

.circleci/cimodel/data/simple/util/versions.py (+1 -1)

@@ -9,7 +9,7 @@ def prefixed_parts(self):
         with the prefix string.
         """
         if self.parts:
-            return [self.prefix + str(self.parts[0])] + list(map(str, self.parts[1:]))
+            return [self.prefix + str(self.parts[0])] + [str(part) for part in self.parts[1:]]
         else:
             return [self.prefix]

caffe2/python/fused_8bit_rowwise_conversion_ops_test.py (+1 -1)

@@ -30,7 +30,7 @@ def floats_to_bytes(floats):
         if isinstance(as_bytes[0], int):
             byte_matrix[i] = list(as_bytes)
         else:
-            byte_matrix[i] = list(map(ord, as_bytes))
+            byte_matrix[i] = [ord(i) for i in as_bytes]
     return byte_matrix

caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py (+1 -1)

@@ -46,7 +46,7 @@ def int8_to_bytes(int8s):
         if isinstance(as_bytes[0], int):
             byte_matrix[i] = list(as_bytes)
         else:
-            byte_matrix[i] = list(map(ord, as_bytes))
+            byte_matrix[i] = [ord(i) for i in as_bytes]
     return byte_matrix

caffe2/python/operator_test/torch_integration_test.py (+1 -1)

@@ -90,7 +90,7 @@ def floats_to_bytes(floats):
         if isinstance(as_bytes[0], int):
             byte_matrix[i] = list(as_bytes)
         else:
-            byte_matrix[i] = list(map(ord, as_bytes))
+            byte_matrix[i] = [ord(i) for i in as_bytes]
     return byte_matrix

caffe2/python/rnn_cell.py (+2 -2)

@@ -42,7 +42,7 @@ def _RectifyName(blob_reference_or_name):
 def _RectifyNames(blob_references_or_names):
     if blob_references_or_names is None:
         return None
-    return list(map(_RectifyName, blob_references_or_names))
+    return [_RectifyName(i) for i in blob_references_or_names]


 class RNNCell(object):
@@ -236,7 +236,7 @@ def get_state_names(self):
         '''
         Returns recurrent state names with self.name scoping applied
         '''
-        return list(map(self.scope, self.get_state_names_override()))
+        return [self.scope(name) for name in self.get_state_names_override()]

     def get_state_names_override(self):
         '''

test/jit/test_list_dict.py (+2 -2)

@@ -948,11 +948,11 @@ def check_list(fn, li):
         check_list(min_intlist, int_list)
         check_list(max_intlist, int_list)

-        bool_li = list(map(lambda x: bool(x), int_list))
+        bool_li = [bool(x) for x in int_list]
         check_list(min_boollist, bool_li)
         check_list(max_boollist, bool_li)

-        float_li = list(map(lambda x: float(x), int_list))
+        float_li = [float(x) for x in int_list]
         check_list(min_floatlist, float_li)
         check_list(max_floatlist, float_li)

test/onnx/test_pytorch_onnx_onnxruntime.py (+2 -2)

@@ -53,7 +53,7 @@ def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None,
 def run_ort(ort_sess, input):
     input_copy = copy.deepcopy(input)
     input, _ = torch.jit._flatten(input_copy)
-    inputs = list(map(to_numpy, input))
+    inputs = [to_numpy(inp) for inp in input]

     ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
     ort_outs = ort_sess.run(None, ort_inputs)
@@ -62,7 +62,7 @@ def run_ort(ort_sess, input):

 def ort_compare_with_pytorch(ort_outs, output, rtol, atol):
     output, _ = torch.jit._flatten(output)
-    outputs = list(map(to_numpy, output))
+    outputs = [to_numpy(outp) for outp in output]

     # compare onnxruntime and PyTorch results
     assert len(outputs) == len(ort_outs), "number of outputs differ"

test/onnx/verify.py (+2 -2)

@@ -386,8 +386,8 @@ def run(args):
             "it had a different set of parameters. Are you assigning Parameters\n"
             "in the forward() of your model definition?")
         with errs.addErrCtxt(initializer_order_hint):
-            errs.requireEqual(list(map(lambda x: x.name, proto.graph.initializer)),
-                              list(map(lambda x: x.name, alt_proto.graph.initializer)),
+            errs.requireEqual([x.name for x in proto.graph.initializer],
+                              [x.name for x in alt_proto.graph.initializer],
                               msg="Parameters list differs")

         # Now check if the embedded parameters are actually the same

test/test_cuda.py (+2 -2)

@@ -2959,9 +2959,9 @@ def test_reduce_add(self):
         self.assertEqual(result.cpu(), x + y)

     def _test_reduce_add_coalesced(self, tensors, buffer_size):
-        dup_tensors = [tensors, list(map(lambda t: t.cuda(1), tensors))]
+        dup_tensors = [tensors, [t.cuda(1) for t in tensors]]

-        r_tensors = list(map(comm.reduce_add, zip(*dup_tensors)))
+        r_tensors = [comm.reduce_add(t) for t in zip(*dup_tensors)]
         for r, t in zip(r_tensors, tensors):
             self.assertEqualTypeString(r, t)
             self.assertEqual(r, t * 2)

test/test_jit.py (+1 -1)

@@ -8030,7 +8030,7 @@ def contained_blocks(node):
             return len(node.findAllNodes("prim::If")) * 2 + len(node.findAllNodes("prim::Loop"))
         for node in ifs + loops:
             outs = list(node.outputs())
-            out_name = list(map(lambda x: x.debugName(), outs))
+            out_name = [x.debugName() for x in outs]
             if len(out_name) == 0:
                 continue
             fc = FileCheck()

test/test_nn.py (+2 -2)

@@ -112,7 +112,7 @@ def _ordered_sequence(self, tensor_type):
     def _padded_sequence(self, tensor_type):
         """Create Tensor of random padded sequences"""
         ordered = self._ordered_sequence(tensor_type)
-        lengths = list(map(len, ordered))
+        lengths = [len(i) for i in ordered]
         padded_tensor = rnn_utils.pad_sequence(ordered)
         return padded_tensor, lengths

@@ -11335,7 +11335,7 @@ def _ordered_sequence(self, device, dtype):
     def _padded_sequence(self, device, dtype):
         """Create Tensor of random padded sequences"""
         ordered = self._ordered_sequence(device, dtype)
-        lengths = list(map(len, ordered))
+        lengths = [len(i) for i in ordered]
         padded_tensor = rnn_utils.pad_sequence(ordered)
         return padded_tensor, lengths

test/test_optim.py (+21 -21)

@@ -878,7 +878,7 @@ def test_step_lr(self):
         # lr = 0.0005 if epoch >= 9
         epochs = 10
         single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3
-        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        targets = [single_targets, [x * epochs for x in single_targets]]
         scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
         self._test(scheduler, targets, epochs)

@@ -897,7 +897,7 @@ def test_get_last_lr_multi_step_lr(self):
         # lr = 0.00005 if 9 <= epoch
         epochs = 10
         single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 1
-        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        targets = [single_targets, [x * epochs for x in single_targets]]
         scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
         self._test_get_last_lr(scheduler, targets, epochs)

@@ -908,7 +908,7 @@ def test_multi_step_lr(self):
         # lr = 0.00005 if epoch >= 9
         epochs = 10
         single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
-        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        targets = [single_targets, [x * epochs for x in single_targets]]
         scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
         self._test(scheduler, targets, epochs)

@@ -919,14 +919,14 @@ def test_multi_step_lr_with_epoch(self):
         # lr = 0.00005 if epoch >= 9
         epochs = 10
         single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
-        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        targets = [single_targets, [x * epochs for x in single_targets]]
         scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
         self._test_with_epoch(scheduler, targets, epochs)

     def test_exp_lr(self):
         epochs = 10
         single_targets = [0.05 * (0.9 ** x) for x in range(epochs)]
-        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        targets = [single_targets, [x * epochs for x in single_targets]]
         scheduler = ExponentialLR(self.opt, gamma=0.9)
         self._test(scheduler, targets, epochs)

@@ -936,7 +936,7 @@ def test_cos_anneal_lr(self):
         single_targets = [eta_min + (0.05 - eta_min) *
                           (1 + math.cos(math.pi * x / epochs)) / 2
                           for x in range(epochs)]
-        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        targets = [single_targets, [x * epochs for x in single_targets]]
         scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
         self._test(scheduler, targets, epochs)

@@ -1058,7 +1058,7 @@ def test_compound_step_and_exp_lr(self):
         single_targets += [0.005 * (0.9 ** x) for x in range(3, 6)]
         single_targets += [0.0005 * (0.9 ** x) for x in range(6, 9)]
         single_targets += [0.00005 * (0.9 ** x) for x in range(9, 12)]
-        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        targets = [single_targets, [x * epochs for x in single_targets]]
         schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
         schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
         self._test(schedulers, targets, epochs)

@@ -1070,7 +1070,7 @@ def test_compound_exp_and_multistep_lr(self):
         single_targets += [0.005 * (0.9 ** x) for x in range(2, 5)]
         single_targets += [0.0005 * (0.9 ** x) for x in range(5, 9)]
         single_targets += [0.00005 * (0.9 ** x) for x in range(9, 11)]
-        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        targets = [single_targets, [x * epochs for x in single_targets]]
         schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
         schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
         self._test(schedulers, targets, epochs)

@@ -1082,7 +1082,7 @@ def test_compound_cosanneal_and_step_lr(self):
                           (1 + math.cos(math.pi * x / epochs)) / 2
                           for x in range(epochs)]
         single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)]
-        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        targets = [single_targets, [x * epochs for x in single_targets]]
         schedulers = [None] * 2
         schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
         schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3)

@@ -1096,7 +1096,7 @@ def test_compound_cosanneal_and_multistep_lr(self):
                           for x in range(epochs)]
         multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001]
         single_targets = [x * y for x, y in zip(single_targets, multipliers)]
-        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        targets = [single_targets, [x * epochs for x in single_targets]]
         schedulers = [None] * 2
         schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
         schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])

@@ -1110,7 +1110,7 @@ def test_compound_cosanneal_and_exp_lr(self):
                           for x in range(epochs)]
         multipliers = [0.1 ** i for i in range(epochs)]
         single_targets = [x * y for x, y in zip(single_targets, multipliers)]
-        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        targets = [single_targets, [x * epochs for x in single_targets]]
         schedulers = [None] * 2
         schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
         schedulers[1] = ExponentialLR(self.opt, gamma=0.1)

@@ -1222,8 +1222,8 @@ def test_cycle_lr_exp_range_mode_one_lr(self):
         diff_lr = max_lr - base_lr
         gamma = 0.9
         xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
-        lr_target = list(map(lambda x: base_lr + x[1] * diff_lr * gamma**x[0], enumerate(xs)))
-        momentum_target = list(map(lambda x: max_lr - x[1] * diff_lr * gamma**x[0], enumerate(xs)))
+        lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)]
+        momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)]
         lr_targets = [lr_target, lr_target]
         momentum_targets = [momentum_target, momentum_target]
         scheduler = CyclicLR(self.opt, base_lr=base_lr,

@@ -1234,10 +1234,10 @@ def test_cycle_lr_exp_range_mode_one_lr(self):

     def test_cycle_lr_triangular_mode(self):
         lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
-        lr_target_2 = list(map(lambda x: x + 1, lr_target_1))
+        lr_target_2 = [x + 1 for x in lr_target_1]
         lr_targets = [lr_target_1, lr_target_2]
         momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
-        momentum_target_2 = list(map(lambda x: x + 1, momentum_target_1))
+        momentum_target_2 = [x + 1 for x in momentum_target_1]
         momentum_targets = [momentum_target_1, momentum_target_2]
         scheduler = CyclicLR(self.opt, base_lr=[1, 2], max_lr=[5, 6], step_size_up=4,
                              cycle_momentum=True, base_momentum=[1, 2], max_momentum=[5, 6],

@@ -1247,11 +1247,11 @@ def test_cycle_lr_triangular_mode(self):
     def test_cycle_lr_triangular2_mode(self):
         lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5, 1,
                        1.25, 1.50, 1.75, 2.00, 1.75]
-        lr_target_2 = list(map(lambda x: x + 2, lr_target_1))
+        lr_target_2 = [x + 2 for x in lr_target_1]
         lr_targets = [lr_target_1, lr_target_2]
         momentum_target_1 = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0, 3.5,
                              3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25]
-        momentum_target_2 = list(map(lambda x: x + 2, momentum_target_1))
+        momentum_target_2 = [x + 2 for x in momentum_target_1]
         momentum_targets = [momentum_target_1, momentum_target_2]
         scheduler = CyclicLR(self.opt, base_lr=[1, 3], max_lr=[5, 7], step_size_up=4,
                              cycle_momentum=True, base_momentum=[1, 3], max_momentum=[5, 7],

@@ -1267,11 +1267,11 @@ def test_cycle_lr_exp_range_mode(self):

         gamma = 0.9
         xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
-        lr_target_1 = list(map(lambda x: base_lr_1 + x[1] * diff_lr_1 * gamma**x[0], enumerate(xs)))
-        lr_target_2 = list(map(lambda x: base_lr_2 + x[1] * diff_lr_2 * gamma**x[0], enumerate(xs)))
+        lr_target_1 = [base_lr_1 + x * diff_lr_1 * gamma**i for i, x in enumerate(xs)]
+        lr_target_2 = [base_lr_2 + x * diff_lr_2 * gamma**i for i, x in enumerate(xs)]
         lr_targets = [lr_target_1, lr_target_2]
-        momentum_target_1 = list(map(lambda x: max_lr_1 - x[1] * diff_lr_1 * gamma**x[0], enumerate(xs)))
-        momentum_target_2 = list(map(lambda x: max_lr_2 - x[1] * diff_lr_2 * gamma**x[0], enumerate(xs)))
+        momentum_target_1 = [max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs)]
+        momentum_target_2 = [max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs)]
         momentum_targets = [momentum_target_1, momentum_target_2]
         scheduler = CyclicLR(self.opt, base_lr=[base_lr_1, base_lr_2],
                              max_lr=[max_lr_1, max_lr_2], step_size_up=4,

test/test_torch.py (+2 -2)

@@ -2676,7 +2676,7 @@ def test_permute(self):
         orig = [1, 2, 3, 4, 5, 6, 7]
         perm = torch.randperm(7).tolist()
         x = torch.Tensor(*orig).fill_(0)
-        new = list(map(lambda x: x - 1, x.permute(*perm).size()))
+        new = [i - 1 for i in x.permute(*perm).size()]
         self.assertEqual(perm, new)
         self.assertEqual(x.size(), orig)

@@ -8553,7 +8553,7 @@ def consec(size, start=1):
             idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1)
             idx2_step = random.randrange(1, 8)
             idx2 = slice(idx2_start, idx2_end, idx2_step)
-            lst_indexed = list(map(lambda l: l[idx2], lst[idx1]))
+            lst_indexed = [l[idx2] for l in lst[idx1]]
             tensor_indexed = tensor[idx1, idx2]
         else:
             lst_indexed = lst[idx1]

tools/codegen/gen.py (+3 -3)

@@ -974,7 +974,7 @@ def main() -> None:
         d = pre_grouped_native_functions[f.func.signature()]
         assert f.func.kind() not in d
         d[f.func.kind()] = f
-    grouped_native_functions = list(map(NativeFunctionGroup.from_dict, pre_grouped_native_functions.values()))
+    grouped_native_functions = [NativeFunctionGroup.from_dict(v) for v in pre_grouped_native_functions.values()]
     # NB: At the moment, grouped_native_functions isn't used by anything,
     # this code lives here to help potential future consumers; for a live
     # example see https://github.com/pytorch/pytorch/pull/45277

@@ -1130,9 +1130,9 @@ def computeSchemaRegister() -> Dict[str, object]:
     }
     cpu_fm.write('SchemaRegister.cpp', computeSchemaRegister)

-    cpu_fm.write('Declarations.yaml', lambda: format_yaml(list(map(compute_declaration_yaml, native_functions))))
+    cpu_fm.write('Declarations.yaml', lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]))
     cpu_fm.write('RegistrationDeclarations.h', lambda: {
-        'registration_declarations': list(map(compute_registration_declarations, native_functions)),
+        'registration_declarations': [compute_registration_declarations(f) for f in native_functions],
     })

     if options.output_dependencies:

torch/jit/frontend.py (+1 -1)

@@ -394,7 +394,7 @@ def build_Expr(ctx, stmt):
     @staticmethod
     def build_Assign(ctx, stmt):
         rhs = build_expr(ctx, stmt.value)
-        lhs = list(map(lambda x: build_expr(ctx, x), stmt.targets))
+        lhs = [build_expr(ctx, x) for x in stmt.targets]
         return Assign(lhs, rhs)

     @staticmethod

torch/nn/parallel/_functions.py (+2 -2)

@@ -13,7 +13,7 @@ def forward(ctx, target_gpus, *inputs):
         assert all(map(lambda i: i.device.type != 'cpu', inputs)), (
             'Broadcast function not implemented for CPU tensors'
         )
-        target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
+        target_gpus = [_get_device_index(x, True) for x in target_gpus]
         ctx.target_gpus = target_gpus
         if len(inputs) == 0:
             return tuple()

@@ -82,7 +82,7 @@ class Scatter(Function):

     @staticmethod
     def forward(ctx, target_gpus, chunk_sizes, dim, input):
-        target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
+        target_gpus = [_get_device_index(x, True) for x in target_gpus]
         ctx.dim = dim
         ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
         streams = None

torch/nn/parallel/data_parallel.py (+3 -3)

@@ -19,7 +19,7 @@ def _check_balance(device_ids):
     has less than 75% of the memory or cores of GPU {}. You can do so by setting
     the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
     environment variable."""
-    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+    device_ids = [_get_device_index(x, True) for x in device_ids]
     dev_props = _get_devices_properties(device_ids)

     def warn_imbalance(get_prop):

@@ -135,7 +135,7 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0):

         self.dim = dim
         self.module = module
-        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+        self.device_ids = [_get_device_index(x, True) for x in device_ids]
         self.output_device = _get_device_index(output_device, True)
         self.src_device_obj = torch.device(device_type, self.device_ids[0])

@@ -200,7 +200,7 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo
     if output_device is None:
         output_device = device_ids[0]

-    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+    device_ids = [_get_device_index(x, True) for x in device_ids]
     output_device = _get_device_index(output_device, True)
     src_device_obj = torch.device(device_type, device_ids[0])
