-
Notifications
You must be signed in to change notification settings - Fork 150
Expand file tree
/
Copy pathinference.py
More file actions
126 lines (102 loc) · 4.48 KB
/
inference.py
File metadata and controls
126 lines (102 loc) · 4.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Local inference entrypoint for the clean JoyAI-Image release."""
from __future__ import annotations
import argparse
import os
import sys
import time
import warnings
from pathlib import Path
import torch
from PIL import Image
# Directory that contains this script; the bundled sources live in ``src/``
# right next to it.
ROOT_DIR = Path(__file__).resolve().parent
SRC_DIR = ROOT_DIR.joinpath('src')

# Prepend ``src`` to the import path (only once) so the project packages that
# ``main`` imports resolve from this checkout.
_src_path_entry = str(SRC_DIR)
if _src_path_entry not in sys.path:
    sys.path.insert(0, _src_path_entry)
del _src_path_entry

# NOTE(review): this silences *every* warning for the whole process, not just
# noisy third-party ones; consider narrowing the filter by category/module.
warnings.filterwarnings('ignore')
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for local inference.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like the previous zero-argument form;
            passing an explicit list lets tests and other scripts reuse the
            parser without touching ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Run local inference without FastAPI.')
    parser.add_argument('--ckpt-root', required=True, help='Checkpoint root.')
    parser.add_argument('--prompt', required=True, help='Edit prompt or T2I prompt.')
    parser.add_argument('--image', help='Optional input image path for image editing.')
    parser.add_argument('--output', default='example.png', help='Output image path.')
    parser.add_argument('--height', type=int, default=1024, help='Only used for text-to-image inference.')
    parser.add_argument('--width', type=int, default=1024, help='Only used for text-to-image inference.')
    parser.add_argument('--steps', type=int, default=50, help='Number of inference steps.')
    parser.add_argument('--guidance-scale', type=float, default=5.0, help='Guidance scale.')
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')
    parser.add_argument('--neg-prompt', default='', help='Negative prompt.')
    parser.add_argument('--basesize', type=int, default=1024, help='Resize bucket base size for image editing inputs.')
    parser.add_argument('--rewrite-prompt', action='store_true', help='Enable prompt rewriting before inference.')
    parser.add_argument('--config', help='Optional config path. Defaults to <ckpt-root>/infer_config.py.')
    parser.add_argument('--rewrite-model', default='gpt-5', help='Model identifier used for prompt rewriting.')
    parser.add_argument('--hsdp-shard-dim', type=int, help='Override config hsdp_shard_dim for multi-GPU FSDP inference.')
    return parser.parse_args(argv)
def load_input_image(image_path: str | None) -> Image.Image | None:
    """Load the optional editing input as an RGB image.

    Args:
        image_path: Path to the input image, or ``None`` / an empty string
            when no input image is supplied (text-to-image inference).

    Returns:
        A fully loaded RGB ``PIL.Image.Image``, or ``None`` when no path was
        given.
    """
    if not image_path:
        return None
    # Use a context manager so the source file handle is released
    # deterministically: Pillow keeps multi-frame files open until the Image
    # is closed. ``convert`` loads the pixel data and returns a new image
    # that does not depend on the source file staying open.
    with Image.open(image_path) as source:
        return source.convert('RGB')
def is_rank0() -> bool:
    """Report whether this process has ``RANK`` equal to 0.

    A missing ``RANK`` environment variable counts as rank 0, so plain
    single-process runs are treated as rank 0.
    """
    rank_text = os.getenv('RANK', '0')
    return not int(rank_text)
def resolve_device() -> torch.device:
    """Choose the torch device for this process.

    When CUDA is available, the GPU indexed by ``LOCAL_RANK`` (default 0) is
    made the current CUDA device and returned; otherwise the CPU device is
    returned.
    """
    if torch.cuda.is_available():
        gpu_index = int(os.environ.get('LOCAL_RANK', '0'))
        torch.cuda.set_device(gpu_index)
        return torch.device('cuda', gpu_index)
    return torch.device('cpu')
def main() -> None:
    """Run one local inference pass driven by command-line arguments.

    Workflow:
    1. Parse CLI options and load settings.
    2. Load the optional input image and pick the device.
    3. Optionally initialise distributed execution.
    4. Build the model and optionally rewrite the prompt.
    5. Run inference: image editing when ``--image`` is given, text-to-image
       otherwise.
    6. On rank 0 only, save the result to ``--output`` and print a summary.

    If ``maybe_init_distributed`` reports that it initialised the distributed
    environment, ``clean_dist_env`` is called on exit, including when an
    exception is raised.
    """
    args = parse_args()

    # Project modules live under SRC_DIR (prepended to sys.path at import
    # time). Importing them only after argument parsing means ``--help`` and
    # argument errors exit without loading them.
    from infer_runtime.model import InferenceParams, build_model
    from infer_runtime.settings import load_settings
    from modules.utils import maybe_init_distributed, clean_dist_env
    from modules.models.attention import describe_attention_backend

    dist_initialized = False
    try:
        settings = load_settings(
            ckpt_root=args.ckpt_root,
            config_path=args.config,
            rewrite_model=args.rewrite_model,
            default_seed=args.seed,
        )

        # Decode the input image before device/distributed setup and the
        # model build, so a missing or unreadable --image fails fast instead
        # of after build_model has already run.
        input_image = load_input_image(args.image)

        device = resolve_device()
        dist_initialized = maybe_init_distributed()
        if is_rank0():
            print(f'Chosen device: {device}')
            print(f'Attention backend: {describe_attention_backend()}')
            print(f'Config path: {settings.config_path}')
            print(f'Checkpoint path: {settings.ckpt_path}')
            if args.hsdp_shard_dim is not None:
                print(f'Override hsdp_shard_dim: {args.hsdp_shard_dim}')

        model = build_model(
            settings,
            device=device,
            hsdp_shard_dim_override=args.hsdp_shard_dim,
        )
        effective_prompt = model.maybe_rewrite_prompt(args.prompt, input_image, args.rewrite_prompt)

        # perf_counter is a monotonic, high-resolution clock, so the measured
        # duration cannot be skewed by system clock adjustments the way
        # time.time() can.
        start_time = time.perf_counter()
        output_image = model.infer(
            InferenceParams(
                prompt=effective_prompt,
                image=input_image,
                height=args.height,
                width=args.width,
                steps=args.steps,
                guidance_scale=args.guidance_scale,
                seed=args.seed,
                neg_prompt=args.neg_prompt,
                basesize=args.basesize,
            )
        )
        elapsed = time.perf_counter() - start_time

        # Every rank runs inference above; only rank 0 writes the file and logs.
        if is_rank0():
            output_path = Path(args.output)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            output_image.save(output_path)
            print(f'Prompt used: {effective_prompt}')
            print(f'Saved output: {output_path}')
            print(f'Time taken: {elapsed:.2f} seconds')
    finally:
        if dist_initialized:
            clean_dist_env()
if __name__ == '__main__':
    # Script entry point: run inference only when executed directly, not when
    # this module is imported.
    main()