from stylegan import G_synthesis, G_mapping
from dataclasses import dataclass
from SphericalOptimizer import SphericalOptimizer
from pathlib import Path
import numpy as np
import time
import torch
from loss import LossBuilder
from functools import partial
from drive import open_url


class PULSE(torch.nn.Module):
    def __init__(self, cache_dir):
        super(PULSE, self).__init__()
        self.synthesis = G_synthesis().cuda()

        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)

        print("Loading Synthesis Network")
        with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir) as f:
            self.synthesis.load_state_dict(torch.load(f))

        # Freeze the pretrained synthesis network; only the latent and noise inputs are optimized.
        for param in self.synthesis.parameters():
            param.requires_grad = False

        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2)

        if Path("gaussian_fit.pt").exists():
            self.gaussian_fit = torch.load("gaussian_fit.pt")
        else:
            print("Fitting Linear Layer to Mapping Network")
            print("\tLoading Mapping Network")
            mapping = G_mapping().cuda()

            with open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir) as f:
                mapping.load_state_dict(torch.load(f))

            print("\tRunning Mapping Network")
            with torch.no_grad():
                torch.manual_seed(0)
                latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda")
                # LeakyReLU(5) inverts the LeakyReLU(0.2) applied in forward(), so the
                # mean/std are fitted in the pre-activation space of the mapping output.
                latent_out = torch.nn.LeakyReLU(5)(mapping(latent))
                self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)}
                torch.save(self.gaussian_fit, "gaussian_fit.pt")
                print('\tSaved "gaussian_fit.pt"')

    def forward(self, ref_im,
                seed,
                loss_str,
                eps,
                noise_type,
                num_trainable_noise_layers,
                tile_latent,
                bad_noise_layers,
                opt_name,
                learning_rate,
                steps,
                lr_schedule,
                save_intermediate,
                **kwargs):

        if seed:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True
        batch_size = ref_im.shape[0]

        # Generate latent tensor
        if tile_latent:
            latent = torch.randn(
                (batch_size, 1, 512), dtype=torch.float, requires_grad=True, device='cuda')
        else:
            latent = torch.randn(
                (batch_size, 18, 512), dtype=torch.float, requires_grad=True, device='cuda')

        # Generate list of noise tensors
        noise = []       # stores all of the noise tensors
        noise_vars = []  # stores the noise tensors that we want to optimize on
        for i in range(18):
            # dimension of the ith noise tensor
            res = (batch_size, 1, 2 ** (i // 2 + 2), 2 ** (i // 2 + 2))

            if noise_type == 'zero' or i in [int(layer) for layer in bad_noise_layers.split('.')]:
                new_noise = torch.zeros(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif noise_type == 'fixed':
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif noise_type == 'trainable':
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                if i < num_trainable_noise_layers:
                    new_noise.requires_grad = True
                    noise_vars.append(new_noise)
                else:
                    new_noise.requires_grad = False
            else:
                raise Exception("unknown noise type")

            noise.append(new_noise)
        var_list = [latent] + noise_vars

        opt_dict = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'sgdm': partial(torch.optim.SGD, momentum=0.9),
            'adamax': torch.optim.Adamax
        }
        opt_func = opt_dict[opt_name]
        # SphericalOptimizer wraps the chosen base optimizer so the optimized
        # variables stay on a sphere (see SphericalOptimizer.py).
        opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)

        schedule_dict = {
            'fixed': lambda x: 1,
            'linear1cycle': lambda x: (9 * (1 - np.abs(x / steps - 1 / 2) * 2) + 1) / 10,
            'linear1cycledrop': lambda x: (9 * (1 - np.abs(x / (0.9 * steps) - 1 / 2) * 2) + 1) / 10
                                          if x < 0.9 * steps
                                          else 1 / 10 + (x - 0.9 * steps) / (0.1 * steps) * (1 / 1000 - 1 / 10),
        }
        schedule_func = schedule_dict[lr_schedule]
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()
        min_loss = np.inf
        best_summary = ""
        start_t = time.time()
        if save_intermediate:
            int_HR = []
            int_LR = []

        print("Optimizing")
        for j in range(steps):
            opt.opt.zero_grad()

            # Duplicate latent in case tile_latent = True
            if tile_latent:
                latent_in = latent.expand(-1, 18, -1)
            else:
                latent_in = latent

            # Apply learned linear mapping to match latent distribution to that of the mapping network
            latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] + self.gaussian_fit["mean"])

            # Normalize image to [0,1] instead of [-1,1]
            gen_im = (self.synthesis(latent_in, noise) + 1) / 2

            # Calculate Losses
            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict['TOTAL'] = loss

            # Save intermediate HR and LR images
            if save_intermediate:
                int_HR.append(gen_im.cpu().detach().clamp(0, 1))
                int_LR.append(loss_builder.D(gen_im).cpu().detach().clamp(0, 1))

            # Save best summary for log
            if loss < min_loss:
                min_loss = loss
                best_summary = f'BEST ({j + 1}) | ' + ' | '.join(
                    [f'{x}: {y:.4f}' for x, y in loss_dict.items()])
                best_im = gen_im.clone()

            loss.backward()
            opt.step()
            scheduler.step()
        total_t = time.time() - start_t
        current_info = f' | time: {total_t:.1f} | it/s: {(j + 1) / total_t:.2f} | batchsize: {batch_size}'
        print(best_summary + current_info)

        if save_intermediate:
            return best_im.cpu().detach().clamp(0, 1), int_HR, int_LR
        else:
            return best_im.cpu().detach().clamp(0, 1)
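

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# hyperparameter values below are assumptions picked for demonstration, and
# the dummy 32x32 tensor stands in for a real downscaled, aligned face image
# in [0, 1]; see run.py in the repository for the actual entry point and its
# defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = PULSE(cache_dir="cache")

    # Hypothetical low-resolution reference batch of shape (N, 3, H, W).
    ref_im = torch.rand((1, 3, 32, 32), device="cuda")

    hr_im = model(
        ref_im,
        seed=0,
        loss_str="100*L2+0.05*GEOCROSS",   # assumed loss specification string
        eps=2e-3,                          # assumed L2 tolerance
        noise_type="trainable",
        num_trainable_noise_layers=5,
        tile_latent=False,
        bad_noise_layers="17",
        opt_name="adam",
        learning_rate=0.4,
        steps=100,
        lr_schedule="linear1cycledrop",
        save_intermediate=False,
    )
    print("Output shape:", hr_im.shape)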