Commit 1b78f1b

Author: hanjian.thu123
Message: [update] release infinity-8b
1 parent ec2f17a commit 1b78f1b

3 files changed: +230 -9 lines changed

README.md

Lines changed: 12 additions & 6 deletions
@@ -20,6 +20,7 @@
 <p>
 
 ## 🔥 Updates!!
+* Feb 18, 2025: 🔥 Infinity-8B Weights & Code is released!
 * Feb 7, 2025: 🌺 Infinity-8B Demo is released! Check [demo](https://opensource.bytedance.com/gmpt/t2i/invite).
 * Dec 24, 2024: 🔥 Training and Testing Codes && Checkpoints && Demo released!
 * Dec 12, 2024: 💻 Add Project Page
@@ -30,10 +31,11 @@
 
 We provide a [demo website](https://opensource.bytedance.com/gmpt/t2i/invite) for you to play with Infinity and generate images interactively. Enjoy the fun of bitwise autoregressive modeling!
 
-We also provide [interactive_infer.ipynb](tools/interactive_infer.ipynb) for you to see more technical details about Infinity.
+We also provide [interactive_infer.ipynb](tools/interactive_infer.ipynb) and [interactive_infer_8b.ipynb](tools/interactive_infer_8b.ipynb) for you to see more technical details about Infinity-2B & Infinity-8B.
 
 ## 📑 Open-Source Plan
 - [ ] Infinity-20B Checkpoints
+- [x] Infinity-8B Checkpoints
 - [x] Training Code
 - [x] Web Demo
 - [x] Inference Code
@@ -86,16 +88,24 @@ We provide Infinity models for you to play with, which are on <a href='https://h
 | model | Resolution | GenEval | DPG | HPSv2.1 | HF weights🤗 |
 |:----------:|:-----:|:--------:|:---------:|:-------:|:------------------------------------------------------------------------------------|
 | Infinity-2B | 1024 | 0.69 / 0.73 $^{\dagger}$ | 83.5 | 32.2 | [infinity_2b_reg.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_2b_reg.pth) |
+| Infinity-8B | 1024 | - | - | - | [infinity_8b.pth](https://huggingface.co/FoundationVision/Infinity/tree/main/infinity_8b_weights) |
 | Infinity-20B | 1024 | - | - | - | [Coming Soon](TBD) |
 
 ${\dagger}$ result is tested with a [prompt rewriter](tools/prompt_rewriter.py).
 
-You can load these models to generate images via the codes in [interactive_infer.ipynb](tools/interactive_infer.ipynb). Note: you need to download [infinity_vae_d32reg.pth](https://huggingface.co/FoundationVision/Infinity/blob/main/infinity_vae_d32reg.pth) and [flan-t5-xl](https://huggingface.co/google/flan-t5-xl) first.
+You can load these models to generate images via the codes in [interactive_infer.ipynb](tools/interactive_infer.ipynb) and [interactive_infer_8b.ipynb](tools/interactive_infer_8b.ipynb) .
 
 
 ## ⚽️ Installation
 1. We use FlexAttention to speedup training, which requires `torch>=2.5.1`.
 2. Install other pip packages via `pip3 install -r requirements.txt`.
+3. Download weights from huggingface. Besides vae & transformers weights on <a href='https://huggingface.co/FoundationVision/infinity'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20weights-FoundationVision/Infinity-yellow'></a>, you should also download [flan-t5-xl](https://huggingface.co/google/flan-t5-xl).
+```
+from transformers import T5Tokenizer, T5ForConditionalGeneration
+tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
+model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl")
+```
+These three lines will download flan-t5-xl to your ~/.cache/huggingface directory.
 
 ## 🎨 Data Preparation
 The structure of the training dataset is listed as below. The training dataset contains a list of json files with name "[h_div_w_template1]_[num_examples].jsonl". Here [h_div_w_template] is a float number, which is the template ratio of height to width of the image. [num_examples] is the number of examples where $h/w$ is around h_div_w_template. [dataset_t2i_iterable.py](infinity/dataset/dataset_t2i_iterable.py) supports training with >100M examples. But we have to specify the number of examples for each h/w template ratio in the filename.
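
The filename convention in the Data Preparation paragraph can be illustrated with a small parser. This is a minimal sketch, not code from the repository: `ShardInfo`, `parse_shard_name`, the regular expression, and the sample filenames are all hypothetical and only assume the `[h_div_w_template]_[num_examples].jsonl` pattern quoted in the README.

```python
import re
from typing import NamedTuple


class ShardInfo(NamedTuple):
    h_div_w_template: float  # template ratio of height to width
    num_examples: int        # number of examples stored for that ratio


# Hypothetical pattern for "[h_div_w_template]_[num_examples].jsonl".
_SHARD_NAME = re.compile(r"^(?P<ratio>\d+(?:\.\d+)?)_(?P<count>\d+)\.jsonl$")


def parse_shard_name(filename: str) -> ShardInfo:
    """Split a dataset filename into its aspect-ratio template and example count."""
    match = _SHARD_NAME.match(filename)
    if match is None:
        raise ValueError(f"unexpected dataset filename: {filename!r}")
    return ShardInfo(float(match["ratio"]), int(match["count"]))


# Illustrative filenames only; a real dataset will have its own ratios and counts.
shards = [parse_shard_name(name) for name in ("1.000_2000.jsonl", "0.571_1500.jsonl")]
total_examples = sum(shard.num_examples for shard in shards)
print(shards[0], total_examples)
```

Reading the counts from the filenames gives the total dataset size without opening any of the files, which may be why the README asks for the count to be part of the name.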
@@ -201,10 +211,6 @@ Infinity shows strong scaling capabilities as illustrated before. Thus we are en
 | a Chinese model is sitting on a train, magazine cover, clothes made of plastic, photorealistic, futuristic style, gray and green light, movie lighting, 32K HD | ![](assets/2b_8b/3l.webp) | ![](assets/2b_8b/3r.webp) |
 | A group of students in a class | ![](assets/2b_20b/4l.jpg) | ![](assets/2b_8b/4r.webp) |
 
-
-
-Currently, Infinity-20B is still on the training phrase. We will release Infinity-20B once the training is completed.
-
 ## 📖 Citation
 If our work assists your research, feel free to give us a star ⭐ or cite us using:
 

tools/interactive_infer_8b.ipynb

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import random\n",
+    "import torch\n",
+    "torch.cuda.set_device(2)\n",
+    "import cv2\n",
+    "import numpy as np\n",
+    "from tools.run_infinity import *\n",
+    "\n",
+    "model_path='weights/infinity_8b_weights'\n",
+    "vae_path='weights/infinity_vae_d56_f8_14_patchify.pth'\n",
+    "text_encoder_ckpt = 'weights/flan-t5-xl-official'\n",
+    "args=argparse.Namespace(\n",
+    " pn='1M',\n",
+    " model_path=model_path,\n",
+    " cfg_insertion_layer=0,\n",
+    " vae_type=14,\n",
+    " vae_path=vae_path,\n",
+    " add_lvl_embeding_only_first_block=1,\n",
+    " use_bit_label=1,\n",
+    " model_type='infinity_8b',\n",
+    " rope2d_each_sa_layer=1,\n",
+    " rope2d_normalized_by_hw=2,\n",
+    " use_scale_schedule_embedding=0,\n",
+    " sampling_per_bits=1,\n",
+    " text_encoder_ckpt=text_encoder_ckpt,\n",
+    " text_channels=2048,\n",
+    " apply_spatial_patchify=1,\n",
+    " h_div_w_template=1.000,\n",
+    " use_flex_attn=0,\n",
+    " cache_dir='/dev/shm',\n",
+    " checkpoint_type='torch_shard',\n",
+    " seed=0,\n",
+    " bf16=1,\n",
+    " save_file='tmp.jpg'\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[Loading tokenizer and text encoder]\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3f68ce998b1546f185e6263884b382ef",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[Loading Infinity]\n",
+      "self.codebook_dim: 56, self.add_lvl_embeding_only_first_block: 1, self.use_bit_label: 1, self.rope2d_each_sa_layer: 1, self.rope2d_normalized_by_hw: 2\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/mnt/bn/foundation-vision/hanjian.thu123/infinity/pub_release/Infinity/tools/run_infinity.py:179: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
+      " with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "self.num_blocks_in_a_chunk=5, depth=40, block_chunks=8\n",
+      "\n",
+      "[constructor] ==== customized_flash_attn=False (using_flash=0/40), fused_mlp=False (fused_mlp=0/40) ==== \n",
+      " [Infinity config ] embed_dim=3584, num_heads=28, depth=40, mlp_ratio=4, swiglu=False num_blocks_in_a_chunk=5\n",
+      " [drop ratios] drop_rate=0.0, drop_path_rate=0.1 (tensor([0.0000, 0.0026, 0.0051, 0.0077, 0.0103, 0.0128, 0.0154, 0.0179, 0.0205,\n",
+      " 0.0231, 0.0256, 0.0282, 0.0308, 0.0333, 0.0359, 0.0385, 0.0410, 0.0436,\n",
+      " 0.0462, 0.0487, 0.0513, 0.0538, 0.0564, 0.0590, 0.0615, 0.0641, 0.0667,\n",
+      " 0.0692, 0.0718, 0.0744, 0.0769, 0.0795, 0.0821, 0.0846, 0.0872, 0.0897,\n",
+      " 0.0923, 0.0949, 0.0974, 0.1000]))\n",
+      "\n",
+      "[you selected Infinity with model_kwargs={'depth': 40, 'embed_dim': 3584, 'num_heads': 28, 'drop_path_rate': 0.1, 'mlp_ratio': 4, 'block_chunks': 8}] model size: 8.38B, bf16=1\n",
+      "[Load Infinity weights]\n"
+     ]
+    }
+   ],
+   "source": [
+    "# load text encoder\n",
+    "text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)\n",
+    "# load vae\n",
+    "vae = load_visual_tokenizer(args)\n",
+    "# load infinity\n",
+    "infinity = load_transformer(vae, args)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "prompt=a cat holds a board with the text 'diffusion is dead'\n",
+      "cfg: [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], tau: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/mnt/bn/foundation-vision/hanjian.thu123/infinity/pub_release/Infinity/tools/run_infinity.py:112: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
+      " with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):\n",
+      "/mnt/bn/foundation-vision/hanjian.thu123/infinity/pub_release/Infinity/infinity/models/basic.py:495: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
+      " with torch.cuda.amp.autocast(enabled=False): # disable half precision\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "cost: 1.7465496063232422, infinity cost=1.7265434265136719\n",
+      "Save to /mnt/bn/foundation-vision/hanjian.thu123/infinity/pub_release/Infinity/tools/ipynb_tmp.jpg\n"
+     ]
+    }
+   ],
+   "source": [
+    "prompt = \"\"\"a cat holds a board with the text 'diffusion is dead'\"\"\"\n",
+    "cfg = 3\n",
+    "tau = 1.0\n",
+    "h_div_w = 1/1 # aspect ratio, height:width\n",
+    "seed = random.randint(0, 10000)\n",
+    "enable_positive_prompt=0\n",
+    "\n",
+    "h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates-h_div_w))]\n",
+    "scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']\n",
+    "scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]\n",
+    "generated_image = gen_one_img(\n",
+    " infinity,\n",
+    " vae,\n",
+    " text_tokenizer,\n",
+    " text_encoder,\n",
+    " prompt,\n",
+    " g_seed=seed,\n",
+    " gt_leak=0,\n",
+    " gt_ls_Bl=None,\n",
+    " cfg_list=cfg,\n",
+    " tau_list=tau,\n",
+    " scale_schedule=scale_schedule,\n",
+    " cfg_insertion_layer=[args.cfg_insertion_layer],\n",
+    " vae_type=args.vae_type,\n",
+    " sampling_per_bits=args.sampling_per_bits,\n",
+    " enable_positive_prompt=enable_positive_prompt,\n",
+    ")\n",
+    "args.save_file = 'ipynb_tmp.jpg'\n",
+    "os.makedirs(osp.dirname(osp.abspath(args.save_file)), exist_ok=True)\n",
+    "cv2.imwrite(args.save_file, generated_image.cpu().numpy())\n",
+    "print(f'Save to {osp.abspath(args.save_file)}')"
+   ]
+  }
+ ],
+ "metadata": {
+  "fileId": "8ac263ab-b18c-41dc-b409-0fb0f32525f0",
+  "filePath": "/mnt/bn/foundation-vision/hanjian.thu123/infinity/infinity/tools/interactive_infer.ipynb",
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
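
The last notebook cell picks the aspect-ratio template closest to the requested `h_div_w` with `np.argmin(np.abs(h_div_w_templates-h_div_w))`. The sketch below shows the same nearest-value selection in plain Python so it runs without NumPy. `nearest_template` is a hypothetical helper, and the template values are made up for illustration; they are not the repository's actual `h_div_w_templates`.

```python
from typing import Sequence


def nearest_template(h_div_w: float, templates: Sequence[float]) -> float:
    """Return the template ratio closest to the requested height/width ratio."""
    if not templates:
        raise ValueError("templates must not be empty")
    # min() keeps the first of several equally distant templates,
    # the same tie-breaking as np.argmin on the absolute differences.
    return min(templates, key=lambda template: abs(template - h_div_w))


# Illustrative template ratios (assumed values, not taken from the repository).
templates = [0.5, 0.75, 1.0, 1.333, 2.0]

print(nearest_template(9 / 16, templates))  # a landscape request
print(nearest_template(1 / 1, templates))   # the square case used in the notebook
```

Snapping to a template matters because the scale schedule in the notebook is looked up by template ratio (`dynamic_resolution_h_w[h_div_w_template_]`), so an arbitrary ratio has to be mapped to one of the available keys first.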

tools/run_infinity.py

Lines changed: 13 additions & 3 deletions
@@ -172,6 +172,7 @@ def load_infinity(
     apply_spatial_patchify=0,
     use_flex_attn=False,
     bf16=False,
+    checkpoint_type='torch',
 ):
     print(f'[Loading Infinity]')
     text_maxlen = 512
@@ -207,8 +208,12 @@ def load_infinity(
         torch.cuda.empty_cache()
 
         print(f'[Load Infinity weights]')
-        state_dict = torch.load(model_path, map_location=device)
-        print(infinity_test.load_state_dict(state_dict))
+        if checkpoint_type == 'torch':
+            state_dict = torch.load(model_path, map_location=device)
+            print(infinity_test.load_state_dict(state_dict))
+        elif checkpoint_type == 'torch_shard':
+            from transformers.modeling_utils import load_sharded_checkpoint
+            load_sharded_checkpoint(infinity_test, model_path, strict=False)
         infinity_test.rng = torch.Generator(device=device)
         return infinity_test
 
@@ -252,7 +257,7 @@ def joint_vi_vae_encode_decode(vae, image_path, scale_schedule, device, tgt_h, t
 def load_visual_tokenizer(args):
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     # load vae
-    if args.vae_type in [16,18,20,24,32,64]:
+    if args.vae_type in [14,16,18,20,24,32,64]:
         from infinity.models.bsq_vae.vae import vae_model
         schedule_mode = "dynamic"
         codebook_dim = args.vae_type
@@ -304,9 +309,13 @@ def load_transformer(vae, args):
         else:
             slim_model_path = model_path
         print(f'load checkpoint from {slim_model_path}')
+    elif args.checkpoint_type == 'torch_shard':
+        slim_model_path = model_path
 
     if args.model_type == 'infinity_2b':
         kwargs_model = dict(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8) # 2b model
+    elif args.model_type == 'infinity_8b':
+        kwargs_model = dict(depth=40, embed_dim=3584, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8)
     elif args.model_type == 'infinity_layer12':
         kwargs_model = dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
     elif args.model_type == 'infinity_layer16':
@@ -335,6 +344,7 @@ def load_transformer(vae, args):
         apply_spatial_patchify=args.apply_spatial_patchify,
         use_flex_attn=args.use_flex_attn,
         bf16=args.bf16,
+        checkpoint_type=args.checkpoint_type,
     )
     return infinity
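
In `load_infinity`, the new `if checkpoint_type == 'torch' / elif checkpoint_type == 'torch_shard'` chain has no final `else`, so any other value skips weight loading without an error. The sketch below shows one possible way to make that dispatch fail loudly. It is not the repository's code: `load_weights` and the stand-in loaders are hypothetical, and are used only so the example runs without torch or real checkpoints.

```python
from typing import Callable, Dict

# A loader takes (model, model_path) and fills the model with weights.
Loader = Callable[[object, str], None]


def load_weights(model: object, model_path: str, checkpoint_type: str,
                 loaders: Dict[str, Loader]) -> None:
    """Dispatch on checkpoint_type and raise on a value no loader handles."""
    loader = loaders.get(checkpoint_type)
    if loader is None:
        known = ", ".join(sorted(loaders))
        raise ValueError(
            f"unknown checkpoint_type {checkpoint_type!r}; expected one of: {known}"
        )
    loader(model, model_path)


# Stand-in loaders (assumptions for illustration). In the diff, 'torch' uses
# torch.load + load_state_dict and 'torch_shard' uses transformers'
# load_sharded_checkpoint.
calls = []
loaders = {
    "torch": lambda model, path: calls.append(("torch", path)),
    "torch_shard": lambda model, path: calls.append(("torch_shard", path)),
}

load_weights(object(), "weights/infinity_8b_weights", "torch_shard", loaders)
print(calls)
```

A table of loaders also keeps the set of supported checkpoint formats in one place, so a third format would only need one new entry.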
