Autocast and Weak Typing #3869
narendasan started this conversation in RFCs
Replies: 1 comment
- For phase 2, should we let users give us a calibration dataloader?
TL;DR
Weak typing behavior in TensorRT is deprecated. However, it is a good way to maximize performance, so we want to create a similar PyTorch-native system for Torch-TensorRT that recovers some of this behavior.
Goal(s)
An automatic pass that users can opt into, which applies autocasting according to TRT's preferred ruleset.
Usecases
Proposed APIs / UX
We use the combination of the args use_explicit_typing and enable_autocast to represent three modes. In this feature we are focusing on the Autocast mode, i.e., setting use_explicit_typing=True and enable_autocast=True. Users can also set low_precision_type, nodes_to_exclude, targets_to_exclude, data_max, and max_depth_of_reduction to specify which ops should run in fp32 and which should run in low precision. We aim to stay consistent with NVIDIA ModelOpt Autocast, so the naming of the args is similar; please refer to the ModelOpt Autocast doc for details.
Example Workflow
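A minimal sketch of what this workflow could look like; the torch_tensorrt.compile entry point, the toy model, and the way the new args are passed through are assumptions based on the proposal above, not a finalized API:

```python
import torch
import torch_tensorrt


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3)
        self.fc = torch.nn.Linear(16 * 30 * 30, 10)

    def forward(self, x):
        x = self.conv(x)
        # Per the description below, ops inside this context manager
        # should remain in fp32.
        with torch.autocast(device_type="cuda", dtype=torch.float32):
            x = torch.relu(x)
        return self.fc(torch.flatten(x, 1))


model = MyModel().eval().cuda()
inputs = [torch.randn(1, 3, 32, 32, device="cuda")]

trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    use_explicit_typing=True,          # strong typing in TensorRT
    enable_autocast=True,              # opt into the proposed Autocast pass
    low_precision_type=torch.float16,  # cast "normal" ops to fp16
    nodes_to_exclude=["^conv2d$"],     # regex: keep conv2d nodes in fp32
)
```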
In this example, low_precision_type=torch.float16 denotes that Autocast should cast normal ops to fp16; nodes_to_exclude denotes that ops matching the regex pattern ^conv2d$ should remain in fp32, etc. torch.autocast(fp32) is used in the forward function so that any ops within the context manager stay in fp32.
Limitations
Internal Implementation
To implement Autocast in Torch-TRT, we need to 1) determine which ops should be cast to a low precision type, like fp16 or bf16, and which ops should be kept in fp32, and 2) modify the FX graph so that every op runs at the right precision.
1) Rule-based Node Classifier
Similar to ModelOpt Autocast, we use a rule-based node classifier to determine the precision of each op. Based on our predefined ruleset, if any rule is met, the op stays in fp32; otherwise it is cast to fp16. Nodes whose target is torch.ops.higher_order.wrap_with_autocast or operator.getitem are skipped directly.
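For illustration only, a minimal sketch of how such a rule-based classifier over the FX graph might be structured; the concrete rules, force-fp32 op list, and the handling of nodes_to_exclude are assumptions here, not the final Torch-TensorRT implementation:

```python
import operator
import re

import torch
from torch.fx import GraphModule, Node

# Illustrative rule: ops assumed to be numerically sensitive stay in fp32.
_FORCE_FP32_TARGETS = {
    torch.ops.aten._softmax.default,
    torch.ops.aten._log_softmax.default,
}
# Nodes with these targets are skipped entirely.
_SKIPPED_TARGETS = {torch.ops.higher_order.wrap_with_autocast, operator.getitem}


def classify_nodes(
    gm: GraphModule, nodes_to_exclude: list[str]
) -> dict[Node, torch.dtype]:
    """Return the target dtype for every call_function node in the graph."""
    decisions: dict[Node, torch.dtype] = {}
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target in _SKIPPED_TARGETS:
            continue  # skipped nodes keep whatever precision they already have
        keep_fp32 = (
            node.target in _FORCE_FP32_TARGETS
            or any(re.match(pattern, node.name) for pattern in nodes_to_exclude)
        )
        decisions[node] = torch.float32 if keep_fp32 else torch.float16
    return decisions
```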
Taking the demo above as an example, the node classifier's decisions are as follows:
2) Modify FX Graph
From step 1 we have determined the precision of each op. Then we add a pre_lowering pass that inserts a Cast op before each op. Nodes whose target is torch.ops.higher_order.wrap_with_autocast or operator.getitem are skipped directly.
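A minimal sketch of such a cast-insertion pass, assuming the decisions from step 1 are available as a node-to-dtype mapping and that aten.to.dtype is used as the Cast op; names and details are illustrative only:

```python
import operator

import torch
from torch.fx import GraphModule, Node

_SKIPPED_TARGETS = {torch.ops.higher_order.wrap_with_autocast, operator.getitem}


def insert_casts(gm: GraphModule, decisions: dict[Node, torch.dtype]) -> GraphModule:
    """Insert a Cast (aten.to.dtype) before each classified node so that its
    tensor inputs arrive in the precision chosen in step 1."""
    for node, dtype in decisions.items():
        if node.target in _SKIPPED_TARGETS:
            continue
        with gm.graph.inserting_before(node):
            for arg in node.all_input_nodes:
                cast = gm.graph.call_function(
                    torch.ops.aten.to.dtype, args=(arg, dtype)
                )
                node.replace_input_with(arg, cast)
    gm.graph.lint()
    gm.recompile()
    return gm
```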
The modified graph of the demo above is as follows:
Implementation Phases
Prototype - #3878
MVP (<TARGET RELEASE VERSION>)
Extension Phase 1 (<TARGET RELEASE VERSION>)
Extension Phase 2 (<TARGET RELEASE VERSION>)