
IREE produces NaNs in onnx.Det test because of how torch-mlir lowers the op #24397

@hanhanW

onnx.Det (aten.linalg_det) torch-to-linalg lowering produces NaN on matrices that need row pivoting

test_det_2d from the ONNX backend node tests fails with a NaN output on every backend (CPU, Vulkan, HIP). The matrix is non-singular and the determinant is well-defined; the lowering just can't compute it because it does Gaussian elimination without partial pivoting.

This was previously hidden by iree-run-module's old --equality_mode=absolute matcher (fabs(NaN - x) > t is false in IEEE-754 for any t, so all-NaN-vs-finite buffers compared "equal"). After iree#24394 (NumPy-style approximate matcher with explicit NaN handling) the failure surfaces in CI.

Reproduce (no downloads, ~10 lines)

Save as det.mlir:

module {
  func.func @test_det_2d(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64} {
    %0 = torch.operator "onnx.Det"(%arg0) : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[],f32>
    return %0 : !torch.vtensor<[],f32>
  }
}

Run on any backend (CPU shown):

# Build the input matrix [[0, 1], [2, 3]]; det = 0*3 - 1*2 = -2
python3 -c "import numpy as np; np.array([0,1,2,3], dtype=np.float32).tofile('input.bin')"

iree-compile det.mlir \
  --iree-hal-target-device=local \
  --iree-hal-local-target-device-backends=llvm-cpu \
  -o det.vmfb

iree-run-module --module=det.vmfb --device=local-sync \
  --input=2x2xf32=@input.bin \
  --output=@out.bin

python3 -c "import numpy as np; print('det =', np.fromfile('out.bin', dtype=np.float32)[0])"

Observed: det = nan
Expected: det = -2.0

np.linalg.det([[0,1],[2,3]]) == -2.0; onnxruntime returns -2.0.

IR walkthrough — where the NaN comes from

After --mlir-print-ir-after=torch-decompose-complex-ops:

%0 = torch.aten.linalg_det %arg0 : !torch.vtensor<[2,2],f32> -> !torch.vtensor<[],f32>

After torch-to-linalg (lowered body, abridged):

%cst = arith.constant 0.000000e+00 : f32
%pivot = tensor.extract_slice %M[0,0] [1,1] [1,1] : ...   // M[0,0] = 0
%row0  = tensor.extract_slice %M[0,0] [1,2] [1,1] : ...   // [0, 1]
%row1  = tensor.extract_slice %M[1,0] [1,2] [1,1] : ...   // [2, 3]
%col0  = tensor.extract_slice %M[1,0] [1,1] [1,1] : ...   // [2]

// Single Gauss-elim step over row1: row1 -= (col0 / pivot) * row0
%3 = linalg.generic ... {
^bb0(%row0_j: f32, %piv: f32, %row1_j: f32, %col0_below: f32, %out: f32):
  %9  = arith.divf %col0_below, %piv : f32        // 2.0 / 0.0  = +inf
  %10 = arith.cmpf one, %piv, %cst   : f32        // 0.0 != 0.0 -> false
  cf.assert %10, "unimplemented: determinants requiring permutations and singular matrices"
  %11 = arith.mulf %9, %row0_j : f32              // inf * row0_j
  %12 = arith.subf %row1_j, %11 : f32             // row1_j - inf-or-NaN
  linalg.yield %12 : f32
}

// Then take the diagonal of the eliminated matrix and reduce by mul:
%6 = linalg.generic { reduction } ins(%U_diag) ... { arith.mulf } -> f32

For M = [[0,1],[2,3]]:

| Step | Result |
| --- | --- |
| divf 2.0, 0.0 | +∞ |
| cmpf one 0.0, 0.0 | false (assert fires per-iteration but the divf already populated the buffer for that lane) |
| subf 3.0, (∞ * 1.0) | −∞ (j=1) |
| subf 2.0, (∞ * 0.0) | 2.0 − NaN = NaN (j=0) |
| prod of U_diag = [0, NaN] (or [0, −∞]) | NaN |

The matrix is not singular (det = −2). The lowering simply has no row pivoting, so any input with an exact zero at M[0,0] produces NaN, and a near-zero pivot there can overflow or lose precision.
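The arithmetic in the walkthrough can be replayed outside the compiler. The following is a minimal NumPy sketch in float32, not the torch-mlir code itself; `det_no_pivot` is a hypothetical name, and the loop structure is an assumption modelled on the abridged IR above.

```python
import numpy as np

def det_no_pivot(m):
    """Gaussian elimination with no row swaps, then product of the diagonal."""
    u = np.array(m, dtype=np.float32)
    n = u.shape[0]
    # Suppress NumPy's divide-by-zero / invalid warnings; we want to see the values.
    with np.errstate(divide="ignore", invalid="ignore"):
        for k in range(n - 1):
            pivot = u[k, k]
            for i in range(k + 1, n):
                factor = u[i, k] / pivot          # 2.0 / 0.0 = +inf for the failing input
                u[i, :] = u[i, :] - factor * u[k, :]
        return np.prod(np.diag(u))

print(det_no_pivot([[0, 1], [2, 3]]))  # nan
print(det_no_pivot([[1, 2], [3, 4]]))  # -2.0
```

For the failing input, the eliminated second row becomes [NaN, −∞], so the diagonal is [0, −∞] and its product is NaN. The second call shows that the same procedure is fine when the leading pivot is non-zero.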

Where to fix

torch-mlir's aten.linalg_det torch-to-linalg conversion does a single elimination sweep and lacks partial pivoting. Implementing partial pivoting (track row swaps and negate the determinant on each swap) closes the gap and removes the "unimplemented: determinants requiring permutations and singular matrices" assert.
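As a reference for the intended semantics, here is a hedged NumPy sketch of determinant-by-elimination with partial pivoting. It is not a proposed torch-mlir patch: `det_partial_pivot` is a hypothetical name, and the actual fix would have to express the same row swap and sign tracking in linalg/scf ops.

```python
import numpy as np

def det_partial_pivot(m):
    """Determinant via Gaussian elimination with partial (row) pivoting."""
    u = np.array(m, dtype=np.float32)
    n = u.shape[0]
    sign = 1.0
    for k in range(n - 1):
        # Choose the row at or below k with the largest |entry| in column k.
        p = k + int(np.argmax(np.abs(u[k:, k])))
        if u[p, k] == 0:
            return np.float32(0.0)      # entire column is zero: matrix is singular
        if p != k:
            u[[k, p]] = u[[p, k]]       # swap rows k and p
            sign = -sign                # each swap flips the determinant's sign
        factors = u[k + 1:, k] / u[k, k]
        u[k + 1:] -= np.outer(factors, u[k])
    return np.float32(sign) * np.prod(np.diag(u))

print(det_partial_pivot([[0, 1], [2, 3]]))  # -2.0
```

On the issue's input, the rows are swapped once, giving [[2, 3], [0, 1]] with sign −1, so the result is −(2 × 1) = −2. A singular input returns 0 rather than tripping an assert.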

Affected tests / xfail entries

Currently test_det_2d is being added to expected_run_failures in:

  • tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O0.json
  • tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O2.json
  • tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan_O0.json

(And presumably the cuda / hip configs once their next CI runs flag it.) Closing this issue should remove all of those entries.

test_det_nd (a stack of three 2×2 matrices, none with a zero pivot) does not trigger the bug: its input matrices happen to be ones the no-pivot LU can handle.

Why this surfaced now

Before #24394, iree-run-module --equality_mode=absolute (the default) flagged a
mismatch only when fabs(actual - expected) > threshold. With actual = NaN, that
comparison evaluates to false in IEEE-754 for any finite expected value and any
threshold, so the comparator reported a match for every NaN element and the
buggy NaN output was silently accepted.

#24394 replaces the matcher with the NumPy-style formula
|actual - expected| <= atol + rtol * |expected| (with an explicit
isnan(actual) && isnan(expected) short-circuit), which correctly flags
NaN-vs-finite as a mismatch. The Det bug is unchanged; the comparator just
now tells the truth about it.
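The difference between the two matchers can be seen in a few lines of Python. This is a sketch of the two formulas as described above, not IREE's implementation: the function names are hypothetical, and the threshold, atol and rtol values are illustrative assumptions, not IREE's defaults.

```python
import math

def old_absolute_match(actual, expected, threshold=1e-4):
    # Reports a mismatch only when fabs(actual - expected) > threshold.
    # Any ordered comparison with NaN is false, so NaN never counts as a mismatch.
    return not (math.fabs(actual - expected) > threshold)

def numpy_style_match(actual, expected, atol=1e-8, rtol=1e-5):
    # NaN matches only NaN; otherwise use the NumPy-style tolerance formula.
    if math.isnan(actual) and math.isnan(expected):
        return True
    return math.fabs(actual - expected) <= atol + rtol * math.fabs(expected)

nan = float("nan")
print(old_absolute_match(nan, -2.0))   # True  (false pass)
print(numpy_style_match(nan, -2.0))    # False (mismatch reported)
```

Writing the check as "difference <= tolerance" instead of "not (difference > tolerance)" is what makes NaN fail by default: both comparisons are false for NaN, but only the first form treats false as a mismatch.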
