
[mxfp8 moe training] to_mxfp8_a2a_dequant flat perf vs high precision a2a #3112

@danielvegamyhre

Description


Looking at torchtitan traces, it seems the mxfp8 a2a dispatch is ~2x faster than bf16 (1681us vs 3487us), but the mxfp8 a2a combine (1) is roughly the same duration as the bf16 a2a combine, and (2) both the bf16 and mxfp8 a2a combine impls take roughly 23x longer than the a2a dispatch. This is unexpected.

Tlparse: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpHhVPih/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
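For context, to_mxfp8_a2a_dequant presumably quantizes tokens to mxfp8 before the dispatch all-to-all and dequantizes them on the receiving rank, which would explain the ~2x smaller dispatch payload vs bf16. Below is a minimal sketch of that general pattern, not torchao's actual implementation: per-row float32 scales stand in for mxfp8's per-32-element e8m0 block scales, and quantized_all_to_all is a made-up name.

```python
# Hypothetical sketch of the quantize -> all_to_all -> dequantize pattern; NOT torchao's
# actual to_mxfp8_a2a_dequant. Real mxfp8 uses e8m0 scales per 32-element block; per-row
# float32 scales are used here only to keep the example short.
import torch
import torch.distributed as dist

def quantized_all_to_all(tokens: torch.Tensor) -> torch.Tensor:
    """tokens: [num_tokens, dim] bf16, with num_tokens divisible by the world size."""
    # Quantize: per-row absmax scaling into float8_e4m3fn (e4m3 max normal value is 448).
    scales = tokens.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / 448.0
    q = (tokens / scales).to(torch.float8_e4m3fn)

    # Dispatch: send the 1-byte payload (viewed as uint8 for NCCL) plus the scales.
    q_out = torch.empty_like(q)
    s_out = torch.empty_like(scales)
    dist.all_to_all_single(q_out.view(torch.uint8), q.view(torch.uint8))
    dist.all_to_all_single(s_out, scales)

    # Dequantize on the receiving rank before routing tokens to local experts.
    return q_out.to(torch.bfloat16) * s_out.to(torch.bfloat16)
```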

Trace for llama4 EP=4, using the default a2a_impl

[trace screenshot]

Trace for llama4 EP=4, using to_mxfp8_a2a_dequant

[trace screenshot]
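For anyone checking the 1681us / 3487us numbers outside of tlparse, a small CUDA-event helper (hypothetical, not part of torchtitan) can isolate the dispatch and combine collectives:

```python
# Hypothetical timing helper, not part of torchtitan: wraps a single collective (e.g. the
# dispatch or combine all_to_all_single call) and reports its mean duration in microseconds.
# All ranks must call this together so each collective can complete.
import torch

def time_collective(fn, iters: int = 20, warmup: int = 5) -> float:
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1000.0 / iters  # elapsed_time is ms; convert to us
```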
