diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/mlp_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/mlp_fusion.cpp
index 57f5a6dfbb1fa0..a603150a8b7b5c 100644
--- a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/mlp_fusion.cpp
+++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/mlp_fusion.cpp
@@ -122,6 +122,34 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
     matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
         auto root = m.get_match_root();
+        // Check that the first input of Multiply is the gate (activation) branch and the second input is the up branch;
+        // otherwise, do not fuse. Input order is critical for correctness: mismatched input order can silently cause
+        // accuracy issues.
+        auto mlp_gated_up_node = pattern_map.at(mlp_gated_up).get_node_shared_ptr();
+        auto input0 = mlp_gated_up_node->input_value(0);
+        auto input1 = mlp_gated_up_node->input_value(1);
+
+        bool input0_is_gate = false;
+        bool input1_is_up = false;
+
+        if (pattern_map.count(mlp_silu_gate) && input0.get_node() == pattern_map.at(mlp_silu_gate).get_node()) {
+            input0_is_gate = true;
+        }
+        if (pattern_map.count(mlp_gelu_gate) && input0.get_node() == pattern_map.at(mlp_gelu_gate).get_node()) {
+            input0_is_gate = true;
+        }
+
+        if (pattern_map.count(mlp_up_proj) && input1.get_node() == pattern_map.at(mlp_up_proj).get_node()) {
+            input1_is_up = true;
+        }
+        if (pattern_map.count(gate_up_proj_split) &&
+            input1.get_node() == pattern_map.at(gate_up_proj_split).get_node() && input1.get_index() == 1) {
+            input1_is_up = true;
+        }
+
+        if (!input0_is_gate || !input1_is_up) {
+            return false;
+        }
         auto src = pattern_map.at(input);
         if (!src.get_element_type().is_real()) {
             // FakeQuantize, should skip fusion
diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mlp_fusion.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mlp_fusion.cpp
index e2b9f9b831720f..41d00ae571dc3b 100644
--- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mlp_fusion.cpp
+++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mlp_fusion.cpp
@@ -6,13 +6,13 @@
 #include <gtest/gtest.h>
 
 #include "common_test_utils/ov_tensor_utils.hpp"
-#include "openvino/runtime/exec_model_info.hpp"
-#include "shared_test_classes/base/ov_subgraph.hpp"
 #include "openvino/op/convert.hpp"
 #include "openvino/op/gelu.hpp"
 #include "openvino/op/matmul.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/swish.hpp"
+#include "openvino/runtime/exec_model_info.hpp"
+#include "shared_test_classes/base/ov_subgraph.hpp"
 
 namespace ov {
 namespace test {
@@ -23,6 +23,7 @@ struct LLMMLPFusionParams {
     size_t up_size;
     std::string act_type;
     bool use_dynamic_quant;
+    bool swap_inputs;  // true = swap inputs to prevent fusion, false = normal order for fusion
 };
 
 class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>, public ov::test::SubgraphBaseTest {
@@ -39,6 +40,7 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>
         result << "up_size=" << obj.param.up_size << "_";
         result << "act_type=" << obj.param.act_type << "_";
         result << "use_dynamic_quant=" << obj.param.use_dynamic_quant << "_";
+        result << "swap_inputs=" << obj.param.swap_inputs << "_";
         result << obj.index;
         return result.str();
     }
@@ -70,7 +72,8 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>
             in_data.start_from = 0;
             in_data.range = 1;
             in_data.resolution = 128;
-            auto tensor_scale_per_oc = ov::test::utils::create_and_fill_tensor(ov::element::f32, ov::Shape{OC, 1}, in_data);
+            auto tensor_scale_per_oc =
+                ov::test::utils::create_and_fill_tensor(ov::element::f32, ov::Shape{OC, 1}, in_data);
             auto scale_per_oc = std::make_shared<ov::op::v0::Constant>(tensor_scale_per_oc);
 
             auto weight_deq = std::make_shared<ov::op::v1::Multiply>(weight_const_f32, scale_per_oc);
@@ -85,7 +88,8 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>
             return std::make_shared<ov::op::v0::Constant>(tensor);
         };
         if (param.use_dynamic_quant)
-            configuration.insert({ov::hint::dynamic_quantization_group_size.name(), std::numeric_limits<uint64_t>::max()});
+            configuration.insert(
+                {ov::hint::dynamic_quantization_group_size.name(), std::numeric_limits<uint64_t>::max()});
 
         auto gate_weight = create_const(param.up_size, param.down_size, 100);
         auto up_weight = create_const(param.up_size, param.down_size, 100);
@@ -101,13 +105,22 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>
         if (param.act_type == "Gelu")
             gate_act = std::make_shared<ov::op::v7::Gelu>(gate_proj);
 
-        auto gate_up = std::make_shared<ov::op::v1::Multiply>(gate_act, up_proj);
+        // Control input order based on swap_inputs parameter
+        std::shared_ptr<ov::Node> gate_up;
+        if (param.swap_inputs) {
+            // Swapped order should prevent fusion
+            gate_up = std::make_shared<ov::op::v1::Multiply>(up_proj, gate_act);
+        } else {
+            // Normal order should allow fusion
+            gate_up = std::make_shared<ov::op::v1::Multiply>(gate_act, up_proj);
+        }
+
         auto output = std::make_shared<ov::op::v0::MatMul>(gate_up, down_weight, false, true);
 
         function = std::make_shared<ov::Model>(ov::OutputVector{output}, ov::ParameterVector{src});
     }
 
-    void check_results() {
+    void check_fusion_result() {
         auto exec_model = compiledModel.get_runtime_model();
 
         int fused_node_found = 0;
@@ -116,7 +129,15 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>
             if (layer_type == "LLMMLP")
                 fused_node_found++;
         }
-        ASSERT_EQ(fused_node_found, 1);
+
+        auto& param = this->GetParam();
+        if (param.swap_inputs) {
+            // When inputs are swapped, fusion should NOT happen
+            ASSERT_EQ(fused_node_found, 0) << "Fusion should not occur with swapped inputs";
+        } else {
+            // Normal case, fusion should happen
+            ASSERT_EQ(fused_node_found, 1) << "Fusion should occur with correct input order";
+        }
     }
 };
 
@@ -124,18 +145,24 @@ TEST_P(LLMMLPFusionTest, CompareWithRefs) {
     if (!ov::with_cpu_x86_avx512_core_amx_bf16())
         GTEST_SKIP();
     run();
-    check_results();
+    check_fusion_result();
 }
 
 namespace {
 
-static ov::test::InputShape ishape{ov::PartialShape{-1, -1, 4096 / 4}, {ov::Shape{1, 8, 4096 / 4}, ov::Shape{5, 37, 4096 / 4}}};
+static ov::test::InputShape ishape{ov::PartialShape{-1, -1, 4096 / 4},
+                                   {ov::Shape{1, 8, 4096 / 4}, ov::Shape{5, 37, 4096 / 4}}};
 
+// Test parameters combining both normal fusion and no-fusion cases
 const std::vector<LLMMLPFusionParams> mlp_params = {
-    {ishape, 4096 / 4, 11008 / 4, "Gelu", false},
-    {ishape, 4096 / 4, 11008 / 4, "Gelu", true},
-    {ishape, 4096 / 4, 11008 / 4, "Swish", false},
-    {ishape, 4096 / 4, 11008 / 4, "Swish", true},
+    // Normal cases - should fuse (swap_inputs = false)
+    {ishape, 4096 / 4, 11008 / 4, "Gelu", false, false},
+    {ishape, 4096 / 4, 11008 / 4, "Gelu", true, false},
+    {ishape, 4096 / 4, 11008 / 4, "Swish", false, false},
+    {ishape, 4096 / 4, 11008 / 4, "Swish", true, false},
+
+    // Port order issue cases - should NOT fuse (swap_inputs = true)
+    {ishape, 4096 / 4, 11008 / 4, "Gelu", false, true},
 };
 
 INSTANTIATE_TEST_SUITE_P(smoke_LLMMLPFusion,
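A note on why the new port-order check is needed: at the graph level Multiply is commutative, so the swapped test graph is numerically identical to the normal one before fusion. The fused LLMMLP node, however, is not symmetric in its ports: it applies the activation to whichever branch it treats as the gate input. If the matcher accepted a match with the branches reversed, the compiled kernel would silently apply the activation to the up projection instead, producing wrong results with no error. A minimal standalone C++ sketch of that failure mode (silu and fused_mlp are illustrative stand-ins, not OpenVINO APIs):

    #include <cmath>
    #include <cstdio>

    static float silu(float x) { return x / (1.0f + std::exp(-x)); }

    // Hypothetical fused-kernel contract: port 0 is the gate branch (gets the
    // activation), port 1 is the up branch (multiplied in as-is).
    static float fused_mlp(float port0, float port1) { return silu(port0) * port1; }

    int main() {
        float gate = -1.5f, up = 2.0f;
        // Graph semantics: Multiply(silu(gate), up) == Multiply(up, silu(gate)).
        float reference = silu(gate) * up;
        // Mis-fused: the two branches land on the kernel's fixed ports swapped.
        float mis_fused = fused_mlp(up, gate);
        std::printf("reference=%f mis_fused=%f\n", reference, mis_fused);
        return 0;
    }

The two values diverge (roughly -0.547 vs -2.642 for these inputs), which is exactly the kind of silent accuracy issue the callback now avoids by returning false instead of fusing when the Multiply inputs are not in gate/up order.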