diff --git a/openjudge/models/openai_chat_model.py b/openjudge/models/openai_chat_model.py index ac9557365..7eaca294b 100644 --- a/openjudge/models/openai_chat_model.py +++ b/openjudge/models/openai_chat_model.py @@ -82,6 +82,11 @@ def __init__( # Initialize client client_args = client_args or {} + + # Add timeout and max_retries defaults if not explicitly provided + client_args.setdefault("timeout", 60.0) + client_args.setdefault("max_retries", 2) + if api_key: client_args["api_key"] = api_key else: diff --git a/tests/models/test_openai_chat_model.py b/tests/models/test_openai_chat_model.py index 81b5bce78..bc693b092 100644 --- a/tests/models/test_openai_chat_model.py +++ b/tests/models/test_openai_chat_model.py @@ -296,6 +296,37 @@ def test_qwen_omni_audio_formatting(self): "data:;base64,", ) + @patch("openjudge.models.openai_chat_model.AsyncOpenAI") + def test_timeout_and_max_retries_passed(self, mock_async_openai): + """Test that timeout and max_retries are passed to AsyncOpenAI.""" + OpenAIChatModel( + model="gpt-4", + api_key="test-key", + client_args={ + "timeout": 30.0, + "max_retries": 3, + }, + ) + + _, kwargs = mock_async_openai.call_args + + assert kwargs["timeout"] == 30.0 + assert kwargs["max_retries"] == 3 + + @patch("openjudge.models.openai_chat_model.AsyncOpenAI") + def test_default_timeout_and_max_retries(self, mock_async_openai): + """Test that default timeout and max_retries are applied.""" + OpenAIChatModel( + model="gpt-4", + api_key="test-key", + ) + + _, kwargs = mock_async_openai.call_args + + assert kwargs["timeout"] == 60.0 + assert kwargs["max_retries"] == 2 + + if __name__ == "__main__": pytest.main([__file__])