JAX FA Benchmarking Script #351
base: dev
Conversation
transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
ipanfilo left a comment:
Why does ck_fused_attn_bwd need modification if the PR is for the fwd-pass only?
Note that I have added the BWD pass implementation to this PR as well.
Pinging @wangye805 @wenchenvincent in case either of you is interested in reviewing this PR as well, thanks!
benchmarks/attention/README.md
    ## JAX Fused-Attention Benchmarking

    The benchmarking process is split into two stages: *generating* the timing data and *visualizing* the timing data. The following steps assume you are located in `TransformerEngine/benchmarks/attention` (i.e., where this README is located). First, ensure that you install the requirements via `pip install -r requirements.txt`.

    Note: Only forward timings are supported at this point.
Update?
Done!
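For context on the "generating the timing data" stage described in the README snippet above, a minimal sketch of a forward-pass timing loop might look like the following. This is not the script from this PR: the shapes, function names, and the plain dot-product attention stand-in are illustrative assumptions.

    # Sketch only: times a jitted attention forward pass the way a simple
    # "generate timing data" stage could, using block_until_ready() so the
    # wall-clock measurement includes kernel execution.
    import time

    import jax
    import jax.numpy as jnp


    def attention_fwd(q, k, v):
        # Plain dot-product attention as a stand-in for the fused kernel under test.
        scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
        return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)


    def time_forward(batch, heads, seqlen, head_dim, iters=10):
        key = jax.random.PRNGKey(0)
        shape = (batch, heads, seqlen, head_dim)
        q, k, v = (jax.random.normal(key, shape, jnp.bfloat16) for _ in range(3))
        fwd = jax.jit(attention_fwd)
        fwd(q, k, v).block_until_ready()  # warm-up / compile
        start = time.perf_counter()
        for _ in range(iters):
            fwd(q, k, v).block_until_ready()
        return (time.perf_counter() - start) / iters


    print(f"avg fwd time: {time_forward(2, 16, 2048, 64) * 1e3:.3f} ms")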
    from transformer_engine.jax import fp8_autocast

    # Needed in order to dump timings properly
    os.environ["XLA_FLAGS"] = "--xla_gpu_graph_level=0"
Is this because you used the time-dumping function in CK fused attention?
Yes
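The ordering assumption behind that flag (as I understand it) is that XLA only reads `XLA_FLAGS` when the backend initializes, so the variable must be set before the first JAX GPU call; disabling graph capture keeps each kernel as a separate launch so the timing dump from CK fused attention is attributed per call. A small sketch, with only the flag value taken from the diff above; everything else is illustrative:

    # Sketch: set the flag before JAX/XLA initializes the GPU backend.
    import os

    os.environ["XLA_FLAGS"] = "--xla_gpu_graph_level=0"  # value from the diff above

    import jax  # noqa: E402  (intentionally imported after setting the flag)

    print(jax.devices())  # the backend initializes here, with the flag in effect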
    attn_bias_type, bias_shape = bias_config
    window_size = None
    if swa:
        window_size = (s_kv // 10, 0)
Why do this for SWA?
This was taken from our JAX FA testing.
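For reviewers unfamiliar with the tuple convention: assuming the usual flash-attention style `(left, right)` window, `window_size = (s_kv // 10, 0)` lets each query attend to at most `s_kv // 10` earlier keys and no later keys. A small illustrative helper (not part of the PR) that builds the equivalent boolean mask:

    # Illustrative only: boolean mask implied by a (left, right) sliding window,
    # assuming each query may attend to `left` earlier keys, itself, and `right`
    # later keys.
    import numpy as np


    def sliding_window_mask(s_q, s_kv, window_size):
        left, right = window_size
        q_idx = np.arange(s_q)[:, None]
        k_idx = np.arange(s_kv)[None, :]
        return (k_idx >= q_idx - left) & (k_idx <= q_idx + right)  # True = allowed


    s_kv = 20
    mask = sliding_window_mask(s_kv, s_kv, (s_kv // 10, 0))
    print(mask.sum(axis=1))  # at most s_kv // 10 previous keys plus the query itself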
Have you compared the kernel time measured from the CK FA API vs. from rocprof?
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: