From 90be642420192ad03a6c69099995794df909686a Mon Sep 17 00:00:00 2001
From: Arin Toaca
Date: Wed, 11 Dec 2019 09:59:22 +0000
Subject: [PATCH] Add fp16 support for python backend.

---
 onnx_tensorrt/backend.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/onnx_tensorrt/backend.py b/onnx_tensorrt/backend.py
index 010f3db8..b206af22 100644
--- a/onnx_tensorrt/backend.py
+++ b/onnx_tensorrt/backend.py
@@ -58,7 +58,7 @@ def count_trailing_ones(vals):
 
 
 class TensorRTBackendRep(BackendRep):
-    def __init__(self, model, device, max_batch_size=32,
+    def __init__(self, model, device, max_batch_size=32, fp16_mode=False,
                  max_workspace_size=None, serialize_engine=False, **kwargs):
         if not isinstance(device, Device):
             device = Device(device)
@@ -89,6 +89,7 @@ def __init__(self, model, device, max_batch_size=32,
 
         self.builder.max_batch_size = max_batch_size
         self.builder.max_workspace_size = max_workspace_size
+        self.builder.fp16_mode = fp16_mode
 
         for layer in self.network:
             print(layer.name)
@@ -231,4 +232,4 @@ def supports_device(cls, device_str):
 prepare = TensorRTBackend.prepare
 run_node = TensorRTBackend.run_node
 run_model = TensorRTBackend.run_model
-supports_device = TensorRTBackend.supports_device
\ No newline at end of file
+supports_device = TensorRTBackend.supports_device