diff --git a/tests/test_graphs.py b/tests/test_graphs.py index bc804749c..280dc908d 100755 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -202,8 +202,7 @@ def pytest_train_model_vectoroutput(model_type, overwrite_data=False): @pytest.mark.parametrize( - "model_type", - ["SAGE", "GIN", "GAT", "MFC", "PNA", "SchNet", "DimeNet", "EGNN"] + "model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "SchNet", "DimeNet", "EGNN"] ) def pytest_train_model_conv_head(model_type, overwrite_data=False): unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data)