
Paged Stashing #2690

Draft
nanz-nv wants to merge 56 commits into NVIDIA:dev from vasunvidia:paged_offloading

Conversation


@nanz-nv nanz-nv commented Dec 17, 2025

Main contributors (Equal Contribution, sorted alphabetically): Nan Zheng (@nanz-nv), Vasudevan Rengasamy (@vasunvidia)
Other contributors (sorted alphabetically): Dennis Liu (@Victarry), Hongbin Liu (@lhb8125), Qi Zhang (@QiZhangNV), Robin Zhang (@buptzyb), Tong Liu (@Autumn1998), Zijie Yan (@yanring)

Background

In token-dropless MoE training, the number of tokens received by each expert can vary from step to step, resulting in dynamically shaped tensors. PyTorch's eager mode supports dynamic shapes naturally: a tensor is created lazily once its shape is known at run time. Although this works well in eager mode, dynamically shaped tensors pose a challenge for CUDA graphs, because the size of a tensor cannot be adjusted at run time without host intervention. To remove the host synchronization and enable CUDA graphs, one solution is to oversize the buffers in the expert part of the model. This, however, causes significantly higher memory consumption than the eager-mode baseline, in the form of memory fragmentation.
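As a toy illustration (my own sketch, not code from this PR): the eager path below allocates the expert output lazily once the dynamic token count is known, while the graph-capturable path pre-allocates static worst-case buffers whose unused rows are the wasted memory discussed above. The capacity and the activation function are assumptions.

import torch

hidden, capacity = 4096, 8192   # capacity = assumed worst-case tokens per expert

def expert_eager(tokens):
    # tokens: [num_tokens, hidden]; num_tokens varies per step, so the output
    # tensor is created lazily at run time -- fine in eager mode.
    return torch.nn.functional.gelu(tokens)

# CUDA-graph-capturable variant: fixed-shape buffers reused every iteration.
# Rows beyond the real (device-resident) token count are wasted capacity.
static_in = torch.zeros(capacity, hidden, device="cuda")
static_out = torch.empty(capacity, hidden, device="cuda")

def expert_static():
    static_out.copy_(torch.nn.functional.gelu(static_in))   # shapes never change, no host sync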


Idea overview

To address this problem, paged stashing decouples the oversized buffers needed for compute from the properly sized buffers needed to store activations for the backward pass. It does so by adding one level of indirection: stashing and restoring. The stash operation copies the activation from the oversized static buffer into a pre-allocated stashing buffer once the forward pass of that module is done, and the restore operation performs the reverse copy during the backward pass.


The key to saving memory is that the stash operation packs variable-size activations into a contiguous stashing buffer, reducing memory fragmentation. For simple schedules in which activation allocation and deallocation follow a first-in-last-out pattern, stash and restore can be done with a simple bump allocator. To accommodate more complicated schedules, e.g. pipeline parallelism, paging can be used, hence the name paged stashing.
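A minimal sketch of the first-in-last-out case, assuming PyTorch saved-tensor hooks, bf16 activations, and an illustrative buffer size. The PR fuses this bookkeeping into GPU kernels rather than doing it in Python, and a real implementation would route only the expert activations (not every saved tensor) through the stash.

import math
import torch

# Pre-allocated, contiguous stashing buffer (size and dtype are assumptions).
stash_buffer = torch.empty(1 << 26, device="cuda", dtype=torch.bfloat16)
_top = 0  # bump pointer into stash_buffer, in elements

def stash(act):
    """Forward: pack a variable-size activation contiguously into the stash."""
    global _top
    n = act.numel()
    off = _top
    stash_buffer[off:off + n].copy_(act.reshape(-1))
    _top += n
    return (off, tuple(act.shape))           # kept by autograd instead of `act`

def restore(meta):
    """Backward: copy the activation back out and pop it off the bump stack."""
    global _top
    off, shape = meta
    n = math.prod(shape)
    act = stash_buffer[off:off + n].view(shape).clone()
    _top = off                               # valid only under first-in-last-out ordering
    return act

# Saved tensors of anything run under this context go through stash/restore.
stash_ctx = torch.autograd.graph.saved_tensors_hooks(stash, restore)

Running the expert block under "with stash_ctx:" during the forward pass is enough; autograd calls restore lazily when the corresponding backward runs, which matches the first-in-last-out assumption for plain layer-by-layer schedules.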

Page management

To accommodate complex scheduling such as that needed in pipeline parallelism, activations are partitioned into pages. Pages are allocated and deallocated by lightweight GPU memory-management kernels that can be fused with the stash/restore kernels. The allocator maintains a freelist implemented as a circular buffer, with one freelist per page type.
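For clarity, here is a plain-Python sketch of one such freelist; the class and method names are hypothetical, and the PR keeps this state on the GPU and fuses the updates into the stash/restore kernels to avoid host synchronization.

class PageFreelist:
    """Free pages of one fixed page size, tracked in a circular buffer of page ids."""

    def __init__(self, num_pages: int):
        self.ring = list(range(num_pages))   # initially every page is free
        self.head = 0                        # next free page id to hand out
        self.tail = 0                        # slot that receives the next freed page id
        self.capacity = num_pages
        self.num_free = num_pages

    def allocate(self) -> int:
        assert self.num_free > 0, "stashing buffer is out of pages"
        page = self.ring[self.head]
        self.head = (self.head + 1) % self.capacity
        self.num_free -= 1
        return page

    def free(self, page: int) -> None:
        self.ring[self.tail] = page
        self.tail = (self.tail + 1) % self.capacity
        self.num_free += 1

A stash of n elements then grabs ceil(n / page_size) page ids and records them, so that restore can gather the activation back and return the pages to the freelist in any order, which is what relaxes the first-in-last-out restriction of the bump allocator.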

CPU offloading

Paged stashing naturally supports offloading. When the stashing buffer is a pinned CPU tensor, activations are offloaded to host memory during the forward pass and reloaded to the GPU during the backward pass.
Furthermore, the page management system can easily be extended to accommodate partial offloading or on-demand offloading. This feature is currently WIP.
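The sketch below shows the same stash/restore flow with a pinned host buffer, assuming bf16 activations and an illustrative side stream; the stream and event handling is my own simplification, not the PR's scheduling code.

import math
import torch

stash_buffer = torch.empty(1 << 26, dtype=torch.bfloat16, pin_memory=True)  # pinned host memory
copy_stream = torch.cuda.Stream()

def offload(act, off):
    """Forward: asynchronously copy the activation to the pinned host stash."""
    n = act.numel()
    copy_stream.wait_stream(torch.cuda.current_stream())    # `act` must be fully produced
    with torch.cuda.stream(copy_stream):
        stash_buffer[off:off + n].copy_(act.reshape(-1), non_blocking=True)
        done = torch.cuda.Event()
        done.record()
    return done   # compute stream waits on this before reusing the oversized buffer

def reload(off, shape):
    """Backward: asynchronously copy the activation back to the GPU."""
    n = math.prod(shape)
    with torch.cuda.stream(copy_stream):
        act = stash_buffer[off:off + n].to("cuda", non_blocking=True)
        ready = torch.cuda.Event()
        ready.record()
    torch.cuda.current_stream().wait_event(ready)            # sync before compute uses it
    return act.view(shape)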

Scheduling

Stash and restore operations can be overlapped with compute by inserting two autograd functions before and after the expert compute layer: a pre-scheduler and a post-scheduler, which schedule the stash and restore operations. The roles of these autograd functions are enumerated below; a minimal sketch follows the list.

  • Pre-scheduler forward: Wait for the previous stash operation to complete and free the max-capacity temporary activations for that stash. The wait is performed here rather than in the post-scheduler forward to reduce peak memory usage, since the following expert compute layer will allocate another set of max-capacity temporary activations.
  • Post-scheduler forward: Since this runs after the expert compute, stash operations for the current layer's activations are scheduled here. If the next layer in the execution order is a backward-pass layer, also schedule restore operations for that layer.
    Additionally, in the case of pipeline parallelism, this can be used to record the pipeline schedule during the first iteration.
  • Post-scheduler backward: Wait for the previous stash operation to complete and free the max-capacity temporary activations for that stash. The wait is performed here rather than in the pre-scheduler backward to reduce peak memory usage, since the following expert compute BPROP layer will allocate another set of max-capacity temporary activations.
    Also wait for the restore operation for the current layer to complete. Additionally, in the case of pipeline parallelism, this can be used to record the pipeline schedule during the first iteration.
  • Pre-scheduler backward: If the next layer in the execution order is a backward-pass layer, schedule restore operations for that layer.
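A hedged sketch of this structure: two identity autograd functions whose forward/backward trigger the scheduling side effects at the points described above. The scheduler object and its method names are placeholders of my own, standing in for the PR's PagedStashRunner logic.

import torch

class StubScheduler:
    """Placeholder for the PR's real stash/restore scheduling logic."""
    def wait_and_free_prev_stash(self): ...
    def schedule_stash(self, layer_id): ...
    def schedule_restore_if_next_is_bprop(self, layer_id): ...
    def wait_restore(self, layer_id): ...

class PreScheduler(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scheduler, layer_id):
        ctx.scheduler, ctx.layer_id = scheduler, layer_id
        scheduler.wait_and_free_prev_stash()            # before experts allocate their temps
        return x

    @staticmethod
    def backward(ctx, grad):
        # Runs after the current layer's expert BPROP.
        ctx.scheduler.schedule_restore_if_next_is_bprop(ctx.layer_id)
        return grad, None, None

class PostScheduler(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scheduler, layer_id):
        ctx.scheduler, ctx.layer_id = scheduler, layer_id
        scheduler.schedule_stash(layer_id)              # activations for this layer are ready
        scheduler.schedule_restore_if_next_is_bprop(layer_id)
        return x

    @staticmethod
    def backward(ctx, grad):
        # Runs before the current layer's expert BPROP.
        ctx.scheduler.wait_and_free_prev_stash()
        ctx.scheduler.wait_restore(ctx.layer_id)
        return grad, None, None

# Illustrative placement around the expert block of each MoE layer:
#   x = PreScheduler.apply(x, scheduler, layer_id)
#   x = experts(x)
#   x = PostScheduler.apply(x, scheduler, layer_id)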

@copy-pr-bot

copy-pr-bot bot commented Dec 17, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Victarry
Contributor

/ok to test 3e8c042

@github-actions
Contributor

Thank you for your contribution!

NVIDIA Megatron-LM is currently transitioning to development on Github. We will aim to review your PR after we complete our transition and stabilize our Github development process.

Thank you for your understanding.

f"{self.paged_tensors_to_reload[pp_schedule_layer]}"
)

def allocate_stash_buffers(self, stash_buffer_size_factor=1.10):


Curious how stash_buffer_size_factor is going to be determined? Is 1.10 reasonable enough?

Author


Whether 1.1 is enough depends on the token distribution in each layer.
One can use the GPU memory remaining after fitting the model/activations as the stashing buffer size. Alternatively, this can be found by iterative trials, similar to deciding the largest sharding/microbatch size that fits on the GPU under load imbalance.

ksivaman and others added 15 commits March 16, 2026 23:59
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
… to use_transformer_engine_op_fuser

Enforce Router padding for paged stashing

Initial commit to enable paged stashing for TE fused op

Enable stashing for 1D shape, colwise_scale_inv tensors

Use moe_paged_stash to enable/disable stashing with fused op

Use use_transformer_engine_op_fuser to enable/disable fused op

Dynamic-shape no-stashing fallback for non-CG

Dynamic-shape no-stashing fallback + Full CG

Eliminate sync in mtp loss cal

enable 1f1b overlap

Add overflow check back temporarily before changes for PagedStashRunner is ready

nanz/megatron-lm!1 - Paged stashing fallback