Context
Pass `fused=True` to the Adam constructor:
- Fuses the entire Adam update into a single kernel per parameter group, which should reduce kernel launch overhead.
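A minimal sketch of the constructor change, assuming a recent PyTorch and CUDA parameters (the model shape and learning rate are placeholders, not values from this project):

```python
import torch

# Placeholder model; fused=True requires the parameters to live on CUDA
# and have floating-point dtypes.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)
```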
Pass `foreach=True` to the `clip_grad_norm(...)` call:
- By default, `clip_grad_norm` iterates per parameter in Python, launching one norm kernel per tensor. With `foreach=True`, torch uses `torch._foreach_norm` to batch all gradient norms into a single multi-tensor kernel, collapsing the dozens to hundreds of individual norm kernels plus a Python reduction loop into one fused kernel.
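Continuing the sketch above, the change is one keyword argument on the clipping call (shown here with `torch.nn.utils.clip_grad_norm_`; `max_norm=1.0` is a placeholder clip value):

```python
# Called after loss.backward() and before optimizer.step().
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0,     # placeholder clip threshold
    foreach=True,     # batch per-tensor norms into one multi-tensor kernel
)
```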
Call `zero_grad(set_to_none=True)` in training:
- With `set_to_none=False`, torch runs one memset kernel per parameter to fill each gradient tensor with zeros. With `set_to_none=True`, it simply drops the `.grad` reference.
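A sketch of one training step with all three changes in place, reusing the `model` and `optimizer` from the first sketch (the dataloader, loss function, and clip value are placeholders, not parts of the real training loop):

```python
loss_fn = torch.nn.MSELoss()
dataloader = [(torch.randn(8, 1024, device="cuda"),
               torch.randn(8, 1024, device="cuda"))]  # stand-in batch

for inputs, targets in dataloader:
    optimizer.zero_grad(set_to_none=True)   # drop .grad refs; no memset kernels
    loss = loss_fn(model(inputs), targets)
    loss.backward()                          # autograd reallocates .grad tensors
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
    optimizer.step()                         # fused Adam update
```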
Other ideas raised, not covered above:
- Move `local_voxel_count` and `global_total_voxels` outside the batch loop to save an all-reduce?
- `torch.compile()` the model
- `torch.compile()` the unscale -> clip -> optimizer step -> update block
- Avoid `.item()` calls for things like gradient logging?
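For the `torch.compile()`-the-model idea, a hedged sketch with default settings (whether this pays off for this model is exactly the open question above; no backend or mode has been tuned here):

```python
# Wrap the model so Inductor can fuse elementwise ops in forward and backward.
compiled_model = torch.compile(model)
# ...then call compiled_model(inputs) in the training step instead of model(inputs).
```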