@@ -395,7 +395,8 @@ ov::pass::FuseMOEExperts::FuseMOEExperts() : MultiMatcher("FuseMOEExperts") {
inputs.emplace_back(ov::op::util::make_try_fold<ov::op::v0::Unsqueeze>(original_weight, const_0));
}

auto fused = ov::op::util::make_try_fold<ov::op::v0::Concat>(inputs, 0);
auto fused = std::make_shared<ov::op::v0::Concat>(inputs, 0);
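// Tagging the Concat with "postponed_constant" (instead of folding it eagerly via
// make_try_fold) presumably defers folding of the stacked expert weights to a later
// stage, so this pass does not materialize an extra in-memory copy of all experts.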
fused->get_rt_info()["postponed_constant"] = true;
if (needs_decompress) {
auto convert = std::make_shared<ov::op::v0::Convert>(fused, target_type);
ov::mark_as_decompression(convert);
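A minimal sketch (not code from this PR) of how a later pass or the serializer might check for the "postponed_constant" hint set above; ov::Node::get_rt_info() is the standard RTInfo accessor, while the helper name and placement are assumptions:

#include <memory>

#include <openvino/core/node.hpp>

// Sketch: returns true when a node carries the "postponed_constant" hint,
// i.e. its value is expected to be folded into a Constant at a later stage.
static bool is_postponed_constant(const std::shared_ptr<ov::Node>& node) {
    const auto& rt_info = node->get_rt_info();
    return rt_info.find("postponed_constant") != rt_info.end();
}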
@@ -418,6 +419,7 @@ ov::pass::FuseMOEExperts::FuseMOEExperts() : MultiMatcher("FuseMOEExperts") {
// Extract input and residual nodes from the pattern match
auto view_reshape_node = last_add_match.at(view_Reshape).get_node_shared_ptr();
auto residual_input_node = last_add_match.at(residual_input).get_node_shared_ptr();
auto original_shape_node = last_add_match.at(original_shape).get_node_shared_ptr();

// Build the fused MoE computation
const size_t num_experts = experts.size();
@@ -489,8 +491,7 @@ ov::pass::FuseMOEExperts::FuseMOEExperts() : MultiMatcher("FuseMOEExperts") {
auto final_output = std::make_shared<ov::op::v1::ReduceSum>(weighted_outputs, const_0, false);

// Reshape back to original shape and add residual connection
auto target_shape = std::make_shared<ov::op::v3::ShapeOf>(view_reshape_node, element::i64);
auto final_reshape = std::make_shared<ov::op::v1::Reshape>(final_output, target_shape, false);
auto final_reshape = std::make_shared<ov::op::v1::Reshape>(final_output, original_shape_node, false);
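// Reusing the ShapeOf captured by the pattern (original_shape_node) instead of recomputing
// the shape of the flattened view presumably restores the pre-flatten [batch, seq, hidden]
// layout, so the residual Add below broadcasts correctly for batch sizes > 1.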
auto final_add = std::make_shared<ov::op::v1::Add>(residual_input_node, final_reshape);

if (last_reshape_node && !last_reshape_node->get_friendly_name().empty()) {
@@ -515,7 +516,7 @@ ov::pass::FuseMOEExperts::FuseMOEExperts() : MultiMatcher("FuseMOEExperts") {
bool ov::pass::FuseMOE::run_on_model(const std::shared_ptr<ov::Model>& model) {
RUN_ON_MODEL_SCOPE(FuseMOE);
ov::pass::Manager manager(get_pass_config(), "FuseMOE");

manager.register_pass<ov::pass::FuseMOEExperts>();
manager.run_passes(model);
return false;
@@ -293,7 +293,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
// todo: enable after plugin support for MoE
// Remove pytestmark to enable e2e test:
// tests/model_hub_tests/transformation_tests/test_moe_transformation.py
// REGISTER_PASS(manager, FuseMOE)
REGISTER_PASS(manager, FuseMOE)

manager.run_passes(f);

@@ -125,7 +125,8 @@ def create_synthetic_moe_model(tmp_path, num_layers, num_experts, dtype="float32
def run_moe(tmp_path,
model_id,
model_link,
ie_device):
ie_device,
batch_size=1):
"""
Test that MoE models are loaded with fused expert subgraphs.

@@ -134,16 +135,23 @@ def run_moe(tmp_path,
contains the characteristic fused MoE pattern.

Additionally verifies output correctness by comparing with original PyTorch model.

Args:
batch_size: Number of sequences to process in parallel (default: 1)
"""
model_cached = snapshot_download(model_id) # required to avoid HF rate limits

# Load original PyTorch model and tokenizer for comparison (from cache to avoid rate limits)
pt_model = AutoModelForCausalLM.from_pretrained(model_cached, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_cached, trust_remote_code=True)

# Prepare test input
test_text = "Test input for MoE model verification"
inputs = tokenizer(test_text, return_tensors="pt")
# Prepare test input with specified batch size
if batch_size == 1:
test_text = "Test input for MoE model verification"
else:
# Create a list of different texts for batch processing
test_text = [f"Test input {i} for MoE model batch verification" for i in range(batch_size)]
inputs = tokenizer(test_text, return_tensors="pt", padding=True)

# Get PyTorch output
with torch.no_grad():
@@ -173,27 +181,31 @@ def run_moe(tmp_path,
max_diff = np.abs(pt_logits - ov_logits).max()
mean_diff = np.abs(pt_logits - ov_logits).mean()

print(f"Output comparison: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
print(f"Output comparison (batch_size={batch_size}): max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

# Verify outputs are close
# Tolerances: rtol=1e-3, atol=1e-3 account for OpenVINO IR conversion and execution differences
assert np.allclose(pt_logits, ov_logits, rtol=1e-3, atol=1e-3), \
f"Output mismatch between PyTorch and OpenVINO fused model: max_diff={max_diff}, mean_diff={mean_diff}"
f"Output mismatch between PyTorch and OpenVINO fused model (batch_size={batch_size}): max_diff={max_diff}, mean_diff={mean_diff}"


@retry(3, exceptions=(OSError,), delay=1)
def run_moe_synthetic(tmp_path,
num_layers,
num_experts,
dtype,
ie_device):
ie_device,
batch_size=1):
"""
Test MoE fusion on synthetically generated models.

Creates a model from config with specified parameters and verifies
that MoE fusion produces the expected fused pattern.

Additionally verifies output correctness by comparing with original PyTorch model.

Args:
batch_size: Number of sequences to process in parallel (default: 1)
"""
model_path = create_synthetic_moe_model(tmp_path, num_layers, num_experts, dtype)

@@ -203,9 +215,13 @@ def run_moe_synthetic(tmp_path,
tokenizer_cache = snapshot_download("optimum-internal-testing/tiny-random-qwen3_moe")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_cache, trust_remote_code=True)

# Prepare test input
test_text = "Test input for synthetic MoE model"
inputs = tokenizer(test_text, return_tensors="pt")
# Prepare test input with specified batch size
if batch_size == 1:
test_text = "Test input for synthetic MoE model"
else:
# Create a list of different texts for batch processing
test_text = [f"Synthetic test input {i} for MoE batch validation" for i in range(batch_size)]
inputs = tokenizer(test_text, return_tensors="pt", padding=True)

# Get PyTorch output
with torch.no_grad():
@@ -240,15 +256,15 @@ def run_moe_synthetic(tmp_path,
max_diff = np.abs(pt_logits - ov_logits).max()
mean_diff = np.abs(pt_logits - ov_logits).mean()

print(f"Synthetic ({num_layers}L, {num_experts}E, {dtype}): max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
print(f"Synthetic ({num_layers}L, {num_experts}E, {dtype}, batch={batch_size}): max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

# Adjust tolerances based on dtype
# FP16/BF16 have lower precision, so allow larger numerical differences
rtol = 1e-2 if dtype in ["float16", "bfloat16"] else 1e-3
atol = 1e-2 if dtype in ["float16", "bfloat16"] else 1e-3

assert np.allclose(pt_logits, ov_logits, rtol=rtol, atol=atol), \
f"Output mismatch for {num_layers}L, {num_experts}E, {dtype}: max_diff={max_diff}, mean_diff={mean_diff}"
f"Output mismatch for {num_layers}L, {num_experts}E, {dtype}, batch={batch_size}: max_diff={max_diff}, mean_diff={mean_diff}"


MOE_PRECOMMIT_TEST_CASES = [
@@ -270,10 +286,15 @@ def moe_test_idfn(entry):
return retval


# Batch sizes to test
BATCH_SIZES = [1, 2, 4]


@pytest.mark.precommit
@pytest.mark.parametrize("model_info_tuple", MOE_PRECOMMIT_TEST_CASES, ids=moe_test_idfn)
def test_moe_precommit(tmp_path, model_info_tuple, ie_device):
"""Test MoE fusion transformation on precommit models."""
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_moe_precommit(tmp_path, model_info_tuple, ie_device, batch_size):
"""Test MoE fusion transformation on precommit models with different batch sizes."""
model_class, model_name, model_link, mark, reason = model_info_tuple
assert mark is None or mark == 'skip' or mark == 'xfail', \
"Incorrect test case: {}, {}".format(model_name, model_link)
@@ -282,36 +303,51 @@ def test_moe_precommit(tmp_path, model_info_tuple, ie_device):
elif mark == 'xfail':
pytest.xfail(reason)

run_moe(tmp_path, model_name, model_link, ie_device)
run_moe(tmp_path, model_name, model_link, ie_device, batch_size)


# Synthetic test cases with different configurations
MOE_SYNTHETIC_TEST_CASES = [
# (num_layers, num_experts, dtype)
# Test different numbers of MoE layers
(1, 4, "float32"), # Single MoE layer
(2, 4, "float32"), # Two MoE layers
(3, 4, "float32"), # Three MoE layers

# Test different numbers of experts
(1, 2, "float32"), # Minimal: 2 experts
(1, 8, "float32"), # More experts: 8
(1, 16, "float32"), # Many experts: 16

# Test different dtypes (important for decompression pattern testing)
(1, 4, "float16"), # FP16 - may have decompression
(1, 4, "bfloat16"), # BF16 - may have decompression

# Combined variations
(2, 8, "float16"), # Multiple layers + more experts + FP16
(3, 16, "bfloat16"), # Multiple layers + many experts + BF16
# (num_layers, num_experts, dtype, batch_size)
# Test different numbers of MoE layers with varying batch sizes
(1, 4, "float32", 1), # Single MoE layer, single batch
(1, 4, "float32", 2), # Single MoE layer, small batch
(1, 4, "float32", 4), # Single MoE layer, larger batch
(2, 4, "float32", 1), # Two MoE layers, single batch
(2, 4, "float32", 4), # Two MoE layers, larger batch
(3, 4, "float32", 1), # Three MoE layers, single batch
(3, 4, "float32", 2), # Three MoE layers, small batch

# Test different numbers of experts with batch processing
(1, 2, "float32", 1), # Minimal: 2 experts, single batch
(1, 2, "float32", 4), # Minimal: 2 experts, larger batch
(1, 8, "float32", 1), # More experts: 8, single batch
(1, 8, "float32", 2), # More experts: 8, small batch
(1, 16, "float32", 1), # Many experts: 16, single batch
(1, 16, "float32", 4), # Many experts: 16, larger batch

# Test different dtypes with batch processing (important for decompression pattern testing)
(1, 4, "float16", 1), # FP16 - may have decompression, single batch
(1, 4, "float16", 4), # FP16 - may have decompression, larger batch
(1, 4, "bfloat16", 1), # BF16 - may have decompression, single batch
(1, 4, "bfloat16", 2), # BF16 - may have decompression, small batch

# Combined variations with batch processing
(2, 8, "float16", 1), # Multiple layers + more experts + FP16, single batch
(2, 8, "float16", 4), # Multiple layers + more experts + FP16, larger batch
(3, 16, "bfloat16", 1), # Multiple layers + many experts + BF16, single batch
(3, 16, "bfloat16", 2), # Multiple layers + many experts + BF16, small batch

# Edge cases: larger batch sizes
(1, 4, "float32", 8), # Large batch size
(2, 8, "float32", 8), # Large batch with more complex model
]


def synthetic_test_idfn(entry):
"""Generate test ID for synthetic test cases."""
num_layers, num_experts, dtype = entry
return f"synthetic-l{num_layers}-e{num_experts}-{dtype}"
num_layers, num_experts, dtype, batch_size = entry
return f"synthetic-l{num_layers}-e{num_experts}-{dtype}-b{batch_size}"


@pytest.mark.precommit
@@ -324,6 +360,7 @@ def test_moe_synthetic(tmp_path, test_params, ie_device):
- Different numbers of MoE layers (1, 2, 3)
- Different numbers of experts (2, 4, 8, 16)
- Different dtypes (float32, float16, bfloat16) to validate decompression handling
- Different batch sizes (1, 2, 4, 8) to ensure fusion works correctly with batched inputs
"""
num_layers, num_experts, dtype = test_params
run_moe_synthetic(tmp_path, num_layers, num_experts, dtype, ie_device)
num_layers, num_experts, dtype, batch_size = test_params
run_moe_synthetic(tmp_path, num_layers, num_experts, dtype, ie_device, batch_size)