18 changes: 17 additions & 1 deletion engines/hf/runner.py
@@ -10,6 +10,8 @@
 from termcolor import cprint
 from transformers import AutoModel, AutoProcessor, AutoTokenizer
 
+from mineru_diffusion.utils.runtime import maybe_disable_flash_attention, resolve_torch_dtype
+
 
 STOP_STRINGS = ("<|endoftext|>", "<|im_end|>")
 SYSTEM_PROMPT = "You are a helpful assistant."
@@ -70,7 +72,21 @@ def run(args: argparse.Namespace) -> None:
     prompt = args.prompt or TASK_PROMPTS[args.prompt_type]
     model_path = Path(args.model_path).resolve()
     device = args.device
-    dtype = getattr(torch, args.dtype)
+    flash_attn_disabled = maybe_disable_flash_attention(device)
+    dtype, resolved_dtype_name = resolve_torch_dtype(device, args.dtype)
+
+    if flash_attn_disabled:
+        cprint(
+            "FlashAttention disabled for this GPU; using PyTorch SDPA fallback.",
+            color="yellow",
+            flush=True,
+        )
+    if resolved_dtype_name != args.dtype:
+        cprint(
+            f"CUDA device does not support {args.dtype}; falling back to {resolved_dtype_name}.",
+            color="yellow",
+            flush=True,
+        )
 
     _print_summary(args, model_path, device, dtype)
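Reviewer note: the gate driving both warnings is the GPU generation. FlashAttention-2 kernels require compute capability 8.0+ (Ampere or newer), and bfloat16 support arrives with the same generation, so parts like V100 (7.0) or T4 (7.5) hit both fallbacks at once. A standalone probe, illustrative only (describe_gpu is a hypothetical name, not part of this PR):

import torch

def describe_gpu(device: str = "cuda") -> None:
    # Mirrors the rule in should_disable_flash_attention() below:
    # pre-Ampere GPUs (major < 8) get the PyTorch SDPA fallback.
    if not torch.cuda.is_available():
        print("no CUDA device; FlashAttention not applicable")
        return
    major, minor = torch.cuda.get_device_capability(torch.device(device))
    flash_ok = major >= 8
    bf16_ok = torch.cuda.is_bf16_supported()
    print(f"compute capability {major}.{minor}: "
          f"flash-attn={'yes' if flash_ok else 'no (SDPA)'}, bf16={bf16_ok}")

describe_gpu()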
56 changes: 56 additions & 0 deletions mineru_diffusion/utils/runtime.py
@@ -0,0 +1,56 @@
+import os
+
+import torch
+
+
+def _parse_device(device: str) -> torch.device | None:
+    try:
+        return torch.device(device)
+    except (RuntimeError, TypeError, ValueError):
+        return None
+
+
+def should_disable_flash_attention(device: str) -> bool:
+    if os.environ.get("MINERU_DISABLE_FLASH_ATTN") == "1":
+        return True
+
+    torch_device = _parse_device(device)
+    if torch_device is None or torch_device.type != "cuda" or not torch.cuda.is_available():
+        return False
+
+    try:
+        major, _ = torch.cuda.get_device_capability(torch_device)
+    except Exception:
+        return False
+    return major < 8
+
+
+def maybe_disable_flash_attention(device: str) -> bool:
+    if not should_disable_flash_attention(device):
+        return False
+
+    os.environ["MINERU_DISABLE_FLASH_ATTN"] = "1"
+
+    try:
+        import flash_attn
+    except ImportError:
+        return True
+
+    flash_attn.flash_attn_func = None
+    return True
+
+
+def resolve_torch_dtype(device: str, requested_dtype: str) -> tuple[torch.dtype, str]:
+    torch_device = _parse_device(device)
+    resolved_dtype = requested_dtype
+
+    if (
+        requested_dtype == "bfloat16"
+        and torch_device is not None
+        and torch_device.type == "cuda"
+        and torch.cuda.is_available()
+        and not torch.cuda.is_bf16_supported()
+    ):
+        resolved_dtype = "float16"
+
+    return getattr(torch, resolved_dtype), resolved_dtype
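Taken together, the three public helpers compose like this. A hedged usage sketch (assumptions: the package is importable, and no real GPU is needed because the env var short-circuits the hardware probe):

import os

from mineru_diffusion.utils.runtime import (
    maybe_disable_flash_attention,
    resolve_torch_dtype,
    should_disable_flash_attention,
)

# The env var forces the fallback regardless of hardware, and
# maybe_disable_flash_attention() sets it so later checks (and child
# processes) inherit the decision.
os.environ["MINERU_DISABLE_FLASH_ATTN"] = "1"
assert should_disable_flash_attention("cuda:0")
assert maybe_disable_flash_attention("cuda:0")

# On a non-CUDA device the requested dtype passes through untouched;
# only bfloat16 on a CUDA device without bf16 support is rewritten to float16.
dtype, name = resolve_torch_dtype("cpu", "bfloat16")
assert name == "bfloat16"

Setting flash_attn.flash_attn_func to None is a blunt guard; the idea, presumably, is that remote model code probing for the kernel finds it unusable and takes its SDPA path instead.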
15 changes: 14 additions & 1 deletion scripts/run_end2end.py
@@ -21,6 +21,7 @@
 sys.path.insert(0, str(REPO_DIR))
 
 from mineru_diffusion.utils.bbox import draw_bbox
+from mineru_diffusion.utils.runtime import maybe_disable_flash_attention, resolve_torch_dtype
 
 
 STOP_STRINGS = ("<|endoftext|>", "<|im_end|>")
@@ -369,13 +370,25 @@ def __init__(
     ) -> None:
         self.model_path = model_path
         self.device = device
-        self.torch_dtype = getattr(torch, dtype)
+        flash_attn_disabled = maybe_disable_flash_attention(device)
+        self.torch_dtype, resolved_dtype_name = resolve_torch_dtype(device, dtype)
         self.max_length = max_length
         self.block_size = block_size
         self.temperature = temperature
         self.remask_strategy = remask_strategy
         self.dynamic_threshold = dynamic_threshold
 
+        if flash_attn_disabled:
+            print(
+                "FlashAttention disabled for this GPU; using PyTorch SDPA fallback.",
+                file=sys.stderr,
+            )
+        if resolved_dtype_name != dtype:
+            print(
+                f"CUDA device does not support {dtype}; falling back to {resolved_dtype_name}.",
+                file=sys.stderr,
+            )
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
         self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
         self.model = AutoModel.from_pretrained(
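The AutoModel.from_pretrained call is truncated by the diff above, so its argument list isn't shown. For illustration only, a hypothetical loader showing where the resolved dtype typically lands (load_model and its arguments are assumptions, not the PR's code):

import torch
from transformers import AutoModel

def load_model(model_path: str, device: str, torch_dtype: torch.dtype):
    # Hypothetical: materialize weights directly in the resolved precision
    # (torch.bfloat16, or torch.float16 after the fallback above).
    model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
    )
    return model.to(device).eval()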