from src.config.argument_parser import parse_args
from src.models import get_model, BaseModel, load_model_mlflow
from src.dataset.dataset import H5Dataset, get_dataset
import torch
import torch.utils.data
import mlflow
from src.metrics import binary_auc
from statistics import mean
from tqdm import trange
import numpy as np

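# Learning-rate schedules: 'cosine' anneals once from lr down to lr / 100 over
# the whole run; 'shark' chains two half-length cosine cycles, i.e. a single
# warm restart at the midpoint (the name presumably refers to the saw-tooth
# shape of the resulting lr curve).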
def get_lr_scheduler(schedule_name, optimizer, epochs, lr):
    if schedule_name == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=lr / 100)
    elif schedule_name == 'shark':
        # Two distinct CosineAnnealingLR instances are required here: reusing
        # one object in both slots would make SequentialLR share annealing
        # state across the two phases.
        return torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[
                torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs // 2, eta_min=lr / 100),
                torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs // 2, eta_min=lr / 100),
            ],
            milestones=[epochs // 2])
    raise ValueError(f"Unknown lr schedule: {schedule_name!r}")
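

# Main training loop: optimizes the model on the LED loss, logs losses,
# loss-weight coefficients, learning rate and checkpoints to mlflow, and
# periodically evaluates position error and per-LED AUC on the validation set.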
| 22 | + |
| 23 | +def train_loop(model : BaseModel, train_dataloader, val_dataloader, device, |
| 24 | + epochs, supervised_count, unsupervised_count, |
| 25 | + lr = .001, validation_rate = 10, |
| 26 | + checkpoint_logging_rate = 10, |
| 27 | + loss_weights = {'pos' : .2,'dist' : .0,'ori' : .0,'led' : .8}, |
| 28 | + lr_schedule = 'cosine' |
| 29 | + ): |
| 30 | + |
| 31 | + |
| 32 | + optimizer = model.optimizer(lr) |
| 33 | + |
| 34 | + lr_schedule = get_lr_scheduler(lr_schedule, optimizer, epochs, lr) |
| 35 | + |
| 36 | + _cuda_weights = {k: torch.tensor([v], device = device) for k, v in loss_weights.items()} |
| 37 | + supervised_count = torch.tensor([supervised_count + 1e-15], device=device) |
| 38 | + unsupervised_count = torch.tensor([unsupervised_count + 1e-15], device=device) |

    for e in trange(epochs):
        losses = [0.0] * len(train_dataloader)
        # One row of per-LED losses per batch, built with a comprehension so
        # the rows are independent lists.
        multiple_led_losses = [[0.0] * len(H5Dataset.LED_TYPES) for _ in range(len(train_dataloader))]

        model.train()
        for batch_i, batch in enumerate(train_dataloader):
            optimizer.zero_grad()

            image = batch['image'].to(device)
            out = model.forward(image)

            led_loss, m_led_loss = model.loss(batch, out)

            loss = led_loss.sum()
            loss.backward()
            optimizer.step()

            losses[batch_i] = loss.detach().item()
            multiple_led_losses[batch_i] = [l.item() for l in m_led_loss]

        multiple_led_losses = np.stack(multiple_led_losses, axis=0)
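        # multiple_led_losses now has shape (num_batches, len(H5Dataset.LED_TYPES)).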

        mlflow.log_metric('train/loss', sum(losses), e)

        mlflow.log_metric('train/loss/coefficients/proj', loss_weights['pos'], e)
        mlflow.log_metric('train/loss/coefficients/dist', loss_weights['dist'], e)
        mlflow.log_metric('train/loss/coefficients/ori', loss_weights['ori'], e)
        mlflow.log_metric('train/loss/coefficients/led', loss_weights['led'], e)

        for i, led_label in enumerate(H5Dataset.LED_TYPES):
            mlflow.log_metric(f'train/loss/led/{led_label}', multiple_led_losses[:, i].mean(), e)

        mlflow.log_metric('train/lr', scheduler.get_last_lr()[0], e)

        if e % checkpoint_logging_rate == 0 or e == epochs - 1:
            model.log_checkpoint(e)

        scheduler.step()
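
        # Periodic validation: median position error and per-LED AUC on the
        # held-out set.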
        if val_dataloader and (e % validation_rate == 0 or e == epochs - 1):
            preds = []
            trues = []
            losses = []
            led_preds = []
            led_trues = []
            led_visibility = []

            model.eval()

            with torch.no_grad():
                for batch in val_dataloader:
                    image = batch['image'].to(device)

                    out = model.forward(image)
                    led_loss, m_led_loss = model.loss(batch, out)
                    loss = led_loss.mean().detach()
                    losses.append(loss.item())

                    pos_preds = model.predict_pos(image)
                    preds.extend(pos_preds)
                    trues.extend(batch['proj_uvz'][:, :-1].cpu().numpy())

                    led_preds.extend(model.predict_leds(batch))
                    led_trues.extend(batch["led_mask"])
                    led_visibility.extend(batch["led_visibility_mask"])

            errors = np.linalg.norm(np.stack(preds) - np.stack(trues), axis=1)
            mlflow.log_metric('validation/position/median_error', np.median(errors), e)

            led_preds = np.array(led_preds)
            led_trues = np.array(led_trues)
            led_visibility = np.array(led_visibility)

            # Per-LED AUC over the samples where that LED is visible, plus the
            # mean across LEDs.
            aucs = []
            for i, led_label in enumerate(H5Dataset.LED_TYPES):
                vis = led_visibility[:, i]
                auc = binary_auc(led_preds[vis, i], led_trues[vis, i])
                mlflow.log_metric(f'validation/led/auc/{led_label}', auc, e)
                aucs.append(auc)
            mlflow.log_metric('validation/led/auc', mean(aucs), e)
            mlflow.log_metric('validation/loss', mean(losses), e)
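

# A minimal standalone invocation of train_loop, for illustration only; the
# model type, dataset path and argument values here are hypothetical (the
# script normally gets them from parse_args):
#
#   model = get_model('some_model_type')(task='led', led_inference=True).to('cuda')
#   ds = get_dataset('data/train.h5', augmentations=True,
#                    only_visible_robots=False, compute_led_visibility=False)
#   loader = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)
#   with mlflow.start_run():
#       train_loop(model, loader, None, 'cuda', epochs=10,
#                  supervised_count=len(ds), unsupervised_count=0)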


def main():
    args = parse_args("train")

    if args.checkpoint_id:
        # Resume training from a checkpoint stored in mlflow.
        model, run_id = load_model_mlflow(
            experiment_id=args.experiment_id,
            mlflow_run_name=args.weights_run_name,
            checkpoint_idx=args.checkpoint_id,
            model_kwargs={'task': args.task, 'led_inference': args.led_inference},
            return_run_id=True)
        model = model.to(args.device)
    else:
        model_cls = get_model(args.model_type)
        model = model_cls(task=args.task, led_inference=args.led_inference).to(args.device)

    train_dataset = get_dataset(
        args.dataset, sample_count=args.sample_count, sample_count_seed=args.sample_count_seed,
        augmentations=True,
        only_visible_robots=args.visible, compute_led_visibility=False,
        supervised_flagging=args.labeled_count,
        supervised_flagging_seed=args.labeled_count_seed,
        non_visible_perc=args.non_visible_perc)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=8, shuffle=True)

    # Validation data
    if args.validation_dataset:
        validation_dataset = get_dataset(args.validation_dataset, augmentations=False, only_visible_robots=True, compute_led_visibility=True)
        validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=64, num_workers=8)
    else:
        validation_dataloader = None

    if args.dry_run:
        # Dry run: silence mlflow by patching its callables with no-ops.
        # setattr is needed to actually replace the module attributes;
        # start_run is left intact so the `with` block below still yields a
        # usable run object.
        for name, val in list(mlflow.__dict__.items()):
            if callable(val) and name != 'start_run':
                setattr(mlflow, name, lambda *args, **kwargs: None)

    # Train on the LED loss only; position/distance/orientation terms are off.
    loss_weights = {
        'pos': 0.,
        'dist': 0.,
        'ori': 0.,
        'led': 1.,
    }
    ds_size = len(train_dataset)
    supervised_count = args.labeled_count if args.labeled_count else ds_size
    if 'cuda' in args.device:
        torch.backends.cudnn.benchmark = True

    with mlflow.start_run(experiment_id=args.experiment_id, run_name=args.run_name) as run:
        mlflow.log_params(vars(args))
        train_loop(model, train_dataloader, validation_dataloader, args.device,
                   epochs=args.epochs, lr=args.learning_rate, loss_weights=loss_weights,
                   supervised_count=supervised_count,
                   unsupervised_count=ds_size - supervised_count,
                   lr_schedule=args.lr_schedule)
        print(run.info.run_id)


if __name__ == "__main__":
    main()