diff --git a/predict.py b/predict.py index f83d0e7..14328f6 100644 --- a/predict.py +++ b/predict.py @@ -1,5 +1,5 @@ import sys -from pathlib import Path +from typing import Iterator import torch import torch.optim as optim from torchvision import transforms, models @@ -11,11 +11,11 @@ from template import imagenet_templates from torchvision.utils import save_image from torchvision.transforms.functional import adjust_contrast -import cog +from cog import BasePredictor, Input, Path from argparse import Namespace -class Predictor(cog.Predictor): +class Predictor(BasePredictor): def setup(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.VGG = models.vgg19(pretrained=True).features @@ -26,10 +26,12 @@ def setup(self): self.style_net.to(self.device) self.clip_model, preprocess = clip.load('ViT-B/32', self.device, jit=False) - @cog.input("image", type=Path, help="Input image (will be cropped before style transfer)") - @cog.input("text", type=str, help="text for style transfer") - @cog.input("iterations", type=int, default=100, help="training iterations") - def predict(self, image, text, iterations): + def predict( + self, + image: Path = Input(description="Input image (will be cropped before style transfer)"), + text: str = Input(description="text for style transfer"), + iterations: int = Input(default=100, description="training iterations") + ) -> Iterator[Path]: training_args = { "lambda_tv": 2e-3, "lambda_patch": 9000, @@ -141,17 +143,17 @@ def predict(self, image, text, iterations): if epoch % 20 == 0 or epoch == steps: yield checkin(epoch, target, total_loss, content_loss, loss_patch, loss_glob, reg_tv, out_path) - return out_path + yield out_path @torch.no_grad() def checkin(epoch, target, total_loss, content_loss, loss_patch, loss_glob, reg_tv, out_path): - sys.stderr.write(f'After {epoch} iterations') - sys.stderr.write(f'Total loss: {total_loss.item()}') - sys.stderr.write(f'Content loss: {content_loss.item()}') - sys.stderr.write(f'patch loss: {loss_patch.item()}') - sys.stderr.write(f'dir loss: {loss_glob.item()}') - sys.stderr.write(f'TV loss: {reg_tv.item()}') + sys.stderr.write(f'After {epoch} iterations\n') + sys.stderr.write(f'Total loss: {total_loss.item()}\n') + sys.stderr.write(f'Content loss: {content_loss.item()}\n') + sys.stderr.write(f'patch loss: {loss_patch.item()}\n') + sys.stderr.write(f'dir loss: {loss_glob.item()}\n') + sys.stderr.write(f'TV loss: {reg_tv.item()}\n') output_image = target.clone() output_image = torch.clamp(output_image, 0, 1) output_image = adjust_contrast(output_image, 1.5)