In all of our known use cases, TorchPipe can completely replace [`torch2trt`](https://github.com/NVIDIA-AI-IOT/torch2trt) through static ONNX composition, dynamic ONNX, pre-generated TensorRT models, and other methods.
The framework prioritizes dynamic `batch` or static `batch` with `batchsize==1`. In practice, some models cannot be converted to a dynamic batch size, or are error-prone when converted. We also support [**loading multiple models with different static batch sizes at the same time**](../Intra-node/schedule#single_node_combine) to simulate a dynamic batch size. The following instructions mainly apply to exporting models with a dynamic batch size.
- Operations such as ``x.view(int(x.size(0)), -1)`` make a dynamic batch size unavailable. Check whether the model file hardcodes the batch dimension, e.g. ``x.view(int(x.size(0)), -1, 1, 1)`` or ``x.reshape(int(x.size(0)), -1, 1, 1)``; wrapping the size in ``int(...)`` bakes the traced batch size into the graph as a constant, which may break dynamic batch size after conversion to ONNX (see the sketch after this list). Note that in Transformer-like networks, the batch dimension is not necessarily dimension 0.
- When the batch dimension is dynamic, older TensorRT versions handle it less well and emit more redundant operators. For example, ``x.view(x.size(0), -1)`` introduces Gather and related operators into the ONNX graph just to compute the first dimension of ``x``. It can be rewritten as ``x = x.view(-1, int(x.size(1)*x.size(2)*x.size(3)))`` or ``x = torch.flatten(x, 1)``, although this rewrite is optional.
- For some models (e.g. LSTM and Transformer on TensorRT 8.5.1), making both the batch dimension and a non-batch dimension dynamic may consume significantly more resources.
- For LayerNorm layers and Transformer-like networks with a dynamic batch size, opset >= 17 and TensorRT >= 8.6.1 are recommended.
- Whenever possible, keep the batch dimension as dimension 0 and leave its length in the default state (i.e., -1) so that redundant operators can be removed.
- Use onnx-simplifier to optimize the exported model; a usage sketch follows the export example below.
- [Smaller optimization ranges usually mean faster speed and lower resource consumption](https://github.com/NVIDIA/TensorRT/issues/1166#issuecomment-815551064).
:::
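To illustrate the reshape guidance in the list above, here is a minimal sketch contrasting a hardcoded batch dimension with a dynamic-batch-friendly rewrite. The module names `BadHead` and `GoodHead` are made up for illustration only:

```python
import torch
import torch.nn as nn

class BadHead(nn.Module):
    def forward(self, x):
        # int(x.size(0)) bakes the traced batch size into the exported graph as a
        # constant, so the resulting ONNX model cannot accept other batch sizes.
        return x.view(int(x.size(0)), -1)

class GoodHead(nn.Module):
    def forward(self, x):
        # Flattening everything except dim 0 keeps the batch dimension symbolic
        # and avoids extra Shape/Gather operators in the exported graph.
        return torch.flatten(x, 1)
```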
After modifying the network, you can use the following code to convert the PyTorch model to an ONNX model:
```python
x = torch.randn(1, *input_shape).cuda()

torch.onnx.export(torch_model,
                  x,
                  onnx_save_path,
                  opset_version=17,
                  do_constant_folding=True,
                  input_names=["input"],   # input name
                  output_names=[f"output_{i}" for i in range(out_size)],  # output names
                  # declare dim 0 of the input as dynamic so the exported model
                  # accepts arbitrary batch sizes
                  dynamic_axes={"input": {0: "batch_size"}})
```
Polygraphy is a tool provided by NVIDIA for testing TensorRT and ONNX models. It provides model conversion functionality, helps debug FP16 precision loss, and allows specifying layers that should not run in FP16.
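Independently of Polygraphy, a quick numerical sanity check of the exported file can also be done with ONNX Runtime. This is only a sketch and reuses `torch_model`, `x`, and `onnx_save_path` from the export example above:

```python
import numpy as np
import onnxruntime as ort
import torch

sess = ort.InferenceSession(onnx_save_path, providers=["CPUExecutionProvider"])
# "input" matches the input_names used in torch.onnx.export above.
onnx_out = sess.run(None, {"input": x.cpu().numpy()})[0]

with torch.no_grad():
    torch_out = torch_model(x)
torch_out = torch_out[0] if isinstance(torch_out, (tuple, list)) else torch_out

# FP32 results should agree closely; a large gap usually indicates an export problem.
print("max abs diff:", np.abs(onnx_out - torch_out.cpu().numpy()).max())
```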
- [PyTorch to ONNX Conversion Tutorial (in Chinese)](https://zhuanlan.zhihu.com/p/498425043)
- [Modifying and Debugging ONNX Models (in Chinese)](https://zhuanlan.zhihu.com/p/516920606)
- [TensorRT Tutorial | Based on version 8.6.1 (in Chinese)](https://www.bilibili.com/video/BV1jj411Z7wG/?spm_id_from=333.999.0.0&vd_source=c31de98543aa977b5899e24bdd5d8f89)
docs/introduction.md

To enhance the peak throughput of deep learning serving, various challenges must be addressed.
There are some industry practices, such as [Triton Inference Server](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#ensemble-models), [Alimama high_service (in Chinese)](https://mp.weixin.qq.com/s/Fd2GNXqO3wl3FrA7Wli3jA), and [Meituan Vision GPU Inference Service Deployment Architecture Optimization Practice (in Chinese)](https://zhuanlan.zhihu.com/p/605094862).
One common complaint from users of the Triton Inference Server is that, in a system with multiple intertwined nodes, a lot of business logic has to be completed on the client side and then invoked on the server through RPC, which can be cumbersome. For performance reasons, unconventional techniques such as shared memory, ensembles, and [Business Logic Scripting (BLS)](https://github.com/triton-inference-server/python_backend#business-logic-scripting) must be considered.
To address these issues, TorchPipe provides a thread-safe function interface for the PyTorch frontend and fine-grained backend extensions for users, by delving into PyTorch's C++ computation backend and CUDA stream management, and by modeling multiple nodes with a domain-specific language.