30 changes: 16 additions & 14 deletions predict.py
@@ -1,5 +1,5 @@
 import sys
-from pathlib import Path
+from typing import Iterator
 import torch
 import torch.optim as optim
 from torchvision import transforms, models
@@ -11,11 +11,11 @@
 from template import imagenet_templates
 from torchvision.utils import save_image
 from torchvision.transforms.functional import adjust_contrast
-import cog
+from cog import BasePredictor, Input, Path
 from argparse import Namespace


-class Predictor(cog.Predictor):
+class Predictor(BasePredictor):
     def setup(self):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.VGG = models.vgg19(pretrained=True).features
@@ -26,10 +26,12 @@ def setup(self):
         self.style_net.to(self.device)
         self.clip_model, preprocess = clip.load('ViT-B/32', self.device, jit=False)

-    @cog.input("image", type=Path, help="Input image (will be cropped before style transfer)")
-    @cog.input("text", type=str, help="text for style transfer")
-    @cog.input("iterations", type=int, default=100, help="training iterations")
-    def predict(self, image, text, iterations):
+    def predict(
+        self,
+        image: Path = Input(description="Input image (will be cropped before style transfer)"),
+        text: str = Input(description="text for style transfer"),
+        iterations: int = Input(default=100, description="training iterations")
+    ) -> Iterator[Path]:
         training_args = {
             "lambda_tv": 2e-3,
             "lambda_patch": 9000,
@@ -141,17 +143,17 @@ def predict(self, image, text, iterations):
         if epoch % 20 == 0 or epoch == steps:
             yield checkin(epoch, target, total_loss, content_loss, loss_patch, loss_glob, reg_tv, out_path)

-    return out_path
+    yield out_path


 @torch.no_grad()
 def checkin(epoch, target, total_loss, content_loss, loss_patch, loss_glob, reg_tv, out_path):
-    sys.stderr.write(f'After {epoch} iterations')
-    sys.stderr.write(f'Total loss: {total_loss.item()}')
-    sys.stderr.write(f'Content loss: {content_loss.item()}')
-    sys.stderr.write(f'patch loss: {loss_patch.item()}')
-    sys.stderr.write(f'dir loss: {loss_glob.item()}')
-    sys.stderr.write(f'TV loss: {reg_tv.item()}')
+    sys.stderr.write(f'After {epoch} iterations\n')
+    sys.stderr.write(f'Total loss: {total_loss.item()}\n')
+    sys.stderr.write(f'Content loss: {content_loss.item()}\n')
+    sys.stderr.write(f'patch loss: {loss_patch.item()}\n')
+    sys.stderr.write(f'dir loss: {loss_glob.item()}\n')
+    sys.stderr.write(f'TV loss: {reg_tv.item()}\n')
     output_image = target.clone()
     output_image = torch.clamp(output_image, 0, 1)
     output_image = adjust_contrast(output_image, 1.5)
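
One subtlety in the return out_path -> yield out_path change above: as soon as predict() contains a yield anywhere, it is a generator function, and a value passed to return is not produced to the caller during iteration (it only rides along on StopIteration). The final image therefore has to be yielded just like the intermediate ones. A two-line illustration of that Python semantic:

    def gen():
        yield "intermediate"
        return "final"  # not produced by iteration; becomes StopIteration.value

    print(list(gen()))  # ['intermediate'] -- the returned value is silently dropped
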
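The checkin() edits fix a separate logging bug: unlike print(), sys.stderr.write() appends no newline, so the old messages ran together on a single line. A quick illustration (the loss value is made up):

    import sys

    sys.stderr.write('Total loss: 0.42')        # old behaviour: no newline
    sys.stderr.write('Total loss: 0.42\n')      # fixed: one message per line
    print('Total loss: 0.42', file=sys.stderr)  # equivalent alternative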