Commit e2fb375

Author: litangwei01
Committed: Sep 1, 2023

    quick_start

1 parent 438a4e4 · commit e2fb375

File tree: 10 files changed, +280 −182 lines


‎docs/installation.mdx

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ The basic environment is as follows:
 All dependencies mentioned above come from a specific default backend. The C++ core itself does not rely on any of the above dependencies.
 :::

-## Using NGC base image
+## Using NGC base image {#NGC}

 The easiest way is to compile from source inside an NGC base image (the official image may still run on lower-version drivers through Forward Compatibility or Minor Version Compatibility).

 First, clone the code:

‎docs/introduction.md

Lines changed: 2 additions & 2 deletions

@@ -17,9 +17,9 @@ To address these issues, TorchPipe provides a thread-safe function interface for

 ![jpg](.././static/images/EngineFlow-light-english.png)
-<center>torchpipe framework diagram</center>
+<center>TorchPipe framework diagram</center>

-**Features of the torchpipe framework:**
+**Features of the TorchPipe framework:**
 - Achieves near-optimal performance (peak throughput / TP99) from a business perspective, reducing widespread negative optimization and performance loss between nodes.
 - A fine-grained, generic backend makes it easy to extend hardware support and lowers the cost of migrating between hardware vendor ecosystems.
 - Simple yet high-performance modeling, including for complex business systems such as multi-model fusion. Typical industrial scenarios include AI systems with up to 10 model nodes in smart cities, and OCR systems that involve independent subgraph scheduling, bucket scheduling, and intelligent batching for extreme optimization.

‎docs/quick_start_new_user.md

Lines changed: 248 additions & 0 deletions

@@ -4,3 +4,251 @@ title: Beginner's Guide - A Small Step Forward
type: explainer
---

# Trial in 30 Minutes (New Users)

TorchPipe is a multi-instance pipeline-parallel library that provides seamless integration between lower-level acceleration libraries (such as TensorRT and OpenCV) and RPC frameworks. It guarantees high service throughput while meeting latency requirements. This document is aimed at new users: those who are at the introductory stage of acceleration-related theory, know some Python, and can read simple code. It mainly covers using TorchPipe to accelerate service deployment, complemented by performance and effect comparisons.

## Catalogue
* [1. Basic knowledge](#1)
* [2. Environment installation and configuration](#2)
* [3. Acceleration case: a service containing only a single model, using ResNet50 as an example](#3)
  * [3.1 Using the TensorRT acceleration scheme](#3.1)
  * [3.2 Using the TorchPipe acceleration scheme](#3.2)
* [4. Performance and effect comparison](#4)

<a name='1'></a>

## 1. Basic knowledge

The field of deep learning has advanced rapidly in recent years, with significant progress in areas such as image recognition, text recognition, and speech recognition. Several model acceleration techniques, including those based on TensorRT and TVM, improve the inference speed of deep learning models through computational and hardware optimization, and have produced notable results in practical applications. This tutorial uses the simplest case from real business deployment to demonstrate how to deploy an online service with TorchPipe. The entire service contains only a single ResNet50 model; the overall service flow is illustrated below.

![pipeline](images/quick_start_new_user/pipeline_en.png)

We briefly explain some concepts that need to be understood for model deployment, which we hope will be helpful if you are trying TorchPipe for the first time. For details, please refer to [Preliminary Knowledge](./preliminaries).

<a name='2'></a>

## 2. Environment installation and configuration

For specific installation steps, please refer to [installation](installation.mdx). We provide two methods for configuring the TorchPipe environment (a quick import check follows the list below):

- [Using the NGC base image](installation.mdx#NGC)
- [Customizing a Dockerfile](installation.mdx#selfdocker)
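
Once the environment is ready, a quick way to confirm that the Python package is visible is to import it. This is only a sketch; it assumes the package is importable as `torchpipe` and exposes the `pipe` entry point used in section 3.2 below.

```py
# Minimal post-install sanity check (a sketch; assumes the package name `torchpipe`).
import torchpipe
from torchpipe import pipe  # the entry point used in section 3.2

print("torchpipe imported from:", torchpipe.__file__)
```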

<a name='3'></a>

## 3. Acceleration case: advancing from TensorRT to TorchPipe

This section begins by discussing the TensorRT acceleration solution and provides a general acceleration strategy for service deployment. Building on that solution, we then use TorchPipe to further optimize acceleration across the entire service.

<a name='3.1'></a>

### 3.1 Using the TensorRT acceleration scheme {#UTAS}

![pipeline](images/quick_start_new_user/trt_en.png)

TensorRT is an SDK for high-performance machine learning inference. It focuses on running an already-trained network quickly and efficiently on NVIDIA hardware. However, TensorRT only optimizes and accelerates the model itself. Therefore, when deploying this service, data decoding and preprocessing still use conventional operations, both done in Python, while model acceleration is achieved by building a TensorRT engine.

The details of each part are as follows:

(1) Data decoding

This part relies on the CPU to decode the data.

```py
## Data decoding (CPU decoding); `img` holds the raw image bytes
img = cv2.imdecode(img, flags=cv2.IMREAD_COLOR)
```

(2) Preprocessing

This part mainly uses built-in PyTorch/torchvision functions.

```py
## Preprocessing
precls_trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = precls_trans(cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), (224, 224)))
```

(3) TensorRT acceleration

```py
def load_classifier(net, max_batch_size, fp16):
    # Build a TensorRT engine from the PyTorch model via torch2trt.
    # `device` is assumed to be defined by the surrounding service code.
    x = torch.ones((1, 3, 224, 224))
    if device == 'gpu':
        x = x.cuda()
        net.cuda()
    net.eval()
    trtmodel = torch2trt(net,
                         [x],
                         fp16_mode=fp16,
                         max_batch_size=max_batch_size,
                         max_workspace_size=32 * max_batch_size)
    del x
    del net
    return trtmodel
```

The overall online service deployment can be found in [main_trt.py](https://g.hz.netease.com/deploy/torchpipe/-/blob/develop/examples/resnet50/main_trt.py).

:::tip
Since TensorRT is not thread-safe, when using this method for model acceleration it is necessary to take a lock (`with self.lock:`) during service deployment.
:::
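
For reference, the locking pattern mentioned in the tip above could look like the sketch below. It is illustrative only: the class name, `self.trt_model`, and `self.lock` are assumptions, not code taken from main_trt.py.

```py
import threading

import torch


class TrtClassifier:
    """Illustrative wrapper showing the `with self.lock:` pattern."""

    def __init__(self, trt_model):
        self.trt_model = trt_model    # engine built by load_classifier(...)
        self.lock = threading.Lock()  # serializes access to the engine

    def infer(self, batch: torch.Tensor) -> torch.Tensor:
        # TensorRT execution is not thread-safe, so every call is guarded.
        with self.lock, torch.no_grad():
            return self.trt_model(batch)
```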

<a name='3.2'></a>

### 3.2 Using the TorchPipe acceleration scheme

The process above makes it clear that when accelerating a single model, the focus is mainly on the model itself, while other parts of the service, such as data decoding and preprocessing, are overlooked. These steps also affect the service's throughput and latency. Therefore, to achieve optimal throughput and latency, we use TorchPipe to optimize the entire service. The specific measures include:

- Multiple instances, dynamic batching, and bucketing on a single computing node
- Pipeline scheduling across multiple nodes
- Logical control flow between nodes

![](images/quick_start_new_user/torchpipe_en.png)

We adjusted the deployment of the service using TorchPipe. The overall online service deployment can be found in [main_torchpipe.py](https://g.hz.netease.com/deploy/torchpipe/-/blob/develop/examples/resnet50/main_torchpipe.py). The core modifications are as follows:

```py
# ------- main -------
# Excerpt from the request handler; `requests`, `bin_data_list`, `bin_data`,
# `results`, and `self.softmax` are defined in the full example.
num_images = len(requests)
for i in range(num_images):
    bin_data_list.append({TASK_DATA_KEY: requests[i].data, "node_name": "cpu_decoder"})

toml_path = "resnet50.toml"
classifier = pipe(toml_path)
classifier(bin_data_list)  # results are written back into each dict

# For each returned dict `bin_data`:
if TASK_RESULT_KEY not in bin_data.keys():
    print("error decode")
    return results
else:
    dis = self.softmax(bin_data[TASK_RESULT_KEY])
```

Compared with the original main function, the amount of code is reduced. The key lies in the toml file, which defines three nodes: `[cpu_decoder]`, `[cpu_posdecoder]`, and `[resnet50]`. These nodes run in sequence and correspond to the three parts described in [section 3.1](quick_start_new_user.md#UTAS), as shown below:

![](images/quick_start_new_user/torchpipe_pipeline_en.png)

The contents of the toml file are as follows:

```toml
# Scheduling parameters
batching_timeout = 5
instance_num = 8
precision = "fp16"

## Data decoding
#
# This corresponds to 3.1 (1) data decoding:
#   img = cv2.imdecode(img, flags=cv2.IMREAD_COLOR)
# Note:
#   The original decoding output format was BGR, and the DecodeMat backend
#   also outputs BGR by default. Since decoding is done on the CPU,
#   DecodeMat is used.
#   After each node, the name of the next node needs to be appended;
#   otherwise it is assumed to be the last node.
[cpu_decoder]
backend = "DecodeMat"
next = "cpu_posdecoder"

## Preprocessing: resize and color conversion
#
# This corresponds to 3.1 (2) preprocessing:
#   precls_trans = transforms.Compose([transforms.ToTensor(), ])
#   img = precls_trans(cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), (224, 224)))
# Note:
#   The original preprocessing order was resize, cv2.COLOR_BGR2RGB, then
#   Normalize. The normalization step is now folded into the model node
#   ([resnet50]), so the output of this node matches the original
#   preprocessing result without normalization.
#   After each node, the name of the next node needs to be appended;
#   otherwise it is assumed to be the last node.
[cpu_posdecoder]
backend = "SyncTensor[Sequential[ResizeMat,cvtColorMat,Mat2Tensor]]"

### Parameters for the resize operation
resize_h = 224
resize_w = 224

### Parameters for the cvtColorMat operation
color = "rgb"

next = "resnet50"

## Preprocessing (normalization) and model acceleration
#
# This corresponds to 3.1 (3) TensorRT acceleration and 3.1 (2) Normalize.
# Note:
#   This differs slightly from the original method of generating the engine
#   online: the model first needs to be converted to ONNX format.
#   For the conversion method, see [Converting Torch to ONNX].
[resnet50]
backend = "Torch[TensorrtTensor]"
min = 1
max = 4
instance_num = 4
model = "/your/model/path/resnet50.onnx"

mean = "123.675, 116.28, 103.53" # 255 * [0.485, 0.456, 0.406]
std = "58.395, 57.120, 57.375"   # 255 * [0.229, 0.224, 0.225]

# TensorrtTensor
"model::cache" = "/your/model/path/resnet50.trt" # or resnet50.trt.encrypted
```

:::tip
- For the specific usage and functionality of other backend operators, please refer to [Basic Backend](./backend-reference/basic), [OpenCV Backend](./backend-reference/opencv), [Torch Backend](./backend-reference/torch), and [Log](./backend-reference/log).
- This deviates slightly from the original method of generating the engine online: the model first needs to be converted to ONNX format. For the conversion method, see [Converting Torch to ONNX](faq/onnx.mdx).
- TorchPipe has resolved the issue of TensorRT objects not being thread-safe, and this has been verified by extensive experiments. The lock can therefore be disabled during service operation, i.e., the `with self.lock:` line from [section 3.1](#UTAS) can be commented out.
:::
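
Putting the pieces together, a minimal single-image call against the toml above could look like the following sketch. It mirrors the excerpt from main_torchpipe.py shown earlier; the image path is illustrative, and error handling is reduced to a single check.

```py
import torch
from torchpipe import pipe, TASK_DATA_KEY, TASK_RESULT_KEY

# Build the whole pipeline (decode -> preprocess -> TensorRT) from the toml.
classifier = pipe("resnet50.toml")

with open("test.jpg", "rb") as f:  # illustrative input image
    raw_bytes = f.read()

inp = {TASK_DATA_KEY: raw_bytes, "node_name": "cpu_decoder"}  # start at the first node
classifier(inp)  # results are written back into the dict

if TASK_RESULT_KEY in inp:
    scores = torch.softmax(inp[TASK_RESULT_KEY].float(), dim=-1)
    print("top-1 class:", int(scores.argmax()))
else:
    print("decode failed")
```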

<a name='4'></a>

## 4. Performance and effect comparison

`python client_qps.py --img_dir /your/testimg/path/ --port 8888 --request_client 20 --request_batch 1`

The specific test code can be found in [client_qps.py](https://g.hz.netease.com/deploy/torchpipe/-/blob/develop/examples/resnet50/client_qps.py).

With the same Thrift service interface, testing on a machine with an NVIDIA 3080 GPU, an 8-core CPU, and a concurrency of 10, we obtain the following results:

- Throughput:

| Methods | QPS |
| :-: | :-: |
| Pure TensorRT | 747.92 |
| Using TorchPipe | 2172.54 |

- Response time (ms):

| Methods | TP50 | TP99 |
| :-: | :-: | :-: |
| Pure TensorRT | 26.74 | 35.24 |
| Using TorchPipe | 8.89 | 14.28 |

- Resource utilization:

| Methods | GPU Utilization | CPU Utilization | Memory Utilization |
| :-: | :-: | :-: | :-: |
| Pure TensorRT | 42-45% | 1473% | 3.8% |
| Using TorchPipe | 48-50% | 983.8% | 1.6% |

‎i18n/zh/docusaurus-plugin-content-docs/current/installation.mdx

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ type: explainer
 All of the above dependencies come from a specific default compute backend. Building the C++ core does not depend on any of them.
 :::

-## 使用NGC基础镜像
+## 使用NGC基础镜像 {#NGC}

 The simplest approach is to compile from source using the <font color='Brown'>NGC image</font> (the official image may still run on lower-version drivers thanks to Forward Compatibility or Minor Version Compatibility).

 First, clone the code:
