Generation overhead: 3.25 GPU syncs per token + PyTorch dispatch overhead #43089

@AmitMY

System Info

  • transformers version: 5.0.0.dev0 (main branch)
  • Platform: Linux
  • Python version: 3.12
  • PyTorch version: 2.x with CUDA
  • GPU: NVIDIA (tested)

Who can help?

@gante @zucchini-nlp

Information

  • My own modified scripts

Tasks

  • My own task or dataset (give details below)

Reproduction

We benchmarked generation overhead using a tiny model (hidden_size=16, 1 layer, vocab_size=256) to isolate framework overhead from actual compute.

Benchmark Script

import time
import warnings
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, LlamaConfig
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList

warnings.filterwarnings("ignore")

class AlwaysPassStoppingCriteria(StoppingCriteria):
    """Never signals "stop", so the loop always runs the full max_new_tokens."""
    def __call__(self, input_ids, scores, **kwargs):
        return torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)

class ExtraSoftmaxLogitsProcessor(LogitsProcessor):
    """Cheap log-softmax pass, just to exercise the logits-processor path."""
    def __call__(self, input_ids, scores):
        return torch.log(F.softmax(scores, dim=-1) + 1e-10)

# Tiny model to minimize compute, expose overhead
config = LlamaConfig(
    vocab_size=256, hidden_size=16, intermediate_size=16,
    num_hidden_layers=1, num_attention_heads=1, num_key_value_heads=1,
    max_position_embeddings=2048, use_cache=True,
)

for device in ["cpu", "cuda"]:
    model = LlamaForCausalLM(config).to(device).eval()
    input_ids = torch.tensor([[1]], device=device)
    attention_mask = torch.ones_like(input_ids)

    stopping_criteria = StoppingCriteriaList([AlwaysPassStoppingCriteria()])
    logits_processor = LogitsProcessorList([ExtraSoftmaxLogitsProcessor()])

    # Warmup
    with torch.no_grad():
        for _ in range(3):
            model.generate(input_ids, attention_mask=attention_mask,
                          min_new_tokens=64, max_new_tokens=64,
                          stopping_criteria=stopping_criteria,
                          logits_processor=logits_processor,
                          do_sample=False, pad_token_id=0)

    if device == "cuda":
        torch.cuda.synchronize()

    # Benchmark
    times = []
    with torch.no_grad():
        for _ in range(10):
            start = time.perf_counter()
            model.generate(input_ids, attention_mask=attention_mask,
                          min_new_tokens=64, max_new_tokens=64,
                          stopping_criteria=stopping_criteria,
                          logits_processor=logits_processor,
                          do_sample=False, pad_token_id=0)
            if device == "cuda":
                torch.cuda.synchronize()
            times.append(time.perf_counter() - start)

    mean_ms = sum(times) / len(times) * 1000
    print(f"{device.upper()}: {mean_ms:.1f} ms for 64 tokens ({mean_ms/64:.2f} ms/token)")

Results

Device   Time (64 tokens)   Per Token
CPU      ~226 ms            ~3.5 ms
GPU      ~578 ms            ~9.0 ms

GPU is 2.5x slower than CPU on this tiny model because overhead dominates.


Root Cause Analysis

1. CPU Slowness: PyTorch Dispatch Overhead (~3.5 ms/token)

Even on CPU, each token takes ~3.5ms despite trivial compute. Profiling shows:

Operation            Time per call
Raw matmul (16x16)   0.002 ms
nn.Linear(16, 16)    0.24 ms
Full forward pass    ~3.0 ms

The one-layer model runs ~8 nn.Linear calls per forward pass (the q/k/v/o attention projections, the MLP gate/up/down projections, and the lm_head). Each call incurs ~0.24 ms of PyTorch dispatch overhead (Python→C++ transition, tensor metadata handling, op dispatch), which is largely irreducible in eager mode.
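
A rough way to reproduce these per-call numbers is to time the pieces directly. A minimal sketch (the helper name and iteration counts are ours; absolute numbers vary by machine and PyTorch build):

import time
import torch

x = torch.randn(1, 16)
w = torch.randn(16, 16)
linear = torch.nn.Linear(16, 16)

@torch.no_grad()
def ms_per_call(fn, iters=10000):
    # Warm up, then report mean wall-clock time per call in milliseconds.
    for _ in range(100):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1000

print(f"raw matmul (16x16): {ms_per_call(lambda: x @ w):.4f} ms")
print(f"nn.Linear(16, 16):  {ms_per_call(lambda: linear(x)):.4f} ms")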

2. GPU Slowness: 3.25 GPU→CPU Syncs Per Token (~5.5 ms/token overhead)

Profiling an 8-token generation with torch.profiler reveals 26 calls to aten::_local_scalar_dense, the op behind .item() and is_nonzero and the actual GPU→CPU sync operation:

aten::is_nonzero     26 calls    41ms total    ~1.5ms each
aten::item           26 calls    41ms total    ~1.5ms each

26 syncs / 8 tokens = 3.25 syncs per token
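
The sync count can be reproduced with torch.profiler. A minimal sketch, reusing model, input_ids, attention_mask, stopping_criteria, and logits_processor from the benchmark script above (on the CUDA device):

from torch.profiler import profile, ProfilerActivity

with torch.no_grad(), profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model.generate(input_ids, attention_mask=attention_mask,
                   min_new_tokens=8, max_new_tokens=8,
                   stopping_criteria=stopping_criteria,
                   logits_processor=logits_processor,
                   do_sample=False, pad_token_id=0)

# aten::_local_scalar_dense is what .item() / is_nonzero dispatch to,
# i.e. the actual blocking GPU->CPU read.
syncs = sum(e.count for e in prof.key_averages() if e.key == "aten::_local_scalar_dense")
print(f"{syncs} syncs for 8 new tokens -> {syncs / 8:.2f} syncs/token")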

Tracing the call sites:

Location                                               Calls per Token   Purpose
utils.py:2674  (_has_unfinished_sequences)             1                 Check if generation is done
utils.py:522   (_cache_dependant_input_preparation)    1                 Check cache state
masking_utils.py:253  (_ignore_causal_mask_sdpa)       1                 Check mask skip condition

Each sync forces a GPU→CPU transfer and stalls the GPU pipeline, costing ~1.5 ms.

Total sync overhead: 3.25 syncs × ~1.5 ms ≈ 5 ms per token, which accounts for most of the ~5.5 ms/token gap between GPU and CPU.
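
For intuition: each of these syncs is just a one-element boolean read, but on the default stream it blocks the host until every previously queued kernel has finished. A standalone illustration (sizes and timings are illustrative only):

import time
import torch

a = torch.randn(4096, 4096, device="cuda")
flag = torch.ones(1, dtype=torch.bool, device="cuda")
torch.cuda.synchronize()

start = time.perf_counter()
b = a @ a                       # kernel is queued asynchronously; this returns almost immediately
launch_ms = (time.perf_counter() - start) * 1000

start = time.perf_counter()
_ = flag.item()                 # GPU->CPU read: the host stalls until the matmul above completes
sync_ms = (time.perf_counter() - start) * 1000

print(f"kernel launch: {launch_ms:.3f} ms, .item() sync: {sync_ms:.3f} ms")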


Expected behavior

Generation should have minimal per-token overhead, especially for:

  • Small/medium models where compute doesn't dominate
  • Latency-sensitive applications
  • Edge deployment scenarios

Potential Improvements

  1. Reduce sync frequency - Check stopping criteria every N tokens instead of every token (see PR Skip attention_mask.all() GPU-CPU sync during generation #43088; a sketch of this idea appears after this list)

  2. Async stopping criteria - Run sync-causing checks in a separate CUDA stream so they don't block the main compute stream (see PR Add async_stopping_criteria flag to reduce GPU-CPU syncs during generation #43085)

  3. Batch boolean checks - Combine multiple boolean tensor checks into a single sync point

  4. Lazy evaluation - Defer _ignore_causal_mask_sdpa and similar checks when not needed

  5. Compile-friendly paths - torch.compile could potentially fuse operations and reduce sync points
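
As a concrete illustration of idea 1 (a sketch only, not what the linked PRs implement), a hypothetical wrapper criterion could evaluate its sync-causing inner check only every n calls and report "not finished" in between, trading at most n-1 extra tokens for fewer syncs. Reuses the imports from the benchmark script above:

class EveryNTokensStoppingCriteria(StoppingCriteria):
    """Hypothetical wrapper: runs the wrapped (sync-causing) criterion only every `n` calls."""

    def __init__(self, inner: StoppingCriteria, n: int = 8):
        self.inner = inner
        self.n = n
        self.calls = 0

    def __call__(self, input_ids, scores, **kwargs):
        self.calls += 1
        if self.calls % self.n == 0:
            return self.inner(input_ids, scores, **kwargs)
        # Skip the expensive check; keep generating.
        return torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)

Note that this only removes syncs caused by the wrapped criterion itself; the loop's own _has_unfinished_sequences check would still need a similar treatment.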

Impact

For real models, this overhead gets amortized by actual compute. But for:

  • Speculative decoding with small draft models
  • On-device/edge models
  • High-throughput serving with small models
  • Distilled/quantized models

...this 3.25 syncs/token overhead becomes significant. Reducing it would improve generation latency across the board.


Related PRs

  • Skip attention_mask.all() GPU-CPU sync during generation #43088
  • Add async_stopping_criteria flag to reduce GPU-CPU syncs during generation #43085