[MMSIG] Support StableSR Algorithm Reproduction #1941

Draft · wants to merge 8 commits into base: main
13 changes: 13 additions & 0 deletions configs/stablesr/README.md
@@ -0,0 +1,13 @@
# StableSR

> [Exploiting Diffusion Prior for Real-World Image Super-Resolution](https://arxiv.org/abs/2305.07015)

> **Task**: Image Super-Resolution

<!-- [ALGORITHM] -->

## Abstract

<!-- [ABSTRACT] -->

We present a novel approach to leverage prior knowledge encapsulated in pre-trained text-to-image diffusion models for blind super-resolution (SR). Specifically, by employing our time-aware encoder, we can achieve promising restoration results without altering the pre-trained synthesis model, thereby preserving the generative prior and minimizing training cost. To remedy the loss of fidelity caused by the inherent stochasticity of diffusion models, we introduce a controllable feature wrapping module that allows users to balance quality and fidelity by simply adjusting a scalar value during the inference process. Moreover, we develop a progressive aggregation sampling strategy to overcome the fixed-size constraints of pre-trained diffusion models, enabling adaptation to resolutions of any size. A comprehensive evaluation of our method using both synthetic and real-world benchmarks demonstrates its superiority over current state-of-the-art approaches.
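
The quality–fidelity control mentioned in the abstract comes down to blending encoder features (which track the low-resolution input) into the decoder with one user-chosen scalar. Below is a minimal sketch of that idea; the function and argument names are illustrative rather than this PR's API, and the actual CFW module applies a small learned transform rather than plain interpolation.

```python
import torch


def controllable_feature_wrapping(dec_feat: torch.Tensor,
                                  enc_feat: torch.Tensor,
                                  w: float) -> torch.Tensor:
    """Blend decoder features (generative quality) with encoder
    features (fidelity to the LR input) using one scalar ``w``.

    w = 0.0 -> pure decoder features, more generative detail
    w = 1.0 -> lean fully on encoder features, higher fidelity
    """
    return dec_feat + w * (enc_feat - dec_feat)
```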
42 changes: 42 additions & 0 deletions configs/stablesr/stablesr_512.py
@@ -0,0 +1,42 @@
# Draft config adapted from configs/controlnet/controlnet-canny.py.
# config for model
stable_diffusion_v15_url = 'runwayml/stable-diffusion-v1-5'
controlnet_canny_url = 'lllyasviel/sd-controlnet-canny'

model = dict(
type='ControlStableDiffusion',
vae=dict(
type='AutoencoderKL',
from_pretrained=stable_diffusion_v15_url,
subfolder='vae'),
unet=dict(
type='UNet2DConditionModel',
subfolder='unet',
from_pretrained=stable_diffusion_v15_url),
text_encoder=dict(
type='ClipWrapper',
clip_type='huggingface',
pretrained_model_name_or_path=stable_diffusion_v15_url,
subfolder='text_encoder'),
tokenizer=stable_diffusion_v15_url,
controlnet=dict(
type='ControlNetModel', from_pretrained=controlnet_canny_url),
scheduler=dict(
type='DDPMScheduler',
from_pretrained=stable_diffusion_v15_url,
subfolder='scheduler'),
test_scheduler=dict(
type='DDIMScheduler',
from_pretrained=stable_diffusion_v15_url,
subfolder='scheduler'),
data_preprocessor=dict(type='DataPreprocessor'),
init_cfg=dict(type='init_from_unet'))
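
If this config lands, a natural smoke test is building the model through MMagic's registry. The snippet below is a sketch under that assumption; the `register_all_modules` step and the config path mirror how other MMagic configs are used, not anything in this diff.

```python
from mmengine import Config

from mmagic.registry import MODELS
from mmagic.utils import register_all_modules

register_all_modules()  # ensure ControlStableDiffusion is registered
cfg = Config.fromfile('configs/stablesr/stablesr_512.py')
model = MODELS.build(cfg.model)  # pulls pretrained weights on first run
```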
111 changes: 56 additions & 55 deletions mmagic/models/diffusion_schedulers/ddim_scheduler.py
@@ -10,34 +10,32 @@

@DIFFUSION_SCHEDULERS.register_module()
class EditDDIMScheduler:
"""```EditDDIMScheduler``` support the diffusion and reverse process
formulated in https://arxiv.org/abs/2010.02502.

The code is heavily influenced by https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py. # noqa
The difference is that we ensemble gradient-guided sampling in step function.

Args:
num_train_timesteps (int, optional): _description_. Defaults to 1000.
beta_start (float, optional): _description_. Defaults to 0.0001.
beta_end (float, optional): _description_. Defaults to 0.02.
beta_schedule (str, optional): _description_. Defaults to "linear".
variance_type (str, optional): _description_. Defaults to 'learned_range'.
timestep_values (_type_, optional): _description_. Defaults to None.
clip_sample (bool, optional): _description_. Defaults to True.
set_alpha_to_one (bool, optional): _description_. Defaults to True.
"""

def __init__(
self,
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule='linear',
variance_type='learned_range',
timestep_values=None,
clip_sample=True,
set_alpha_to_one=True,
):

def __init__(self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = 'linear',
variance_type: str = 'learned_range',
timestep_values=None,
clip_sample: bool = True,
set_alpha_to_one: bool = True):
"""``EditDDIMScheduler`` supports the diffusion and reverse process
formulated in https://arxiv.org/abs/2010.02502.

The code is heavily influenced by https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py. # noqa
The difference is that we incorporate gradient-guided sampling into
the ``step`` function.

Args:
num_train_timesteps (int, optional): Number of diffusion steps used
to train the model. Defaults to 1000.
beta_start (float, optional): Starting ``beta`` value of the noise
schedule. Defaults to 0.0001.
beta_end (float, optional): Final ``beta`` value of the noise
schedule. Defaults to 0.02.
beta_schedule (str, optional): Schedule mapping the beta range to a
sequence of betas, e.g. ``'linear'``. Defaults to "linear".
variance_type (str, optional): How the reverse-process variance is
parameterized. Defaults to 'learned_range'.
timestep_values (list, optional): Custom timestep values; evenly
spaced timesteps are used when None. Defaults to None.
clip_sample (bool, optional): Whether to clip the predicted original
sample to [-1, 1] for numerical stability. Defaults to True.
set_alpha_to_one (bool, optional): Whether to fix the alpha
cumulative product of the final step to 1. Defaults to True.
"""
self.num_train_timesteps = num_train_timesteps
self.beta_start = beta_start
self.beta_end = beta_end
@@ -93,22 +91,6 @@ def set_timesteps(self, num_inference_steps, offset=0):
self.num_train_timesteps // self.num_inference_steps)[::-1].copy()
self.timesteps += offset
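
For intuition, the stride-based spacing above yields evenly spaced, descending timesteps. With 1000 training steps, 50 inference steps, and no offset (illustrative numbers, not defaults asserted by this diff):

```python
import numpy as np

num_train_timesteps, num_inference_steps, offset = 1000, 50, 0
timesteps = np.arange(
    0, num_train_timesteps,
    num_train_timesteps // num_inference_steps)[::-1].copy()
timesteps += offset
print(timesteps[:5])  # -> [980 960 940 920 900]
```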

def scale_model_input(self,
sample: torch.FloatTensor,
timestep: Optional[int] = None) -> torch.FloatTensor:
"""Ensures interchangeability with schedulers that need to scale the
denoising model input depending on the current timestep.

Args:
sample (`torch.FloatTensor`): input sample
timestep (`int`, optional): current timestep

Returns:
`torch.FloatTensor`: scaled input sample
"""

return sample

def _get_variance(self, timestep, prev_timestep):
"""get variance."""

@@ -121,17 +103,15 @@ def _get_variance(self, timestep, prev_timestep):
beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
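
The expression above is the DDIM posterior variance, sigma_t^2 = (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev). A self-contained sketch with an illustrative linear beta schedule; the `prev_timestep < 0` branch mirrors the `set_alpha_to_one=True` behavior:

```python
import numpy as np

betas = np.linspace(0.0001, 0.02, 1000)   # linear schedule, as in defaults
alphas_cumprod = np.cumprod(1.0 - betas)


def ddim_variance(timestep: int, prev_timestep: int) -> float:
    alpha_prod_t = alphas_cumprod[timestep]
    alpha_prod_t_prev = (alphas_cumprod[prev_timestep]
                         if prev_timestep >= 0 else 1.0)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    return (beta_prod_t_prev / beta_prod_t) * (
        1 - alpha_prod_t / alpha_prod_t_prev)
```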

def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
cond_fn=None,
cond_kwargs={},
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
):
def step(self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
cond_fn=None,
cond_kwargs={},
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None):
"""step forward."""

output = {}
@@ -250,10 +230,31 @@ def add_noise(self, original_samples, noise, timesteps):

sqrt_alpha_prod = self.alphas_cumprod[timesteps]**0.5
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps])**0.5

if not isinstance(sqrt_alpha_prod, float):
sqrt_alpha_prod = float(sqrt_alpha_prod)
if not isinstance(sqrt_one_minus_alpha_prod, float):
sqrt_one_minus_alpha_prod = float(sqrt_one_minus_alpha_prod)
noisy_samples = (
sqrt_alpha_prod * original_samples +
sqrt_one_minus_alpha_prod * noise)
return noisy_samples
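
The two lines above implement the closed-form forward process q(x_t | x_0) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps; note that the `float()` casts imply a scalar timestep per call rather than a batched tensor. A hedged usage sketch (the registry import path and latent shape are assumptions, not taken from this PR):

```python
import torch

from mmagic.registry import DIFFUSION_SCHEDULERS  # assumed import path

scheduler = DIFFUSION_SCHEDULERS.build(dict(type='EditDDIMScheduler'))
x0 = torch.randn(1, 4, 64, 64)      # a clean latent
noise = torch.randn_like(x0)
xt = scheduler.add_noise(x0, noise, timesteps=500)  # scalar timestep
```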

def scale_model_input(self,
sample: torch.FloatTensor,
timestep: Optional[int] = None) -> torch.FloatTensor:
"""Ensures interchangeability with schedulers that need to scale the
denoising model input depending on the current timestep.

Args:
sample (`torch.FloatTensor`): input sample
timestep (`int`, optional): current timestep

Returns:
`torch.FloatTensor`: scaled input sample
"""

return sample

def __len__(self):
return self.num_train_timesteps
19 changes: 9 additions & 10 deletions mmagic/models/diffusion_schedulers/ddpm_scheduler.py
@@ -17,8 +17,8 @@ def __init__(self,
beta_end: float = 0.02,
beta_schedule: str = 'linear',
trained_betas: Optional[Union[np.array, list]] = None,
variance_type='fixed_small',
clip_sample=True):
variance_type: str = 'fixed_small',
clip_sample: bool = True):
"""```EditDDPMScheduler``` support the diffusion and reverse process
formulated in https://arxiv.org/abs/2006.11239.

@@ -46,6 +46,8 @@ def __init__(self,
original image (x0) to [-1, 1]. Defaults to True.
"""
self.num_train_timesteps = num_train_timesteps
self.variance_type = variance_type
self.clip_sample = clip_sample
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
elif beta_schedule == 'linear':
@@ -74,18 +76,16 @@ def __init__(self,
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

self.variance_type = variance_type
self.clip_sample = clip_sample

def set_timesteps(self, num_inference_steps):
"""set timesteps."""
def set_timesteps(self, num_inference_steps, offset=0):
"""set timesteps, optionally shifted by ``offset``."""

num_inference_steps = min(self.num_train_timesteps,
num_inference_steps)
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(
0, self.num_train_timesteps,
self.num_train_timesteps // self.num_inference_steps)[::-1].copy()
self.timesteps += offset

def _get_variance(self, t, predicted_variance=None, variance_type=None):
"""get variance."""
@@ -133,13 +133,12 @@ def step(self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
predict_epsilon=True,
predict_epsilon: bool = True,
cond_fn=None,
cond_kwargs={},
generator=None):

"""step forward."""
t = timestep
"""step forward."""

if model_output.shape[1] == sample.shape[
1] * 2 and self.variance_type in ['learned', 'learned_range']:
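
When the UNet also predicts the variance, the output carries twice the sample's channels and must be split into a noise prediction and a variance prediction before use. The body is elided in this diff; a plausible handling, mirroring the diffusers DDPM scheduler rather than quoting this PR, is:

```python
import torch


def split_learned_variance(model_output: torch.Tensor,
                           sample: torch.Tensor):
    """Split a [B, 2C, H, W] prediction into noise and variance halves."""
    return torch.split(model_output, sample.shape[1], dim=1)
```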