
Commit ac22914 (root commit)
Initial refactor

10 files changed: +921 −0 lines

.gitignore (+5)

@@ -0,0 +1,5 @@
.DS_Store
.idea/*
runs/*
input/*
cache/*
PULSE.py (+176)

@@ -0,0 +1,176 @@
from stylegan import G_synthesis, G_mapping
from dataclasses import dataclass
from SphericalOptimizer import SphericalOptimizer
from pathlib import Path
import numpy as np
import time
import torch
from loss import LossBuilder
from functools import partial
from drive import open_url


class PULSE(torch.nn.Module):
    def __init__(self, cache_dir):
        super(PULSE, self).__init__()

        self.synthesis = G_synthesis().cuda()

        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)

        print("Loading Synthesis Network")
        with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir) as f:
            self.synthesis.load_state_dict(torch.load(f))

        for param in self.synthesis.parameters():
            param.requires_grad = False

        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2)

        if Path("gaussian_fit.pt").exists():
            self.gaussian_fit = torch.load("gaussian_fit.pt")
        else:
            print("Fitting Linear Layer to Mapping Network")
            print("\tLoading Mapping Network")
            mapping = G_mapping().cuda()

            with open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir) as f:
                mapping.load_state_dict(torch.load(f))

            print("\tRunning Mapping Network")
            with torch.no_grad():
                torch.manual_seed(0)
                latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda")
                # slope-5 LeakyReLU inverts the slope-0.2 LeakyReLU applied in forward()
                latent_out = torch.nn.LeakyReLU(5)(mapping(latent))
                self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)}
                torch.save(self.gaussian_fit, "gaussian_fit.pt")
                print("\tSaved \"gaussian_fit.pt\"")

    def forward(self, ref_im,
                seed,
                loss_str,
                eps,
                noise_type,
                num_trainable_noise_layers,
                tile_latent,
                bad_noise_layers,
                opt_name,
                learning_rate,
                steps,
                lr_schedule,
                save_intermediate,
                **kwargs):

        if seed:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        batch_size = ref_im.shape[0]

        # Generate latent tensor
        if tile_latent:
            latent = torch.randn(
                (batch_size, 1, 512), dtype=torch.float, requires_grad=True, device='cuda')
        else:
            latent = torch.randn(
                (batch_size, 18, 512), dtype=torch.float, requires_grad=True, device='cuda')

        # Generate list of noise tensors
        noise = []       # stores all of the noise tensors
        noise_vars = []  # stores the noise tensors that we want to optimize on

        for i in range(18):
            # dimension of the ith noise tensor
            res = (batch_size, 1, 2**(i//2+2), 2**(i//2+2))

            if noise_type == 'zero' or i in [int(layer) for layer in bad_noise_layers.split('.')]:
                new_noise = torch.zeros(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif noise_type == 'fixed':
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif noise_type == 'trainable':
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                if i < num_trainable_noise_layers:
                    new_noise.requires_grad = True
                    noise_vars.append(new_noise)
                else:
                    new_noise.requires_grad = False
            else:
                raise Exception("unknown noise type")

            noise.append(new_noise)

        var_list = [latent] + noise_vars

        opt_dict = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'sgdm': partial(torch.optim.SGD, momentum=0.9),
            'adamax': torch.optim.Adamax
        }
        opt_func = opt_dict[opt_name]
        opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)

        schedule_dict = {
            'fixed': lambda x: 1,
            'linear1cycle': lambda x: (9*(1-np.abs(x/steps-1/2)*2)+1)/10,
            'linear1cycledrop': lambda x: (9*(1-np.abs(x/(0.9*steps)-1/2)*2)+1)/10 if x < 0.9*steps else 1/10 + (x-0.9*steps)/(0.1*steps)*(1/1000-1/10),
        }
        schedule_func = schedule_dict[lr_schedule]
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        min_loss = np.inf
        best_summary = ""
        start_t = time.time()
        if save_intermediate:
            int_HR = []
            int_LR = []

        print("Optimizing")
        for j in range(steps):
            opt.opt.zero_grad()

            # Duplicate latent in case tile_latent = True
            if tile_latent:
                latent_in = latent.expand(-1, 18, -1)
            else:
                latent_in = latent

            # Apply learned linear mapping to match latent distribution to that of the mapping network
            latent_in = self.lrelu(latent_in*self.gaussian_fit["std"] + self.gaussian_fit["mean"])

            # Normalize image to [0,1] instead of [-1,1]
            gen_im = (self.synthesis(latent_in, noise)+1)/2

            # Calculate Losses
            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict['TOTAL'] = loss

            # Save intermediate HR and LR images
            if save_intermediate:
                int_HR.append(gen_im.cpu().detach().clamp(0, 1))
                int_LR.append(loss_builder.D(gen_im).cpu().detach().clamp(0, 1))

            # Save best summary for log
            if loss < min_loss:
                min_loss = loss
                best_summary = f'BEST ({j+1}) | ' + ' | '.join(
                    [f'{x}: {y:.4f}' for x, y in loss_dict.items()])
                best_im = gen_im.clone()

            loss.backward()
            opt.step()
            scheduler.step()

        total_t = time.time() - start_t
        current_info = f' | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}'
        print(best_summary + current_info)

        if save_intermediate:
            return best_im.cpu().detach().clamp(0, 1), int_HR, int_LR
        else:
            return best_im.cpu().detach().clamp(0, 1)
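For orientation, here is a minimal, hypothetical driver for PULSE.forward. It assumes a CUDA device, a low-resolution reference batch scaled to [0, 1], and a loss_str in whatever format loss.py's LossBuilder parses (that file is not shown in this diff); every argument value below is illustrative rather than a canonical default.

# Hypothetical usage sketch; argument values are illustrative, not defaults.
import torch
from PULSE import PULSE

model = PULSE(cache_dir="cache")
ref_im = torch.rand((1, 3, 32, 32), device="cuda")  # stand-in low-res batch in [0, 1]

hr_im = model(ref_im,
              seed=0,                           # falsy here, so determinism flags stay off
              loss_str="100*L2+0.05*GEOCROSS",  # assumed format, defined by loss.py
              eps=2e-3,
              noise_type="trainable",
              num_trainable_noise_layers=5,
              tile_latent=False,
              bad_noise_layers="17",            # '.'-separated layer indices to zero out
              opt_name="adam",
              learning_rate=0.4,
              steps=100,
              lr_schedule="linear1cycledrop",
              save_intermediate=False)
print(hr_im.shape)                              # one high-resolution image per reference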

SphericalOptimizer.py (+26)

@@ -0,0 +1,26 @@
import math
import torch
from torch.optim import Optimizer

# Spherical Optimizer Class
# Uses the first two dimensions as batch information
# Optimizes over the surface of a sphere using the initial radius throughout
#
# Example Usage:
# opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01)

class SphericalOptimizer(Optimizer):
    def __init__(self, optimizer, params, **kwargs):
        self.opt = optimizer(params, **kwargs)
        self.params = params
        with torch.no_grad():
            self.radii = {param: (param.pow(2).sum(tuple(range(2, param.ndim)), keepdim=True) + 1e-9).sqrt() for param in params}

    @torch.no_grad()
    def step(self, closure=None):
        loss = self.opt.step(closure)
        for param in self.params:
            param.data.div_((param.pow(2).sum(tuple(range(2, param.ndim)), keepdim=True) + 1e-9).sqrt())
            param.mul_(self.radii[param])

        return loss
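A minimal sketch of what the projection does, assuming (as the class does) that the first two dimensions are batch information and the norm is taken over the rest:

# Toy check: each step re-projects every (batch, ...) slice back onto the
# sphere of its initial radius, so only the direction is optimized.
import torch
from SphericalOptimizer import SphericalOptimizer

x = torch.randn(4, 1, 512, requires_grad=True)   # dims 0-1 treated as batch info
opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.1)
r0 = x.detach().pow(2).sum(2, keepdim=True).sqrt()

for _ in range(3):
    opt.opt.zero_grad()
    (x ** 2).sum().backward()                    # arbitrary loss pulling x toward 0
    opt.step()

r1 = x.detach().pow(2).sum(2, keepdim=True).sqrt()
print(torch.allclose(r0, r1, atol=1e-4))         # True: radii are preserved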

bicubic.py (+75)

@@ -0,0 +1,75 @@
import torch
from torch import nn
from torch.nn import functional as F


class BicubicDownSample(nn.Module):
    def bicubic_kernel(self, x, a=-0.50):
        """
        This equation is exactly copied from the website below:
        https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic
        """
        abs_x = torch.abs(x)
        if abs_x <= 1.:
            return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1
        elif 1. < abs_x < 2.:
            return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a
        else:
            return 0.0

    def __init__(self, factor=4, cuda=True, padding='reflect'):
        super().__init__()
        self.factor = factor
        size = factor * 4
        k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor)
                          for i in range(size)], dtype=torch.float32)
        k = k / torch.sum(k)
        # k = torch.einsum('i,j->ij', (k, k))
        k1 = torch.reshape(k, shape=(1, 1, size, 1))
        self.k1 = torch.cat([k1, k1, k1], dim=0)
        k2 = torch.reshape(k, shape=(1, 1, 1, size))
        self.k2 = torch.cat([k2, k2, k2], dim=0)
        self.cuda = '.cuda' if cuda else ''
        self.padding = padding
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
        # x = torch.from_numpy(x).type('torch.FloatTensor')
        filter_height = self.factor * 4
        filter_width = self.factor * 4
        stride = self.factor

        pad_along_height = max(filter_height - stride, 0)
        pad_along_width = max(filter_width - stride, 0)
        filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda))
        filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda))

        # compute actual padding values for each side
        pad_top = pad_along_height // 2
        pad_bottom = pad_along_height - pad_top
        pad_left = pad_along_width // 2
        pad_right = pad_along_width - pad_left

        # apply mirror padding
        if nhwc:
            x = torch.transpose(torch.transpose(
                x, 2, 3), 1, 2)  # NHWC to NCHW

        # downscaling performed by 1-d convolution
        x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding)
        x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3)
        if clip_round:
            x = torch.clamp(torch.round(x), 0.0, 255.)

        x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding)
        x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3)
        if clip_round:
            x = torch.clamp(torch.round(x), 0.0, 255.)

        if nhwc:
            x = torch.transpose(torch.transpose(x, 1, 3), 1, 2)
        if byte_output:
            # fixed: the format placeholder was missing, so .format() was a no-op
            return x.type('torch{}.ByteTensor'.format(self.cuda))
        else:
            return x
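A quick smoke test of the downsampler; with factor=4 the module convolves with a 16-tap separable kernel at stride 4, so a 64x64 input should come back 16x16 (cuda=False keeps the sketch CPU-only):

# Minimal check of factor-4 bicubic downsampling on a random NCHW batch.
import torch
from bicubic import BicubicDownSample

D = BicubicDownSample(factor=4, cuda=False)
x = torch.rand(2, 3, 64, 64)   # NCHW, values in [0, 1]
y = D(x)
print(y.shape)                 # expected: torch.Size([2, 3, 16, 16])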

drive.py (+90)

@@ -0,0 +1,90 @@
# URL helpers, see https://github.com/NVlabs/stylegan
# ------------------------------------------------------------------------------------------

import requests
import html
import hashlib
import glob
import os
import io
from typing import Any
import re
import uuid

def is_url(obj: Any) -> bool:
    """Determine whether the given object is a valid URL string."""
    if not isinstance(obj, str) or not "://" in obj:
        return False
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
    except:
        return False
    return True


def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data."""
    assert is_url(url)
    assert num_attempts >= 1

    # Lookup from cache.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache_dir is not None:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            return open(cache_files[0], "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive quota exceeded")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache_dir is not None:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file)  # atomic

    # Return data as file object.
    return io.BytesIO(url_data)
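Typical use mirrors the calls in PULSE.py: the first call downloads and writes to the cache, later calls with the same cache_dir read straight from disk. The URL below is the synthesis-weights link already used above.

# Sketch: fetch once over HTTP, then serve from the local cache on repeat calls.
from drive import open_url

url = "https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8"
with open_url(url, cache_dir="cache") as f:
    data = f.read()
print(len(data), "bytes")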

gaussian_fit.pt (4.46 KB, binary file not shown)
