diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py
index 77902dcf5852..3a9d8c32cadc 100644
--- a/src/diffusers/models/transformers/transformer_hidream_image.py
+++ b/src/diffusers/models/transformers/transformer_hidream_image.py
@@ -389,7 +389,9 @@ def forward(self, x):
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
         expert_cache = torch.zeros_like(x)
         idxs = flat_expert_indices.argsort()
-        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
+        count_freq = torch.bincount(flat_expert_indices, minlength=self.num_activated_experts)
+        tokens_per_expert = count_freq.cumsum(dim=0)
+
         token_idxs = idxs // self.num_activated_experts
         for i, end_idx in enumerate(tokens_per_expert):
             start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
diff --git a/tests/models/transformers/test_models_transformer_hidream.py b/tests/models/transformers/test_models_transformer_hidream.py
index fa0fa5123ac8..14336713a358 100644
--- a/tests/models/transformers/test_models_transformer_hidream.py
+++ b/tests/models/transformers/test_models_transformer_hidream.py
@@ -20,6 +20,10 @@
 from diffusers import HiDreamImageTransformer2DModel
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
+    is_torch_compile,
+    require_torch_2,
+    require_torch_gpu,
+    slow,
     torch_device,
 )
 
@@ -94,3 +98,20 @@ def test_set_attn_processor_for_determinism(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HiDreamImageTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    @require_torch_gpu
+    @require_torch_2
+    @is_torch_compile
+    @slow
+    def test_torch_compile_recompilation_and_graph_break(self):
+        torch._dynamo.reset()
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict).to(torch_device)
+        model = torch.compile(model, fullgraph=True)
+
+        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)