
Commit b2ca2c8

janeyx99 authored and pytorchmergebot committed
[optim][adagrad] group tensors in foreach to maximize perf (pytorch#92362)
another one

Pull Request resolved: pytorch#92362
Approved by: https://github.com/albanD
1 parent 44132cc commit b2ca2c8
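
The point of the change is that torch._foreach_* kernels only fuse work across tensors that share a device (and, in practice, a dtype), so _multi_tensor_adagrad now buckets params, grads, state_sums, and state_steps by (device, dtype) and runs the fused update once per bucket instead of assuming one homogeneous list. The snippet below is only a sketch of that bucketing idea; group_by_device_and_dtype is a hypothetical stand-in written for illustration, not the private torch.utils._foreach_utils helper that the commit actually imports.

from collections import defaultdict
from typing import Dict, List, Tuple

import torch


def group_by_device_and_dtype(
    tensorlists: List[List[torch.Tensor]],
) -> Dict[Tuple[torch.device, torch.dtype], List[List[torch.Tensor]]]:
    # Bucket parallel tensor lists by the (device, dtype) of the leading list,
    # keeping the lists aligned index-for-index inside each bucket.
    grouped: Dict[Tuple[torch.device, torch.dtype], List[List[torch.Tensor]]] = defaultdict(
        lambda: [[] for _ in tensorlists]
    )
    for i, leader in enumerate(tensorlists[0]):
        key = (leader.device, leader.dtype)
        for j, tlist in enumerate(tensorlists):
            grouped[key][j].append(tlist[i])
    return dict(grouped)


# Two dtypes -> two buckets; each bucket gets one fused multi-tensor call,
# mirroring how the grouped loop in the diff issues one update per bucket.
params = [torch.zeros(3), torch.zeros(3, dtype=torch.float64), torch.zeros(2)]
grads = [torch.ones_like(p) for p in params]
for (device, dtype), (p_bucket, g_bucket) in group_by_device_and_dtype([params, grads]).items():
    torch._foreach_add_(p_bucket, g_bucket)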

File tree

1 file changed: +46 -43 lines changed


torch/optim/adagrad.py

Lines changed: 46 additions & 43 deletions
@@ -3,6 +3,7 @@
 
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
                         _differentiable_doc, _maximize_doc)
+from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 from typing import List, Optional
 
 __all__ = ["Adagrad", "adagrad"]
@@ -322,48 +323,50 @@ def _multi_tensor_adagrad(
     if len(params) == 0:
         return
 
-    if maximize:
-        grads = torch._foreach_neg(grads)
+    grouped_tensorlists = _group_tensors_by_device_and_dtype([params, grads, state_sums, state_steps])
+    for device_params, device_grads, device_state_sums, device_state_steps in grouped_tensorlists.values():
+
+        if maximize:
+            device_grads = torch._foreach_neg(device_grads)
+
+        device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)
+
+        if device_has_sparse_grad:
+            return _single_tensor_adagrad(
+                device_params,
+                device_grads,
+                device_state_sums,
+                device_state_steps,
+                lr=lr,
+                weight_decay=weight_decay,
+                lr_decay=lr_decay,
+                eps=eps,
+                has_sparse_grad=True,
+                maximize=False,
+                differentiable=differentiable,
+            )
 
-    if has_sparse_grad is None:
-        has_sparse_grad = any(grad.is_sparse for grad in grads)
+        # Update steps
+        torch._foreach_add_(device_state_steps, 1)
 
-    if has_sparse_grad:
-        return _single_tensor_adagrad(
-            params,
-            grads,
-            state_sums,
-            state_steps,
-            lr=lr,
-            weight_decay=weight_decay,
-            lr_decay=lr_decay,
-            eps=eps,
-            has_sparse_grad=has_sparse_grad,
-            maximize=False,
-            differentiable=differentiable,
-        )
-
-    # Update steps
-    torch._foreach_add_(state_steps, 1)
-
-    if weight_decay != 0:
-        torch._foreach_add_(grads, params, alpha=weight_decay)
-
-    minus_clr = [-lr / (1 + (step - 1) * lr_decay) for step in state_steps]
-
-    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
-    state_sums = [
-        torch.view_as_real(x) if torch.is_complex(x) else x for x in state_sums
-    ]
-    torch._foreach_addcmul_(state_sums, grads, grads, value=1)
-    std = torch._foreach_add(torch._foreach_sqrt(state_sums), eps)
-    toAdd = torch._foreach_div(torch._foreach_mul(grads, minus_clr), std)
-    toAdd = [
-        torch.view_as_complex(x) if torch.is_complex(params[i]) else x
-        for i, x in enumerate(toAdd)
-    ]
-    torch._foreach_add_(params, toAdd)
-    state_sums = [
-        torch.view_as_complex(x) if torch.is_complex(params[i]) else x
-        for i, x in enumerate(state_sums)
-    ]
+        if weight_decay != 0:
+            torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
+
+        minus_clr = [-lr / (1 + (step - 1) * lr_decay) for step in device_state_steps]
+
+        device_grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_grads]
+        device_state_sums = [
+            torch.view_as_real(x) if torch.is_complex(x) else x for x in device_state_sums
+        ]
+        torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1)
+        std = torch._foreach_add(torch._foreach_sqrt(device_state_sums), eps)
+        toAdd = torch._foreach_div(torch._foreach_mul(device_grads, minus_clr), std)
+        toAdd = [
+            torch.view_as_complex(x) if torch.is_complex(device_params[i]) else x
+            for i, x in enumerate(toAdd)
+        ]
+        torch._foreach_add_(device_params, toAdd)
+        device_state_sums = [
+            torch.view_as_complex(x) if torch.is_complex(device_params[i]) else x
+            for i, x in enumerate(device_state_sums)
+        ]
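
A quick way to exercise the grouped path after this change is to give the optimizer parameters whose (device, dtype) buckets differ. The usage sketch below is illustrative only (the shapes, learning rate, and explicit foreach flag are arbitrary choices, and it assumes a build that contains this commit):

import torch

# float32 and float64 parameters land in different (device, dtype) buckets,
# so the foreach path issues one fused update per bucket. Adding a CUDA
# parameter, if a GPU is available, would simply create another bucket.
params = [
    torch.nn.Parameter(torch.randn(4, 4)),
    torch.nn.Parameter(torch.randn(4, 4, dtype=torch.float64)),
]
opt = torch.optim.Adagrad(params, lr=0.1, foreach=True)

loss = (params[0] ** 2).sum() + (params[1] ** 2).sum()
loss.backward()
opt.step()  # grouped multi-tensor update, one bucket per (device, dtype)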
