From a02c830dc981a537b13fd338e38933012b537b5e Mon Sep 17 00:00:00 2001
From: Albert
Date: Tue, 23 Jun 2026 22:21:32 +0800
Subject: [PATCH 1/2] Fix(Conv2d)

---
 musa_ext/kernels/math/musa_conv2d_op.cc | 11 +----------
 test/ops/conv2d_op_test.py              | 11 +++++++++++
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/musa_ext/kernels/math/musa_conv2d_op.cc b/musa_ext/kernels/math/musa_conv2d_op.cc
index 26e16414..be447000 100644
--- a/musa_ext/kernels/math/musa_conv2d_op.cc
+++ b/musa_ext/kernels/math/musa_conv2d_op.cc
@@ -160,16 +160,7 @@ Status RunMusaConv2D(OpKernelContext* ctx, const Tensor& input,
                             static_cast<int>(status));
   }
 
-  mConvolution::Algorithm algo;
-  status = conv.GetRecommendForwardAlgorithm(handle, algo, y, x, w);
-  if (status != mStatus::SUCCESS) {
-    return errors::Internal(
-        "muDNN Convolution::GetRecommendForwardAlgorithm failed. status=",
-        static_cast<int>(status), ", data_format=NHWC",
-        ", input_shape=", input.shape().DebugString(),
-        ", filter_shape=", filter.shape().DebugString(),
-        ", output_shape=", output->shape().DebugString());
-  }
+  const mConvolution::Algorithm algo = mConvolution::Algorithm::IMPLICIT_GEMM;
 
   size_t workspace_size = 0;
   status = conv.GetForwardWorkspaceSize(handle, workspace_size, y, x, w, algo);
diff --git a/test/ops/conv2d_op_test.py b/test/ops/conv2d_op_test.py
index 670d445d..b64a5a2b 100755
--- a/test/ops/conv2d_op_test.py
+++ b/test/ops/conv2d_op_test.py
@@ -155,6 +155,17 @@ def testConv2DPointwise1x1NHWC(self):
         padding="SAME",
         data_format="NHWC")
 
+  def testConv2DResNet50Conv4ShapeNHWC(self):
+    """ResNet50 conv4_block1_2 shape should not use an unstable algorithm."""
+    self._test_conv2d(
+        input_shape=[1, 14, 14, 256],
+        filter_shape=[3, 3, 256, 256],
+        dtype=tf.float32,
+        strides=[1, 1, 1, 1],
+        padding="SAME",
+        data_format="NHWC",
+        seed=2027)
+
   def testConv2DValidNCHW(self):
     """NCHW + VALID + stride=1."""
     for dtype in [tf.float32, tf.float16]:

From 0a007af5f8d8924170b2a2f0dd69ef1b4be54520 Mon Sep 17 00:00:00 2001
From: Albert
Date: Tue, 23 Jun 2026 22:23:28 +0800
Subject: [PATCH 2/2] Fix(Conv2d)

---
 musa_ext/kernels/math/musa_conv2d_op.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/musa_ext/kernels/math/musa_conv2d_op.cc b/musa_ext/kernels/math/musa_conv2d_op.cc
index be447000..4a03d620 100644
--- a/musa_ext/kernels/math/musa_conv2d_op.cc
+++ b/musa_ext/kernels/math/musa_conv2d_op.cc
@@ -316,8 +316,7 @@ class MusaConv2DOp : public MusaOpKernel {
       return;
     }
 
-    const bool use_tf32 =
-        tf32_enabled_ && dilation_h_ == 1 && dilation_w_ == 1;
+    const bool use_tf32 = tf32_enabled_ && dilation_h_ == 1 && dilation_w_ == 1;
 
     if (data_format_ == FORMAT_NHWC) {
       OP_REQUIRES_OK(