diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 80991a3ebbb5f..f4ccff1d7770d 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -2373,11 +2373,11 @@ "test_slice_negative_axes", "test_slice_start_out_of_bounds", "test_slice", - // "test_softmax_axis_0_expanded", + "test_softmax_axis_0_expanded", "test_softmax_axis_0", - // "test_softmax_axis_1_expanded", + "test_softmax_axis_1_expanded", "test_softmax_axis_1", - // "test_softmax_axis_2_expanded", + "test_softmax_axis_2_expanded", "test_softmax_axis_2", // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded", // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob_expanded", @@ -2447,13 +2447,13 @@ // "test_softmax_cross_entropy_sum_log_prob_expanded", // "test_softmax_cross_entropy_sum_log_prob", // "test_softmax_cross_entropy_sum", - // "opset13/test_softmax_default_axis_expanded", + "opset13/test_softmax_default_axis_expanded", "opset13/test_softmax_default_axis", - // "test_softmax_example_expanded", + "test_softmax_example_expanded", "test_softmax_example", - // "test_softmax_large_number_expanded", + "test_softmax_large_number_expanded", "test_softmax_large_number", - // "test_softmax_negative_axis_expanded", + "test_softmax_negative_axis_expanded", "test_softmax_negative_axis", // // "test_softplus_example", // // "test_softplus", diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 99d137f81864c..a4e1e31aba504 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -29,14 +29,48 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); const auto input_size = input_shape.size(); + emscripten::val options = emscripten::val::object(); + NodeAttrHelper helper(node); - const int32_t default_axis = node.SinceVersion() < 13 ? 1 : -1; + const auto since_version = node.SinceVersion(); + const int32_t default_axis = since_version < 13 ? 1 : -1; int32_t axis = helper.Get("axis", default_axis); - axis = static_cast(HandleNegativeAxis(axis, input_size)); + axis = SafeInt(HandleNegativeAxis(axis, input_size)); + + // Prior to opset 13, Softmax operates with different semantics compared to opset 13 and later. + // Specifically, it normalizes over the flattened range of dimensions starting from the specified + // axis to the last dimension. + // In contrast, WebNN's softmax aligns with the behavior introduced in opset 13 and later. + // To handle the differences for earlier opsets, a reshape operation can be applied if necessary. + const bool do_reshape = since_version < 13 && axis != SafeInt(input_size - 1); + std::vector input_shape_uint32; + if (do_reshape) { + input_shape_uint32 = GetNarrowedIntFromInt64(input_shape); + // Need to reshape the input to 2D tensor with new shape [M, N]. + // M = d0*d1*...*d(axis-1), N = d(axis)*...*d(n-1) + const auto M = Product(std::vector(input_shape_uint32.begin(), input_shape_uint32.begin() + axis)); + const auto N = Product(std::vector(input_shape_uint32.begin() + axis, input_shape_uint32.end())); + emscripten::val new_shape = emscripten::val::array(); + new_shape.set(0, M); + new_shape.set(1, N); + + options.set("label", node.Name() + "_reshape_input"); + input = model_builder.GetBuilder().call("reshape", input, new_shape, options); + // Apply softmax along the last dimension (N). + axis = 1; + } - emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call("softmax", input, axis, options); + + if (do_reshape) { + // Softmax has the same output shape as input shape. + // Reshape the output back to the original input shape. + options.set("label", node.Name() + "_reshape_output"); + output = model_builder.GetBuilder().call( + "reshape", output, emscripten::val::array(input_shape_uint32), options); + } + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); }