IFU dev v2.6 #374
base: dev
Conversation
Signed-off-by: Przemek Tredak <[email protected]>
* tests drop Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move dir Signed-off-by: Pawel Gadzinski <[email protected]> * tests fox Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski <[email protected]> Signed-off-by: Przemek Tredak <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemek Tredak <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Fix README render on PyPI Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Update README.rst Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Use anonymous hyperlink for duplicate. Fix indent. Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Check tensor-recipe compatibility Signed-off-by: Evgeny Tsykunov <[email protected]> * Tensor class in recipe, checking for *Base Signed-off-by: Evgeny Tsykunov <[email protected]> * Extend recipe __repr__ with recipe_type Signed-off-by: Evgeny Tsykunov <[email protected]> * Warn about recipe change Signed-off-by: Evgeny Tsykunov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enable dynamic recipe change: clear fp8 workspace Signed-off-by: Evgeny Tsykunov <[email protected]> * TE 1.x checkpoint compatibility Signed-off-by: Evgeny Tsykunov <[email protected]> * Disable warning for recipe wrappers Signed-off-by: Evgeny Tsykunov <[email protected]> * Test recipe change Signed-off-by: Evgeny Tsykunov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use QuantizedTensorBase Signed-off-by: Evgeny Tsykunov <[email protected]> * Fix circular import Signed-off-by: Evgeny Tsykunov <[email protected]> * Revert previous circular import fix Signed-off-by: Evgeny Tsykunov <[email protected]> * Fix pytorch imports in common Signed-off-by: Evgeny Tsykunov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Let quantizer know about the recipe Signed-off-by: Evgeny Tsykunov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix imports Signed-off-by: Evgeny Tsykunov <[email protected]> --------- Signed-off-by: Evgeny Tsykunov <[email protected]> Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Przemyslaw Tredak <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Fix split_overlap_rs aggregate=True chunk offset calculation Signed-off-by: Guyue Huang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add unit test for aggregate=True Signed-off-by: Guyue Huang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix unit test Signed-off-by: Guyue Huang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Guyue Huang <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
…te (#1799) * Use an empty torch tensor to indicate no fp8 information in extra_state Signed-off-by: Peter St. John <[email protected]> * Add huggingface from_pretrained / save_pretrained tests Adds integration tests to ensure models containing TransformerLayer objects can be saved and loaded using the from_pretrained and save_pretrained methods. Signed-off-by: Peter St. John <[email protected]> --------- Signed-off-by: Peter St. John <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
…n (#1611) * docs drop Signed-off-by: Pawel Gadzinski <[email protected]> * a Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * Update docs/debug/1_getting_started.rst Co-authored-by: Przemyslaw Tredak <[email protected]> Signed-off-by: Paweł Gadziński <[email protected]> * Update docs/debug/1_getting_started.rst Co-authored-by: Przemyslaw Tredak <[email protected]> Signed-off-by: Paweł Gadziński <[email protected]> * fixes Signed-off-by: Pawel Gadzinski <[email protected]> * fix imgs Signed-off-by: Pawel Gadzinski <[email protected]> --------- Signed-off-by: Pawel Gadzinski <[email protected]> Signed-off-by: Paweł Gadziński <[email protected]> Co-authored-by: Przemyslaw Tredak <[email protected]>
add docstring for CP Signed-off-by: Charlene Yang <[email protected]>
* Add missing docs for C API Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Grammar, typos, copy-paste errors Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * remove contiguous word Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Better wording Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* fix model parallel encoder to be properly sharded Signed-off-by: Sudhakar Singh <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Sudhakar Singh <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
fix saved_tensors Signed-off-by: Pawel Gadzinski <[email protected]>
Fix incorrectly skipped test_quantize_dbias tests Signed-off-by: Jeremy Berchtold <[email protected]>
Remove comm_gemm_overlap docs Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Build support for cuda 13 Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix build for cudnn 8.9*; cuda 12.1 Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * readd include Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
…rity (#1811) Make primitive names more granular for better disabling granularity Signed-off-by: Jeremy Berchtold <[email protected]>
Document all recipes Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
…#1804) Activation ops support fusing backward pass with quantize Signed-off-by: Tim Moon <[email protected]>
* Fix env variable name in test.sh scripts to properly test pure-JAX implementations Signed-off-by: Jeremy Berchtold <[email protected]> * Update test scripts to use pure-JAX impl in encoder test_custom_call_compute.py already uses pure-JAX impl as reference so testing the pure-JAX impl against itself would be redundant. The encoder tests have their own implementation so testing the pure-JAX impl of primitives is still useful. Signed-off-by: Jeremy Berchtold <[email protected]> * Update qa/L0_jax_unittest/test.sh Co-authored-by: Phuong Nguyen <[email protected]> Signed-off-by: jberchtold-nvidia <[email protected]> --------- Signed-off-by: Jeremy Berchtold <[email protected]> Signed-off-by: jberchtold-nvidia <[email protected]> Co-authored-by: Phuong Nguyen <[email protected]>
* Modify the test cases Signed-off-by: Przemek Tredak <[email protected]> * Make the tests reproducible on different machines Signed-off-by: Przemek Tredak <[email protected]> * Fixed the cache of the gamma_in_weight_dtype setting Signed-off-by: Przemek Tredak <[email protected]> * Reinstate the tests Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * More verbose code and comments Signed-off-by: Przemek Tredak <[email protected]> --------- Signed-off-by: Przemek Tredak <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* added conda installation Signed-off-by: Santosh Bhavani <[email protected]> * fix for pypi Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Santosh Bhavani <[email protected]> Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Fix single FW build with multi FW available Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Some fixes Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fixes Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * sug Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
…ion (#1822) Update jax_scaled_masked_softmax to match TE kernel implementation Signed-off-by: Jeremy Berchtold <[email protected]>
* fp8 gemm with direct quant Signed-off-by: Phuong Nguyen <[email protected]> --------- Signed-off-by: Phuong Nguyen <[email protected]>
* removes unnecessary reshapes for FP8 GEMM * use nn.jax.scaled_matmul Signed-off-by: Phuong Nguyen <[email protected]> --------- Signed-off-by: Phuong Nguyen <[email protected]>
…needed (#1817) * Linear op avoids saving input tensor if weight grad is not needed Signed-off-by: Tim Moon <[email protected]> * Linear op forward avoids producing quantized tensors with unnecessary usages Signed-off-by: Tim Moon <[email protected]> * Fix linter warnings Signed-off-by: Tim Moon <[email protected]> * Avoid unnecessary usages in fused linear ops Signed-off-by: Tim Moon <[email protected]> --------- Signed-off-by: Tim Moon <[email protected]>
…#1813) * Changed the Tensor allocation strategy Signed-off-by: Przemek Tredak <[email protected]> * Fixes Signed-off-by: Przemek Tredak <[email protected]> * Disable debug flag Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the double free error Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Przemek Tredak <[email protected]> * Fixed pyTorch recipe extension Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Przemek Tredak <[email protected]> * Fix Signed-off-by: Przemek Tredak <[email protected]> * Hide TensorAllocator and fix the usage in LayerNorm Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleaning Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Przemek Tredak <[email protected]> * Fix permutation Signed-off-by: Przemek Tredak <[email protected]> --------- Signed-off-by: Przemek Tredak <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Support SWA in CP Ring Attn THD striped sharding Signed-off-by: Hua Huang <[email protected]> * Add some comments; move check to _FusedAttnCPWithP2PHelper.check_supported() Signed-off-by: Hua Huang <[email protected]> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Remove unused check Signed-off-by: Hua Huang <[email protected]> --------- Signed-off-by: Hua Huang <[email protected]>
Signed-off-by: Tim Moon <[email protected]> Signed-off-by: Tim Moon <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Quantizer update Signed-off-by: Evgeny Tsykunov <[email protected]> * Update import Signed-off-by: Evgeny <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Introduce _update_weight_quantizers and _get_weight_tensors/_get_weight_quantizers Signed-off-by: Evgeny <[email protected]> * Add test Signed-off-by: Evgeny <[email protected]> * Move _quantizer to the QuantizedTensorBase Signed-off-by: Evgeny <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix import Signed-off-by: Evgeny Tsykunov <[email protected]> --------- Signed-off-by: Evgeny Tsykunov <[email protected]> Signed-off-by: Evgeny <[email protected]> Co-authored-by: Evgeny Tsykunov <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak <[email protected]>
…talled (#1834) * Add warning for multi framework case Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Alp Dener <[email protected]> * fix Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Alp Dener <[email protected]>
LGTM -- only covered common dir and cpp tests.
transformer_engine/common/fused_router/fused_topk_with_score_function.cu
Micky774 left a comment:
I mainly focused on the attention sections, and the install/build. Overall looks good, just some minor nits/questions.
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)

is_training = True
Why do we not use is_training = config.head_dim_qk <= 192 and config.head_dim_v <= 128, as on line 400 later in this function, when determining the available backend? Won't this potentially cause issues if the later is_training is False, where we could have had a certain backend enabled at this step but didn't because we assumed is_training=True? Or is that not a problem?
Missed this conflict during the git merge. Fixed. Thanks
#endif
// ROCm fused attn has two backends: aotriton and ck
// They both have the same shape and stride for softmax and rng aux tensors
// CK now supports bias features
Indent to keep aligned
Done. Thanks
pyproject.toml (outdated)
# See LICENSE for license information.

[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax[cuda12]", "flax>=0.7.1"]
To clarify, this means that when building TE one must have both JAX and PyTorch installed, even when building for just a single framework, right?
I'm not very familiar with this new pyproject.toml. Based on my experience with pip install --no-build-isolation, the source build does not interact with this pyproject.toml, as far as I can tell.
Please correct me if I'm wrong.
For lines 369-371:

# Deallocate GEMM input tensor if no longer needed
if not weight.requires_grad and not return_layernorm_output:
    ln_out = ln_out_total = None
    clear_tensor_data(ln_out, ln_out_total)
we shouldn't set ln_out and ln_out_total to None first; we should clear the tensor data first and then set them to None.
Also, my commit ([Feat] Add transpose cache to LayerNorm kernel (#279)) had instead:
if not weight.requires_grad:
    if not return_layernorm_output:
        clear_tensor_data(ln_out, ln_out_total)
    ln_out = None
Done. Thanks
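For reference, a minimal self-contained sketch of why the ordering matters; clear_tensor_data below is a local stand-in for TE's helper, not the real implementation:

import torch


def clear_tensor_data(*tensors):
    """Stand-in for TE's helper: release each tensor's underlying storage."""
    for t in tensors:
        if t is not None:
            t.data = torch.Tensor()  # detach the storage from the tensor object


ln_out = torch.randn(4, 8)
ln_out_total = torch.randn(4, 8)
saved = (ln_out, ln_out_total)  # e.g. references still held by the autograd context

# Wrong order: the local names are rebound to None before the helper runs, so
# clear_tensor_data(None, None) is a no-op and `saved` still holds full storage.
#     ln_out = ln_out_total = None
#     clear_tensor_data(ln_out, ln_out_total)

# Intended order: release the storage first, then drop the local references.
clear_tensor_data(ln_out, ln_out_total)
ln_out = ln_out_total = None
print(saved[0].numel())  # 0 -> the storage was actually released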
ci/pytorch.sh (outdated)
#_WORKERS_COUNT=$TEST_WORKERS
mkdir -p ${TEST_DIR}/checkpoint
python ${TEST_DIR}/test_checkpoint.py --save-checkpoint all --checkpoint-dir ${TEST_DIR}/checkpoint
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=${TEST_DIR}/checkpoint run 1 test_checkpoint.py
I think test_checkpoint does not involve calling different fused_attn backends, so this whole addition should be under if [ $_fus_attn = "$_DEFAULT_FUSED_ATTN" ].
In fact, I did find that this test involves attention:

return te.TransformerLayer(1, 1, 1)
It indeed creates a TransformerLayer, but it calls neither fwd nor bwd, so no FA backend is invoked. It tests the saving/loading of state for torch.nn.Module-derived classes. In fact, even flash vs. fused attention makes no difference in the state.
I see. Done. Thanks
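For illustration, a minimal sketch of the kind of state save/load round trip test_checkpoint.py exercises (assuming transformer_engine.pytorch is installed and a GPU is available; the real test is more involved):

import tempfile

import torch
import transformer_engine.pytorch as te

# Build the layer exactly as in the snippet above; no forward/backward is run,
# so no fused-attention backend is ever invoked.
model = te.TransformerLayer(1, 1, 1)

with tempfile.NamedTemporaryFile(suffix=".pt") as f:
    torch.save(model.state_dict(), f.name)          # save module state
    state = torch.load(f.name, weights_only=False)  # tolerate non-tensor extra_state

# Loading back into a freshly constructed layer should succeed regardless of
# which attention backend (flash vs. fused) would be selected at runtime.
reloaded = te.TransformerLayer(1, 1, 1)
reloaded.load_state_dict(state)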
Description
Targeted NV upstream commit: ca7407e (2025/07/18), based on our ROCm dev commit 6bbd03c.
Fixes https://github.com/ROCm/frameworks-internal/issues/13729
Type of change
Changes
See NV upstream release doc for upstream changes.
Our IFU conflict resolutions are listed in the following commits:
1). common: 4d3ca4d
2). jax extension: 5ce0afd
3). pytorch extension: c9c9126
4). build/installation: 9730903
5). cpp gtests: 51bdbb8
6). pytorch pytests: ba59f81
7). jax pytests: 5842c24
Checklist: