Skip to content

Commit da413d1

Browse files
author
Nicholas Carlotti
committed
<EXP> Grad
1 parent c7d82dc commit da413d1

File tree

2 files changed

+286
-0
lines changed

2 files changed

+286
-0
lines changed

src/models/resnet.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import numpy as np
2+
from numpy import unravel_index, stack
3+
4+
import torch
5+
from torchvision.models import MobileNetV2
6+
import torch
7+
from src.models import ModelRegistry, BaseModel
8+
from src.models.fcn import FullyConvPredictorMixin
9+
import torch.nn as nn
10+
from pytorch_grad_cam import AblationCAM
11+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, RawScoresOutputTarget
12+
from src.models import ModelRegistry, BaseModel
13+
14+
@ModelRegistry("cam")
class ResnetCAMWrapper(BaseModel):
    """LED classifier with CAM-based position estimation.

    Despite the name, the backbone is a compact MobileNetV2 (three
    inverted-residual stages, 6 outputs — one per LED). LED states are
    predicted via sigmoid outputs; the robot's image position is estimated
    by taking, for each of the 4 LEDs of interest, the argmax of its
    AblationCAM activation map and averaging the resulting coordinates.
    """

    def __init__(self, *args, **kwargs) -> None:
        # BaseModel does not accept 'led_inference'; drop it before forwarding.
        kwargs.pop('led_inference')
        # kwargs.pop('task')

        super().__init__(*args, **kwargs)
        # Inverted-residual settings: t (expansion), c (channels), n (repeats), s (stride).
        self.configs = [
            # t, c, n, s
            [1, 32, 1, 2],
            [6, 64, 2, 2],
            [6, 128, 2, 1],
        ]
        self.model = MobileNetV2(num_classes=6, inverted_residual_setting=self.configs)
        # Hook the CAM on the first layer of the last feature block.
        target_layers = [self.model.features[-1][0]]
        self.cam = AblationCAM(self.model, target_layers=target_layers)

    def forward(self, x):
        """Return per-LED on/off probabilities, shape (batch, 6)."""
        # torch.sigmoid: torch.nn.functional.sigmoid is deprecated.
        return torch.sigmoid(self.model(x))

    def loss(self, batch, model_out):
        """Binary cross-entropy over LED states.

        Returns (scalar training loss, detached per-LED mean losses).
        """
        led_preds = model_out
        led_labels = batch['led_mask'].to(model_out.device)
        # Element-wise BCE in one call (equivalent to the former per-column loop).
        losses = torch.nn.functional.binary_cross_entropy(
            led_preds, led_labels.float(), reduction='none'
        )
        # We only care about 4 leds
        losses[:, 1] = 0.
        losses[:, 2] = 0.
        # * 1.5 == 6/4 compensates the mean for the two zeroed columns.
        return losses.mean() * 1.5, losses.detach().mean(0)

    def predict_leds(self, x):
        """Predict LED probabilities for a batch of images."""
        out = self(x)
        return out

    def predict_pos(self, images):
        """Estimate an (x, y) image position per input image, shape (batch, 2).

        Each of the 4 LEDs of interest contributes the argmax of its CAM,
        mapped back to input-image coordinates; the 4 points are averaged.
        """
        led_ids = [0, 3, 4, 5]
        coords = np.zeros((images.shape[0], 4, 2))

        for image_idx in range(images.shape[0]):
            x = images[image_idx, ...][None, ...]
            for i, l in enumerate(led_ids):
                maps = self.cam(input_tensor=x, targets=[ClassifierOutputTarget(l)])
                out_map_shape = maps.shape[-2:]
                maps = maps.reshape((*maps.shape[:-2], -1))
                max_idx = maps.argmax(1)
                indexes = unravel_index(max_idx, out_map_shape)
                # x y
                indexes = stack([indexes[1], indexes[0]]).T.astype('float32')
                indexes /= np.array([out_map_shape[1], out_map_shape[0]])
                indexes *= np.array([x.shape[-1], x.shape[-2]])

                # Bug fix: the half-cell centering offset must use the spatial
                # dims (H = shape[-2], W = shape[-1]); the original used
                # shape[0]/shape[1], i.e. batch and channel counts.
                y_scale_f = x.shape[-2] / out_map_shape[0]
                x_scale_f = x.shape[-1] / out_map_shape[1]
                indexes += np.array([x_scale_f, y_scale_f]) / 2

                # Bug fix: store the result for THIS image only; the original
                # `coords[:, i, :]` overwrote every batch row on each pass,
                # leaving all rows equal to the last image's coordinates.
                coords[image_idx, i, :] = indexes
        return coords.mean(1)

    def optimizer(self, learning_rate):
        """Adam over all wrapper parameters."""
        return torch.optim.Adam(self.parameters(), lr=learning_rate)
81+
82+
83+
84+
85+
86+
87+
88+

