diff --git a/tests/models/transformers/test_models_transformer_hidream.py b/tests/models/transformers/test_models_transformer_hidream.py index fa0fa5123ac8..3f4f6be5cff9 100644 --- a/tests/models/transformers/test_models_transformer_hidream.py +++ b/tests/models/transformers/test_models_transformer_hidream.py @@ -23,13 +23,13 @@ torch_device, ) -from ..test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin enable_full_determinism() -class HiDreamTransformerTests(ModelTesterMixin, unittest.TestCase): +class HiDreamTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): model_class = HiDreamImageTransformer2DModel main_input_name = "hidden_states" model_split_percents = [0.8, 0.8, 0.9]