How to pass gradients to backward()
#10257
In my experiment, I have a loss function that is not defined by an expression, but I do have a formulation of the gradient (formally a subgradient), so I have to pass the gradient manually. In plain PyTorch I implement it in the following way and it works fine:

self.optimizer.zero_grad()
y_hat = self.model(x_train)
grad = compute_grad(y_hat, y)
y_hat.backward(gradient=grad)
self.optimizer.step()

Would the following be a correct implementation in Lightning?

def training_step(self, batch, batch_idx):
    opt = self.optimizers()
    x, y = batch
    y_hat = self(x)
    grad = compute_grad(y_hat, y)
    opt.zero_grad()
    y_hat.backward(gradient=grad)
    opt.step()
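For context, passing a tensor to backward(gradient=...) seeds the vector-Jacobian product: autograd uses it as d(loss)/d(y_hat) instead of deriving it from a scalar loss. A minimal self-contained sketch (the linear model, random data, and the residual-style grad are made up for illustration; compute_grad stays the asker's own function):

import torch

model = torch.nn.Linear(4, 2)
x_train = torch.randn(8, 4)
y = torch.randn(8, 2)

y_hat = model(x_train)
grad = (y_hat - y).detach()   # stand-in for compute_grad(y_hat, y)

# Seeds backprop with the hand-computed (sub)gradient; each parameter's
# .grad receives grad^T @ d(y_hat)/d(param).
y_hat.backward(gradient=grad)

# Equivalent surrogate form (same parameter gradients, since grad is detached):
# (y_hat * grad).sum().backward()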
@JayMan91 Haven't tried, but you could probably try manual optimization:

def __init__(...):
    ...
    self.automatic_optimization = False  # use manual optimization
    ...

def training_step(...):
    ...
    self.manual_backward(y_hat, gradient=grad)
    ...

manual optimization: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#manual-optimization
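Putting that suggestion together, here is a sketch of what the full module might look like, assuming manual_backward forwards keyword arguments such as gradient= through to Tensor.backward (the class name, constructor arguments, optimizer choice, and the stand-in gradient are illustrative, not part of the original post):

import torch
import pytorch_lightning as pl

class SubgradientModule(pl.LightningModule):   # hypothetical name
    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.model = model
        self.lr = lr
        self.automatic_optimization = False    # switch to manual optimization

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        y_hat = self(x)
        # stand-in for the user's own compute_grad(y_hat, y)
        grad = (y_hat - y).detach()
        opt.zero_grad()
        # manual_backward replaces y_hat.backward(...) so Lightning can apply
        # precision/strategy handling; extra kwargs are forwarded to backward
        self.manual_backward(y_hat, gradient=grad)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)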
LightningModule.manual_backward: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#manual-backward
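A possible alternative that keeps automatic optimization, relying on the vector-Jacobian identity noted above: return a surrogate scalar whose parameter gradient equals the hand-computed one. A sketch, assuming the (sub)gradient is detached so autograd treats it as a constant (compute_grad is again the asker's own function):

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    grad = compute_grad(y_hat, y).detach()   # constant w.r.t. the parameters
    # d/d(params) of (y_hat * grad).sum() equals grad^T @ d(y_hat)/d(params),
    # so Lightning's automatic backward/step does the rest
    return (y_hat * grad).sum()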