diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_moe_experts.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_moe_experts.cpp
index cbe4b036cc09d8..308008fb5fc266 100644
--- a/src/common/transformations/src/transformations/common_optimizations/fuse_moe_experts.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/fuse_moe_experts.cpp
@@ -395,7 +395,8 @@ ov::pass::FuseMOEExperts::FuseMOEExperts() : MultiMatcher("FuseMOEExperts") {
             inputs.emplace_back(ov::op::util::make_try_fold<ov::op::v0::Unsqueeze>(original_weight, const_0));
         }

-        auto fused = ov::op::util::make_try_fold<ov::op::v0::Concat>(inputs, 0);
+        auto fused = std::make_shared<ov::op::v0::Concat>(inputs, 0);
+        fused->get_rt_info()["postponed_constant"] = true;
         if (needs_decompress) {
             auto convert = std::make_shared<ov::op::v0::Convert>(fused, target_type);
             ov::mark_as_decompression(convert);
@@ -418,6 +419,7 @@ ov::pass::FuseMOEExperts::FuseMOEExperts() : MultiMatcher("FuseMOEExperts") {
         // Extract input and residual nodes from the pattern match
         auto view_reshape_node = last_add_match.at(view_Reshape).get_node_shared_ptr();
         auto residual_input_node = last_add_match.at(residual_input).get_node_shared_ptr();
+        auto original_shape_node = last_add_match.at(original_shape).get_node_shared_ptr();

         // Build the fused MoE computation
         const size_t num_experts = experts.size();
@@ -489,8 +491,7 @@ ov::pass::FuseMOEExperts::FuseMOEExperts() : MultiMatcher("FuseMOEExperts") {
         auto final_output = std::make_shared<ov::op::v1::ReduceSum>(weighted_outputs, const_0, false);

         // Reshape back to original shape and add residual connection
-        auto target_shape = std::make_shared<ov::op::v3::ShapeOf>(view_reshape_node, element::i64);
-        auto final_reshape = std::make_shared<ov::op::v1::Reshape>(final_output, target_shape, false);
+        auto final_reshape = std::make_shared<ov::op::v1::Reshape>(final_output, original_shape_node, false);
         auto final_add = std::make_shared<ov::op::v1::Add>(residual_input_node, final_reshape);

         if (last_reshape_node && !last_reshape_node->get_friendly_name().empty()) {
@@ -515,7 +516,7 @@ ov::pass::FuseMOEExperts::FuseMOEExperts() : MultiMatcher("FuseMOEExperts") {
 bool ov::pass::FuseMOE::run_on_model(const std::shared_ptr<ov::Model>& model) {
     RUN_ON_MODEL_SCOPE(FuseMOE);
     ov::pass::Manager manager(get_pass_config(), "FuseMOE");
-
+    manager.register_pass<ov::pass::FuseMOEExperts>();
     manager.run_passes(model);
     return false;
diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp
index 585a7d0a747f14..5f541f73246af6 100644
--- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp
@@ -293,7 +293,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>& f) {
     // todo: enable after plugin support for MoE
     // Remove pytestmark to enable e2e test:
     // tests/model_hub_tests/transformation_tests/test_moe_transformation.py
-    // REGISTER_PASS(manager, FuseMOE)
+    REGISTER_PASS(manager, FuseMOE)

     manager.run_passes(f);
diff --git a/tests/model_hub_tests/transformation_tests/test_moe_transformation.py b/tests/model_hub_tests/transformation_tests/test_moe_transformation.py
index 8b6200081f5e49..d49e396e62feb5 100644
--- a/tests/model_hub_tests/transformation_tests/test_moe_transformation.py
+++ b/tests/model_hub_tests/transformation_tests/test_moe_transformation.py
@@ -125,7 +125,8 @@ def create_synthetic_moe_model(tmp_path, num_layers, num_experts, dtype="float32
 def run_moe(tmp_path,
             model_id,
             model_link,
-            ie_device):
+            ie_device,
+            batch_size=1):
     """
     Test that MoE models are loaded with fused expert subgraphs.

@@ -134,6 +135,9 @@ def run_moe(tmp_path,
     contains the characteristic fused MoE pattern.

     Additionally verifies output correctness by comparing with original PyTorch model.
+
+    Args:
+        batch_size: Number of sequences to process in parallel (default: 1)
     """

     model_cached = snapshot_download(model_id)  # required to avoid HF rate limits
@@ -141,9 +145,13 @@ def run_moe(tmp_path,
     pt_model = AutoModelForCausalLM.from_pretrained(model_cached, trust_remote_code=True)
     tokenizer = AutoTokenizer.from_pretrained(model_cached, trust_remote_code=True)

-    # Prepare test input
-    test_text = "Test input for MoE model verification"
-    inputs = tokenizer(test_text, return_tensors="pt")
+    # Prepare test input with specified batch size
+    if batch_size == 1:
+        test_text = "Test input for MoE model verification"
+    else:
+        # Create a list of different texts for batch processing
+        test_text = [f"Test input {i} for MoE model batch verification" for i in range(batch_size)]
+    inputs = tokenizer(test_text, return_tensors="pt", padding=True)

     # Get PyTorch output
     with torch.no_grad():
@@ -173,12 +181,12 @@ def run_moe(tmp_path,
     max_diff = np.abs(pt_logits - ov_logits).max()
     mean_diff = np.abs(pt_logits - ov_logits).mean()

-    print(f"Output comparison: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+    print(f"Output comparison (batch_size={batch_size}): max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

     # Verify outputs are close
     # Tolerances: rtol=1e-3, atol=1e-3 account for OpenVINO IR conversion and execution differences
     assert np.allclose(pt_logits, ov_logits, rtol=1e-3, atol=1e-3), \
-        f"Output mismatch between PyTorch and OpenVINO fused model: max_diff={max_diff}, mean_diff={mean_diff}"
+        f"Output mismatch between PyTorch and OpenVINO fused model (batch_size={batch_size}): max_diff={max_diff}, mean_diff={mean_diff}"


 @retry(3, exceptions=(OSError,), delay=1)
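
A note on the padded-batch path above: calling the tokenizer with padding=True only works when the checkpoint defines a pad token. Below is a minimal standalone sketch of the same input-preparation logic; the helper name and the EOS-as-pad fallback are illustrative assumptions, not part of the patch.

def build_batch_inputs(tokenizer, batch_size):
    # Mirrors the input preparation in run_moe; helper name is illustrative.
    if batch_size == 1:
        texts = "Test input for MoE model verification"
    else:
        texts = [f"Test input {i} for MoE model batch verification" for i in range(batch_size)]
    # padding=True requires a pad token; reusing EOS here is an assumption about
    # the tested checkpoints, not something the patch itself does.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer(texts, return_tensors="pt", padding=True)
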
@@ -186,7 +194,8 @@ def run_moe_synthetic(tmp_path,
                       num_layers,
                       num_experts,
                       dtype,
-                      ie_device):
+                      ie_device,
+                      batch_size=1):
     """
     Test MoE fusion on synthetically generated models.

     This creates small MoE models with configurable parameters and verifies
     that MoE fusion produces the expected fused pattern.

     Additionally verifies output correctness by comparing with original PyTorch model.
+
+    Args:
+        batch_size: Number of sequences to process in parallel (default: 1)
     """

     model_path = create_synthetic_moe_model(tmp_path, num_layers, num_experts, dtype)
@@ -203,9 +215,13 @@ def run_moe_synthetic(tmp_path,
     tokenizer_cache = snapshot_download("optimum-internal-testing/tiny-random-qwen3_moe")
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_cache, trust_remote_code=True)

-    # Prepare test input
-    test_text = "Test input for synthetic MoE model"
-    inputs = tokenizer(test_text, return_tensors="pt")
+    # Prepare test input with specified batch size
+    if batch_size == 1:
+        test_text = "Test input for synthetic MoE model"
+    else:
+        # Create a list of different texts for batch processing
+        test_text = [f"Synthetic test input {i} for MoE batch validation" for i in range(batch_size)]
+    inputs = tokenizer(test_text, return_tensors="pt", padding=True)

     # Get PyTorch output
     with torch.no_grad():
@@ -240,7 +256,7 @@ def run_moe_synthetic(tmp_path,
     max_diff = np.abs(pt_logits - ov_logits).max()
     mean_diff = np.abs(pt_logits - ov_logits).mean()

-    print(f"Synthetic ({num_layers}L, {num_experts}E, {dtype}): max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+    print(f"Synthetic ({num_layers}L, {num_experts}E, {dtype}, batch={batch_size}): max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

     # Adjust tolerances based on dtype
     # FP16/BF16 have lower precision, so allow larger numerical differences
@@ -248,7 +264,7 @@ def run_moe_synthetic(tmp_path,
     atol = 1e-2 if dtype in ["float16", "bfloat16"] else 1e-3

     assert np.allclose(pt_logits, ov_logits, rtol=rtol, atol=atol), \
-        f"Output mismatch for {num_layers}L, {num_experts}E, {dtype}: max_diff={max_diff}, mean_diff={mean_diff}"
+        f"Output mismatch for {num_layers}L, {num_experts}E, {dtype}, batch={batch_size}: max_diff={max_diff}, mean_diff={mean_diff}"

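
The dtype-dependent tolerance selection above can be read as a small helper. A sketch follows; the helper name is illustrative, and the rtol rule is assumed to mirror the atol line visible in the hunk.

def tolerances_for(dtype):
    # FP16/BF16 accumulate more rounding error than FP32, so both tolerances are
    # widened; treating rtol the same way as the visible atol rule is an assumption.
    if dtype in ["float16", "bfloat16"]:
        return 1e-2, 1e-2
    return 1e-3, 1e-3

rtol, atol = tolerances_for("bfloat16")  # (0.01, 0.01)
rtol, atol = tolerances_for("float32")   # (0.001, 0.001)
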
 MOE_PRECOMMIT_TEST_CASES = [
@@ -270,10 +286,15 @@ def moe_test_idfn(entry):
     return retval


+# Batch sizes to test
+BATCH_SIZES = [1, 2, 4]
+
+
 @pytest.mark.precommit
 @pytest.mark.parametrize("model_info_tuple", MOE_PRECOMMIT_TEST_CASES, ids=moe_test_idfn)
-def test_moe_precommit(tmp_path, model_info_tuple, ie_device):
-    """Test MoE fusion transformation on precommit models."""
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+def test_moe_precommit(tmp_path, model_info_tuple, ie_device, batch_size):
+    """Test MoE fusion transformation on precommit models with different batch sizes."""
     model_class, model_name, model_link, mark, reason = model_info_tuple
     assert mark is None or mark == 'skip' or mark == 'xfail', \
         "Incorrect test case: {}, {}".format(model_name, model_link)
@@ -282,36 +303,51 @@ def test_moe_precommit(tmp_path, model_info_tuple, ie_device):
     elif mark == 'xfail':
         pytest.xfail(reason)

-    run_moe(tmp_path, model_name, model_link, ie_device)
+    run_moe(tmp_path, model_name, model_link, ie_device, batch_size)


 # Synthetic test cases with different configurations
 MOE_SYNTHETIC_TEST_CASES = [
-    # (num_layers, num_experts, dtype)
-    # Test different numbers of MoE layers
-    (1, 4, "float32"),     # Single MoE layer
-    (2, 4, "float32"),     # Two MoE layers
-    (3, 4, "float32"),     # Three MoE layers
-
-    # Test different numbers of experts
-    (1, 2, "float32"),     # Minimal: 2 experts
-    (1, 8, "float32"),     # More experts: 8
-    (1, 16, "float32"),    # Many experts: 16
-
-    # Test different dtypes (important for decompression pattern testing)
-    (1, 4, "float16"),     # FP16 - may have decompression
-    (1, 4, "bfloat16"),    # BF16 - may have decompression
-
-    # Combined variations
-    (2, 8, "float16"),     # Multiple layers + more experts + FP16
-    (3, 16, "bfloat16"),   # Multiple layers + many experts + BF16
+    # (num_layers, num_experts, dtype, batch_size)
+    # Test different numbers of MoE layers with varying batch sizes
+    (1, 4, "float32", 1),    # Single MoE layer, single batch
+    (1, 4, "float32", 2),    # Single MoE layer, small batch
+    (1, 4, "float32", 4),    # Single MoE layer, larger batch
+    (2, 4, "float32", 1),    # Two MoE layers, single batch
+    (2, 4, "float32", 4),    # Two MoE layers, larger batch
+    (3, 4, "float32", 1),    # Three MoE layers, single batch
+    (3, 4, "float32", 2),    # Three MoE layers, small batch
+
+    # Test different numbers of experts with batch processing
+    (1, 2, "float32", 1),    # Minimal: 2 experts, single batch
+    (1, 2, "float32", 4),    # Minimal: 2 experts, larger batch
+    (1, 8, "float32", 1),    # More experts: 8, single batch
+    (1, 8, "float32", 2),    # More experts: 8, small batch
+    (1, 16, "float32", 1),   # Many experts: 16, single batch
+    (1, 16, "float32", 4),   # Many experts: 16, larger batch
+
+    # Test different dtypes with batch processing (important for decompression pattern testing)
+    (1, 4, "float16", 1),    # FP16 - may have decompression, single batch
+    (1, 4, "float16", 4),    # FP16 - may have decompression, larger batch
+    (1, 4, "bfloat16", 1),   # BF16 - may have decompression, single batch
+    (1, 4, "bfloat16", 2),   # BF16 - may have decompression, small batch
+
+    # Combined variations with batch processing
+    (2, 8, "float16", 1),    # Multiple layers + more experts + FP16, single batch
+    (2, 8, "float16", 4),    # Multiple layers + more experts + FP16, larger batch
+    (3, 16, "bfloat16", 1),  # Multiple layers + many experts + BF16, single batch
+    (3, 16, "bfloat16", 2),  # Multiple layers + many experts + BF16, small batch
+
+    # Edge cases: larger batch sizes
+    (1, 4, "float32", 8),    # Large batch size
+    (2, 8, "float32", 8),    # Large batch with more complex model
 ]


 def synthetic_test_idfn(entry):
     """Generate test ID for synthetic test cases."""
-    num_layers, num_experts, dtype = entry
-    return f"synthetic-l{num_layers}-e{num_experts}-{dtype}"
+    num_layers, num_experts, dtype, batch_size = entry
+    return f"synthetic-l{num_layers}-e{num_experts}-{dtype}-b{batch_size}"


 @pytest.mark.precommit
@@ -324,6 +360,7 @@ def test_moe_synthetic(tmp_path, test_params, ie_device):
     - Different numbers of MoE layers (1, 2, 3)
     - Different numbers of experts (2, 4, 8, 16)
     - Different dtypes (float32, float16, bfloat16) to validate decompression handling
+    - Different batch sizes (1, 2, 4, 8) to ensure fusion works correctly with batched inputs
     """
-    num_layers, num_experts, dtype = test_params
-    run_moe_synthetic(tmp_path, num_layers, num_experts, dtype, ie_device)
+    num_layers, num_experts, dtype, batch_size = test_params
+    run_moe_synthetic(tmp_path, num_layers, num_experts, dtype, ie_device, batch_size)
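
For reference, each four-element synthetic tuple expands into one pytest case, and BATCH_SIZES multiplies every precommit model case by three. A standalone sketch of the resulting synthetic test IDs; the sample tuples are taken from MOE_SYNTHETIC_TEST_CASES above.

def synthetic_test_idfn(entry):
    # Same ID scheme as in the test file: layers, experts, dtype, batch size.
    num_layers, num_experts, dtype, batch_size = entry
    return f"synthetic-l{num_layers}-e{num_experts}-{dtype}-b{batch_size}"

sample = [(1, 4, "float32", 1), (1, 16, "float32", 4), (3, 16, "bfloat16", 2)]
print([synthetic_test_idfn(t) for t in sample])
# ['synthetic-l1-e4-float32-b1', 'synthetic-l1-e16-float32-b4', 'synthetic-l3-e16-bfloat16-b2']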