|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": 9, |
| 6 | + "metadata": {}, |
| 7 | + "outputs": [], |
| 8 | + "source": [ |
| 9 | + "import random\n", |
| 10 | + "import torch\n", |
| 11 | + "torch.cuda.set_device(2)\n", |
| 12 | + "import cv2\n", |
| 13 | + "import numpy as np\n", |
| 14 | + "from tools.run_infinity import *\n", |
| 15 | + "\n", |
| 16 | + "model_path='weights/infinity_8b_weights'\n", |
| 17 | + "vae_path='weights/infinity_vae_d56_f8_14_patchify.pth'\n", |
| 18 | + "text_encoder_ckpt = 'weights/flan-t5-xl-official'\n", |
| 19 | + "args=argparse.Namespace(\n", |
| 20 | + " pn='1M',\n", |
| 21 | + " model_path=model_path,\n", |
| 22 | + " cfg_insertion_layer=0,\n", |
| 23 | + " vae_type=14,\n", |
| 24 | + " vae_path=vae_path,\n", |
| 25 | + " add_lvl_embeding_only_first_block=1,\n", |
| 26 | + " use_bit_label=1,\n", |
| 27 | + " model_type='infinity_8b',\n", |
| 28 | + " rope2d_each_sa_layer=1,\n", |
| 29 | + " rope2d_normalized_by_hw=2,\n", |
| 30 | + " use_scale_schedule_embedding=0,\n", |
| 31 | + " sampling_per_bits=1,\n", |
| 32 | + " text_encoder_ckpt=text_encoder_ckpt,\n", |
| 33 | + " text_channels=2048,\n", |
| 34 | + " apply_spatial_patchify=1,\n", |
| 35 | + " h_div_w_template=1.000,\n", |
| 36 | + " use_flex_attn=0,\n", |
| 37 | + " cache_dir='/dev/shm',\n", |
| 38 | + " checkpoint_type='torch_shard',\n", |
| 39 | + " seed=0,\n", |
| 40 | + " bf16=1,\n", |
| 41 | + " save_file='tmp.jpg'\n", |
| 42 | + ")" |
| 43 | + ] |
| 44 | + }, |
| 45 | + { |
| 46 | + "cell_type": "code", |
| 47 | + "execution_count": 10, |
| 48 | + "metadata": {}, |
| 49 | + "outputs": [ |
| 50 | + { |
| 51 | + "name": "stdout", |
| 52 | + "output_type": "stream", |
| 53 | + "text": [ |
| 54 | + "[Loading tokenizer and text encoder]\n" |
| 55 | + ] |
| 56 | + }, |
| 57 | + { |
| 58 | + "data": { |
| 59 | + "application/vnd.jupyter.widget-view+json": { |
| 60 | + "model_id": "3f68ce998b1546f185e6263884b382ef", |
| 61 | + "version_major": 2, |
| 62 | + "version_minor": 0 |
| 63 | + }, |
| 64 | + "text/plain": [ |
| 65 | + "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" |
| 66 | + ] |
| 67 | + }, |
| 68 | + "metadata": {}, |
| 69 | + "output_type": "display_data" |
| 70 | + }, |
| 71 | + { |
| 72 | + "name": "stdout", |
| 73 | + "output_type": "stream", |
| 74 | + "text": [ |
| 75 | + "[Loading Infinity]\n", |
| 76 | + "self.codebook_dim: 56, self.add_lvl_embeding_only_first_block: 1, self.use_bit_label: 1, self.rope2d_each_sa_layer: 1, self.rope2d_normalized_by_hw: 2\n" |
| 77 | + ] |
| 78 | + }, |
| 79 | + { |
| 80 | + "name": "stderr", |
| 81 | + "output_type": "stream", |
| 82 | + "text": [ |
| 83 | + "/mnt/bn/foundation-vision/hanjian.thu123/infinity/pub_release/Infinity/tools/run_infinity.py:179: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n", |
| 84 | + " with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():\n" |
| 85 | + ] |
| 86 | + }, |
| 87 | + { |
| 88 | + "name": "stdout", |
| 89 | + "output_type": "stream", |
| 90 | + "text": [ |
| 91 | + "self.num_blocks_in_a_chunk=5, depth=40, block_chunks=8\n", |
| 92 | + "\n", |
| 93 | + "[constructor] ==== customized_flash_attn=False (using_flash=0/40), fused_mlp=False (fused_mlp=0/40) ==== \n", |
| 94 | + " [Infinity config ] embed_dim=3584, num_heads=28, depth=40, mlp_ratio=4, swiglu=False num_blocks_in_a_chunk=5\n", |
| 95 | + " [drop ratios] drop_rate=0.0, drop_path_rate=0.1 (tensor([0.0000, 0.0026, 0.0051, 0.0077, 0.0103, 0.0128, 0.0154, 0.0179, 0.0205,\n", |
| 96 | + " 0.0231, 0.0256, 0.0282, 0.0308, 0.0333, 0.0359, 0.0385, 0.0410, 0.0436,\n", |
| 97 | + " 0.0462, 0.0487, 0.0513, 0.0538, 0.0564, 0.0590, 0.0615, 0.0641, 0.0667,\n", |
| 98 | + " 0.0692, 0.0718, 0.0744, 0.0769, 0.0795, 0.0821, 0.0846, 0.0872, 0.0897,\n", |
| 99 | + " 0.0923, 0.0949, 0.0974, 0.1000]))\n", |
| 100 | + "\n", |
| 101 | + "[you selected Infinity with model_kwargs={'depth': 40, 'embed_dim': 3584, 'num_heads': 28, 'drop_path_rate': 0.1, 'mlp_ratio': 4, 'block_chunks': 8}] model size: 8.38B, bf16=1\n", |
| 102 | + "[Load Infinity weights]\n" |
| 103 | + ] |
| 104 | + } |
| 105 | + ], |
| 106 | + "source": [ |
| 107 | + "# load text encoder\n", |
| 108 | + "text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)\n", |
| 109 | + "# load vae\n", |
| 110 | + "vae = load_visual_tokenizer(args)\n", |
| 111 | + "# load infinity\n", |
| 112 | + "infinity = load_transformer(vae, args)" |
| 113 | + ] |
| 114 | + }, |
| 115 | + { |
| 116 | + "cell_type": "code", |
| 117 | + "execution_count": 11, |
| 118 | + "metadata": {}, |
| 119 | + "outputs": [ |
| 120 | + { |
| 121 | + "name": "stdout", |
| 122 | + "output_type": "stream", |
| 123 | + "text": [ |
| 124 | + "prompt=a cat holds a board with the text 'diffusion is dead'\n", |
| 125 | + "cfg: [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], tau: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]\n" |
| 126 | + ] |
| 127 | + }, |
| 128 | + { |
| 129 | + "name": "stderr", |
| 130 | + "output_type": "stream", |
| 131 | + "text": [ |
| 132 | + "/mnt/bn/foundation-vision/hanjian.thu123/infinity/pub_release/Infinity/tools/run_infinity.py:112: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n", |
| 133 | + " with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):\n", |
| 134 | + "/mnt/bn/foundation-vision/hanjian.thu123/infinity/pub_release/Infinity/infinity/models/basic.py:495: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n", |
| 135 | + " with torch.cuda.amp.autocast(enabled=False): # disable half precision\n" |
| 136 | + ] |
| 137 | + }, |
| 138 | + { |
| 139 | + "name": "stdout", |
| 140 | + "output_type": "stream", |
| 141 | + "text": [ |
| 142 | + "cost: 1.7465496063232422, infinity cost=1.7265434265136719\n", |
| 143 | + "Save to /mnt/bn/foundation-vision/hanjian.thu123/infinity/pub_release/Infinity/tools/ipynb_tmp.jpg\n" |
| 144 | + ] |
| 145 | + } |
| 146 | + ], |
| 147 | + "source": [ |
| 148 | + "prompt = \"\"\"a cat holds a board with the text 'diffusion is dead'\"\"\"\n", |
| 149 | + "cfg = 3\n", |
| 150 | + "tau = 1.0\n", |
| 151 | + "h_div_w = 1/1 # aspect ratio, height:width\n", |
| 152 | + "seed = random.randint(0, 10000)\n", |
| 153 | + "enable_positive_prompt=0\n", |
| 154 | + "\n", |
| 155 | + "h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates-h_div_w))]\n", |
| 156 | + "scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']\n", |
| 157 | + "scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]\n", |
| 158 | + "generated_image = gen_one_img(\n", |
| 159 | + " infinity,\n", |
| 160 | + " vae,\n", |
| 161 | + " text_tokenizer,\n", |
| 162 | + " text_encoder,\n", |
| 163 | + " prompt,\n", |
| 164 | + " g_seed=seed,\n", |
| 165 | + " gt_leak=0,\n", |
| 166 | + " gt_ls_Bl=None,\n", |
| 167 | + " cfg_list=cfg,\n", |
| 168 | + " tau_list=tau,\n", |
| 169 | + " scale_schedule=scale_schedule,\n", |
| 170 | + " cfg_insertion_layer=[args.cfg_insertion_layer],\n", |
| 171 | + " vae_type=args.vae_type,\n", |
| 172 | + " sampling_per_bits=args.sampling_per_bits,\n", |
| 173 | + " enable_positive_prompt=enable_positive_prompt,\n", |
| 174 | + ")\n", |
| 175 | + "args.save_file = 'ipynb_tmp.jpg'\n", |
| 176 | + "os.makedirs(osp.dirname(osp.abspath(args.save_file)), exist_ok=True)\n", |
| 177 | + "cv2.imwrite(args.save_file, generated_image.cpu().numpy())\n", |
| 178 | + "print(f'Save to {osp.abspath(args.save_file)}')" |
| 179 | + ] |
| 180 | + } |
| 181 | + ], |
| 182 | + "metadata": { |
| 183 | + "fileId": "8ac263ab-b18c-41dc-b409-0fb0f32525f0", |
| 184 | + "filePath": "/mnt/bn/foundation-vision/hanjian.thu123/infinity/infinity/tools/interactive_infer.ipynb", |
| 185 | + "kernelspec": { |
| 186 | + "display_name": "Python 3", |
| 187 | + "language": "python", |
| 188 | + "name": "python3" |
| 189 | + }, |
| 190 | + "language_info": { |
| 191 | + "codemirror_mode": { |
| 192 | + "name": "ipython", |
| 193 | + "version": 3 |
| 194 | + }, |
| 195 | + "file_extension": ".py", |
| 196 | + "mimetype": "text/x-python", |
| 197 | + "name": "python", |
| 198 | + "nbconvert_exporter": "python", |
| 199 | + "pygments_lexer": "ipython3", |
| 200 | + "version": "3.9.2" |
| 201 | + } |
| 202 | + }, |
| 203 | + "nbformat": 4, |
| 204 | + "nbformat_minor": 2 |
| 205 | +} |
0 commit comments