30 changes: 28 additions & 2 deletions guided_diffusion/dpm_solver.py
@@ -410,6 +410,17 @@ def data_prediction_fn(self, x, t):
if self.correcting_x0_fn is not None:
x0 = self.correcting_x0_fn(x0, t)
return x0

+ def data_prediction_fn_last_step(self, x, t):
+ """
+ Return the corrected data prediction x0 together with the model's calibration output.
+ """
+ noise, cal = self.model(torch.cat((self.img,x), dim=1).to(dtype = torch.float), t)
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+ x0 = (x - sigma_t * noise[:,0:1,:,:]) / alpha_t
+ if self.correcting_x0_fn is not None:
+ x0 = self.correcting_x0_fn(x0, t)
+ return x0, cal

def model_fn(self, x, t):
"""
@@ -419,6 +430,16 @@ def model_fn(self, x, t):
return self.data_prediction_fn(x, t)
else:
return self.noise_prediction_fn(x, t)

+ def model_fn_last_step(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model for the last diffusion step.
+ """
+ if self.algorithm_type == "dpmsolver++":
+ return self.data_prediction_fn_last_step(x, t)
+ else:
+ x = torch.cat((self.img,x), dim=1).to(dtype = torch.float)
+ return self.model(x, t)

def get_time_steps(self, skip_type, t_T, t_0, N, device):
"""Compute the intermediate time steps for sampling.
@@ -1189,10 +1210,15 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time
x = self.correcting_xt_fn(x, t, step + 1)
if return_intermediate:
intermediates.append(x)
- cal = None
- out = self.model(torch.cat((self.img,x), dim=1).to(dtype = torch.float), t)
+ out = self.model_fn_last_step(x, t)
+ if isinstance(out, tuple):
+ x, cal = out
+ else:
+ x = out
+ cal = None

if return_intermediate:
return x, intermediates
else:
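Note on the dpm_solver.py changes above (illustration only, not part of the diff): the new last-step path converts the predicted noise to a data prediction with the usual x0 = (x_t - sigma_t * eps_hat) / alpha_t rule and, under dpmsolver++, also returns the model's calibration map. A minimal sketch of both pieces, with made-up schedule values and a hypothetical solver object standing in for a configured DPM_Solver instance:

# Illustration only: the noise-to-data conversion used by data_prediction_fn_last_step.
import torch

def eps_to_x0(x_t, eps_hat, alpha_t, sigma_t):
    return (x_t - sigma_t * eps_hat) / alpha_t

x_t = torch.randn(1, 1, 8, 8)
eps_hat = torch.randn_like(x_t)
alpha_t = torch.tensor(0.9)                      # made-up value; normally noise_schedule.marginal_alpha(t)
sigma_t = torch.sqrt(1 - alpha_t ** 2)           # made-up value; normally noise_schedule.marginal_std(t)
x0 = eps_to_x0(x_t, eps_hat, alpha_t, sigma_t)

# Illustration only: unpacking the optional calibration output of the final step.
def run_last_step(solver, x, t):
    out = solver.model_fn_last_step(x, t)
    if isinstance(out, tuple):                   # dpmsolver++ path: (x0, cal)
        return out
    return out, None                             # other algorithm types: no calibration map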
2 changes: 1 addition & 1 deletion guided_diffusion/gaussian_diffusion.py
@@ -550,7 +550,7 @@ def p_sample_loop_known(
method="multistep",
)
sample = sample.detach() ### MODIFIED: for DPM-Solver OOM issue
- sample[:,-1,:,:] = norm(sample[:,-1,:,:])
+ #sample[:,-1,:,:] = norm(sample[:,-1,:,:])
final["sample"] = sample
final["cal"] = cal

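Note on the gaussian_diffusion.py change above (illustration only, not part of the diff): the line being commented out standardized the last channel of the final sample with the norm helper from guided_diffusion/utils.py, i.e. a plain per-tensor z-score. A minimal sketch of the operation that is now skipped:

import torch

def zscore(t):
    # same as norm() in guided_diffusion/utils.py: zero mean, unit standard deviation
    return (t - torch.mean(t)) / torch.std(t)

sample = torch.randn(2, 4, 16, 16)
sample[:, -1, :, :] = zscore(sample[:, -1, :, :])   # what the disabled line used to do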
28 changes: 11 additions & 17 deletions guided_diffusion/utils.py
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+ import torchvision.utils as vutils


softmax_helper = lambda x: F.softmax(x, 1)
@@ -45,16 +46,12 @@ def __exit__(self, *args):
def staple(a):
# a: n,c,h,w detach tensor
mvres = mv(a)
- gap = 0.4
- if gap > 0.02:
- for i, s in enumerate(a):
- r = s * mvres
- res = r if i == 0 else torch.cat((res,r),0)
- nres = mv(res)
- gap = torch.mean(torch.abs(mvres - nres))
- mvres = nres
- a = res
- return mvres
+ for i, s in enumerate(a):
+ r = s * mvres
+ res = r if i == 0 else torch.cat((res,r),0)
+ nres = mv(res)
+ a = res
+ return nres

def allone(disc,cup):
disc = np.array(disc) / 255
@@ -69,10 +66,7 @@ def dice_score(pred, targs):
return 2. * (pred*targs).sum() / (pred+targs).sum()

def mv(a):
- # res = Image.fromarray(np.uint8(img_list[0] / 2 + img_list[1] / 2 ))
- # res.show()
- b = a.size(0)
- return torch.sum(a, 0, keepdim=True) / b
+ return torch.mean(a, dim=0, keepdim=True)

def tensor_to_img_array(tensor):
image = tensor.cpu().detach().numpy()
@@ -85,10 +79,10 @@ def export(tar, img_path=None):
if c == 3:
vutils.save_image(tar, fp = img_path)
else:
- s = th.tensor(tar)[:,-1,:,:].unsqueeze(1)
- s = th.cat((s,s,s),1)
+ s = torch.tensor(tar)[:,-1,:,:].unsqueeze(1)
+ s = torch.cat((s,s,s),1)
vutils.save_image(s, fp = img_path)

def norm(t):
- m, s, v = torch.mean(t), torch.std(t), torch.var(t)
+ m, s = torch.mean(t), torch.std(t)
return (t - m) / s
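
Note on the utils.py changes above (illustration only, not part of the diff): the trimmed staple/mv pair now does a single mean-vote reweighting pass over the ensemble instead of tracking the unused gap value. A minimal sketch of the equivalent computation, assuming a is an (n, c, h, w) stack of per-sample segmentation predictions:

import torch

def fuse_once(a):
    mvres = torch.mean(a, dim=0, keepdim=True)   # mv(): mean vote over the ensemble
    res = a * mvres                              # reweight each member by the mean vote
    return torch.mean(res, dim=0, keepdim=True)  # fused (1, c, h, w) result, as staple() returns

ensemble = torch.rand(5, 1, 16, 16)              # e.g. five sampled segmentation masks
fused = fuse_once(ensemble)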