|
3 | 3 |
|
4 | 4 | from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
|
5 | 5 | _differentiable_doc, _maximize_doc)
|
| 6 | +from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype |
6 | 7 | from typing import List, Optional
|
7 | 8 |
|
8 | 9 | __all__ = ["Adagrad", "adagrad"]
|
def _multi_tensor_adagrad(
    params,
    grads,
    state_sums,
    state_steps,
    *,
    lr,
    weight_decay,
    lr_decay,
    eps,
    has_sparse_grad,
    maximize,
    differentiable,
):
    """Multi-tensor (``torch._foreach_*``) implementation of the Adagrad step.

    Tensors are grouped by (device, dtype) so each foreach kernel receives a
    homogeneous list. Groups containing sparse gradients fall back to the
    single-tensor implementation, since foreach ops do not support sparse
    tensors.

    NOTE(review): the parameter list is reconstructed from the names the body
    uses (the original ``def`` line is outside this view) — confirm against
    the full file. ``has_sparse_grad`` is accepted for interface compatibility
    but sparsity is re-detected per device group below.

    Args:
        params: parameters to update (updated in place).
        grads: gradients, same order as ``params``.
        state_sums: running sums of squared gradients (updated in place).
        state_steps: per-parameter step counters (incremented in place).
        lr: learning rate.
        weight_decay: L2 penalty coefficient; folded into the gradient.
        lr_decay: learning-rate decay factor.
        eps: term added to the denominator for numerical stability.
        has_sparse_grad: caller's sparsity hint (ignored; see NOTE above).
        maximize: maximize instead of minimize (gradients are negated).
        differentiable: forwarded to the single-tensor fallback only.
    """
    # Foreach kernels raise on empty tensor lists, so bail out early.
    if len(params) == 0:
        return

    grouped_tensorlists = _group_tensors_by_device_and_dtype(
        [params, grads, state_sums, state_steps]
    )
    for (
        device_params,
        device_grads,
        device_state_sums,
        device_state_steps,
    ) in grouped_tensorlists.values():

        if maximize:
            # Out-of-place negation: the caller's grad tensors are not mutated.
            device_grads = torch._foreach_neg(device_grads)

        device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)

        if device_has_sparse_grad:
            # BUG FIX: this used `return _single_tensor_adagrad(...)`, which
            # aborted the whole function and silently skipped the update for
            # every remaining (device, dtype) group. Handle this group via the
            # fallback, then continue with the rest.
            _single_tensor_adagrad(
                device_params,
                device_grads,
                device_state_sums,
                device_state_steps,
                lr=lr,
                weight_decay=weight_decay,
                lr_decay=lr_decay,
                eps=eps,
                has_sparse_grad=True,
                # Grads were already negated above when maximizing.
                maximize=False,
                differentiable=differentiable,
            )
            continue

        # Update steps.
        torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            # NOTE(review): when `maximize` is False, `device_grads` aliases
            # the caller's grads, so this in-place add mutates them — this
            # matches the pre-existing behavior; confirm it is intended.
            torch._foreach_add_(device_grads, device_params, alpha=weight_decay)

        # Decayed negative learning rate per parameter: -lr / (1 + (t-1)*lr_decay).
        minus_clr = [-lr / (1 + (step - 1) * lr_decay) for step in device_state_steps]

        # Complex tensors are updated through interleaved real views; the
        # in-place foreach ops below write through these views, so the real
        # state_sums storage is updated even though only locals are rebound.
        device_grads = [
            torch.view_as_real(x) if torch.is_complex(x) else x for x in device_grads
        ]
        device_state_sums = [
            torch.view_as_real(x) if torch.is_complex(x) else x
            for x in device_state_sums
        ]

        # state_sum += grad * grad
        torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1)
        std = torch._foreach_add(torch._foreach_sqrt(device_state_sums), eps)
        # param += (-clr * grad) / (sqrt(state_sum) + eps)
        toAdd = torch._foreach_div(torch._foreach_mul(device_grads, minus_clr), std)
        toAdd = [
            torch.view_as_complex(x) if torch.is_complex(device_params[i]) else x
            for i, x in enumerate(toAdd)
        ]
        torch._foreach_add_(device_params, toAdd)
        # (Removed a trailing rebind of `device_state_sums` back to complex
        # views: it only rebound a local at the end of the loop body and had
        # no observable effect.)
0 commit comments