diff --git a/python/sglang/test/test_deepep_utils.py b/python/sglang/test/test_deepep_utils.py index aa15b5a0bb2..9f9baf76f69 100644 --- a/python/sglang/test/test_deepep_utils.py +++ b/python/sglang/test/test_deepep_utils.py @@ -77,8 +77,8 @@ def create_grouped_scores( ): num_tokens, num_experts = scores.shape scores = scores.view(num_tokens, num_groups, -1) - mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device) - mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores) + mask = torch.zeros((num_tokens, num_groups, 1), dtype=torch.bool, device=scores.device) + mask.scatter_(1, group_idx.unsqueeze(-1), True) return (scores * mask).view(num_tokens, num_experts)