From 678c03b76240279f788b2b5d441625077022648c Mon Sep 17 00:00:00 2001 From: Avinash Sharma Date: Tue, 23 Apr 2024 23:58:08 -0700 Subject: [PATCH] Fix nan issue for fp16 torch.randn/randn_like in ConvertAtenUniformOp (#3184) For ops that use ConvertAtenUniformOp (e.g. torch.randn/randn_like), the fp16 datatype returns NaN values. Lowering [this repro](https://gist.github.com/aviator19941/1c65e658241dea6906ca423f9abaee69) results in NaNs; this PR fixes the issue. --- lib/Conversion/TorchToLinalg/Random.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index 3b18844df516..6519a272330e 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -129,6 +129,7 @@ class ConvertAtenUniformOp : public OpConversionPattern { Value generator = adaptor.getGenerator(); RankedTensorType resultType = self.getType().cast(); Type elemTy = resultType.getElementType(); + Type f64Ty = rewriter.getF64Type(); if (!isa(elemTy)) return rewriter.notifyMatchFailure(op, "This op only support float type"); @@ -139,8 +140,8 @@ class ConvertAtenUniformOp : public OpConversionPattern { "generator is supported"); // Get key, min and max used by `linalg.generic` compute payload. Value key = rewriter.create(loc); - Value min = convertScalarToDtype(rewriter, loc, from, elemTy); - Value max = convertScalarToDtype(rewriter, loc, to, elemTy); + Value min = convertScalarToDtype(rewriter, loc, from, f64Ty); + Value max = convertScalarToDtype(rewriter, loc, to, f64Ty); // Construct the `linalg.generic` op. 
auto resultRank = resultType.getRank(); @@ -179,11 +180,14 @@ class ConvertAtenUniformOp : public OpConversionPattern { // res = cast(F64, tempN) * scale + min Value updateFloat = - b.create(loc, elemTy, randomVal); + b.create(loc, f64Ty, randomVal); Value updateScaled = b.create(loc, updateFloat, scale); Value res = b.create(loc, updateScaled, min); - b.create(loc, res); + Value truncRes = res; + if (elemTy.isa()) + truncRes = b.create(loc, elemTy, res); + b.create(loc, truncRes); }) .getResult(0);