train_grad_baseline.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
from src.config.argument_parser import parse_args
2+
from src.models import get_model, BaseModel, load_model_mlflow
3+
from src.dataset.dataset import H5Dataset, get_dataset
4+
import torch
5+
import torch.utils.data
6+
import mlflow
7+
from src.metrics import binary_auc, angle_difference, leds_auc
8+
from statistics import mean
9+
from tqdm import trange
10+
import numpy as np
11+
12+
13+
def get_lr_scheduler(schedule_name, optimizer, epochs, lr):
    """Build the learning-rate scheduler named by *schedule_name*.

    'cosine': one cosine annealing cycle over all epochs, decaying to lr/100.
    'shark':  two back-to-back cosine cycles of epochs//2 each ("shark fin").

    Raises:
        ValueError: if *schedule_name* is not recognized (the original
        silently returned None, which crashes later at .get_last_lr()).
    """
    if schedule_name == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, lr / 100, -1)
    elif schedule_name == 'shark':
        # Bug fix: build two DISTINCT CosineAnnealingLR instances. The
        # original passed the same object twice ([sched] * 2), so the second
        # phase continued the first phase's annealing state instead of
        # restarting the cycle.
        return torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            [
                torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs // 2, lr / 100, -1)
                for _ in range(2)
            ],
            [epochs // 2, ])
    raise ValueError(f"Unknown lr schedule: {schedule_name!r}")
22+
23+
def train_loop(model : BaseModel, train_dataloader, val_dataloader, device,
               epochs, supervised_count, unsupervised_count,
               lr = .001, validation_rate = 10,
               checkpoint_logging_rate = 10,
               loss_weights = {'pos' : .2,'dist' : .0,'ori' : .0,'led' : .8},
               lr_schedule = 'cosine'
               ):
    """Train *model* on the LED task, logging metrics to mlflow.

    Runs *epochs* epochs of LED-loss training; every *validation_rate*
    epochs (and on the last epoch) evaluates position error and per-LED AUC
    on *val_dataloader* (skipped when it is None). Checkpoints via
    model.log_checkpoint every *checkpoint_logging_rate* epochs.

    NOTE(review): *loss_weights* is a mutable default argument — shared
    across calls; safe only as long as it is never mutated in here.
    """

    optimizer = model.optimizer(lr)

    # Rebinds the name: from here on `lr_schedule` is the scheduler object,
    # not the schedule-name string.
    lr_schedule = get_lr_scheduler(lr_schedule, optimizer, epochs, lr)

    # NOTE(review): _cuda_weights and the two count tensors below are never
    # read in this function — presumably leftovers from a weighted /
    # semi-supervised variant; confirm before deleting.
    _cuda_weights = {k: torch.tensor([v], device = device) for k, v in loss_weights.items()}
    supervised_count = torch.tensor([supervised_count + 1e-15], device=device)
    unsupervised_count = torch.tensor([unsupervised_count + 1e-15], device=device)

    for e in trange(epochs):
        # Per-batch scalar losses and per-batch per-LED losses (6 LEDs).
        losses = [0] * len(train_dataloader)
        # Aliased inner lists are harmless: every slot is reassigned below.
        multiple_led_losses = [[0] * 6, ] * len(train_dataloader)

        # NOTE(review): the lists below are initialized here but only the
        # validation branch uses (and re-initializes) preds/trues; the
        # theta_*/dist_* lists appear unused in this function.
        preds = []
        theta_preds = []
        dist_preds = []

        trues = []
        dist_trues = []
        theta_trues = []

        model.train()
        for batch_i, batch in enumerate(train_dataloader):
            optimizer.zero_grad()

            image = batch['image'].to(device)
            out = model.forward(image)

            # model.loss returns (training loss, detached per-LED losses).
            led_loss, m_led_loss = model.loss(batch, out)

            loss = led_loss.sum()
            loss.backward()
            optimizer.step()

            losses[batch_i] = loss.detach().item()
            multiple_led_losses[batch_i] = [l.item() for l in m_led_loss]

        # (num_batches, 6) array for per-LED epoch means.
        multiple_led_losses = np.stack(multiple_led_losses, axis = 0)

        mlflow.log_metric('train/loss', sum(losses), e)

        # Logged every epoch even though the weights are constant, so runs
        # with different weightings can be compared in the mlflow UI.
        mlflow.log_metric('train/loss/coefficients/proj', loss_weights['pos'], e)
        mlflow.log_metric('train/loss/coefficients/dist', loss_weights['dist'], e)
        mlflow.log_metric('train/loss/coefficients/ori', loss_weights['ori'], e)
        mlflow.log_metric('train/loss/coefficients/led', loss_weights['led'], e)

        for i, led_label, in enumerate(H5Dataset.LED_TYPES):
            mlflow.log_metric(f'train/loss/led/{led_label}', multiple_led_losses[:, i].mean(), e)

        # Log the lr used THIS epoch (before stepping the scheduler).
        mlflow.log_metric('train/lr', lr_schedule.get_last_lr()[0], e)

        if e % checkpoint_logging_rate == 0 or e == epochs - 1:
            model.log_checkpoint(e)

        lr_schedule.step()

        if val_dataloader and (e % validation_rate == 0 or e == epochs - 1):
            preds = []
            trues = []
            losses = []
            led_preds = []
            led_trues = []
            led_visibility = []

            model.eval()

            with torch.no_grad():
                for batch in val_dataloader:
                    image = batch['image'].to(device)

                    out = model.forward(image)
                    led_loss, m_led_loss = model.loss(batch, out)
                    mean_l_loss = led_loss.mean().detach()

                    loss = mean_l_loss
                    losses.append(loss.item())

                    # Position predictions vs. ground-truth projected (u, v)
                    # (the last proj_uvz component, z, is dropped).
                    pos_preds = model.predict_pos(image)
                    preds.extend(pos_preds)
                    trues.extend(batch['proj_uvz'][:, :-1].cpu().numpy())

                    # NOTE(review): predict_leds is given the whole batch
                    # dict here, while the CAM wrapper's predict_leds(x)
                    # forwards its argument straight into the model —
                    # confirm this shouldn't be predict_leds(image).
                    led_preds.extend(model.predict_leds(batch))
                    led_trues.extend(batch["led_mask"])
                    led_visibility.extend(batch["led_visibility_mask"])

            # Euclidean pixel error per sample; median is robust to outliers.
            errors = np.linalg.norm(np.stack(preds) - np.stack(trues), axis = 1)
            mlflow.log_metric('validation/position/median_error', np.median(errors), e)

            led_preds = np.array(led_preds)
            led_trues = np.array(led_trues)
            led_visibility = np.array(led_visibility)

            # Per-LED AUC computed only over samples where that LED is
            # visible, then averaged across LEDs.
            aucs = []
            for i, led_label in enumerate(H5Dataset.LED_TYPES):
                vis = led_visibility[:, i]
                auc = binary_auc(led_preds[vis, i], led_trues[vis, i])
                mlflow.log_metric(f'validation/led/auc/{led_label}', auc, e)
                aucs.append(auc)
            mlflow.log_metric('validation/led/auc', mean(aucs), e)
            mlflow.log_metric('validation/loss', mean(losses), e)
138+
139+
140+
def main():
    """Entry point: build/load the model, assemble datasets, run training."""
    args = parse_args("train")

    if args.checkpoint_id:
        # Resume from an mlflow-logged checkpoint.
        model, run_id = load_model_mlflow(experiment_id=args.experiment_id, mlflow_run_name=args.weights_run_name, checkpoint_idx=args.checkpoint_id,
                                          model_kwargs={'task' : args.task, 'led_inference' : args.led_inference}, return_run_id=True)
        model = model.to(args.device)
    else:
        model_cls = get_model(args.model_type)
        model = model_cls(task = args.task, led_inference = args.led_inference).to(args.device)

    # (Fix: removed the duplicated `train_dataset = train_dataset = ...`.)
    train_dataset = get_dataset(args.dataset, sample_count=args.sample_count, sample_count_seed=args.sample_count_seed, augmentations=True,
                                only_visible_robots=args.visible, compute_led_visibility=False,
                                supervised_flagging=args.labeled_count,
                                supervised_flagging_seed=args.labeled_count_seed,
                                non_visible_perc=args.non_visible_perc
                                )
    print(args.batch_size)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = args.batch_size, num_workers=8, shuffle=True)

    """
    Validation data
    """
    if args.validation_dataset:
        validation_dataset = get_dataset(args.validation_dataset, augmentations=False, only_visible_robots=True, compute_led_visibility=True)
        validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size = 64, num_workers=8)
    else:
        validation_dataloader = None

    if args.dry_run:
        # Bug fix: the original loop only rebound the loop-local `val`
        # (a no-op), so dry runs still logged to mlflow. Actually stub the
        # module's public callables; start_run must remain a context manager
        # exposing .info.run_id for the code below.
        class _NullRun:
            class info:
                run_id = None

            def __enter__(self):
                return self

            def __exit__(self, *exc):
                return False

        for name, val in list(mlflow.__dict__.items()):
            if callable(val) and not name.startswith('_'):
                setattr(mlflow, name, lambda *args, **kwargs: None)
        mlflow.start_run = lambda *args, **kwargs: _NullRun()

    loss_weights = {
        'pos' : 0.,
        'dist' : 0.,
        'ori' : 0.,
        'led' : 1.,
    }
    ds_size = len(train_dataset)
    supervised_count = args.labeled_count if args.labeled_count else ds_size
    if 'cuda' in args.device:
        # Fixed input sizes: let cudnn benchmark conv algorithms once.
        torch.backends.cudnn.benchmark = True

    with mlflow.start_run(experiment_id=args.experiment_id, run_name=args.run_name) as run:
        mlflow.log_params(vars(args))
        train_loop(model, train_dataloader, validation_dataloader, args.device,
                   epochs=args.epochs, lr=args.learning_rate, loss_weights = loss_weights,
                   supervised_count=supervised_count,
                   unsupervised_count=ds_size - supervised_count,
                   lr_schedule=args.lr_schedule)
    print(run.info.run_id)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)