onnx.Det (aten.linalg_det) torch-to-linalg lowering produces NaN on matrices that need row pivoting
test_det_2d from the ONNX backend node tests fails with a NaN output on every backend (CPU, Vulkan, HIP). The matrix is non-singular and the determinant is well-defined; the lowering just can't compute it because it does Gaussian elimination without partial pivoting.
This was previously hidden by iree-run-module's old --equality_mode=absolute matcher (fabs(NaN - x) > t is false in IEEE-754 for any t, so all-NaN-vs-finite buffers compared "equal"). After iree#24394 (NumPy-style approximate matcher with explicit NaN handling) the failure surfaces in CI.
Reproduce (no downloads, ~10 lines)
Save as det.mlir:
```mlir
module {
  func.func @test_det_2d(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64} {
    %0 = torch.operator "onnx.Det"(%arg0) : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[],f32>
    return %0 : !torch.vtensor<[],f32>
  }
}
```
Run on any backend (CPU shown):
```shell
# Build the input matrix [[0, 1], [2, 3]]; det = 0*3 - 1*2 = -2
python3 -c "import numpy as np; np.array([0,1,2,3], dtype=np.float32).tofile('input.bin')"

iree-compile det.mlir \
  --iree-hal-target-device=local \
  --iree-hal-local-target-device-backends=llvm-cpu \
  -o det.vmfb

iree-run-module --module=det.vmfb --device=local-sync \
  --input=2x2xf32=@input.bin \
  --output=@out.bin

python3 -c "import numpy as np; print('det =', np.fromfile('out.bin', dtype=np.float32)[0])"
```
Observed: det = nan
Expected: det = -2.0
np.linalg.det([[0,1],[2,3]]) == -2.0; onnxruntime returns -2.0.
IR walkthrough — where the NaN comes from
After --mlir-print-ir-after=torch-decompose-complex-ops:
```mlir
%0 = torch.aten.linalg_det %arg0 : !torch.vtensor<[2,2],f32> -> !torch.vtensor<[],f32>
```
After torch-to-linalg (lowered body, abridged):
```mlir
%cst = arith.constant 0.000000e+00 : f32
%pivot = tensor.extract_slice %M[0,0] [1,1] [1,1] : ...  // M[0,0] = 0
%row0  = tensor.extract_slice %M[0,0] [1,2] [1,1] : ...  // [0, 1]
%row1  = tensor.extract_slice %M[1,0] [1,2] [1,1] : ...  // [2, 3]
%col0  = tensor.extract_slice %M[1,0] [1,1] [1,1] : ...  // [2]

// Single Gauss-elim step over row1: row1 -= (col0 / pivot) * row0
%3 = linalg.generic ... {
^bb0(%row0_j: f32, %piv: f32, %row1_j: f32, %col0_below: f32, %out: f32):
  %9  = arith.divf %col0_below, %piv : f32   // 2.0 / 0.0 = +inf
  %10 = arith.cmpf one, %piv, %cst : f32     // 0.0 != 0.0 -> false
  cf.assert %10, "unimplemented: determinants requiring permutations and singular matrices"
  %11 = arith.mulf %9, %row0_j : f32         // inf * row0_j
  %12 = arith.subf %row1_j, %11 : f32        // row1_j - inf-or-NaN
  linalg.yield %12 : f32
}

// Then take the diagonal of the eliminated matrix and reduce by mul:
%6 = linalg.generic { reduction } ins(%U_diag) ... { arith.mulf } -> f32
```
For M = [[0,1],[2,3]]:

| Step | Result |
|---|---|
| `divf 2.0, 0.0` | `+∞` |
| `cmpf one 0.0, 0.0` | `false` (the assert condition fails on every iteration, but the `divf` result has already been computed for that lane) |
| `subf 3.0, (∞ * 1.0)` | `−∞` (j=1) |
| `subf 2.0, (∞ * 0.0)` | `2.0 − NaN = NaN` (j=0) |
| `prod` of `U_diag = [0, −∞]` | `0 × (−∞) = NaN` (the j=0 NaN sits below the diagonal, so it is not part of `U_diag`) |
The matrix is not singular (det = −2). The lowering simply has no row pivoting, so any input with a zero at M[0,0] produces NaN, and a near-zero pivot can overflow the multiplier or lose accuracy instead of returning the correct determinant.
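To see the failure mode outside the compiler, here is a minimal pure-Python sketch of Gaussian elimination without row pivoting followed by a product over the diagonal. It is written to mirror the IR above (row_i -= (col / pivot) * row_k) and is not torch-mlir's actual code; `det_no_pivot` and `ieee_div` are hypothetical helpers. Plain Python raises ZeroDivisionError on `x / 0.0`, so `ieee_div` emulates the IEEE-754 result that `arith.divf` produces. Python floats are f64 rather than f32, which does not change the inf/NaN behaviour for this input.

```python
import math


def ieee_div(a: float, b: float) -> float:
    """IEEE-754 division. Plain `a / b` raises ZeroDivisionError when b == 0.0,
    whereas arith.divf on real hardware returns +/-inf or NaN."""
    if b == 0.0:
        if a == 0.0 or math.isnan(a):
            return math.nan  # 0/0 and NaN/0 are NaN
        # Sign of x/0 is sign(a) XOR sign(b), honouring a signed zero in b.
        return math.copysign(math.inf, math.copysign(1.0, a) * math.copysign(1.0, b))
    return a / b


def det_no_pivot(m: list[list[float]]) -> tuple[float, list[list[float]]]:
    """Hypothetical model of the current lowering: eliminate below each diagonal
    entry WITHOUT swapping rows, then multiply the diagonal together."""
    n = len(m)
    u = [[float(x) for x in row] for row in m]
    for k in range(n - 1):
        pivot = u[k][k]  # M[0,0] = 0 for the failing input
        for i in range(k + 1, n):
            factor = ieee_div(u[i][k], pivot)  # 2.0 / 0.0 -> +inf
            for j in range(n):
                # j=0: 2.0 - inf * 0.0 -> NaN; j=1: 3.0 - inf * 1.0 -> -inf
                u[i][j] = u[i][j] - factor * u[k][j]
    det = 1.0
    for k in range(n):
        det *= u[k][k]  # 0.0 * -inf -> NaN
    return det, u


det, u = det_no_pivot([[0, 1], [2, 3]])
print(det)  # nan
print(u)    # [[0.0, 1.0], [nan, -inf]]
print(det_no_pivot([[1, 2], [3, 4]])[0])  # -2.0 (non-zero pivot: no problem)
```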
Where to fix
torch-mlir's aten.linalg_det torch-to-linalg conversion does a single elimination sweep with no partial pivoting. Adding partial pivoting (at each step, swap in the row with the largest-magnitude entry in the pivot column, and negate the determinant on each swap) closes the gap and removes the unimplemented: determinants requiring permutations and singular matrices assert. If the whole pivot column is zero, the matrix is singular and the determinant is simply 0.
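The partial-pivoting fix can be sketched in a few lines of pure Python. This is the textbook algorithm, not a torch-mlir patch: `det_partial_pivot` is a hypothetical name, the input is assumed to be a square list of lists, and an actual lowering would have to express the pivot search and row swaps with tensor/linalg ops. Returning 0 on an all-zero pivot column is one way the same change could also cover the singular-matrix half of the current assert.

```python
def det_partial_pivot(m: list[list[float]]) -> float:
    """Determinant via Gaussian elimination with partial pivoting (textbook sketch)."""
    n = len(m)
    u = [[float(x) for x in row] for row in m]
    sign = 1.0
    for k in range(n):
        # Pivot row: largest |entry| in column k, at or below the diagonal.
        p = max(range(k, n), key=lambda i: abs(u[i][k]))
        if u[p][k] == 0.0:
            # Whole column is zero from the diagonal down: the matrix is singular.
            return 0.0
        if p != k:
            u[k], u[p] = u[p], u[k]
            sign = -sign  # every row swap negates the determinant
        for i in range(k + 1, n):
            factor = u[i][k] / u[k][k]  # |u[k][k]| > 0, so no division by zero
            for j in range(k, n):
                u[i][j] -= factor * u[k][j]
    det = sign
    for k in range(n):
        det *= u[k][k]
    return det


print(det_partial_pivot([[0, 1], [2, 3]]))  # -2.0
print(det_partial_pivot([[1, 2], [2, 4]]))  # 0.0 (singular)
```

For the failing input, the first step swaps [2, 3] into the pivot row, so the division is 0.0 / 2.0 instead of 2.0 / 0.0 and no inf or NaN is ever produced.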
Affected tests / xfail entries
Currently test_det_2d is being added to expected_run_failures in:

- tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O0.json
- tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O2.json
- tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan_O0.json
(And presumably the cuda / hip configs once their next CI runs flag it.) Closing this issue should remove all of those entries.
test_det_nd (a batch of three 2×2 matrices, none with a zero pivot) does not trigger the bug, because its inputs happen to be ones the no-pivot elimination handles.
Why this surfaced now
Before #24394, iree-run-module --equality_mode=absolute (the default) checked
fabs(actual - expected) > threshold. With actual = NaN, that expression
evaluates to false in IEEE-754 for any finite expected value and any
threshold, so the comparator reported a match for every NaN element and the
buggy NaN output was silently accepted.
#24394 replaces the matcher with the NumPy-style formula
|actual - expected| <= atol + rtol * |expected| (with explicit
isnan(actual) && isnan(expected) short-circuit), which correctly flags
NaN-vs-finite as a mismatch. The Det bug is unchanged; the comparator just
now tells the truth about it.
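The difference between the two matchers is easy to check in isolation. The sketch below is a pure-Python model of the two formulas quoted above, not IREE's implementation; `old_absolute_mismatch` and `numpy_style_match` are hypothetical names, and the tolerance values are placeholders rather than iree-run-module defaults. Note that the `<=` form already rejects NaN-vs-finite on its own, because every ordered comparison involving NaN is false; the explicit isnan short-circuit is what lets NaN-vs-NaN count as a match.

```python
import math


def old_absolute_mismatch(actual: float, expected: float, threshold: float) -> bool:
    """Old-style check: flag a mismatch only if |actual - expected| > threshold."""
    return abs(actual - expected) > threshold  # NaN > t is always False


def numpy_style_match(actual: float, expected: float, atol: float, rtol: float) -> bool:
    """New-style check: |actual - expected| <= atol + rtol * |expected|,
    with NaN matching only NaN."""
    if math.isnan(actual) or math.isnan(expected):
        return math.isnan(actual) and math.isnan(expected)
    return abs(actual - expected) <= atol + rtol * abs(expected)


# The test_det_2d case: actual NaN, expected -2.0 (tolerances are placeholders).
print(old_absolute_mismatch(math.nan, -2.0, 1e-5))        # False -> treated as "equal"
print(numpy_style_match(math.nan, -2.0, 1e-5, 1e-5))      # False -> mismatch reported
print(numpy_style_match(math.nan, math.nan, 1e-5, 1e-5))  # True
```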