2 changes: 1 addition & 1 deletion nemoguardrails/actions/llm/generation.py
@@ -137,7 +137,7 @@ async def init(self):
self._init_flows_index(),
)

def _extract_user_message_example(self, flow: Flow):
def _extract_user_message_example(self, flow: Flow) -> None:
"""Heuristic to extract user message examples from a flow."""
elements = [
item
8 changes: 4 additions & 4 deletions nemoguardrails/llm/filters.py
@@ -150,7 +150,7 @@ def to_messages(colang_history: str) -> List[dict]:
# a message from the user, and the rest gets translated to messages from the assistant.
lines = colang_history.split("\n")

bot_lines = []
bot_lines: list[str] = []
for i, line in enumerate(lines):
if line.startswith('user "'):
# If we have bot lines in the buffer, we first add a bot message.
@@ -191,8 +191,8 @@ def to_messages_v2(colang_history: str) -> List[dict]:
# a message from the user, and the rest gets translated to messages from the assistant.
lines = colang_history.split("\n")

user_lines = []
bot_lines = []
user_lines: list[str] = []
bot_lines: list[str] = []
for line in lines:
if line.startswith("user action:"):
if len(bot_lines) > 0:
@@ -285,7 +285,7 @@ def verbose_v1(colang_history: str) -> str:
return "\n".join(lines)


def to_chat_messages(events: List[dict]) -> str:
def to_chat_messages(events: List[dict]) -> List[dict]:
"""Filter that turns an array of events into a sequence of user/assistant messages.

Properly handles multimodal content by preserving the structure when the content
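For reference, a minimal sketch of how the to_messages filter above is typically driven. The Colang history snippet and the exact keys of the returned dicts are illustrative assumptions; only the 'user "' prefix handling is confirmed by this diff.

from nemoguardrails.llm.filters import to_messages

# Hypothetical Colang-style history; the layout beyond the 'user "' prefix is assumed.
history = (
    'user "What can you do?"\n'
    "  ask about capabilities\n"
    "bot respond about capabilities\n"
    '  "I can answer questions about the knowledge base."'
)

messages = to_messages(history)  # -> List[dict], per the annotated signature above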
23 changes: 12 additions & 11 deletions nemoguardrails/llm/helpers.py
@@ -13,18 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Type, Union
from typing import List, Optional, Type

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM, BaseLLM
from langchain_core.language_models.llms import LLM


def get_llm_instance_wrapper(
llm_instance: Union[LLM, BaseLLM], llm_type: str
) -> Type[LLM]:
def get_llm_instance_wrapper(llm_instance: LLM, llm_type: str) -> Type[LLM]:
"""Wraps an LLM instance in a class that can be registered with LLMRails.

This is useful to create specific types of LLMs using a generic LLM provider
@@ -47,7 +45,7 @@ def model_kwargs(self):
These are needed to allow changes to the arguments of the LLM calls.
"""
if hasattr(llm_instance, "model_kwargs"):
return llm_instance.model_kwargs
return getattr(llm_instance, "model_kwargs")
return {}

@property
@@ -66,26 +64,29 @@ def _modify_instance_kwargs(self):
"""

if hasattr(llm_instance, "model_kwargs"):
if isinstance(llm_instance.model_kwargs, dict):
llm_instance.model_kwargs["temperature"] = self.temperature
llm_instance.model_kwargs["streaming"] = self.streaming
model_kwargs = getattr(llm_instance, "model_kwargs")
if isinstance(model_kwargs, dict):
model_kwargs["temperature"] = self.temperature
model_kwargs["streaming"] = self.streaming

def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs,
) -> str:
self._modify_instance_kwargs()
return llm_instance._call(prompt, stop, run_manager)
return llm_instance._call(prompt, stop, run_manager, **kwargs)

async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs,
) -> str:
self._modify_instance_kwargs()
return await llm_instance._acall(prompt, stop, run_manager)
return await llm_instance._acall(prompt, stop, run_manager, **kwargs)

return WrapperLLM
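
For context, a hedged usage sketch for the tightened wrapper signature. The register_llm_provider helper and the HuggingFacePipeline setup follow the usual custom-provider pattern and are assumptions, not part of this PR.

from langchain_community.llms import HuggingFacePipeline

from nemoguardrails.llm.helpers import get_llm_instance_wrapper
from nemoguardrails.llm.providers import register_llm_provider  # assumed registration helper

# After this change the wrapped instance must be an LLM subclass (Union[LLM, BaseLLM] is gone).
hf_llm = HuggingFacePipeline.from_model_id(model_id="gpt2", task="text-generation")

CustomHFLLM = get_llm_instance_wrapper(llm_instance=hf_llm, llm_type="custom_hf_pipeline")
register_llm_provider("custom_hf_pipeline", CustomHFLLM)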
7 changes: 5 additions & 2 deletions nemoguardrails/llm/models/initializer.py
@@ -20,12 +20,15 @@
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM

from .langchain_initializer import ModelInitializationError, init_langchain_model
from nemoguardrails.llm.models.langchain_initializer import (
ModelInitializationError,
init_langchain_model,
)


# later we can easily conver it to a class
def init_llm_model(
model_name: Optional[str],
model_name: str,
provider_name: str,
mode: Literal["chat", "text"],
kwargs: Dict[str, Any],
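A minimal call sketch for the stricter signature; the model and provider names below are placeholders, not values taken from this PR.

from nemoguardrails.llm.models.initializer import init_llm_model

# model_name is now a required str (previously Optional[str]).
llm = init_llm_model(
    model_name="gpt-4o",         # placeholder model name
    provider_name="openai",      # placeholder provider
    mode="chat",
    kwargs={"temperature": 0.0},
)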
15 changes: 8 additions & 7 deletions nemoguardrails/llm/params.py
@@ -20,7 +20,7 @@
"""

import logging
from typing import Dict, Type
from typing import Any, Dict, Type

from langchain.base_language import BaseLanguageModel

@@ -33,18 +33,18 @@ class LLMParams:
def __init__(self, llm: BaseLanguageModel, **kwargs):
self.llm = llm
self.altered_params = kwargs
self.original_params = {}
self.original_params: dict[str, Any] = {}

def __enter__(self):
# Here we can access and modify the global language model parameters.
self.original_params = {}
for param, value in self.altered_params.items():
if hasattr(self.llm, param):
self.original_params[param] = getattr(self.llm, param)
setattr(self.llm, param, value)

elif hasattr(self.llm, "model_kwargs"):
if param not in self.llm.model_kwargs:
model_kwargs = getattr(self.llm, "model_kwargs", {})
if param not in model_kwargs:
log.warning(
"Parameter %s does not exist for %s. Passing to model_kwargs",
param,
@@ -53,9 +53,10 @@ def __enter__(self):

self.original_params[param] = None
else:
self.original_params[param] = self.llm.model_kwargs[param]
self.original_params[param] = model_kwargs[param]

self.llm.model_kwargs[param] = value
model_kwargs[param] = value
setattr(self.llm, "model_kwargs", model_kwargs)

else:
log.warning(
Expand All @@ -64,7 +65,7 @@ def __enter__(self):
self.llm.__class__.__name__,
)

def __exit__(self, type, value, traceback):
def __exit__(self, exc_type, value, traceback):
# Restore original parameters when exiting the context
for param, value in self.original_params.items():
if hasattr(self.llm, param):
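For reference, a sketch of the context-manager behavior this hunk preserves: matching attributes (or model_kwargs entries) are overridden inside the with block and restored on exit. ChatOpenAI is used purely as an example; any LangChain BaseLanguageModel works.

from langchain_openai import ChatOpenAI

from nemoguardrails.llm.params import LLMParams

llm = ChatOpenAI(model="gpt-4o-mini")  # example model; requires an OpenAI API key

with LLMParams(llm, temperature=0.1, top_p=0.5):
    response = llm.invoke("Hello!")  # overridden values are in effect here
# On __exit__, the original attribute / model_kwargs values are restored.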
39 changes: 30 additions & 9 deletions nemoguardrails/llm/providers/huggingface/pipeline.py
@@ -13,14 +13,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from typing import Any, List, Optional

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.schema.output import GenerationChunk
from langchain_community.llms import HuggingFacePipeline

# Import HuggingFacePipeline with fallbacks for different LangChain versions
HuggingFacePipeline = None # type: ignore[assignment]

try:
from langchain_community.llms import (
HuggingFacePipeline, # type: ignore[attr-defined,no-redef]
)
except ImportError:
# Fallback for older versions of langchain
try:
from langchain.llms import (
HuggingFacePipeline, # type: ignore[attr-defined,no-redef]
)
except ImportError:
# Create a dummy class if HuggingFacePipeline is not available
class HuggingFacePipeline: # type: ignore[misc,no-redef]
def __init__(self, *args, **kwargs):
raise ImportError("HuggingFacePipeline is not available")


class HuggingFacePipelineCompatible(HuggingFacePipeline):
@@ -47,12 +66,13 @@ def _call(
)

# Streaming for NeMo Guardrails is not supported in sync calls.
if self.model_kwargs and self.model_kwargs.get("streaming"):
raise Exception(
model_kwargs = getattr(self, "model_kwargs", {})
if model_kwargs and model_kwargs.get("streaming"):
raise NotImplementedError(
"Streaming mode not supported for HuggingFacePipeline in NeMo Guardrails!"
)

llm_result = self._generate(
llm_result = getattr(self, "_generate")(
[prompt],
stop=stop,
run_manager=run_manager,
@@ -78,11 +98,12 @@ async def _acall(
)

# Handle streaming, if the flag is set
if self.model_kwargs and self.model_kwargs.get("streaming"):
model_kwargs = getattr(self, "model_kwargs", {})
if model_kwargs and model_kwargs.get("streaming"):
# Retrieve the streamer object, needs to be set in model_kwargs
streamer = self.model_kwargs.get("streamer")
streamer = model_kwargs.get("streamer")
if not streamer:
raise Exception(
raise ValueError(
"Cannot stream, please add HuggingFace streamer object to model_kwargs!"
)

@@ -99,7 +120,7 @@ async def _acall(
run_manager=run_manager,
**kwargs,
)
loop.create_task(self._agenerate(**generation_kwargs))
loop.create_task(getattr(self, "_agenerate")(**generation_kwargs))

# And start waiting for the chunks to come in.
completion = ""
Expand All @@ -111,7 +132,7 @@ async def _acall(

return completion

llm_result = await self._agenerate(
llm_result = await getattr(self, "_agenerate")(
[prompt],
stop=stop,
run_manager=run_manager,
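A hedged sketch of the streaming setup that the _acall checks above expect: streaming=True plus a streamer object passed through model_kwargs. The gpt2 model and pipeline wiring are illustrative assumptions.

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from nemoguardrails.llm.providers.huggingface.pipeline import HuggingFacePipelineCompatible
from nemoguardrails.llm.providers.huggingface.streamers import AsyncTextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=64)

# The streamer must live in model_kwargs, otherwise _acall raises the ValueError above.
streamer = AsyncTextIteratorStreamer(tokenizer, skip_prompt=True)
llm = HuggingFacePipelineCompatible(
    pipeline=pipe,
    model_kwargs={"streaming": True, "streamer": streamer},
)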
26 changes: 22 additions & 4 deletions nemoguardrails/llm/providers/huggingface/streamers.py
@@ -14,11 +14,27 @@
# limitations under the License.

import asyncio
from typing import TYPE_CHECKING, Optional

from transformers.generation.streamers import TextStreamer
TRANSFORMERS_AVAILABLE = True
try:
from transformers.generation.streamers import ( # type: ignore[import-untyped]
TextStreamer,
)
except ImportError:
# Fallback if transformers is not available
TRANSFORMERS_AVAILABLE = False

class TextStreamer: # type: ignore[no-redef]
def __init__(self, *args, **kwargs):
pass

class AsyncTextIteratorStreamer(TextStreamer):

if TYPE_CHECKING:
from transformers import AutoTokenizer # type: ignore[import-untyped]


class AsyncTextIteratorStreamer(TextStreamer): # type: ignore[misc]
"""
Simple async implementation for HuggingFace Transformers streamers.

@@ -30,12 +46,14 @@ def __init__(
self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs
):
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
self.text_queue = asyncio.Queue()
self.text_queue: asyncio.Queue[str] = asyncio.Queue()
self.stop_signal = None
self.loop = None
self.loop: Optional[asyncio.AbstractEventLoop] = None

def on_finalized_text(self, text: str, stream_end: bool = False):
"""Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
if self.loop is None:
return
if len(text) > 0:
asyncio.run_coroutine_threadsafe(self.text_queue.put(text), self.loop)

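For context, a rough consumer-side sketch built only on what this diff shows (text_queue and stop_signal); the iteration protocol the pipeline actually uses may differ.

async def drain(streamer) -> str:
    """Collect chunks from the streamer's queue until the stop signal arrives."""
    completion = ""
    while True:
        chunk = await streamer.text_queue.get()
        if chunk is streamer.stop_signal:  # stop_signal is None by default
            break
        completion += chunk
    return completion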