Skip to content

Commit

Permalink
Add fp16 support for python backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
ArinToaca authored and Arin Toaca committed Dec 11, 2019
1 parent 8716c9b commit 90be642
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions onnx_tensorrt/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def count_trailing_ones(vals):


class TensorRTBackendRep(BackendRep):
def __init__(self, model, device, max_batch_size=32,
def __init__(self, model, device, max_batch_size=32, fp16_mode=False,
max_workspace_size=None, serialize_engine=False, **kwargs):
if not isinstance(device, Device):
device = Device(device)
Expand Down Expand Up @@ -89,6 +89,7 @@ def __init__(self, model, device, max_batch_size=32,

self.builder.max_batch_size = max_batch_size
self.builder.max_workspace_size = max_workspace_size
self.builder.fp16_mode = fp16_mode

for layer in self.network:
print(layer.name)
Expand Down Expand Up @@ -231,4 +232,4 @@ def supports_device(cls, device_str):
prepare = TensorRTBackend.prepare
run_node = TensorRTBackend.run_node
run_model = TensorRTBackend.run_model
supports_device = TensorRTBackend.supports_device
supports_device = TensorRTBackend.supports_device

0 comments on commit 90be642

Please sign in to comment.