diff --git a/engines/hf/runner.py b/engines/hf/runner.py
index 427de4e..d260e0e 100644
--- a/engines/hf/runner.py
+++ b/engines/hf/runner.py
@@ -10,6 +10,8 @@
 from termcolor import cprint
 from transformers import AutoModel, AutoProcessor, AutoTokenizer
 
+from mineru_diffusion.utils.runtime import maybe_disable_flash_attention, resolve_torch_dtype
+
 STOP_STRINGS = ("<|endoftext|>", "<|im_end|>")
 SYSTEM_PROMPT = "You are a helpful assistant."
 
@@ -70,7 +72,21 @@ def run(args: argparse.Namespace) -> None:
     prompt = args.prompt or TASK_PROMPTS[args.prompt_type]
     model_path = Path(args.model_path).resolve()
     device = args.device
-    dtype = getattr(torch, args.dtype)
+    flash_attn_disabled = maybe_disable_flash_attention(device)
+    dtype, resolved_dtype_name = resolve_torch_dtype(device, args.dtype)
+
+    if flash_attn_disabled:
+        cprint(
+            "FlashAttention disabled for this GPU; using PyTorch SDPA fallback.",
+            color="yellow",
+            flush=True,
+        )
+    if resolved_dtype_name != args.dtype:
+        cprint(
+            f"CUDA device does not support {args.dtype}; falling back to {resolved_dtype_name}.",
+            color="yellow",
+            flush=True,
+        )
 
     _print_summary(args, model_path, device, dtype)
 
diff --git a/mineru_diffusion/utils/runtime.py b/mineru_diffusion/utils/runtime.py
new file mode 100644
index 0000000..b183f9a
--- /dev/null
+++ b/mineru_diffusion/utils/runtime.py
@@ -0,0 +1,56 @@
+import os
+
+import torch
+
+
+def _parse_device(device: str) -> torch.device | None:
+    try:
+        return torch.device(device)
+    except (RuntimeError, TypeError, ValueError):
+        return None
+
+
+def should_disable_flash_attention(device: str) -> bool:
+    if os.environ.get("MINERU_DISABLE_FLASH_ATTN") == "1":
+        return True
+
+    torch_device = _parse_device(device)
+    if torch_device is None or torch_device.type != "cuda" or not torch.cuda.is_available():
+        return False
+
+    try:
+        major, _ = torch.cuda.get_device_capability(torch_device)
+    except Exception:
+        return False
+    return major < 8
+
+
+def maybe_disable_flash_attention(device: str) -> bool:
+    if not should_disable_flash_attention(device):
+        return False
+
+    os.environ["MINERU_DISABLE_FLASH_ATTN"] = "1"
+
+    try:
+        import flash_attn
+    except ImportError:
+        return True
+
+    flash_attn.flash_attn_func = None
+    return True
+
+
+def resolve_torch_dtype(device: str, requested_dtype: str) -> tuple[torch.dtype, str]:
+    torch_device = _parse_device(device)
+    resolved_dtype = requested_dtype
+
+    if (
+        requested_dtype == "bfloat16"
+        and torch_device is not None
+        and torch_device.type == "cuda"
+        and torch.cuda.is_available()
+        and not torch.cuda.is_bf16_supported()
+    ):
+        resolved_dtype = "float16"
+
+    return getattr(torch, resolved_dtype), resolved_dtype
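Note: the `major < 8` gate in should_disable_flash_attention tracks FlashAttention-2's
Ampere (SM 8.0) floor, and torch.cuda.is_bf16_supported() fails on the same pre-Ampere
parts (e.g. V100 at SM 7.0, T4 at SM 7.5), so the two fallbacks tend to fire together.
A minimal sketch of how the helpers compose at a call site, mirroring the runner.py
change above; the "cuda:0" device string and "bfloat16" request are illustrative:

    import torch

    from mineru_diffusion.utils.runtime import (
        maybe_disable_flash_attention,
        resolve_torch_dtype,
    )

    device = "cuda:0"

    # On a pre-Ampere GPU this sets MINERU_DISABLE_FLASH_ATTN=1, neuters
    # flash_attn.flash_attn_func if that package is importable, and returns True.
    flash_attn_disabled = maybe_disable_flash_attention(device)

    # The same hardware lacks bfloat16 support, so the helper resolves to
    # float16 and also returns the dtype name it actually picked, for logging.
    dtype, resolved_name = resolve_torch_dtype(device, "bfloat16")
    assert dtype is getattr(torch, resolved_name)
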
diff --git a/scripts/run_end2end.py b/scripts/run_end2end.py
index af67b6c..9ac2baf 100644
--- a/scripts/run_end2end.py
+++ b/scripts/run_end2end.py
@@ -21,6 +21,7 @@
 sys.path.insert(0, str(REPO_DIR))
 
 from mineru_diffusion.utils.bbox import draw_bbox
+from mineru_diffusion.utils.runtime import maybe_disable_flash_attention, resolve_torch_dtype
 
 STOP_STRINGS = ("<|endoftext|>", "<|im_end|>")
 
@@ -369,13 +370,25 @@ def __init__(
     ) -> None:
         self.model_path = model_path
         self.device = device
-        self.torch_dtype = getattr(torch, dtype)
+        flash_attn_disabled = maybe_disable_flash_attention(device)
+        self.torch_dtype, resolved_dtype_name = resolve_torch_dtype(device, dtype)
         self.max_length = max_length
         self.block_size = block_size
         self.temperature = temperature
         self.remask_strategy = remask_strategy
         self.dynamic_threshold = dynamic_threshold
 
+        if flash_attn_disabled:
+            print(
+                "FlashAttention disabled for this GPU; using PyTorch SDPA fallback.",
+                file=sys.stderr,
+            )
+        if resolved_dtype_name != dtype:
+            print(
+                f"CUDA device does not support {dtype}; falling back to {resolved_dtype_name}.",
+                file=sys.stderr,
+            )
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
         self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
         self.model = AutoModel.from_pretrained(
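The MINERU_DISABLE_FLASH_ATTN variable also doubles as a manual kill switch:
should_disable_flash_attention checks it before touching any CUDA API, so exporting
it forces the SDPA path on any hardware. A small sketch, assuming the
mineru_diffusion package is importable:

    import os

    os.environ["MINERU_DISABLE_FLASH_ATTN"] = "1"

    from mineru_diffusion.utils.runtime import maybe_disable_flash_attention

    # The env-var check short-circuits before any CUDA query, so the fallback
    # is reported even on CPU or on Ampere-and-newer GPUs.
    assert maybe_disable_flash_attention("cpu") is True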