diff --git a/guided_diffusion/dpm_solver.py b/guided_diffusion/dpm_solver.py index ce91f1f..fa83b24 100644 --- a/guided_diffusion/dpm_solver.py +++ b/guided_diffusion/dpm_solver.py @@ -410,6 +410,17 @@ def data_prediction_fn(self, x, t): if self.correcting_x0_fn is not None: x0 = self.correcting_x0_fn(x0, t) return x0 + + def data_prediction_fn_last_step(self, x, t): + """ + Return both model predictions (with corrector). + """ + noise, cal = self.model(torch.cat((self.img,x), dim=1).to(dtype = torch.float), t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise[:,0:1,:,:]) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0, cal def model_fn(self, x, t): """ @@ -419,6 +430,16 @@ def model_fn(self, x, t): return self.data_prediction_fn(x, t) else: return self.noise_prediction_fn(x, t) + + def model_fn_last_step(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model for the last diffusion step. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn_last_step(x, t) + else: + x = torch.cat((self.img,x), dim=1).to(dtype = torch.float) + return self.model(x, t) def get_time_steps(self, skip_type, t_T, t_0, N, device): """Compute the intermediate time steps for sampling. @@ -1189,10 +1210,15 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time x = self.correcting_xt_fn(x, t, step + 1) if return_intermediate: intermediates.append(x) - cal = None - out = self.model(torch.cat((self.img,x), dim=1).to(dtype = torch.float), t) + + out = self.model_fn_last_step(x, t) + if isinstance(out, tuple): x, cal = out + else: + x = out + cal = None + if return_intermediate: return x, intermediates else: diff --git a/guided_diffusion/gaussian_diffusion.py b/guided_diffusion/gaussian_diffusion.py index 47a8ac3..d30c628 100644 --- a/guided_diffusion/gaussian_diffusion.py +++ b/guided_diffusion/gaussian_diffusion.py @@ -550,7 +550,7 @@ def p_sample_loop_known( method="multistep", ) sample = sample.detach() ### MODIFIED: for DPM-Solver OOM issue - sample[:,-1,:,:] = norm(sample[:,-1,:,:]) + #sample[:,-1,:,:] = norm(sample[:,-1,:,:]) final["sample"] = sample final["cal"] = cal diff --git a/guided_diffusion/utils.py b/guided_diffusion/utils.py index 7c112eb..b044ed1 100644 --- a/guided_diffusion/utils.py +++ b/guided_diffusion/utils.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import torchvision.utils as vutils softmax_helper = lambda x: F.softmax(x, 1) @@ -45,16 +46,12 @@ def __exit__(self, *args): def staple(a): # a: n,c,h,w detach tensor mvres = mv(a) - gap = 0.4 - if gap > 0.02: - for i, s in enumerate(a): - r = s * mvres - res = r if i == 0 else torch.cat((res,r),0) - nres = mv(res) - gap = torch.mean(torch.abs(mvres - nres)) - mvres = nres - a = res - return mvres + for i, s in enumerate(a): + r = s * mvres + res = r if i == 0 else torch.cat((res,r),0) + nres = mv(res) + a = res + return nres def allone(disc,cup): disc = np.array(disc) / 255 @@ -69,10 +66,7 @@ def dice_score(pred, targs): return 2. * (pred*targs).sum() / (pred+targs).sum() def mv(a): - # res = Image.fromarray(np.uint8(img_list[0] / 2 + img_list[1] / 2 )) - # res.show() - b = a.size(0) - return torch.sum(a, 0, keepdim=True) / b + return torch.mean(a, dim=0, keepdim=True) def tensor_to_img_array(tensor): image = tensor.cpu().detach().numpy() @@ -85,10 +79,10 @@ def export(tar, img_path=None): if c == 3: vutils.save_image(tar, fp = img_path) else: - s = th.tensor(tar)[:,-1,:,:].unsqueeze(1) - s = th.cat((s,s,s),1) + s = torch.tensor(tar)[:,-1,:,:].unsqueeze(1) + s = torch.cat((s,s,s),1) vutils.save_image(s, fp = img_path) def norm(t): - m, s, v = torch.mean(t), torch.std(t), torch.var(t) + m, s = torch.mean(t), torch.std(t) return (t - m) / s