Skip to content

Commit f2436c1

Browse files
committed
Update model initializers to return None if provider not found
1 parent 3ae4277 commit f2436c1

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

nemoguardrails/llm/models/langchain_initializer.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def init_langchain_model(
147147
# For chat mode, fall back to community chat models
148148
ModelInitializer(_init_community_chat_models, ["chat"]),
149149
# For text mode, use text completion
150-
ModelInitializer(_init_text_completion_model, ["text"]),
150+
ModelInitializer(_init_text_completion_model, ["text", "chat"]),
151151
]
152152

153153
# Track the last exception for better error reporting
@@ -235,7 +235,7 @@ def _init_chat_completion_model(
235235

236236
def _init_text_completion_model(
237237
model_name: str, provider_name: str, kwargs: Dict[str, Any]
238-
) -> BaseLLM:
238+
) -> BaseLLM | None:
239239
"""Initialize a text completion model.
240240
241241
Args:
@@ -249,9 +249,14 @@ def _init_text_completion_model(
249249
Raises:
250250
RuntimeError: If the provider is not found
251251
"""
252-
provider_cls = _get_text_completion_provider(provider_name)
252+
try:
253+
provider_cls = _get_text_completion_provider(provider_name)
254+
except RuntimeError as e:
255+
return None
256+
253257
if provider_cls is None:
254-
raise ValueError()
258+
return None
259+
255260
kwargs = _update_model_kwargs(provider_cls, model_name, kwargs)
256261
# remove stream_usage parameter as it's not supported by text completion APIs
257262
# (e.g., OpenAI's AsyncCompletions.create() doesn't accept this parameter)
@@ -261,7 +266,7 @@ def _init_text_completion_model(
261266

262267
def _init_community_chat_models(
263268
model_name: str, provider_name: str, kwargs: Dict[str, Any]
264-
) -> BaseChatModel:
269+
) -> BaseChatModel | None:
265270
"""Initialize community chat models.
266271
267272
Args:
@@ -278,14 +283,14 @@ def _init_community_chat_models(
278283
"""
279284
provider_cls = _get_chat_completion_provider(provider_name)
280285
if provider_cls is None:
281-
raise ValueError()
286+
return None
282287
kwargs = _update_model_kwargs(provider_cls, model_name, kwargs)
283288
return provider_cls(**kwargs)
284289

285290

286291
def _init_gpt35_turbo_instruct(
287292
model_name: str, provider_name: str, kwargs: Dict[str, Any]
288-
) -> BaseLLM:
293+
) -> BaseLLM | None:
289294
"""Initialize GPT-3.5 Turbo Instruct model.
290295
291296
Currently init_chat_model from langchain infers this as a chat model.

tests/llm_providers/test_langchain_initialization_methods.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,9 @@ def test_init_community_chat_models_no_provider(self):
134134
"nemoguardrails.llm.models.langchain_initializer._get_chat_completion_provider"
135135
) as mock_get_provider:
136136
mock_get_provider.return_value = None
137-
with pytest.raises(ValueError):
138-
_init_community_chat_models("community-model", "provider", {})
137+
assert (
138+
_init_community_chat_models("community-model", "provider", {}) is None
139+
)
139140

140141

141142
class TestTextCompletionInitializer:
@@ -178,8 +179,7 @@ def test_init_text_completion_model_no_provider(self):
178179
"nemoguardrails.llm.models.langchain_initializer._get_text_completion_provider"
179180
) as mock_get_provider:
180181
mock_get_provider.return_value = None
181-
with pytest.raises(ValueError):
182-
_init_text_completion_model("text-model", "provider", {})
182+
assert _init_text_completion_model("text-model", "provider", {}) is None
183183

184184

185185
class TestUpdateModelKwargs:

tests/llm_providers/test_langchain_initializer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,13 @@ def test_all_initializers_fail(mock_initializers):
108108
mock_initializers["special"].return_value = None
109109
mock_initializers["chat"].return_value = None
110110
mock_initializers["community"].return_value = None
111+
mock_initializers["text"].return_value = None
111112
with pytest.raises(ModelInitializationError):
112113
init_langchain_model("unknown-model", "provider", "chat", {})
113114
mock_initializers["special"].assert_called_once()
114115
mock_initializers["chat"].assert_called_once()
115116
mock_initializers["community"].assert_called_once()
116-
mock_initializers["text"].assert_not_called()
117+
mock_initializers["text"].assert_called_once()
117118

118119

119120
def test_unsupported_mode(mock_initializers):
@@ -148,7 +149,7 @@ def test_all_initializers_raise_exceptions(mock_initializers):
148149
mock_initializers["special"].assert_called_once()
149150
mock_initializers["chat"].assert_called_once()
150151
mock_initializers["community"].assert_called_once()
151-
mock_initializers["text"].assert_not_called()
152+
mock_initializers["text"].assert_called_once()
152153

153154

154155
def test_duplicate_modes_in_initializer(mock_initializers):

0 commit comments

Comments (0)