Skip to content

Commit 1b1dfed

Browse files
committed
removing redundant cast and adding torch dtype check
1 parent f513d6c commit 1b1dfed

File tree

1 file changed

+1
-6
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+1
-6
lines changed

py/torch_tensorrt/dynamo/conversion/impl/cat.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,6 @@ def unify_and_concat_trt_tensors(
5555
layer = ctx.net.add_constant(shape, const_arr)
5656
set_layer_name(layer, target, f"{name}_dim{i}_const")
5757
t = layer.get_output(0)
58-
59-
# optional cast
60-
if cast_dtype and isinstance(t, TRTTensor):
61-
t = cast_trt_tensor(ctx, t, cast_dtype, f"{name}_cast_{i}")
62-
6358
trt_tensors.append(t)
6459

6560
if not has_dynamic and not force_trt_output:
@@ -70,7 +65,7 @@ def unify_and_concat_trt_tensors(
7065
# Explicit cast requested
7166
if isinstance(cast_dtype, _enums.dtype):
7267
final_dtype = cast_dtype.to(trt.DataType)
73-
elif isinstance(cast_dtype, np.dtype):
68+
elif isinstance(cast_dtype, (np.dtype, trt.dtype)):
7469
final_dtype = _enums.dtype._from(cast_dtype).to(trt.DataType)
7570
else:
7671
final_dtype = cast_dtype # already trt.DataType

0 commit comments

Comments
 (0)