diff --git a/train.py b/train.py index 1c7914e..a615c54 100644 --- a/train.py +++ b/train.py @@ -218,7 +218,7 @@ def main(args): start_time = time() # Labels to condition the model with (feel free to change): - ys = torch.randint(1000, size=(local_batch_size,), device=device) + ys = torch.randint(args.num_classes, size=(local_batch_size,), device=device) use_cfg = args.cfg_scale > 1.0 # Create sampling noise: n = ys.size(0) @@ -227,7 +227,7 @@ def main(args): # Setup classifier-free guidance: if use_cfg: zs = torch.cat([zs, zs], 0) - y_null = torch.tensor([1000] * n, device=device) + y_null = torch.tensor([args.num_classes] * n, device=device) ys = torch.cat([ys, y_null], 0) sample_model_kwargs = dict(y=ys, cfg_scale=args.cfg_scale) model_fn = ema.forward_with_cfg