[sparse] Migrate Float8SemiSparseTensor off of AQT #3361
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3361
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures as of commit d2f51b6 with merge base 5f33595. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
| """Use torchao cutlass kernel for fp8 + 2:4 sparse mm, requires building torchao with CUDA | ||
| """ | ||
| SPARSE_CUTLASS = "sparse_cutlass" |
my understanding is this is a new packing format, why is this a new kernel preference?
sparse_cutlass vs. sparse_cusparselt/hipsparselt is something we will need for AMD support coming up next half, which sounds like a kernel preference to me (decide which op to use).
But if this is a more general thing and packing_format is the more specific way to decide op dispatch, I am fine with using that as well.
@jcaip, it would be good to specify whether the data format will be different and the kernels different, or whether the data format is the same and only the kernels differ.
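To make the distinction in this thread concrete, here is a minimal sketch (all names below are hypothetical, not the actual torchao API): a packing format describes how the quantized 2:4 data is laid out in memory, while a kernel preference describes which backend op runs on that data.

```python
from enum import Enum


class SparsePackingFormat(str, Enum):
    # hypothetical: how the fp8 + 2:4 sparse weight is stored
    CUTLASS = "sparse_cutlass"        # kept values and 2:4 metadata as separate tensors
    CUSPARSELT = "sparse_cusparselt"  # one buffer with the metadata appended to the values


class SparseKernelPreference(str, Enum):
    # hypothetical: which op to call for the same logical matmul
    AUTO = "auto"                     # let the library pick based on hardware and packing
    CUTLASS = "sparse_cutlass"        # torchao CUTLASS fp8 2:4 kernel
    CUSPARSELT = "sparse_cusparselt"  # cuSPARSELt (or hipSPARSELt on AMD)
```

In this framing, the open question is whether CUTLASS vs. cuSPARSELt changes the stored bytes (a packing format) or only the op called on the same bytes (a kernel preference).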
    kernel_choice = "sparse_cutlass"
elif kernel_preference == KernelPreference.SPARSE_CUTLASS:
    # if the user explicitly chose the sparse_cutlass kernel preference, use the CUTLASS kernel
    assert is_sm_at_least_90(), (
        "Specified sparse_cutlass kernel and hardware is not >= SM 9.0 (>= H100)"
    )
    kernel_choice = "sparse_cutlass"
if "sparse_cutlass" is the only option, then I don't think we are dealing with a kernel preference here?
from .float8_tensor import QuantizeTensorToFloat8Kwargs


class Float8SemiSparseTensor(TorchAOBaseTensor):
is there a more descriptive name, something like Float8With2By4SparsityTensor?
    dtype: Optional[torch.dtype] = None,
):
    super().__init__()
    self.sparse_quantized_data = sparse_quantized_data
how about qdata to match other tensors
We can do sparse_qdata? But I think just qdata is a bit confusing, since the quantized data is split between the specified values and the metadata.
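For context, a small sketch of the two layouts behind this naming discussion; sparse_quantized_data and sparse_qdata come from the diff, while the other attribute names and shapes are illustrative assumptions:

```python
from dataclasses import dataclass

import torch


@dataclass
class CutlassStylePacking:
    # split layout: kept values and 2:4 metadata are separate tensors, so a
    # single name like "qdata" would be ambiguous about which piece it means
    sparse_quantized_data: torch.Tensor  # float8 values of the kept elements, roughly (N, K // 2)
    meta: torch.Tensor                   # 2:4 metadata recording which positions were kept
    scale: torch.Tensor                  # per-row quantization scale


@dataclass
class CusparseltStylePacking:
    # packed layout: one buffer with the metadata appended after the values
    # (the SPARSE_CUSPARSELT format described further down)
    sparse_qdata: torch.Tensor
    scale: torch.Tensor
```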
| """ | ||
| Sparse packing formats for 2:4 sparsity + FP8 quantization | ||
| """ | ||
| SPARSE_CUTLASS = "sparse_cutlass" |
The intent is for the sparse tensor to use OPAQUE, and you can keep these formats internal to your workflow
SPARSE_CUTLASS = "sparse_cutlass"

"""
SPARSE_CUSPARSELT will pack the quantized data into a single tensor, sparse_qdata, which contains the specified values with the metadata appended.
This packing format will dispatch to `_cslt_sparse_mm`, which does not fuse per-row scaling into the matmul.
"""
SPARSE_CUSPARSELT = "sparse_cusparselt"
Should these belong to Float8PackingFormat? We structure these by "dtype" currently.
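To illustrate the dispatch difference described in the docstrings above: since `_cslt_sparse_mm` does not fuse per-row scaling, the scales have to be applied to the output after the matmul. The sketch below assumes the private PyTorch op's argument order, the out_dtype choice, and the tensor/scale names; it is not the actual torchao implementation.

```python
import torch


def sparse_fp8_mm_cusparselt(act_fp8, act_scale, w_packed_fp8, w_scale, bias=None):
    # cuSPARSELt path: w_packed_fp8 holds the kept values with the metadata appended.
    # The mm itself is unscaled, so the per-row activation and weight scales are
    # applied to the output afterwards (signature and dtypes are assumptions).
    y = torch._cslt_sparse_mm(w_packed_fp8, act_fp8.t(), out_dtype=torch.bfloat16).t()
    y = y * act_scale * w_scale.t()
    if bias is not None:
        y = y + bias
    return y


# By contrast, the CUTLASS packing format dispatches to a torchao CUTLASS kernel
# that fuses these per-row scales into the matmul itself.
```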
This PR migrates Float8DynamicActivationFloat8SemiSparseWeightConfig off of the AQT CutlassSemiSparseLayout subclass.
The old AQT flow can still be used by passing version=1 into the config.
Testing:
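For reference, a minimal usage sketch of the migrated config, assuming the import path shown below and that the config accepts the version argument described above; the toy model is for illustration only:

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    quantize_,
)

model = torch.nn.Sequential(torch.nn.Linear(2048, 2048)).cuda().to(torch.bfloat16)

# New flow (this PR): weights are swapped for the new semi-sparse fp8 tensor subclass.
quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

# Old AQT flow (CutlassSemiSparseLayout) is still reachable by passing version=1:
# quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig(version=1))
```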