[Bug] Unexpected performance drop with float8 training + compiling only nn.Linear layers + using selective per op AC #786
Comments
Thanks for flagging the issue. It looks like the memory usage is too high, which might cause throughput degradation. I think it's worth checking.
Also copying over what @awgu said offline, supporting point 1.
cc: @vkuzo, fyi
@tianyu-l I think we are running into another instance of this issue in #808 (see my comment here #808 (comment)). I went ahead and pulled some memory snapshots and analyzed them here #808 (comment). I'm still not sure of the root cause, but please feel free to add any input. If I'm correct that this issue is related, then the scope of this bug is bigger than we thought, and selective per op AC has some deeper issues that affect not just compiling only linear layers, but also fp8 row-wise quantization.
At the top of this issue, the peak memory usage in the problematic experiment is 92GB, which is too close to the max memory available on the H100 GPU.
^ is likely related to peak memory being close to the machine limit. I don't think there is conclusive evidence that #808 and this issue are related; they just both seem to be about selective per-op AC.
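For reference, memory snapshots like the ones mentioned above can be captured with PyTorch's allocator-history tooling. This is only a minimal sketch (not the exact capture setup used for #808), with a toy model standing in for the real training loop:

```python
import torch

# Start recording allocator events (PyTorch exposes these private helpers in recent versions).
torch.cuda.memory._record_memory_history(max_entries=100_000)

# Toy stand-in for a few real training iterations.
model = torch.nn.Linear(4096, 4096).cuda()
for _ in range(3):
    x = torch.randn(32, 4096, device="cuda")
    model(x).sum().backward()

# Dump a snapshot that can be inspected at https://pytorch.org/memory_viz, then stop recording.
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
```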
Summary
I'm doing some benchmarking with torchtitan on H100s to compare the experimental feature in #778 with other training configurations.
One important comparison is:
I've run this comparison using:
However, specifically when using selective per op AC, there is a massive drop in performance with production float8 training when torch.compile is applied only to the nn.Linear layers, as compared to compiling the full model.
This does not occur when using no AC or full AC.
I would expect some performance degradation when compiling only nn.Linear instead of the full model, but the drop-off is massive (see the screenshot of benchmarks below): TFLOPS drops from 386.58 down to 125.88!
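To make the two configurations concrete, here is a minimal sketch (not torchtitan's actual code) of the difference between compiling the full model and compiling only the nn.Linear submodules; the toy model is just a stand-in:

```python
import torch
import torch.nn as nn

# Toy stand-in for a transformer block stack.
model = nn.Sequential(
    nn.Linear(4096, 4096),
    nn.SiLU(),
    nn.Linear(4096, 4096),
).cuda()

# (a) Full-model compile: dynamo traces the entire forward into compiled graphs.
model_full = torch.compile(model)

# (b) Linear-only compile: wrap each nn.Linear individually and leave the
#     surrounding ops (activations, norms, AC wrappers, ...) in eager mode.
for name, child in model.named_children():
    if isinstance(child, nn.Linear):
        setattr(model, name, torch.compile(child))
```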
I looked at the traces for these 2 runs, and found some surprising issues:
Only nn.Linears compiled (267ms):
(trace screenshot)
Full model compiled (71us):
(trace screenshot)
In the linear-only compiled run there is a long (~267ms) `FSDP::post_backward_reduce` call that does not appear in the fully compiled version (or rather, it is orders of magnitude faster there, at ~71us).
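For anyone reproducing this, a trace showing spans like `FSDP::post_backward_reduce` can be captured with the standard PyTorch profiler. A minimal sketch (torchtitan has its own profiling config, which is not shown here), using a toy step in place of the real FSDP training loop:

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Toy stand-in for one forward/backward step of the real training loop.
model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda")

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    with_stack=True,
) as prof:
    model(x).sum().backward()

# Open the resulting JSON in chrome://tracing or Perfetto; in a real FSDP run,
# search for the FSDP::post_backward_reduce span.
prof.export_chrome_trace("trace.json")
```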
Steps to reproduce
1. Update `training_configs/llama3_8b.toml` to run prod float8 + fully compiled model + selective per op AC on H100s, then run: `NGPU=4 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh`
2. Update `training_configs/llama3_8b.toml` to run prod float8 + only linear layers compiled + selective per op AC on H100s (I don't think the # of GPUs matters), then run: `TORCHTITAN_COMPILE_LINEAR_ONLY=1 NGPU=4 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh`
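For context, "selective per op AC" refers to activation checkpointing with a per-operator save/recompute policy. A minimal sketch of that idea in plain PyTorch, assuming the `create_selective_checkpoint_contexts` API from recent PyTorch releases (torchtitan's actual policy and op list differ):

```python
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Save matmul outputs during forward; recompute everything else in backward.
def policy_fn(ctx, op, *args, **kwargs):
    if op in (torch.ops.aten.mm.default, torch.ops.aten.addmm.default):
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

block = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024)).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

out = checkpoint(
    block,
    x,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)
out.sum().backward()
```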
cc @vkuzo @soulitzer