From ce1eed4f44c1a66c7729e993721c0c428e1f7332 Mon Sep 17 00:00:00 2001
From: andreasjansson
Date: Tue, 15 Mar 2022 17:38:42 -0700
Subject: [PATCH] Upgrade to Cog version 0.1

The [new version of Cog](https://github.com/replicate/cog/releases/tag/v0.1.0)
improves the Python API, along with several other changes. This PR upgrades
CLIPstyler to Cog version >= 0.1.

I have already pushed this to Replicate for you, so you don't need to do
anything for the demo to keep working:
https://replicate.com/paper11667/clipstyler
---
 predict.py | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/predict.py b/predict.py
index f83d0e7..14328f6 100644
--- a/predict.py
+++ b/predict.py
@@ -1,5 +1,5 @@
 import sys
-from pathlib import Path
+from typing import Iterator
 import torch
 import torch.optim as optim
 from torchvision import transforms, models
@@ -11,11 +11,11 @@
 from template import imagenet_templates
 from torchvision.utils import save_image
 from torchvision.transforms.functional import adjust_contrast
 
-import cog
+from cog import BasePredictor, Input, Path
 from argparse import Namespace
 
-class Predictor(cog.Predictor):
+class Predictor(BasePredictor):
     def setup(self):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.VGG = models.vgg19(pretrained=True).features
@@ -26,10 +26,12 @@ def setup(self):
         self.style_net.to(self.device)
         self.clip_model, preprocess = clip.load('ViT-B/32', self.device, jit=False)
 
-    @cog.input("image", type=Path, help="Input image (will be cropped before style transfer)")
-    @cog.input("text", type=str, help="text for style transfer")
-    @cog.input("iterations", type=int, default=100, help="training iterations")
-    def predict(self, image, text, iterations):
+    def predict(
+        self,
+        image: Path = Input(description="Input image (will be cropped before style transfer)"),
+        text: str = Input(description="text for style transfer"),
+        iterations: int = Input(default=100, description="training iterations")
+    ) -> Iterator[Path]:
         training_args = {
             "lambda_tv": 2e-3,
             "lambda_patch": 9000,
@@ -141,17 +143,17 @@ def predict(self, image, text, iterations):
             if epoch % 20 == 0 or epoch == steps:
                 yield checkin(epoch, target, total_loss, content_loss, loss_patch, loss_glob, reg_tv, out_path)
 
-        return out_path
+        yield out_path
 
 
 @torch.no_grad()
 def checkin(epoch, target, total_loss, content_loss, loss_patch, loss_glob, reg_tv, out_path):
-    sys.stderr.write(f'After {epoch} iterations')
-    sys.stderr.write(f'Total loss: {total_loss.item()}')
-    sys.stderr.write(f'Content loss: {content_loss.item()}')
-    sys.stderr.write(f'patch loss: {loss_patch.item()}')
-    sys.stderr.write(f'dir loss: {loss_glob.item()}')
-    sys.stderr.write(f'TV loss: {reg_tv.item()}')
+    sys.stderr.write(f'After {epoch} iterations\n')
+    sys.stderr.write(f'Total loss: {total_loss.item()}\n')
+    sys.stderr.write(f'Content loss: {content_loss.item()}\n')
+    sys.stderr.write(f'patch loss: {loss_patch.item()}\n')
+    sys.stderr.write(f'dir loss: {loss_glob.item()}\n')
+    sys.stderr.write(f'TV loss: {reg_tv.item()}\n')
     output_image = target.clone()
     output_image = torch.clamp(output_image, 0, 1)
    output_image = adjust_contrast(output_image, 1.5)
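
Note for the review: with this upgrade, `predict()` is a generator that yields an
intermediate result every 20 iterations and the final image last, rather than
returning a single path. The sketch below is a minimal, hypothetical way to
exercise that behavior locally by calling the predictor directly; it assumes the
repository's dependencies and model weights are available, and the file name and
style prompt are placeholders, not part of this PR.

```python
# Hypothetical local smoke test for the upgraded predictor (not part of the diff).
from pathlib import Path

from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads VGG-19, the style network, and CLIP ViT-B/32

# predict() now yields output paths progressively: one every 20 iterations,
# then the final stylized image.
for out in predictor.predict(
    image=Path("input.jpg"),           # placeholder input image
    text="a watercolor painting",      # placeholder style prompt
    iterations=100,
):
    print(out)
```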