Skip to content

Commit 89e1bf9

Browse files
committed
Fix backpropagation steps order in basics/quickstart_tutorial.py
1 parent 45a5b17 commit 89e1bf9

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

beginner_source/basics/quickstart_tutorial.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,15 +140,16 @@ def train(dataloader, model, loss_fn, optimizer):
140140
model.train()
141141
for batch, (X, y) in enumerate(dataloader):
142142
X, y = X.to(device), y.to(device)
143-
143+
144+
# Zero the gradients for batch
145+
optimizer.zero_grad()
144146
# Compute prediction error
145147
pred = model(X)
146148
loss = loss_fn(pred, y)
147-
148149
# Backpropagation
149150
loss.backward()
151+
# Optimizer step
150152
optimizer.step()
151-
optimizer.zero_grad()
152153

153154
if batch % 100 == 0:
154155
loss, current = loss.item(), (batch + 1) * len(X)

0 commit comments

Comments
 (0)