
Commit ee4e724

feat(llm): add provider-agnostic parameter mapping system
Implements flexible LLM parameter transformation to support provider-specific naming conventions (e.g., max_tokens -> max_new_tokens for HuggingFace).
1 parent 89225dc commit ee4e724

File tree

7 files changed: +732 additions, -8 deletions


nemoguardrails/actions/llm/utils.py

Lines changed: 13 additions & 3 deletions
@@ -30,6 +30,7 @@
     tool_calls_var,
 )
 from nemoguardrails.integrations.langchain.message_utils import dicts_to_messages
+from nemoguardrails.llm.parameter_mapping import transform_llm_params
 from nemoguardrails.logging.callbacks import logging_callbacks
 from nemoguardrails.logging.explain import LLMCallInfo

@@ -97,9 +98,18 @@ async def llm_call(
     _setup_llm_call_info(llm, model_name, model_provider)
     all_callbacks = _prepare_callbacks(custom_callback_handlers)

-    generation_llm: Union[BaseLanguageModel, Runnable] = (
-        llm.bind(stop=stop, **llm_params) if llm_params and llm is not None else llm
-    )
+    if llm_params or stop:
+        params_to_transform = llm_params.copy() if llm_params else {}
+        if stop is not None:
+            params_to_transform["stop"] = stop
+        transformed_params = transform_llm_params(
+            params_to_transform, llm, model_provider
+        )
+        generation_llm: Union[BaseLanguageModel, Runnable] = llm.bind(
+            **transformed_params
+        )
+    else:
+        generation_llm: Union[BaseLanguageModel, Runnable] = llm

     if isinstance(prompt, str):
         response = await _invoke_with_string_prompt(
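
As a quick illustration of the transformation this new branch performs (an editorial sketch, not part of the diff; the stand-in LLM class is hypothetical and only needed so the helper has an instance to inspect):

```python
from nemoguardrails.llm.parameter_mapping import transform_llm_params


class _FakeHFLLM:
    """Hypothetical stand-in; the provider is passed explicitly below."""


params = {"max_tokens": 100, "temperature": 0.5, "stop": ["\n"]}
print(transform_llm_params(params, _FakeHFLLM(), provider="huggingface"))
# With the built-in HuggingFace mapping, max_tokens is renamed and the rest
# pass through unchanged:
# {"max_new_tokens": 100, "temperature": 0.5, "stop": ["\n"]}
```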

nemoguardrails/llm/parameter_mapping.py (new file)

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for transforming LLM parameters between internal and provider-specific formats."""

import logging
import weakref
from typing import Any, Dict, Optional

from langchain.base_language import BaseLanguageModel

log = logging.getLogger(__name__)

# Global registry to store parameter mappings for LLM instances
_llm_parameter_mappings = weakref.WeakKeyDictionary()

PROVIDER_PARAMETER_MAPPINGS = {
    "huggingface": {
        "max_tokens": "max_new_tokens",
    },
    "google_vertexai": {
        "max_tokens": "max_output_tokens",
    },
}


def register_llm_parameter_mapping(
    llm: BaseLanguageModel, parameter_mapping: Dict[str, Optional[str]]
) -> None:
    """Register a parameter mapping for a specific LLM instance.

    Args:
        llm: The LLM instance
        parameter_mapping: The parameter mapping dictionary
    """
    _llm_parameter_mappings[llm] = parameter_mapping
    log.debug(f"Registered parameter mapping for LLM {type(llm).__name__}")


def get_llm_parameter_mapping(
    llm: BaseLanguageModel,
) -> Optional[Dict[str, Optional[str]]]:
    """Get the registered parameter mapping for an LLM instance.

    Args:
        llm: The LLM instance

    Returns:
        The parameter mapping if registered, None otherwise
    """
    return _llm_parameter_mappings.get(llm)


def _infer_provider_from_module(llm: BaseLanguageModel) -> Optional[str]:
    """Infer provider name from the LLM's module path.

    This function extracts the provider name from LangChain package naming conventions:
    - langchain_openai -> openai
    - langchain_anthropic -> anthropic
    - langchain_google_genai -> google_genai
    - langchain_nvidia_ai_endpoints -> nvidia_ai_endpoints
    - langchain_community.chat_models.ollama -> ollama

    Args:
        llm: The LLM instance

    Returns:
        The inferred provider name, or None if it cannot be determined
    """
    module = type(llm).__module__

    if module.startswith("langchain_"):
        package = module.split(".")[0]
        provider = package.replace("langchain_", "")

        if provider == "community":
            parts = module.split(".")
            if len(parts) >= 3:
                provider = parts[-1]
                log.debug(
                    "Inferred provider '%s' from community module %s", provider, module
                )
                return provider
        else:
            log.debug("Inferred provider '%s' from module %s", provider, module)
            return provider

    log.debug(f"Could not infer provider from module {module}")
    return None


def get_llm_provider(llm: BaseLanguageModel) -> Optional[str]:
    """Get the provider name for an LLM instance by inferring from module path.

    This function extracts the provider name from LangChain package naming conventions.
    See _infer_provider_from_module for details on the inference logic.

    Args:
        llm: The LLM instance

    Returns:
        The provider name if it can be inferred, None otherwise
    """
    return _infer_provider_from_module(llm)


def transform_llm_params(
    llm_params: Dict[str, Any],
    llm: BaseLanguageModel,
    provider: Optional[str] = None,
    parameter_mapping: Optional[Dict[str, Optional[str]]] = None,
) -> Dict[str, Any]:
    """Transform LLM parameters using provider-specific or custom mappings.

    Args:
        llm_params: The original parameters dictionary
        llm: The LLM instance to infer provider from
        provider: Optional provider name. If None, will be automatically determined from llm.
        parameter_mapping: Custom mapping dictionary. If None, uses built-in provider mappings.
            Key is the internal parameter name, value is the provider parameter name.
            If value is None, the parameter is dropped.

    Returns:
        Transformed parameters dictionary
    """
    if not llm_params:
        return llm_params

    mapping = parameter_mapping
    if mapping is None:
        mapping = get_llm_parameter_mapping(llm)
        if mapping:
            log.debug("Using registered parameter mapping for LLM instance")
        else:
            if provider is None:
                provider = get_llm_provider(llm)

            if provider and provider in PROVIDER_PARAMETER_MAPPINGS:
                mapping = PROVIDER_PARAMETER_MAPPINGS[provider]
                log.debug("Using built-in parameter mapping for provider: %s", provider)
            else:
                return llm_params

    if not mapping:
        return llm_params

    transformed_params = {}

    for param_name, param_value in llm_params.items():
        if param_name in mapping:
            mapped_name = mapping[param_name]
            if mapped_name is not None:
                # Rename the parameter to the provider-specific name.
                transformed_params[mapped_name] = param_value
                log.debug("Mapped parameter %s -> %s", param_name, mapped_name)
            else:
                log.debug("Dropped parameter %s", param_name)
        else:
            transformed_params[param_name] = param_value

    return transformed_params
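
A minimal usage sketch for the new module (editorial addition, not part of the commit): a mapping registered for a specific instance takes precedence over the built-in provider table, and a None value drops the parameter entirely. The stand-in LLM class is hypothetical.

```python
from nemoguardrails.llm.parameter_mapping import (
    register_llm_parameter_mapping,
    transform_llm_params,
)


class _FakeVertexLLM:
    """Hypothetical stand-in; a registered mapping short-circuits provider inference."""


llm = _FakeVertexLLM()
register_llm_parameter_mapping(llm, {"max_tokens": "max_output_tokens", "seed": None})

print(transform_llm_params({"max_tokens": 256, "seed": 42, "top_p": 0.9}, llm))
# -> {"max_output_tokens": 256, "top_p": 0.9}
#    ("seed" is dropped, "top_p" passes through untouched)
```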

nemoguardrails/rails/llm/config.py

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ class Model(BaseModel):
         description="Configuration parameters for reasoning LLMs.",
     )
     parameters: Dict[str, Any] = Field(default_factory=dict)
+    parameter_mapping: Optional[Dict[str, Optional[str]]] = Field(
+        default=None,
+        description="Optional parameter mapping to transform parameter names for provider-specific requirements. "
+        "Keys are internal parameter names, values are provider parameter names. "
+        "Set value to null to drop a parameter.",
+    )

     mode: Literal["chat", "text"] = Field(
         default="chat",

nemoguardrails/rails/llm/llmrails.py

Lines changed: 18 additions & 1 deletion
@@ -74,6 +74,7 @@
     ModelInitializationError,
     init_llm_model,
 )
+from nemoguardrails.llm.parameter_mapping import register_llm_parameter_mapping
 from nemoguardrails.logging.explain import ExplainInfo
 from nemoguardrails.logging.processing_log import compute_generation_log
 from nemoguardrails.logging.stats import LLMStats

@@ -443,11 +444,19 @@ def _init_llms(self):
         if self.llm:
             # If an LLM was provided via constructor, use it as the main LLM
             # Log a warning if a main LLM is also specified in the config
-            if any(model.type == "main" for model in self.config.models):
+            main_model = next(
+                (model for model in self.config.models if model.type == "main"), None
+            )
+            if main_model:
                 log.warning(
                     "Both an LLM was provided via constructor and a main LLM is specified in the config. "
                     "The LLM provided via constructor will be used and the main LLM from config will be ignored."
                 )
+                # Still register parameter mapping from config if available
+                if main_model.parameter_mapping:
+                    register_llm_parameter_mapping(
+                        self.llm, main_model.parameter_mapping
+                    )
             self.runtime.register_action_param("llm", self.llm)

             self._configure_main_llm_streaming(self.llm)

@@ -465,6 +474,10 @@ def _init_llms(self):
                 mode="chat",
                 kwargs=kwargs,
             )
+            if main_model.parameter_mapping:
+                register_llm_parameter_mapping(
+                    self.llm, main_model.parameter_mapping
+                )
             self.runtime.register_action_param("llm", self.llm)

             self._configure_main_llm_streaming(

@@ -500,6 +513,10 @@ def _init_llms(self):
                 kwargs=kwargs,
             )

+            if llm_config.parameter_mapping:
+                register_llm_parameter_mapping(
+                    llm_model, llm_config.parameter_mapping
+                )
             if llm_config.type == "main":
                 # If a main LLM was already injected, skip creating another
                 # one. Otherwise, create and register it.
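
Putting the pieces together, _init_llms now registers any parameter_mapping declared in the config against the corresponding LLM instance, including the case where the main LLM is injected via the constructor. A hedged end-to-end sketch (editorial; engine and model values are placeholders, and actually constructing LLMRails requires the matching provider package to be installed):

```python
from nemoguardrails import LLMRails, RailsConfig

YAML_CONFIG = """
models:
  - type: main
    engine: huggingface_pipeline   # placeholder
    model: some/hf-model           # placeholder
    parameter_mapping:
      max_tokens: max_new_tokens
      frequency_penalty: null      # null drops the parameter entirely
"""

config = RailsConfig.from_content(yaml_content=YAML_CONFIG)

# During initialization, _init_llms registers the mapping for the main LLM,
# so later llm_call invocations rename max_tokens automatically.
rails = LLMRails(config)
```
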
Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for LLM parameter mapping integration in llm_call function."""

from unittest.mock import AsyncMock, Mock

import pytest

from nemoguardrails.actions.llm.utils import llm_call


class MockResponse:
    """Mock response object."""

    def __init__(self, content="Test response"):
        self.content = content


class MockHuggingFaceLLM:
    """Mock HuggingFace LLM for testing parameter mapping."""

    __module__ = "langchain_huggingface.llms"

    def __init__(self):
        self.bind = Mock(return_value=self)
        self.ainvoke = AsyncMock(return_value=MockResponse())


@pytest.mark.asyncio
async def test_llm_call_with_registered_parameter_mapping():
    """Test llm_call applies registered parameter mapping correctly."""
    from nemoguardrails.llm.parameter_mapping import register_llm_parameter_mapping

    mock_llm = MockHuggingFaceLLM()
    register_llm_parameter_mapping(mock_llm, {"max_tokens": "max_new_tokens"})

    result = await llm_call(
        llm=mock_llm,
        prompt="Test prompt",
        llm_params={"max_tokens": 100, "temperature": 0.5},
    )

    mock_llm.bind.assert_called_once_with(max_new_tokens=100, temperature=0.5)
    assert result == "Test response"


@pytest.mark.asyncio
async def test_llm_call_with_builtin_mapping():
    """Test llm_call uses built-in provider mapping when no custom mapping provided."""
    mock_llm = MockHuggingFaceLLM()

    result = await llm_call(
        llm=mock_llm,
        prompt="Test prompt",
        llm_params={"max_tokens": 50, "temperature": 0.7},
    )

    mock_llm.bind.assert_called_once_with(max_new_tokens=50, temperature=0.7)
    assert result == "Test response"


@pytest.mark.asyncio
async def test_llm_call_with_dropped_parameter():
    """Test llm_call drops parameters mapped to None."""
    from nemoguardrails.llm.parameter_mapping import register_llm_parameter_mapping

    mock_llm = MockHuggingFaceLLM()
    register_llm_parameter_mapping(
        mock_llm, {"max_tokens": "max_new_tokens", "unsupported_param": None}
    )

    result = await llm_call(
        llm=mock_llm,
        prompt="Test prompt",
        llm_params={"max_tokens": 100, "unsupported_param": "value"},
    )

    mock_llm.bind.assert_called_once_with(max_new_tokens=100)
    assert result == "Test response"


@pytest.mark.asyncio
async def test_llm_call_without_params():
    """Test llm_call works without llm_params."""
    mock_llm = MockHuggingFaceLLM()

    result = await llm_call(llm=mock_llm, prompt="Test prompt")

    mock_llm.bind.assert_not_called()
    mock_llm.ainvoke.assert_called_once()
    assert result == "Test response"


@pytest.mark.asyncio
async def test_llm_call_with_stop_tokens():
    """Test llm_call handles stop tokens correctly with parameter mapping."""
    from nemoguardrails.llm.parameter_mapping import register_llm_parameter_mapping

    mock_llm = MockHuggingFaceLLM()
    register_llm_parameter_mapping(mock_llm, {"max_tokens": "max_new_tokens"})

    result = await llm_call(
        llm=mock_llm,
        prompt="Test prompt",
        stop=["END", "STOP"],
        llm_params={"max_tokens": 100},
    )

    mock_llm.bind.assert_called_once_with(stop=["END", "STOP"], max_new_tokens=100)
    assert result == "Test response"
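
One case the tests above leave implicit is passthrough for providers that have no mapping at all. A hedged companion test sketch in the same style (editorial, not part of the commit; the mock class and module name are hypothetical):

```python
from unittest.mock import AsyncMock, Mock

import pytest

from nemoguardrails.actions.llm.utils import llm_call


class _Response:
    """Minimal stand-in for the model response object."""

    def __init__(self, content="Test response"):
        self.content = content


class MockOpenAILLM:
    """Mock whose module resolves to a provider with no built-in mapping."""

    __module__ = "langchain_openai.chat_models"

    def __init__(self):
        self.bind = Mock(return_value=self)
        self.ainvoke = AsyncMock(return_value=_Response())


@pytest.mark.asyncio
async def test_llm_call_passthrough_without_mapping():
    """Parameters should pass through unchanged when no mapping applies."""
    mock_llm = MockOpenAILLM()

    result = await llm_call(
        llm=mock_llm, prompt="Test prompt", llm_params={"max_tokens": 10}
    )

    mock_llm.bind.assert_called_once_with(max_tokens=10)
    assert result == "Test response"
```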
