11 changes: 10 additions & 1 deletion .gitignore
@@ -1,2 +1,11 @@
 runs
-wandb
+wandb
+yolov7/torpedoes_2025_a-1/test/labels/
+
+yolov7/torpedoes_2025_a-1/train/labels/
+
+yolov7/torpedoes_2025_a-1/valid/labels/
+
+yolov7/torpedoes_2025_a-1/
+
+*.url
48 changes: 26 additions & 22 deletions yolov7/train.py
@@ -38,7 +38,7 @@
 logger = logging.getLogger(__name__)
 
 
-def train(hyp, opt, device, tb_writer=None):
+def train(hyp, opt, device, tb_writer=None, log=False):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
     save_dir, epochs, batch_size, total_batch_size, weights, rank, freeze = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, opt.freeze
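With the new `log` flag defaulting to `False`, W&B logging becomes opt-in. A minimal sketch of a call site, assuming `hyp`, `opt`, and `device` have been prepared the way train.py's `__main__` block prepares them (this PR does not add such a driver script):

```python
from train import train

train(hyp, opt, device, tb_writer=None)            # default: W&B logging disabled
train(hyp, opt, device, tb_writer=None, log=True)  # restore the original behaviour
```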
@@ -65,15 +65,16 @@ def train(hyp, opt, device, tb_writer=None):
     is_coco = opt.data.endswith('coco.yaml')
 
     # Logging- Doing this before checking the dataset. Might update data_dict
-    loggers = {'wandb': None}  # loggers dict
-    if rank in [-1, 0]:
-        opt.hyp = hyp  # add hyperparameters
-        run_id = torch.load(weights, map_location=device).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
-        wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
-        loggers['wandb'] = wandb_logger.wandb
-        data_dict = wandb_logger.data_dict
-        if wandb_logger.wandb:
-            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # WandbLogger might update weights, epochs if resuming
+    if log:
+        loggers = {'wandb': None}  # loggers dict
+        if rank in [-1, 0]:
+            opt.hyp = hyp  # add hyperparameters
+            run_id = torch.load(weights, map_location=device, weights_only=False).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
+            wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
+            loggers['wandb'] = wandb_logger.wandb
+            data_dict = wandb_logger.data_dict
+            if wandb_logger.wandb:
+                weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # WandbLogger might update weights, epochs if resuming
 
     nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
@@ -84,7 +85,7 @@ def train(hyp, opt, device, tb_writer=None):
     if pretrained:
         with torch_distributed_zero_first(rank):
             attempt_download(weights)  # download if not found locally
-        ckpt = torch.load(weights, map_location=device)  # load checkpoint
+        ckpt = torch.load(weights, map_location=device, weights_only=False)  # load checkpoint
         model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
         exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
         state_dict = ckpt['model'].float().state_dict()  # to FP32
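The added `weights_only=False` arguments track a PyTorch behaviour change: from PyTorch 2.6 on, `torch.load` defaults to `weights_only=True` and refuses to unpickle arbitrary Python objects, which breaks YOLOv7 checkpoints that store a full `Model` instance. A minimal sketch of the failure and the opt-out, assuming a standard `yolov7.pt` checkpoint (full unpickling can execute code, so this is only appropriate for trusted files):

```python
import torch

# On PyTorch >= 2.6 the old call raises an UnpicklingError, because the
# checkpoint holds whole Python objects (the Model), not just tensors:
# ckpt = torch.load('yolov7.pt', map_location='cpu')

# Opting out restores the pre-2.6 behaviour for trusted checkpoints.
ckpt = torch.load('yolov7.pt', map_location='cpu', weights_only=False)
print(type(ckpt['model']), ckpt.get('wandb_id'))
```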
@@ -394,7 +395,7 @@ def train(hyp, opt, device, tb_writer=None):
                 # if tb_writer:
                 #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                 #     tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), [])  # add model graph
-            elif plots and ni == 10 and wandb_logger.wandb:
+            elif plots and ni == 10 and log and wandb_logger.wandb:
                 wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
                                               save_dir.glob('train*.jpg') if x.exists()]})
 
@@ -411,7 +412,8 @@
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
-                wandb_logger.current_epoch = epoch + 1
+                if log:
+                    wandb_logger.current_epoch = epoch + 1
                 results, maps, times = test.test(data_dict,
                                                  batch_size=batch_size * 2,
                                                  imgsz=imgsz_test,
@@ -421,7 +423,7 @@
                                                  save_dir=save_dir,
                                                  verbose=nc < 50 and final_epoch,
                                                  plots=plots and final_epoch,
-                                                 wandb_logger=wandb_logger,
+                                                 wandb_logger=wandb_logger if log else None,
                                                  compute_loss=compute_loss,
                                                  is_coco=is_coco,
                                                  v5_metric=opt.v5_metric)
@@ -440,14 +442,15 @@
             for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                 if tb_writer:
                     tb_writer.add_scalar(tag, x, epoch)  # tensorboard
-                if wandb_logger.wandb:
+                if log and wandb_logger.wandb:
                     wandb_logger.log({tag: x})  # W&B
 
             # Update best mAP
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
             if fi > best_fitness:
                 best_fitness = fi
-            wandb_logger.end_epoch(best_result=best_fitness == fi)
+            if log:
+                wandb_logger.end_epoch(best_result=best_fitness == fi)
 
             # Save model
             if (not opt.nosave) or (final_epoch and not opt.evolve):  # if save
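For reference, the `fitness` score that gates `best_fitness` here is, throughout the yolov5/yolov7 family, a fixed weighted sum over `[P, R, mAP@.5, mAP@.5:.95]`. A sketch of that definition as it appears upstream (quoted from the repos, not from this diff):

```python
import numpy as np

def fitness(x):
    # x has shape (n, 4+); columns are [P, R, mAP@.5, mAP@.5:.95, ...]
    w = [0.0, 0.0, 0.1, 0.9]  # P and R are ignored; mAP@.5:.95 dominates
    return (x[:, :4] * w).sum(1)

fitness(np.array([0.9, 0.8, 0.7, 0.5]).reshape(1, -1))  # -> array([0.52])
```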
@@ -458,7 +461,8 @@
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
-                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
+                        'wandb_id': wandb_logger.wandb_run.id if (wandb_logger.wandb and log) else None
+                        }
 
                 # Save last, best and delete
                 torch.save(ckpt, last)
@@ -472,7 +476,7 @@
                     torch.save(ckpt, wdir / 'epoch_{:03d}.pt'.format(epoch))
                 elif epoch >= (epochs-5):
                     torch.save(ckpt, wdir / 'epoch_{:03d}.pt'.format(epoch))
-                if wandb_logger.wandb:
+                if log and wandb_logger.wandb:
                     if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
                         wandb_logger.log_model(
                             last.parent, opt, epoch, fi, best_model=best_fitness == fi)
@@ -484,7 +488,7 @@
         # Plots
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
-            if wandb_logger.wandb:
+            if log and wandb_logger.wandb:
                 files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
                 wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
                                               if (save_dir / f).exists()]})
@@ -513,7 +517,7 @@
                 strip_optimizer(f)  # strip optimizers
         if opt.bucket:
             os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload
-        if wandb_logger.wandb and not opt.evolve:  # Log the stripped model
+        if wandb_logger.wandb and log and not opt.evolve:  # Log the stripped model
             wandb_logger.wandb.log_artifact(str(final), type='model',
                                             name='run_' + wandb_logger.wandb_run.id + '_model',
                                             aliases=['last', 'best', 'stripped'])
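One detail worth flagging: when `train` runs with `log=False`, `wandb_logger` is never bound, so every guard has to test `log` first to short-circuit safely. The `if log and wandb_logger.wandb:` pattern used in most hunks is fine; the two expressions that evaluate `wandb_logger` before `log` (the `'wandb_id'` checkpoint entry and the artifact guard just above) would raise `NameError` on the no-logging path unless `wandb_logger` is bound elsewhere:

```python
# Safe: short-circuits before touching the possibly-unbound name.
if log and wandb_logger.wandb:
    ...

# Not safe when log=False and wandb_logger was never created:
# 'wandb_id': wandb_logger.wandb_run.id if (wandb_logger.wandb and log) else None
# if wandb_logger.wandb and log and not opt.evolve: ...
```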
@@ -573,8 +577,8 @@ def train(hyp, opt, device, tb_writer=None):
     # check_requirements()
 
     # Resume
-    wandb_run = check_wandb_resume(opt)
-    if opt.resume and not wandb_run:  # resume an interrupted run
+    # wandb_run = check_wandb_resume(opt)
+    if opt.resume:  # resume an interrupted run
         ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
         assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
         apriori = opt.global_rank, opt.local_rank
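With `check_wandb_resume` commented out, `--resume` always takes the local-checkpoint path. A sketch of what that resolution does, assuming the yolov7 helper `utils.general.get_latest_run`, which globs for the newest `last*.pt` under the search directory and returns `''` when none exists (which is what trips the assert above on a fresh checkout):

```python
from utils.general import get_latest_run

print(get_latest_run())  # e.g. 'runs/train/exp7/weights/last.pt', or '' if no runs yet
```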
2 changes: 1 addition & 1 deletion yolov7/utils/datasets.py
@@ -361,7 +361,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
         self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
         self.mosaic_border = [-img_size // 2, -img_size // 2]
         self.stride = stride
-        self.path = path
+        self.path = 'gate_lyps_2025-1/images'
         #self.albumentations = Albumentations() if augment else None
 
         try:
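The observable effect, as a sketch: the `path` argument is still what the `try:` block that follows globs for image files, but the stored attribute now always points at the hardcoded directory. Here `'data/train.txt'` is a placeholder for any valid dataset list:

```python
from utils.datasets import LoadImagesAndLabels

ds = LoadImagesAndLabels('data/train.txt', img_size=640, batch_size=16)
assert ds.path == 'gate_lyps_2025-1/images'  # no longer reflects the argument
```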
2 changes: 1 addition & 1 deletion yolov7/utils/general.py
@@ -240,7 +240,7 @@ def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
     return image_weights
 
 
-def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
+def coco80_to_coco91_class():  # converts 80-index (#2014) to 91-index (paper)
     # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
     # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
     # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
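For context, `coco80_to_coco91_class` returns a lookup list mapping the contiguous 80-class index used by the model to the sparse 91-id scheme of the COCO paper. A small usage sketch:

```python
from utils.general import coco80_to_coco91_class

coco91 = coco80_to_coco91_class()
print(coco91[0])   # -> 1: 80-class index 0 ('person') is paper id 1
print(coco91[79])  # -> 90: the last 80-class index maps to paper id 90
```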