From 1ab8c9bedc62a0825e187c836f0e76541f3cfd8a Mon Sep 17 00:00:00 2001
From: Logan Adams <114770087+loadams@users.noreply.github.com>
Date: Fri, 17 Jan 2025 17:20:46 -0800
Subject: [PATCH] Update `torch.norm` to `torch.linalg.norm` and
 `torch.linalg.vector_norm` (#6931)

- [x] Update the PR since `torch.norm` and `torch.linalg.norm` have
[different function
signatures](https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm).
- [x] Check for numeric differences between the functions (see the sketch below).
- [x] Determine why others have reported performance differences
[here](https://github.com/pytorch/pytorch/issues/136360).
- [x] Update to `torch.linalg.vector_norm`.
A follow-up PR handles these in the comm folder: #6960.
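
Since the first two checklist items hinge on how the three entry points
differ, here is a minimal sketch of the equivalence check behind them. It
assumes only a recent stock PyTorch where `torch.linalg.vector_norm` is
available; the tensor shape and seed are arbitrary.

```python
import torch

torch.manual_seed(0)
x = torch.randn(128, 64, dtype=torch.float64)  # shape and seed are arbitrary

# p=2 (the default): all three entry points agree on a 2-D tensor,
# since the Frobenius norm equals the 2-norm of the flattened tensor.
assert torch.allclose(torch.norm(x), torch.linalg.norm(x))
assert torch.allclose(torch.norm(x), torch.linalg.vector_norm(x))

# Non-default ord: torch.linalg.norm treats a 2-D input as a matrix
# (ord=1 is the maximum absolute column sum), while torch.norm and
# torch.linalg.vector_norm both reduce over the flattened tensor.
assert torch.allclose(torch.norm(x, p=1), torch.linalg.vector_norm(x, ord=1))
assert not torch.allclose(torch.norm(x, p=1), torch.linalg.norm(x, ord=1))
```

The last assertion is why the grad-norm call sites below use
`torch.linalg.vector_norm(..., ord=norm_type)` rather than
`torch.linalg.norm`: with a non-default `ord`, `torch.linalg.norm` computes
a matrix norm for 2-D inputs (and raises for higher ranks), while
`torch.linalg.vector_norm` always reduces over the flattened tensor,
matching the old `torch.norm(g, norm_type)` behavior.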
---
 deepspeed/runtime/comm/compressed.py    | 2 +-
 deepspeed/runtime/comm/hccl.py          | 2 +-
 deepspeed/runtime/fp16/onebit/lamb.py   | 2 +-
 deepspeed/runtime/zero/stage3.py        | 2 +-
 deepspeed/runtime/zero/stage_1_and_2.py | 5 +++--
 5 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/deepspeed/runtime/comm/compressed.py b/deepspeed/runtime/comm/compressed.py
index 7f8c7395451dc..2c5482eb1ad76 100644
--- a/deepspeed/runtime/comm/compressed.py
+++ b/deepspeed/runtime/comm/compressed.py
@@ -96,7 +96,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro
 
         compensated_server_m.add_(server_error)
 
-        server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
+        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
 
         server_error.set_(compensated_server_m -
                           server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
diff --git a/deepspeed/runtime/comm/hccl.py b/deepspeed/runtime/comm/hccl.py
index 09fb11a731b8f..b8639c7da4c99 100644
--- a/deepspeed/runtime/comm/hccl.py
+++ b/deepspeed/runtime/comm/hccl.py
@@ -83,7 +83,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro
 
         compensated_server_m.add_(server_error)
 
-        server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
+        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
 
         server_error.set_(compensated_server_m -
                           server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
diff --git a/deepspeed/runtime/fp16/onebit/lamb.py b/deepspeed/runtime/fp16/onebit/lamb.py
index 89b6f40a308c3..9e7bae816ecde 100644
--- a/deepspeed/runtime/fp16/onebit/lamb.py
+++ b/deepspeed/runtime/fp16/onebit/lamb.py
@@ -177,7 +177,7 @@ def step(self, closure=None, grads=None):
                 # This is used to reduce compression error during compression stage.
                 momentum_scales = []
                 for group in self.param_groups:
-                    momentum_scales.append([(torch.linalg.norm(self.state[p]['exp_avg']) /
+                    momentum_scales.append([(torch.linalg.vector_norm(self.state[p]['exp_avg']) /
                                              np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
                                             for p in group['params']])
                 united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 28f91cb9b3abd..9c06567ed100a 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -2101,7 +2101,7 @@ def step(self, closure=None):
             return
 
         norm_groups = self._get_norm_groups()
-        scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))
+        scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))
 
         # Stash unscaled gradient norm
         self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 0508766f88964..ed3425167944a 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1691,7 +1691,8 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
                     continue
                 if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                     all_norms.append(
-                        torch.norm(g.data.double().detach(), norm_type).to(get_accelerator().current_device_name()))
+                        torch.linalg.vector_norm(g.data.double().detach(),
+                                                 ord=norm_type).to(get_accelerator().current_device_name()))
             if len(all_norms) > 0:
                 total_norm = torch.stack(all_norms).square().sum().float()
             else:
@@ -1795,7 +1796,7 @@ def scaled_global_norm(self, norm_type=2):
             self._average_expert_grad_norms(norm_groups)
 
         # calculating L2 norm
-        return torch.norm(torch.stack(norm_groups), p=norm_type)
+        return torch.linalg.vector_norm(torch.stack(norm_groups), ord=norm_type)
 
     def get_bit16_param_group(self, group_no):
         bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]