
Commit 714c632

Author: hanjian.thu123
Commit message: [update] code release
1 parent 2346f5c commit 714c632


55 files changed, +10876 −7 lines changed

.gitignore

Lines changed: 46 additions & 0 deletions
*.swp
**/__pycache__/**
**/.ipynb_checkpoints/**
.idea/*
llava/
_vis_cached/
_vqgan/
_vae/
_vae*/
ckpt/
log/
tb*/
img*/
local_output*
_auto_*
sd-vae-ft-mse/
stable-diffusion-v1-4/
*.pth
*.pth.tar
*.ckpt
*.log
*.txt
*.ipynb
toscli
*.hydra
wandb
*.jsonl
*.jpg
*.png
*.json
*.csv
*.tar.gz
*.bin
data/
tmp
output
*.tsv
*.mp4
output/*
results/
*.JPEG
debug/
weights
checkpoints
ref.py
wandb

README.md

Lines changed: 141 additions & 7 deletions
# Infinity $\infty$: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis

<div align="center">

[![demo platform](https://img.shields.io/badge/Play%20with%20Infinity%21-Infinity%20demo%20platform-lightblue)](https://opensource.bytedance.com/gmpt/t2i/invite)&nbsp;
[![arXiv](https://img.shields.io/static/v1?label=Project%20Page&message=Github&color=blue&logo=github-pages)](https://foundationvision.github.io/infinity.project/)&nbsp;
[![arXiv](https://img.shields.io/badge/arXiv%20paper-2412.04431-b31b1b.svg)](https://arxiv.org/abs/2412.04431)&nbsp;
[![huggingface weights](https://img.shields.io/badge/%F0%9F%A4%97%20Weights-FoundationVision/Infinity-yellow)](https://huggingface.co/FoundationVision/infinity)&nbsp;

</div>
<p align="center" style="font-size: larger;">
<p>

## 🔥 Updates!!
* Dec 24, 2024: 🔥 Training and Testing Codes && Checkpoints && Demo released!
* Dec 12, 2024: 💻 Add Project Page
* Dec 5, 2024: 🤗 Paper release

## 🕹️ Try and Play with Infinity!

We provide a [demo website](https://opensource.bytedance.com/gmpt/t2i/invite) for you to play with Infinity and generate images interactively. Enjoy the fun of bitwise autoregressive modeling!

We also provide [interactive_infer.ipynb](tools/interactive_infer.ipynb) for you to see more technical details about Infinity.

## 📑 Open-Source Plan
- [ ] Infinity-20B Checkpoints
- [x] Training Code
- [x] Web Demo
- [x] Inference Code
- [x] Infinity-2B Checkpoints
- [x] VAE Checkpoints

## 📖 Introduction
We present Infinity, a bitwise visual autoregressive model capable of generating high-resolution, photorealistic images. Infinity redefines visual autoregressive modeling under a bitwise token prediction framework with an infinite-vocabulary tokenizer & classifier and bitwise self-correction. By theoretically scaling the tokenizer vocabulary size to infinity and concurrently scaling the transformer size, our method unleashes powerful scaling capabilities. Infinity sets a new record for autoregressive text-to-image models, outperforming top-tier diffusion models such as SD3-Medium and SDXL. Notably, Infinity surpasses SD3-Medium by improving the GenEval benchmark score from 0.62 to 0.73 and the ImageReward benchmark score from 0.87 to 0.96, achieving a win rate of 66%. Without extra optimization, Infinity generates a high-quality 1024×1024 image in 0.8 seconds, making it 2.6× faster than SD3-Medium and establishing it as the fastest text-to-image model.

### 🔥 Redefines VAR under a bitwise token prediction framework 🚀:

<p align="center">
<img src="assets/framework_row.png" width=95%>
<p>

Infinite-Vocabulary Tokenizer✨: We propose a new bitwise multi-scale residual quantizer, which significantly reduces memory usage and enables training with extremely large vocabularies, e.g. $V_d = 2^{32}$ or $V_d = 2^{64}$.

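To make the bitwise formulation concrete, here is a minimal, illustrative sketch (our own toy example, not the repository's quantizer): each channel of a continuous residual feature is binarized by its sign, so a $d$-dimensional feature yields $d$ bit labels that implicitly index a vocabulary of size $2^d$ without ever materializing a codebook.

```python
import torch

def bitwise_quantize(residual: torch.Tensor) -> torch.Tensor:
    """Toy bitwise quantizer: binarize each channel of a residual feature
    (..., d) to +1/-1 by its sign. The d signs jointly act as an index into
    an implicit vocabulary of size 2^d, with no codebook stored in memory."""
    return torch.where(residual >= 0,
                       torch.ones_like(residual),
                       -torch.ones_like(residual))

# One feature vector with d = 32 channels.
feat = torch.randn(1, 32)
bits = bitwise_quantize(feat)        # values in {+1, -1}
bit_labels = (bits > 0).long()       # 32 binary labels used as supervision
print(bit_labels.shape)              # torch.Size([1, 32])
```
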
Infinite-Vocabulary Classifier✨: A conventional classifier predicts $2^d$ indices, whereas IVC predicts $d$ bits instead. Slight perturbations to near-zero values in continuous features flip the index label completely, while the bit labels change only subtly and still provide steady supervision. Besides, with $d = 32$ and $h = 2048$, a conventional classifier requires 8.8T parameters while IVC only requires 0.13M.

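The parameter counts quoted above follow directly from the head shapes. A quick back-of-the-envelope check, treating IVC as $d$ independent binary (2-way) classifiers over a hidden size $h$ (our reading of the reported numbers, not necessarily the exact head layout in the repository):

```python
d, h = 32, 2048

conventional = h * (2 ** d)   # one logit per index in a 2^d-entry vocabulary
ivc = h * d * 2               # d binary (2-way) classifiers over the hidden state

print(f"conventional head: {conventional / 1e12:.1f}T parameters")  # ~8.8T
print(f"IVC head:          {ivc / 1e6:.2f}M parameters")            # ~0.13M
```
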
Bitwise Self-Correction✨: Teacher-forcing training in AR introduces a severe train-test discrepancy: the transformer only learns to refine features without recognizing and correcting its own mistakes, so mistakes propagate and amplify across scales and eventually corrupt the generated images. We propose Bitwise Self-Correction (BSC) to mitigate this train-test discrepancy.

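As a rough, hypothetical illustration of the idea only (not the repository's BSC implementation, which re-quantizes the perturbed features): during training, a fraction of the teacher-forcing bit labels can be flipped so the model learns from prefixes that contain mistakes, similar in spirit to the `noise_apply_strength` knob used for validation-loss evaluation.

```python
import torch

def flip_bits(bits: torch.Tensor, noise_apply_strength: float = 0.2) -> torch.Tensor:
    """Hypothetical sketch: randomly flip a fraction of +1/-1 bit labels so the
    transformer is trained on prefixes that contain errors, mimicking the
    mistakes it will make at test time."""
    flip_mask = torch.rand_like(bits) < noise_apply_strength
    return torch.where(flip_mask, -bits, bits)

noisy_bits = flip_bits(torch.ones(1, 32), noise_apply_strength=0.2)
```
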
### 🔥 Scaling Vocabulary benefits Reconstruction and Generation 📈:

<p align="center">
<img src="assets/scaling_vocabulary.png" width=95%>
<p>

### 🔥 Discovering Scaling Laws in Infinity transformers 📈:

<p align="center">
<img src="assets/scaling_models.png" width=95%>
<p>

## Infinity Model ZOO
We provide Infinity models for you to play with, which are available on <a href='https://huggingface.co/FoundationVision/infinity'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20weights-FoundationVision/Infinity-yellow'></a> or can be downloaded from the following links:

### Visual Tokenizer

| vocabulary | stride | IN-256 rFID $\downarrow$ | IN-256 PSNR $\uparrow$ | IN-512 rFID $\downarrow$ | IN-512 PSNR $\uparrow$ | HF weights🤗 |
|:----------:|:------:|:------------------------:|:----------------------:|:------------------------:|:----------------------:|:-------------|
| $V_d=2^{16}$ | 16 | 1.22 | 20.9 | 0.31 | 22.6 | [infinity_vae_d16.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d16.pth) |
| $V_d=2^{24}$ | 16 | 0.75 | 22.0 | 0.30 | 23.5 | [infinity_vae_d24.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d24.pth) |
| $V_d=2^{32}$ | 16 | 0.61 | 22.7 | 0.23 | 24.4 | [infinity_vae_d32.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d32.pth) |
| $V_d=2^{64}$ | 16 | 0.33 | 24.9 | 0.15 | 26.4 | [infinity_vae_d64.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d64.pth) |
| $V_d=2^{32}$ | 16 | 0.75 | 21.9 | 0.32 | 23.6 | [infinity_vae_d32_reg.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d32_reg.pth) |

### Infinity
| model | Resolution | GenEval | DPG | HPSv2.1 | HF weights🤗 |
|:----------:|:----------:|:--------:|:---------:|:-------:|:-------------|
| Infinity-2B | 1024 | 0.69 / 0.73 $^{\dagger}$ | 83.5 | 32.2 | [infinity_2B.pth](https://huggingface.co/FoundationVision/var/resolve/main/infinity_2b_reg.pth) |
| Infinity-20B | 1024 | - | - | - | [Coming Soon](TBD) |

${\dagger}$ result is tested with a [prompt rewriter](tools/prompt_rewriter.py).

You can load these models to generate images via the code in [interactive_infer.ipynb](tools/interactive_infer.ipynb). Note: you need to download [infinity_vae_d32reg.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) and [flan-t5-xl](https://huggingface.co/google/flan-t5-xl) first.

## Installation
1. We use FlexAttention to speed up training, which requires `torch>=2.5.1` (a quick import check is shown below).
2. Install other pip packages via `pip3 install -r requirements.txt`.

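If you want to confirm that the installed torch build is new enough, a quick check (assuming a Python environment with the requirements installed) is to import the FlexAttention entry point, which only exists from torch 2.5 onward:

```python
# Quick sanity check: FlexAttention ships with torch>=2.5.
import torch
from torch.nn.attention.flex_attention import flex_attention  # ImportError on older torch

print(torch.__version__)
```
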
## Data Preparation
The structure of the training dataset is listed below. The training dataset contains a list of jsonl files named "[h_div_w_template1]_[num_examples].jsonl". Here [h_div_w_template1] is a float number: the template ratio of height to width of the image. [num_examples] is the number of examples whose $h/w$ is around h_div_w_template1. [dataset_t2i_iterable.py](infinity/dataset/dataset_t2i_iterable.py) supports training with >100M examples, but you have to specify the number of examples for each h/w template ratio in the filename.

```
/path/to/dataset/:
    [h_div_w_template1]_[num_examples].jsonl
    [h_div_w_template2]_[num_examples].jsonl
    [h_div_w_template3]_[num_examples].jsonl
```

Each "[h_div_w_template1]_[num_examples].jsonl" file contains lines of dumped json item. Each json item contains the following information:
107+
```
{
    "image_path": "path/to/image, required",
    "h_div_w": "float value of h_div_w for the image, required",
    "long_caption": "long caption of the image, required",
    "long_caption_type": "InternVL 2.0, required",
    "short_caption": "short caption of the image, optional",
    "short_caption_type": "user prompt, optional"
}
```

Still have questions about data preparation? Easy: we have provided a toy dataset with 10 images. You can prepare your own dataset by referring to [this](data/infinity_toy_data).

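If you prefer to script the layout above, here is a small, hypothetical helper (not part of the repository) that groups examples by their rounded h/w ratio and writes one `[h_div_w]_[num_examples].jsonl` file per group:

```python
import json
from collections import defaultdict
from pathlib import Path

def build_jsonl_dataset(items, out_dir):
    """items: iterable of dicts with at least image_path, h_div_w, long_caption."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    groups = defaultdict(list)
    for item in items:
        # Round the aspect ratio to a coarse template value, e.g. 1.0, 0.75, 1.333.
        template = round(item["h_div_w"], 3)
        groups[template].append(item)
    for template, examples in groups.items():
        # Filename encodes the template ratio and the number of examples.
        path = out_dir / f"{template}_{len(examples)}.jsonl"
        with path.open("w") as f:
            for ex in examples:
                f.write(json.dumps(ex) + "\n")

build_jsonl_dataset(
    [{"image_path": "images/cat.jpg", "h_div_w": 1.0,
      "long_caption": "A cat sitting on a wooden table.",
      "long_caption_type": "InternVL 2.0"}],
    "toy_dataset",
)
```
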
## Training Scripts
We provide [train.sh](scripts/train.sh) to train Infinity-2B with one command:
```shell
bash scripts/train.sh
```

To train Infinity with different model sizes {125M, 1B, 2B} and different {256/512/1024} resolutions, you can run the following commands:
```shell
# 125M, layer12, pixel number = 256 x 256 = 0.06M pixels
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
  --model=layer12c4 --pn 0.06M --exp_name=infinity_125M_pn_0.06M

# 1B, layer24, pixel number = 256 x 256 = 0.06M pixels
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
  --model=layer24c4 --pn 0.06M --exp_name=infinity_1B_pn_0.06M

# 2B, layer32, pixel number = 256 x 256 = 0.06M pixels
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
  --model=2bc8 --pn 0.06M --exp_name=infinity_2B_pn_0.06M

# 2B, layer32, pixel number = 512 x 512 = 0.25M pixels
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
  --model=2bc8 --pn 0.25M --exp_name=infinity_2B_pn_0.25M

# 2B, layer32, pixel number = 1024 x 1024 = 1M pixels
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
  --model=2bc8 --pn 1M --exp_name=infinity_2B_pn_1M
```

A folder named `local_output` will be created to save the checkpoints and logs.
You can monitor the training process by checking the logs in `local_output/log.txt` and `local_output/stdout.txt`. We highly recommend using [wandb](https://wandb.ai/site/) for detailed logging.

If your experiment is interrupted, just rerun the command; training will **automatically resume** from the last checkpoint in `local_output/ckpt*.pth`.

## Evaluation
We provide [eval.sh](scripts/eval.sh) for evaluation on various benchmarks with only one command. In particular, [eval.sh](scripts/eval.sh) supports evaluation on commonly used metrics such as [GenEval](https://github.com/djghosh13/geneval), [ImageReward](https://github.com/THUDM/ImageReward), [HPSv2.1](https://github.com/tgxs002/HPSv2), FID and Validation Loss. Please refer to [evaluation/README.md](evaluation/README.md) for more details.
```shell
bash scripts/eval.sh
```

## One More Thing: Infinity-20B is coming soon 📆
Infinity shows strong scaling capabilities as illustrated above, which encourages us to continue scaling the model size up to 20B. Here we present side-by-side comparison results between Infinity-2B and Infinity-20B.

| Prompt | Infinity (# params=2B) | Infinity (# params=20B) |
| ------ | ---------------------- | ----------------------- |
| Create an image with the text "Always Priority" on a wooden sign | ![](assets/2b_20b/1l.jpg) | ![](assets/2b_20b/1r.jpg) |
| Show the text 'Driver Unknown Hard Clearly' in a surreal, imaginative style with a dreamlike landscape backdrop. | ![](assets/2b_20b/2l.jpg) | ![](assets/2b_20b/2r.jpg) |
| A photograph of a quaint two-story house with a distinctive red-tiled gable roof. The house is painted in a light, sandy color, which contrasts with the vibrant red roof. | ![](assets/2b_20b/3l.jpg) | ![](assets/2b_20b/3r.jpg) |
| A group of students in a class | ![](assets/2b_20b/4l.jpg) | ![](assets/2b_20b/4r.jpg) |

Currently, Infinity-20B is still in the training phase. We will release Infinity-20B once training is completed.

## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

assets/2b_20b/.DS_Store

6 KB
Binary file not shown.

conf.py

Lines changed: 4 additions & 0 deletions
HF_TOKEN = '[YOUR HF_TOKEN]'
HF_HOME = '[YOUR HF_HOME]'

GPT_AK = '[YOUR GPT_AK]'

evaluation/README.md

Lines changed: 65 additions & 0 deletions
# Overview
We provide [eval.sh](scripts/eval.sh) for evaluation on various benchmarks with only one command. In particular, [eval.sh](scripts/eval.sh) supports evaluation on commonly used metrics such as [GenEval](https://github.com/djghosh13/geneval), [ImageReward](https://github.com/THUDM/ImageReward), [HPSv2.1](https://github.com/tgxs002/HPSv2), FID and Validation Loss.

# Usage

## Basic Configuration

```shell
# set arguments
pn=1M
model_type=infinity_2b
infinity_model_path=[infinity_model_path]
out_dir_root=[out_dir_root]
vae_type=32
vae_path=[vae_path]
cfg=4
tau=1
text_encoder_ckpt=[text_encoder_ckpt]
text_channels=2048
sub_fix=cfg${cfg}_tau${tau}
```

## ImageReward
[ImageReward](https://github.com/THUDM/ImageReward) is a metric for evaluating the human preference score of generated images. It learns human preference by fine-tuning a CLIP model on 137K human-ranked image pairs.
```shell
out_dir=${out_dir_root}/image_reward_${sub_fix}
infer_eval_image_reward
```

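If you want to spot-check a few images outside of `eval.sh`, the standalone `image-reward` pip package exposes a simple scoring API. The snippet below is an assumption about that package's interface (independent of the repository's `infer_eval_image_reward` helper), and the image filename is a placeholder:

```python
# pip install image-reward  (assumed; see https://github.com/THUDM/ImageReward)
import ImageReward as RM

model = RM.load("ImageReward-v1.0")
# Score one or more generated images against the prompt they were generated from.
score = model.score("a photo of a cat wearing a hat", ["generated.png"])
print(score)
```
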
## HPS v2.1
[HPSv2.1](https://github.com/tgxs002/HPSv2) is a metric for evaluating the human preference score of generated images. It learns human preference by fine-tuning a CLIP model on 798K human-ranked image pairs. The human-ranked image pairs are annotated by human experts.
```shell
out_dir=${out_dir_root}/hpsv21_${sub_fix}
infer_eval_hpsv21
```

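Similarly, the `hpsv2` pip package can score individual images directly. This is an assumption about that package's API, not the repository's `infer_eval_hpsv21` helper, and the filename is a placeholder:

```python
# pip install hpsv2  (assumed; see https://github.com/tgxs002/HPSv2)
import hpsv2

# Score a generated image against its prompt with the v2.1 preference model.
result = hpsv2.score("generated.png", "a photo of a cat wearing a hat", hps_version="v2.1")
print(result)
```
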
## GenEval
[GenEval](https://github.com/djghosh13/geneval) is an object-focused framework for evaluating text-to-image alignment.
```shell
rewrite_prompt=0
out_dir=${out_dir_root}/gen_eval_${sub_fix}
test_gen_eval
```

## FID
For testing FID, you need to provide a jsonl file that contains text prompts and ground-truth images. We highly recommend that the jsonl file contains more than 20,000 examples, since FID needs abundant examples to be reliable.
```shell
long_caption_fid=1
jsonl_filepath=[jsonl_filepath]
out_dir=${out_dir_root}/val_long_caption_fid_${sub_fix}
rm -rf ${out_dir}
test_fid
```

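For a quick, standalone FID number between two image folders (independent of the `test_fid` helper above), the `clean-fid` package is one option. This is a hypothetical usage sketch; the folder names are placeholders:

```python
# pip install clean-fid  (assumed)
from cleanfid import fid

# Compare a folder of generated images against a folder of ground-truth images.
score = fid.compute_fid("generated_images/", "ground_truth_images/")
print(f"FID: {score:.2f}")
```
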
## Validation Loss
For testing validation loss, you need to provide a jsonl folder like the training jsonl folder. Besides, you should set the noise-applying strength for Bitwise Self-Correction to the same strength used in the training phase.
```shell
out_dir=${out_dir_root}/val_loss_${sub_fix}
reweight_loss_by_scale=0
jsonl_folder=[jsonl_folder]
noise_apply_strength=0.2
test_val_loss
```
Lines changed: 59 additions & 0 deletions
# dataset settings
dataset_type = 'CocoPanopticDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='LoadPanopticAnnotations',
        with_bbox=True,
        with_mask=True,
        with_seg=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='SegRescale', scale_factor=1 / 4),
    dict(type='DefaultFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/panoptic_train2017.json',
        img_prefix=data_root + 'train2017/',
        seg_prefix=data_root + 'annotations/panoptic_train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/panoptic_val2017.json',
        img_prefix=data_root + 'val2017/',
        seg_prefix=data_root + 'annotations/panoptic_val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/panoptic_val2017.json',
        img_prefix=data_root + 'val2017/',
        seg_prefix=data_root + 'annotations/panoptic_val2017/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric=['PQ'])
Lines changed: 27 additions & 0 deletions
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]

# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = 0
# set multi-process start method as `fork` to speed up the training
mp_start_method = 'fork'

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
#   or not by default.
# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
