From 645a643591b71c75a9ac291df886195149c18e5c Mon Sep 17 00:00:00 2001 From: iorymaeda <86793373+iorymaeda@users.noreply.github.com> Date: Sat, 15 Jul 2023 15:38:41 +0300 Subject: [PATCH] Checking for requires_grad --- torch_optimizer/lookahead.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_optimizer/lookahead.py b/torch_optimizer/lookahead.py index 39abf07..53afe3f 100644 --- a/torch_optimizer/lookahead.py +++ b/torch_optimizer/lookahead.py @@ -57,6 +57,9 @@ def __init__( def _update(self, group: Dict[str, Any]) -> None: for fast in group["params"]: + if not fast.requires_grad: + continue + param_state = self.state[fast] if "slow_param" not in param_state: param_state["slow_param"] = torch.clone(fast.data).detach()