diff --git a/onnx_tensorrt/backend.py b/onnx_tensorrt/backend.py index 010f3db8..b206af22 100644 --- a/onnx_tensorrt/backend.py +++ b/onnx_tensorrt/backend.py @@ -58,7 +58,7 @@ def count_trailing_ones(vals): class TensorRTBackendRep(BackendRep): - def __init__(self, model, device, max_batch_size=32, + def __init__(self, model, device, max_batch_size=32, fp16_mode=False, max_workspace_size=None, serialize_engine=False, **kwargs): if not isinstance(device, Device): device = Device(device) @@ -89,6 +89,7 @@ def __init__(self, model, device, max_batch_size=32, self.builder.max_batch_size = max_batch_size self.builder.max_workspace_size = max_workspace_size + self.builder.fp16_mode = fp16_mode for layer in self.network: print(layer.name) @@ -231,4 +232,4 @@ def supports_device(cls, device_str): prepare = TensorRTBackend.prepare run_node = TensorRTBackend.run_node run_model = TensorRTBackend.run_model -supports_device = TensorRTBackend.supports_device \ No newline at end of file +supports_device = TensorRTBackend.supports_device