@@ -122,6 +122,34 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto root = m.get_match_root();
// Check that the first input of Multiply is the gate (activation) branch and the second input is the up branch;
// otherwise, do not fuse. Input order is critical for correctness: mismatched input order can silently cause
// accuracy issues.
Comment on lines +125 to +127

Contributor:
Could you please move this limitation from the callback function into the pattern to match?

Contributor Author:
Hi @chenhu-wang: previously we had the makePattern() function to get the pattern port information, but in the recent ov master branch this function has been removed (#32180), so it seems we can only get the pattern port information at callback time.

Contributor:
    auto mlp_silu_gate = wrap_type<Swish>({mlp_gate_proj->output(0) | gate_up_proj_split->output(0)});
    auto mlp_gelu_gate = wrap_type<Gelu>({mlp_gate_proj->output(0) | gate_up_proj_split->output(0)});

    auto mlp_up_proj = wrap_type<MatMul>({input, up_proj_weight | up_proj_weight_compressed | up_proj_weight_deq},
                                         {{"transpose_a", false}, {"transpose_b", true}});

    auto mlp_gated_up = wrap_type<Multiply>({mlp_silu_gate->output(0) | mlp_gelu_gate->output(0),
                                             mlp_up_proj->output(0) | gate_up_proj_split->output(1)},
                                            {{"auto_broadcast", "numpy"}});

Does this not work in the pattern?

Contributor Author:
It seems not to work. On my side the pattern always matches no matter which port (output(0) or output(1)) is used. It seems the 'wrap_type' object does not really distinguish the actual port information, only the node type.
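
A minimal standalone sketch (not part of this PR, and only assuming the public OpenVINO node API) of the property the callback-time check relies on: the two outputs of a single VariadicSplit share one producer node and differ only in their port index, so comparing producer nodes alone cannot tell the gate half from the up half; the match has to be validated with Output::get_index(), as the callback added in this PR does.

    #include <iostream>
    #include <memory>

    #include "openvino/op/constant.hpp"
    #include "openvino/op/parameter.hpp"
    #include "openvino/op/variadic_split.hpp"

    int main() {
        // Split a [1, 8] tensor into two [1, 4] halves along axis 1,
        // mimicking a fused gate_up projection split into gate and up branches.
        auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 8});
        auto axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
        auto lengths = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, {4, 4});
        auto split = std::make_shared<ov::op::v1::VariadicSplit>(data, axis, lengths);

        ov::Output<ov::Node> gate_half = split->output(0);
        ov::Output<ov::Node> up_half = split->output(1);

        // Comparing producer nodes cannot distinguish the two branches...
        std::cout << (gate_half.get_node() == up_half.get_node()) << "\n";  // prints 1
        // ...only the port index does, which is what the pass callback checks via get_index().
        std::cout << gate_half.get_index() << " " << up_half.get_index() << "\n";  // prints 0 1
        return 0;
    }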

Contributor:
OK, maybe the transformation team needs to be aware of this.

Contributor Author:
Hi @mryzhov: do you have any comments or suggestions about this?

auto mlp_gated_up_node = pattern_map.at(mlp_gated_up).get_node_shared_ptr();
auto input0 = mlp_gated_up_node->input_value(0);
auto input1 = mlp_gated_up_node->input_value(1);

bool input0_is_gate = false;
bool input1_is_up = false;

if (pattern_map.count(mlp_silu_gate) && input0.get_node() == pattern_map.at(mlp_silu_gate).get_node()) {
input0_is_gate = true;
}
if (pattern_map.count(mlp_gelu_gate) && input0.get_node() == pattern_map.at(mlp_gelu_gate).get_node()) {
input0_is_gate = true;
}

if (pattern_map.count(mlp_up_proj) && input1.get_node() == pattern_map.at(mlp_up_proj).get_node()) {
input1_is_up = true;
}
if (pattern_map.count(gate_up_proj_split) &&
input1.get_node() == pattern_map.at(gate_up_proj_split).get_node() && input1.get_index() == 1) {
input1_is_up = true;
}

if (!input0_is_gate || !input1_is_up) {
return false;
}
auto src = pattern_map.at(input);
if (!src.get_element_type().is_real()) {
// FakeQuantize, should skip fusion
@@ -6,13 +6,13 @@
#include <vector>

#include "common_test_utils/ov_tensor_utils.hpp"
#include "openvino/runtime/exec_model_info.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gelu.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/swish.hpp"
#include "openvino/runtime/exec_model_info.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"

namespace ov {
namespace test {
@@ -23,6 +23,7 @@ struct LLMMLPFusionParams {
size_t up_size;
std::string act_type;
bool use_dynamic_quant;
bool swap_inputs; // true = swap inputs to prevent fusion, false = normal order for fusion
};

class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>, public ov::test::SubgraphBaseTest {
@@ -39,6 +40,7 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
result << "up_size=" << obj.param.up_size << "_";
result << "act_type=" << obj.param.act_type << "_";
result << "use_dynamic_quant=" << obj.param.use_dynamic_quant << "_";
result << "swap_inputs=" << obj.param.swap_inputs << "_";
result << obj.index;
return result.str();
}
@@ -70,7 +72,8 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
in_data.start_from = 0;
in_data.range = 1;
in_data.resolution = 128;
auto tensor_scale_per_oc = ov::test::utils::create_and_fill_tensor(ov::element::f32, ov::Shape{OC, 1}, in_data);
auto tensor_scale_per_oc =
ov::test::utils::create_and_fill_tensor(ov::element::f32, ov::Shape{OC, 1}, in_data);
auto scale_per_oc = std::make_shared<ov::op::v0::Constant>(tensor_scale_per_oc);

auto weight_deq = std::make_shared<ov::op::v1::Multiply>(weight_const_f32, scale_per_oc);
@@ -85,7 +88,8 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
return std::make_shared<ov::op::v0::Constant>(tensor);
};
if (param.use_dynamic_quant)
configuration.insert({ov::hint::dynamic_quantization_group_size.name(), std::numeric_limits<uint64_t>::max()});
configuration.insert(
{ov::hint::dynamic_quantization_group_size.name(), std::numeric_limits<uint64_t>::max()});

auto gate_weight = create_const(param.up_size, param.down_size, 100);
auto up_weight = create_const(param.up_size, param.down_size, 100);
@@ -101,13 +105,22 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
if (param.act_type == "Gelu")
gate_act = std::make_shared<ov::op::v7::Gelu>(gate_proj);

auto gate_up = std::make_shared<ov::op::v1::Multiply>(gate_act, up_proj);
// Control input order based on swap_inputs parameter
std::shared_ptr<ov::op::v1::Multiply> gate_up;
if (param.swap_inputs) {
// Swapped order should prevent fusion
gate_up = std::make_shared<ov::op::v1::Multiply>(up_proj, gate_act);
} else {
// Normal order should allow fusion
gate_up = std::make_shared<ov::op::v1::Multiply>(gate_act, up_proj);
}

auto output = std::make_shared<ov::op::v0::MatMul>(gate_up, down_weight, false, true);

function = std::make_shared<ov::Model>(ov::OutputVector{output}, ov::ParameterVector{src});
}

void check_results() {
void check_fusion_result() {
auto exec_model = compiledModel.get_runtime_model();

int fused_node_found = 0;
@@ -116,26 +129,40 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
if (layer_type == "LLMMLP")
fused_node_found++;
}
ASSERT_EQ(fused_node_found, 1);

auto& param = this->GetParam();
if (param.swap_inputs) {
// When inputs are swapped, fusion should NOT happen
ASSERT_EQ(fused_node_found, 0) << "Fusion should not occur with swapped inputs";
} else {
// Normal case, fusion should happen
ASSERT_EQ(fused_node_found, 1) << "Fusion should occur with correct input order";
}
}
};

TEST_P(LLMMLPFusionTest, CompareWithRefs) {
if (!ov::with_cpu_x86_avx512_core_amx_bf16())
GTEST_SKIP();
run();
check_results();
check_fusion_result();
}

namespace {

static ov::test::InputShape ishape{ov::PartialShape{-1, -1, 4096 / 4}, {ov::Shape{1, 8, 4096 / 4}, ov::Shape{5, 37, 4096 / 4}}};
static ov::test::InputShape ishape{ov::PartialShape{-1, -1, 4096 / 4},
{ov::Shape{1, 8, 4096 / 4}, ov::Shape{5, 37, 4096 / 4}}};

// Test parameters combining both normal fusion and no-fusion cases
const std::vector<LLMMLPFusionParams> mlp_params = {
{ishape, 4096 / 4, 11008 / 4, "Gelu", false},
{ishape, 4096 / 4, 11008 / 4, "Gelu", true},
{ishape, 4096 / 4, 11008 / 4, "Swish", false},
{ishape, 4096 / 4, 11008 / 4, "Swish", true},
// Normal cases - should fuse (swap_inputs = false)
{ishape, 4096 / 4, 11008 / 4, "Gelu", false, false},
{ishape, 4096 / 4, 11008 / 4, "Gelu", true, false},
{ishape, 4096 / 4, 11008 / 4, "Swish", false, false},
{ishape, 4096 / 4, 11008 / 4, "Swish", true, false},

// Port order issue cases - should NOT fuse (swap_inputs = true)
{ishape, 4096 / 4, 11008 / 4, "Gelu", false, true},
};

INSTANTIATE_TEST_SUITE_P(smoke_LLMMLPFusion,