
Improve GPU utilization #62


Description

@PatrickRMiles
  • Pass fused=True to the Adam constructor
  • Pass foreach=True to the clip_grad_norm_(...) call
  • Call zero_grad(set_to_none=True) in the training loop
  • Can we compute local_voxel_count and global_total_voxels once, outside the batch loop, to save a per-batch all-reduce? (see the sketch after this list)
  • torch.compile() the model
  • torch.compile() the unscale -> clip -> optimizer step -> update block
  • Can we reduce per-batch .item() calls for things like gradient logging?
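
Rough sketch of the voxel-count hoisting item, assuming the per-rank count is constant across batches and torch.distributed is already initialized (the count value and the loop body are placeholders):

```python
import torch
import torch.distributed as dist

# Placeholder: however the per-rank voxel count is actually obtained.
local_voxel_count = torch.tensor(12345, device="cuda", dtype=torch.long)

# One all-reduce before the batch loop instead of one per batch,
# assuming the counts do not change between batches.
global_total_voxels = local_voxel_count.clone()
dist.all_reduce(global_total_voxels, op=dist.ReduceOp.SUM)

for batch_idx in range(100):  # placeholder batch loop
    # use local_voxel_count / global_total_voxels here without re-reducing
    ...
```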

Context

Pass fused=True to the Adam constructor:

  • Fuses the entire Adam update into a single kernel per parameter group, which should reduce kernel launch overhead.
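
A minimal sketch of the constructor change, assuming a plain Adam setup on CUDA (the model and learning rate are placeholders):

```python
import torch

model = torch.nn.Linear(512, 512).cuda()  # placeholder model

# fused=True runs the Adam update as fused multi-tensor CUDA kernels
# instead of a Python loop of per-parameter ops.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, fused=True)
```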

Pass foreach=True to the clip_grad_norm_(...) call:

  • The default clip_grad_norm_ path iterates over parameters in Python, launching one norm kernel per tensor. With foreach=True, PyTorch uses torch._foreach_norm to batch all gradient norms into a single multi-tensor kernel, collapsing the dozens-to-hundreds of individual norm kernels plus a Python reduction loop into one fused kernel.
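
A minimal sketch of the call, assuming gradients have already been produced by backward() (the model, loss, and max_norm are placeholders):

```python
import torch
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(512, 512).cuda()                        # placeholder model
loss = model(torch.randn(32, 512, device="cuda")).pow(2).mean() # placeholder loss
loss.backward()                                                 # populate .grad on every parameter

# foreach=True batches the per-parameter norms into multi-tensor kernels
# (torch._foreach_norm) instead of launching one norm kernel per tensor.
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
```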

Call zero_grad(set_to_none=True) in the training loop:

  • With set_to_none=False, PyTorch launches one memset kernel per parameter to zero-fill the gradient tensors. With set_to_none=True, it simply drops the .grad references instead.
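
A minimal sketch of the loop change (the model, optimizer, and batch are placeholders):

```python
import torch

model = torch.nn.Linear(512, 512).cuda()                   # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, fused=True)
inputs = torch.randn(32, 512, device="cuda")               # placeholder batch

for _ in range(10):
    # Drop the .grad references instead of launching one memset kernel
    # per parameter; autograd allocates fresh grads during backward().
    optimizer.zero_grad(set_to_none=True)
    loss = model(inputs).pow(2).mean()
    loss.backward()
    optimizer.step()
```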
