Skip to content

Commit 64b0a42

Browse files
Create lr_finder.py
1 parent a3d2668 commit 64b0a42

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

utils/lr_finder.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from torch_lr_finder import LRFinder
2+
3+
4+
def find_lr(model, train_loader, test_loader, epochs, optimizer, criterion, device):
5+
"""
6+
Find best LR.
7+
"""
8+
lr_finder = LRFinder(model, optimizer, criterion, device=device)
9+
lr_finder.range_test(
10+
train_loader,
11+
val_loader=test_loader,
12+
step_mode="linear",
13+
end_lr=0.5,
14+
num_iter=epochs * len(test_loader),
15+
diverge_th=50,
16+
)
17+
max_lr = lr_finder.plot(suggest_lr=True, skip_start=0, skip_end=0)
18+
lr_finder.reset()
19+
20+
return max_lr[-1]

0 commit comments

Comments
 (0)