Commit aeb20c8

release training codes and config files.
1 parent e501cd0 commit aeb20c8

21 files changed, 3154 additions and 7 deletions.

README.md

Lines changed: 8 additions & 6 deletions
@@ -20,8 +20,9 @@ S-Lab, Nanyang Technological University

 :star: If CodeFormer is helpful to your images or projects, please help star this repo. Thanks! :hugs:

-**[<font color=#d1585d>News</font>]**: :whale: *We regret to inform you that the release of our code will be postponed from its earlier plan. Nevertheless, we assure you that it will be made available **by the end of this April**. Thank you for your understanding and patience. Our apologies for any inconvenience this may cause.*
+
 ### Update
+- **2023.04.19**: :whale: Training code and config files are publicly available now.
 - **2023.04.09**: Add features of inpainting and colorization for cropped and aligned face images.
 - **2023.02.10**: Include `dlib` as a new face detector option, it produces more accurate face identity.
 - **2022.10.05**: Support video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper:
@@ -30,7 +31,7 @@ S-Lab, Nanyang Technological University
 - [**More**](docs/history_changelog.md)

 ### TODO
-- [ ] Add training code and config files
+- [x] Add training code and config files
 - [x] Add checkpoint and script for face inpainting
 - [x] Add checkpoint and script for face colorization
 - [x] ~~Add background image enhancement~~
@@ -77,13 +78,13 @@ conda install -c conda-forge dlib (only for face detection or cropping with dlib
 ### Quick Inference

 #### Download Pre-trained Models:
-Download the facelib and dlib pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases) | [Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can manually download the pretrained models OR download by running the following command:
+Download the facelib and dlib pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can manually download the pretrained models OR download by running the following command:
 ```
 python scripts/download_pretrained_models.py facelib
 python scripts/download_pretrained_models.py dlib (only for dlib face detector)
 ```

-Download the CodeFormer pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases) | [Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can manually download the pretrained models OR download by running the following command:
+Download the CodeFormer pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can manually download the pretrained models OR download by running the following command:
 ```
 python scripts/download_pretrained_models.py CodeFormer
 ```
@@ -141,7 +142,8 @@ python inference_colorization.py --input_path [image folder]|[image path]
 # (check out the examples in inputs/masked_faces)
 python inference_inpainting.py --input_path [image folder]|[image path]
 ```
-
+#### Training:
+You can find training commands in training documents: [English](docs/train.md) **|** [简体中文](docs/train_CN.md).

 ### Citation
 If our work is useful for your research, please consider citing:
@@ -162,4 +164,4 @@ This project is licensed under <a rel="license" href="https://github.com/sczhou/
 This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some codes are brought from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome works.

 ### Contact
-If you have any questions, please feel free to reach out to me at `[email protected]`.
+If you have any questions, please feel free to reach out to me at `[email protected]`.

basicsr/archs/codeformer_arch.py

Lines changed: 5 additions & 1 deletion
@@ -162,9 +162,13 @@ class CodeFormer(VQAutoEncoder):
     def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
                 connect_list=['32', '64', '128', '256'],
-                fix_modules=['quantize','generator']):
+                fix_modules=['quantize','generator'], vqgan_path=None):
         super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

+        if vqgan_path is not None:
+            self.load_state_dict(
+                torch.load(vqgan_path, map_location='cpu')['params_ema'])
+
         if fix_modules is not None:
             for module in fix_modules:
                 for param in getattr(self, module).parameters():
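
The added `vqgan_path` branch loads pretrained weights before the `fix_modules` loop runs, so any frozen module keeps the loaded values. The ordering can be sketched as below, assuming a checkpoint saved as `{'params_ema': state_dict}`, which is what the indexing in the hunk suggests. `TinyAutoEncoder`, its layer sizes, and the temp-file name are hypothetical stand-ins rather than the real `CodeFormer` class, and freezing via `requires_grad = False` is an assumption because the hunk ends before the loop body.

```python
import os
import tempfile

import torch
from torch import nn


class TinyAutoEncoder(nn.Module):
    """Hypothetical stand-in for a VQGAN-style model; not the real CodeFormer."""

    def __init__(self, ckpt_path=None, fix_modules=("quantize", "generator")):
        super().__init__()
        self.encoder = nn.Linear(8, 4)
        self.quantize = nn.Linear(4, 4)
        self.generator = nn.Linear(4, 8)

        # Load pretrained weights first so frozen modules hold the trained values.
        if ckpt_path is not None:
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            self.load_state_dict(checkpoint["params_ema"])

        # Then freeze the selected sub-modules (assumed behaviour of the loop body).
        if fix_modules is not None:
            for name in fix_modules:
                for param in getattr(self, name).parameters():
                    param.requires_grad = False


# Usage: write a checkpoint in the assumed layout, then build a model from it.
source = TinyAutoEncoder()
ckpt_file = os.path.join(tempfile.mkdtemp(), "vqgan_demo.pth")
torch.save({"params_ema": source.state_dict()}, ckpt_file)
model = TinyAutoEncoder(ckpt_path=ckpt_file)
```

Passing `map_location="cpu"` keeps the load independent of the device the checkpoint was saved on; the caller can move the model to a GPU afterwards.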

basicsr/data/data_util.py

Lines changed: 87 additions & 0 deletions
@@ -1,7 +1,9 @@
 import cv2
+import math
 import numpy as np
 import torch
 from os import path as osp
+from PIL import Image, ImageDraw
 from torch.nn import functional as F

 from basicsr.data.transforms import mod_crop
@@ -303,3 +305,88 @@ def duf_downsample(x, kernel_size=13, scale=4):
     if squeeze_flag:
         x = x.squeeze(0)
     return x
+
+
+def brush_stroke_mask(img, color=(255,255,255)):
+    min_num_vertex = 8
+    max_num_vertex = 28
+    mean_angle = 2*math.pi / 5
+    angle_range = 2*math.pi / 12
+    # training large mask ratio (training setting)
+    min_width = 30
+    max_width = 70
+    # very large mask ratio (test setting and refine after 200k)
+    # min_width = 80
+    # max_width = 120
+    def generate_mask(H, W, img=None):
+        average_radius = math.sqrt(H*H+W*W) / 8
+        mask = Image.new('RGB', (W, H), 0)
+        if img is not None: mask = img # Image.fromarray(img)
+
+        for _ in range(np.random.randint(1, 4)):
+            num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
+            angle_min = mean_angle - np.random.uniform(0, angle_range)
+            angle_max = mean_angle + np.random.uniform(0, angle_range)
+            angles = []
+            vertex = []
+            for i in range(num_vertex):
+                if i % 2 == 0:
+                    angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
+                else:
+                    angles.append(np.random.uniform(angle_min, angle_max))
+
+            h, w = mask.size
+            vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
+            for i in range(num_vertex):
+                r = np.clip(
+                    np.random.normal(loc=average_radius, scale=average_radius//2),
+                    0, 2*average_radius)
+                new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
+                new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
+                vertex.append((int(new_x), int(new_y)))
+
+            draw = ImageDraw.Draw(mask)
+            width = int(np.random.uniform(min_width, max_width))
+            draw.line(vertex, fill=color, width=width)
+            for v in vertex:
+                draw.ellipse((v[0] - width//2,
+                              v[1] - width//2,
+                              v[0] + width//2,
+                              v[1] + width//2),
+                             fill=color)
+
+        return mask
+
+    width, height = img.size
+    mask = generate_mask(height, width, img)
+    return mask
+
+
+def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10):
+    """Generate a random free form mask with configuration.
+    Args:
+        config: Config should have configuration including IMG_SHAPES,
+            VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
+    Returns:
+        tuple: (top, left, height, width)
+    Link:
+        https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
+    """
+    height = shape[0]
+    width = shape[1]
+    mask = np.zeros((height, width), np.float32)
+    times = np.random.randint(times-5, times)
+    for i in range(times):
+        start_x = np.random.randint(width)
+        start_y = np.random.randint(height)
+        for j in range(1 + np.random.randint(5)):
+            angle = 0.01 + np.random.randint(max_angle)
+            if i % 2 == 0:
+                angle = 2 * 3.1415926 - angle
+            length = 10 + np.random.randint(max_len-20, max_len)
+            brush_w = 5 + np.random.randint(max_width-30, max_width)
+            end_x = (start_x + length * np.sin(angle)).astype(np.int32)
+            end_y = (start_y + length * np.cos(angle)).astype(np.int32)
+            cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
+            start_x, start_y = end_x, end_y
+    return mask.astype(np.float32)
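
The free-form mask logic in `random_ff_mask` chains a few thick line segments from a random start point, alternating the angle direction between strokes. The sketch below illustrates that idea using only NumPy. It rasterizes each segment with a point-to-segment distance test instead of `cv2.line`, uses a conventional `(x, y)` point order, and takes a seeded `Generator`; those choices, along with the names `segment_mask`, `free_form_mask`, and `strokes`, are assumptions made for illustration and do not reproduce the committed function exactly.

```python
import numpy as np


def segment_mask(height, width, p0, p1, thickness):
    """Return a float32 mask that is 1.0 within thickness/2 of segment p0-p1.

    Points are (x, y) pairs; this is a NumPy stand-in for cv2.line.
    """
    ys, xs = np.mgrid[0:height, 0:width].astype(np.float32)
    x0, y0 = p0
    x1, y1 = p1
    dx, dy = x1 - x0, y1 - y0
    seg_len_sq = dx * dx + dy * dy
    if seg_len_sq == 0:
        t = np.zeros_like(xs)
    else:
        # Project each pixel onto the segment and clamp to its end points.
        t = np.clip(((xs - x0) * dx + (ys - y0) * dy) / seg_len_sq, 0.0, 1.0)
    dist = np.hypot(xs - (x0 + t * dx), ys - (y0 + t * dy))
    return (dist <= thickness / 2).astype(np.float32)


def free_form_mask(shape, rng, max_angle=10, max_len=100, max_width=70, strokes=6):
    """Draw `strokes` chained polylines; returns an (H, W) float32 mask of 0s and 1s."""
    height, width = shape
    mask = np.zeros((height, width), np.float32)
    for i in range(strokes):
        # Each stroke starts at a random pixel inside the image.
        x = float(rng.integers(width))
        y = float(rng.integers(height))
        for _ in range(1 + int(rng.integers(5))):
            angle = 0.01 + float(rng.integers(max_angle))
            if i % 2 == 0:
                angle = 2 * np.pi - angle  # flip direction on even strokes
            length = 10 + int(rng.integers(max_len - 20, max_len))
            brush_w = 5 + int(rng.integers(max_width - 30, max_width))
            end_x = x + length * np.cos(angle)
            end_y = y + length * np.sin(angle)
            stroke = segment_mask(height, width, (x, y), (end_x, end_y), brush_w)
            mask = np.maximum(mask, stroke)
            x, y = end_x, end_y  # the next segment continues from this end point
    return mask


rng = np.random.default_rng(0)
demo_mask = free_form_mask((64, 96), rng)
```

Two details of the committed code are worth checking against this sketch. The `random_ff_mask` docstring describes a `config` argument and a tuple return value, while the function takes `shape` and returns a mask array. In `brush_stroke_mask`, `h, w = mask.size` unpacks Pillow's `(width, height)` tuple in the opposite order, which only matters for non-square images.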
