
[BUG][ATTENTION] GQA sliding-window fwd mismatches non-causal right-window cases #1616

Description

@superAngGao

Summary

Nightly op tests still have a real correctness failure in fixed-length GQA sliding-window forward, separate from the TileLang autotune missing-argument issue and the CI Triton cache permission issue.

Nightly run: https://github.com/tile-ai/TileOPs/actions/runs/28125205861/job/83305954027

The failing op in the nightly report is:

  • GroupedQueryAttentionSlidingWindowFwdOp
  • module reported by nightly: tileops.ops.attention.gqa
  • current test path: tests/ops/test_gqa_sliding_window_fwd.py
  • current op wrapper: GqaSlidingWindowFwdOp
  • kernel selected on H200: GQASlidingWindowFwdWgmmaPipelinedKernel

Failing Cases

Two fixed-length cases fail against the pure PyTorch reference:

| Case | Window | dtype | Error |
| --- | --- | --- | --- |
| test_gqa_sliding_window_fwd_op[2-512-8-2-64-False-64-64-dtype6-False] | non-causal, wl=64, wr=64 | bf16 | max abs diff 0.474609375, 91170 / 524288 mismatched (17.4%) |
| test_gqa_sliding_window_fwd_op[2-512-8-2-64-False--1-64-dtype12-False] | non-causal, wl=-1, wr=64 | fp16 | max abs diff 0.444580078125, 77024 / 524288 mismatched (14.7%) |

Failure excerpt:

AssertionError: Tensor-likes are not close!
Mismatched elements: 91170 / 524288 (17.4%)
Greatest absolute difference: 0.474609375 at index (1, 26, 2, 45) (up to 0.01 allowed)

AssertionError: Tensor-likes are not close!
Mismatched elements: 77024 / 524288 (14.7%)
Greatest absolute difference: 0.444580078125 at index (1, 256, 5, 7) (up to 0.01 allowed)

Both failures used the default H200 WGMMA config:

GQASlidingWindowFwdWgmmaPipelinedKernel initialized with config: {'block_m': 128, 'block_n': 128, 'num_stages': 3, 'threads': 256}

Useful Clues

The PyTorch reference mask is:

if is_causal:
    mask = mask | (q_pos < k_pos)
if wl >= 0:
    mask = mask | (k_pos < q_pos - wl)
if wr >= 0:
    mask = mask | (k_pos > q_pos + wr)
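
For readers without the test file at hand, the excerpt above can be expanded into a small standalone sketch. It uses plain Python lists instead of tensors, and the helper names `build_masked` and `visible_range` are hypothetical (they are not part of the test suite). `True` means the key is hidden from the query, matching the `mask | ...` accumulation in the excerpt.

```python
def build_masked(seq_len, is_causal, wl, wr):
    """Return masked[q][k], True where key k is hidden from query q.

    Mirrors the reference conditions quoted above; wl/wr < 0 disables
    the corresponding side of the window.
    """
    masked = [[False] * seq_len for _ in range(seq_len)]
    for q_pos in range(seq_len):
        for k_pos in range(seq_len):
            m = False
            if is_causal:
                m = m or (q_pos < k_pos)
            if wl >= 0:
                m = m or (k_pos < q_pos - wl)
            if wr >= 0:
                m = m or (k_pos > q_pos + wr)
            masked[q_pos][k_pos] = m
    return masked


def visible_range(masked_row):
    """First and last visible key position for one query row."""
    keys = [k for k, hidden in enumerate(masked_row) if not hidden]
    return keys[0], keys[-1]


seq_len = 512
symmetric = build_masked(seq_len, is_causal=False, wl=64, wr=64)
right_only = build_masked(seq_len, is_causal=False, wl=-1, wr=64)

print(visible_range(symmetric[26]))    # (0, 90)
print(visible_range(symmetric[256]))   # (192, 320)
print(visible_range(right_only[256]))  # (0, 320)
```

Query rows 26 and 256 are the rows reported in the failure excerpt; in both failing configurations the visible range extends 64 keys to the right of the query.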

Observed pattern from the nightly run:

  • fixed-length causal cases pass;
  • fixed-length non-causal no-window case passes;
  • fixed-length non-causal left-window-only case passes;
  • fixed-length non-causal cases with a right window (wr=64) fail in at least one dtype;
  • fixed-length wl=64, wr=64 passes for fp16 but fails for bf16;
  • fixed-length right-only wl=-1, wr=64 fails for fp16;
  • varlen GQA sliding-window cases with similar window shapes passed in the same run.

This points toward the fixed-length WGMMA kernel's non-causal right-window masking / block-boundary handling rather than the reference implementation or the generic test harness.
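
One way to make the block-boundary hypothesis concrete is to compute, from the reference mask alone, which KV tiles each query tile has to visit. The sketch below does that for the reported config (`block_m = block_n = 128`, sequence length 512). The function `kv_tile_range` is hypothetical and is not taken from the kernel; it only shows what the loop bounds would need to cover, under the assumption that the kernel iterates over KV tiles per query tile.

```python
def kv_tile_range(m, block_m, block_n, seq_len, is_causal, wl, wr):
    """Inclusive (first, last) KV tile holding any visible key for query tile m."""
    q_first = m * block_m
    q_last = min(seq_len, q_first + block_m) - 1
    # The leftmost visible key in the tile is set by its first query row.
    k_lo = max(0, q_first - wl) if wl >= 0 else 0
    # The rightmost visible key in the tile is set by its last query row.
    k_hi = seq_len - 1
    if wr >= 0:
        k_hi = min(k_hi, q_last + wr)
    if is_causal:
        k_hi = min(k_hi, q_last)
    return k_lo // block_n, k_hi // block_n


def all_ranges(is_causal, wl, wr, seq_len=512, block_m=128, block_n=128):
    tiles = (seq_len + block_m - 1) // block_m
    return [kv_tile_range(m, block_m, block_n, seq_len, is_causal, wl, wr)
            for m in range(tiles)]


print(all_ranges(False, 64, 64))   # [(0, 1), (0, 2), (1, 3), (2, 3)]
print(all_ranges(False, -1, 64))   # [(0, 1), (0, 2), (0, 3), (0, 3)]
print(all_ranges(True, -1, -1))    # [(0, 0), (0, 1), (0, 2), (0, 3)]
```

With `wr=64`, every query tile except the last needs KV tile `m + 1`, which a causal-style upper bound would never visit. Whether the WGMMA kernel actually drops or mis-masks that tile is exactly what the investigation needs to confirm; this sketch does not establish it.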

Proposed Investigation

  • Reproduce the two failing cases locally on H200 with the WGMMA kernel.
  • Compare fixed-length mask bounds with the varlen implementation, especially the right-window upper bound (k_pos <= q_pos + wr) and how it is applied per block_m/block_n tile.
  • Check whether the issue only appears when wr >= 0, and whether bf16 accumulation/casting exposes an additional error for symmetric windows.
  • Add a narrow regression test for:
    • non-causal right-only wl=-1, wr=64, fp16;
    • non-causal symmetric wl=64, wr=64, bf16.
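
When reproducing locally, it may help to map the worst-case indices from the failure excerpt onto query tiles and KV heads. The sketch below assumes the output layout is (batch, seq, heads, dim), which is consistent with 2 × 512 × 8 × 64 = 524288 elements and with the test id, and it assumes the common GQA mapping `kv_head = q_head // (heads // heads_kv)`. Both are assumptions, and `FailureSite` and `decode` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FailureSite:
    batch: int
    q_pos: int
    q_tile: int        # index of the block_m tile holding q_pos
    row_in_tile: int   # offset of q_pos inside that tile
    q_head: int
    kv_head: int       # assumed GQA mapping, see lead-in
    dim: int


def decode(index, heads=8, heads_kv=2, block_m=128):
    """Decode an assumed (batch, seq, head, dim) index into tile coordinates."""
    b, s, h, d = index
    group = heads // heads_kv
    return FailureSite(b, s, s // block_m, s % block_m, h, h // group, d)


bf16_site = decode((1, 26, 2, 45))   # symmetric wl=64, wr=64 case
fp16_site = decode((1, 256, 5, 7))   # right-only wl=-1, wr=64 case

print(bf16_site.q_tile, bf16_site.row_in_tile, bf16_site.kv_head)  # 0 26 0
print(fp16_site.q_tile, fp16_site.row_in_tile, fp16_site.kv_head)  # 2 0 1
```

These are only the positions of the largest absolute difference, not the full mismatch set, so they are a starting point for where to dump intermediate values rather than evidence of a pattern.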

Acceptance Criteria

  • Both failing fixed-length GQA sliding-window cases pass with atol=1e-2, rtol=1e-2.
  • Existing causal, left-window-only, no-window, and varlen sliding-window cases remain green.
  • The fix is covered by targeted regression cases for right-window masking.
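
For reference, the tolerance in the first criterion follows the usual `torch.testing.assert_close` rule: an element passes when `|actual - expected| <= atol + rtol * |expected|`. A minimal pure-Python sketch of that rule is below; `is_close` and `mismatch_report` are hypothetical helpers, and the sample values are illustrative only, not taken from the nightly run.

```python
def is_close(actual, expected, atol=1e-2, rtol=1e-2):
    """Element-wise closeness rule: |a - e| <= atol + rtol * |e|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


def mismatch_report(actual, expected, atol=1e-2, rtol=1e-2):
    """Return (mismatched count, total count, greatest absolute difference)."""
    pairs = list(zip(actual, expected))
    bad = sum(1 for a, e in pairs if not is_close(a, e, atol, rtol))
    max_abs = max(abs(a - e) for a, e in pairs)
    return bad, len(pairs), max_abs


# Illustrative values only.
actual = [0.0, 1.0, 2.0]
expected = [0.0, 1.5, 2.005]
print(mismatch_report(actual, expected))  # (1, 3, 0.5)
```

An absolute difference near 0.47, as reported for both failing cases, is far outside this bound for any attention output of ordinary magnitude, so the failures are unlikely to be rounding noise alone.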
