From 17dc66e776cd6e780009faec3d70612e9bdcaaf5 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Thu, 11 Dec 2025 18:43:53 +0100 Subject: [PATCH 1/9] refactor(streaming): simplify streaming support validation - Remove `streaming` and `streaming_supported` properties from RailsConfig - Add `StreamingNotSupportedError` exception for clearer error handling - Move streaming validation from config-time to runtime in LLMRails - Remove `main_llm_supports_streaming` flag and related logic - Update CLI to catch StreamingNotSupportedError instead of pre-checking - Simplify server API streaming logic to use stream_async directly - Configure LLM streaming only when stream_async is called - Remove redundant streaming warnings and fallback logic BREAKING CHANGE: `RailsConfig.streaming` and `RailsConfig.streaming_supported` properties have been removed. Streaming support is now validated at runtime when `stream_async()` is called. --- nemoguardrails/cli/chat.py | 40 +++++++++++++++------------- nemoguardrails/rails/llm/config.py | 19 ------------- nemoguardrails/rails/llm/llmrails.py | 34 +++++------------------ nemoguardrails/server/api.py | 18 ++++--------- 4 files changed, 32 insertions(+), 79 deletions(-) diff --git a/nemoguardrails/cli/chat.py b/nemoguardrails/cli/chat.py index 1561159c7..ec1b09f29 100644 --- a/nemoguardrails/cli/chat.py +++ b/nemoguardrails/cli/chat.py @@ -28,6 +28,7 @@ from nemoguardrails.colang.v2_x.runtime.eval import eval_expression from nemoguardrails.colang.v2_x.runtime.flows import State from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x +from nemoguardrails.exceptions import InvalidRailsConfigurationError from nemoguardrails.logging import verbose from nemoguardrails.logging.verbose import console from nemoguardrails.rails.llm.options import ( @@ -65,11 +66,6 @@ async def _run_chat_v1_0( raise RuntimeError("config_path cannot be None when server_url is None") rails_config = RailsConfig.from_path(config_path) rails_app = LLMRails(rails_config, verbose=verbose) - if streaming and not rails_config.streaming_supported: - console.print( - f"WARNING: The config `{config_path}` does not support streaming. Falling back to normal mode." - ) - streaming = False else: rails_app = None @@ -83,19 +79,25 @@ async def _run_chat_v1_0( if not server_url: # If we have streaming from a locally loaded config, we initialize the handler. - if streaming and not server_url and rails_app and rails_app.main_llm_supports_streaming: - bot_message_list = [] - async for chunk in rails_app.stream_async(messages=history): - if '{"event": "ABORT"' in chunk: - dict_chunk = json.loads(chunk) - console.print("\n\n[red]" + f"ABORT streaming. {dict_chunk['data']}" + "[/]") - break - - console.print("[green]" + f"{chunk}" + "[/]", end="") - bot_message_list.append(chunk) - - bot_message_text = "".join(bot_message_list) - bot_message = {"role": "assistant", "content": bot_message_text} + if streaming and not server_url and rails_app: + try: + bot_message_list = [] + async for chunk in rails_app.stream_async(messages=history): + if '{"event": "ABORT"' in chunk: + dict_chunk = json.loads(chunk) + console.print("\n\n[red]" + f"ABORT streaming. {dict_chunk['data']}" + "[/]") + break + + console.print("[green]" + f"{chunk}" + "[/]", end="") + bot_message_list.append(chunk) + + bot_message_text = "".join(bot_message_list) + bot_message = {"role": "assistant", "content": bot_message_text} + except InvalidRailsConfigurationError as e: + # TODO: improve this error message + raise InvalidRailsConfigurationError( + f"The config `{config_path}` does not support streaming. {e}" + ) from e else: if rails_app is None: @@ -124,7 +126,7 @@ async def _run_chat_v1_0( # String or other fallback case bot_message = {"role": "assistant", "content": str(response)} - if not streaming or not rails_app.main_llm_supports_streaming: + if not streaming: # We print bot messages in green. content = bot_message.get("content", str(bot_message)) console.print("[green]" + f"{content}" + "[/]") diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index c3909fafa..4fe23799e 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -1373,11 +1373,6 @@ class RailsConfig(BaseModel): description="Configuration for the various rails (input, output, etc.).", ) - streaming: bool = Field( - default=False, - description="Whether this configuration should use streaming mode or not.", - ) - enable_rails_exceptions: bool = Field( default=False, description="If set, the pre-defined guardrails raise exceptions instead of returning pre-defined messages.", @@ -1665,20 +1660,6 @@ def parse_object(cls, obj): return cls.parse_obj(obj) - @property - def streaming_supported(self): - """Whether the current config supports streaming or not.""" - - if len(self.rails.output.flows) > 0: - # if we have output rails streaming enabled - # we keep it in case it was needed when we have - # support per rails - if self.rails.output.streaming and self.rails.output.streaming.enabled: - return True - return False - - return True - def __add__(self, other): """Adds two RailsConfig objects.""" return _join_rails_configs(self, other) diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 0833710fa..934c8cc61 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -155,9 +155,6 @@ def __init__( # should be removed self.events_history_cache = {} - # Weather the main LLM supports streaming - self.main_llm_supports_streaming = False - # We also load the default flows from the `default_flows.yml` file in the current folder. # But only for version 1.0. # TODO: decide on the default flows for 2.x. @@ -377,10 +374,9 @@ def _prepare_model_kwargs(self, model_config): if api_key: kwargs["api_key"] = api_key - # enable streaming token usage when streaming is enabled + # enable streaming token usage # providers that don't support this parameter will simply ignore it - if self.config.streaming: - kwargs["stream_usage"] = True + kwargs["stream_usage"] = True return kwargs @@ -398,22 +394,9 @@ def _configure_main_llm_streaming( provider_name (Optional[str], optional): Optional provider name for logging. """ - if not self.config.streaming: - return if hasattr(llm, "streaming"): setattr(llm, "streaming", True) - self.main_llm_supports_streaming = True - else: - self.main_llm_supports_streaming = False - if model_name and provider_name: - log.warning( - "Model %s from provider %s does not support streaming.", - model_name, - provider_name, - ) - else: - log.warning("Provided main LLM does not support streaming.") def _init_llms(self): """ @@ -442,7 +425,6 @@ def _init_llms(self): ) self.runtime.register_action_param("llm", self.llm) - self._configure_main_llm_streaming(self.llm) else: # Otherwise, initialize the main LLM from the config main_model = next((model for model in self.config.models if model.type == "main"), None) @@ -457,11 +439,6 @@ def _init_llms(self): ) self.runtime.register_action_param("llm", self.llm) - self._configure_main_llm_streaming( - self.llm, - model_name=main_model.model, - provider_name=main_model.engine, - ) else: log.warning("No main LLM specified in the config and no LLM provided via constructor.") @@ -1190,10 +1167,9 @@ def _validate_streaming_with_output_rails(self) -> None: not self.config.rails.output.streaming or not self.config.rails.output.streaming.enabled ): raise InvalidRailsConfigurationError( - "stream_async() cannot be used when output rails are configured but " + "Streaming cannot be used when output rails are configured but " "rails.output.streaming.enabled is False. Either set " - "rails.output.streaming.enabled to True in your configuration, or use " - "generate_async() instead of stream_async()." + "rails.output.streaming.enabled to True in your configuration, or disable streaming." ) @overload @@ -1246,6 +1222,8 @@ def stream_async( streaming_handler = StreamingHandler(include_generation_metadata=include_generation_metadata) + self._configure_main_llm_streaming(self.llm) # type: ignore + # Create a properly managed task with exception handling async def _generation_task(): try: diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index 658cffd01..9bfa0f78c 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -37,7 +37,6 @@ GenerationResponse, ) from nemoguardrails.server.datastore.datastore import DataStore -from nemoguardrails.streaming import StreamingHandler logging.basicConfig(level=logging.INFO) log = logging.getLogger(__name__) @@ -426,18 +425,11 @@ async def chat_completion(body: RequestBody, request: Request): # And prepend them. messages = thread_messages + messages - if body.stream and llm_rails.config.streaming_supported and llm_rails.main_llm_supports_streaming: - # Create the streaming handler instance - streaming_handler = StreamingHandler() - - # Start the generation - asyncio.create_task( - llm_rails.generate_async( - messages=messages, - streaming_handler=streaming_handler, - options=body.options, - state=body.state, - ) + if body.stream: + streaming_handler = llm_rails.stream_async( + messages=messages, + options=body.options, + state=body.state, ) # TODO: Add support for thread_ids in streaming mode From 980979d7d71d08a3358f4e877a064114f5ce91ae Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Thu, 11 Dec 2025 18:48:53 +0100 Subject: [PATCH 2/9] test: update tests for streaming refactor - Remove RailsConfig.streaming references from all tests - Update streaming tests to use StreamingNotSupportedError - Remove tests for deprecated streaming_supported property - Remove tests for main_llm_supports_streaming flag - Update CLI tests to verify StreamingNotSupportedError handling - Simplify test fixtures to remove streaming config parameters - Update token usage tests to reflect stream_usage always enabled - Fix test mocks to align with new streaming validation approach --- tests/cli/test_chat.py | 27 ++-- tests/cli/test_chat_v2x_integration.py | 17 +- tests/rails/llm/test_config.py | 59 +------ tests/runnable_rails/test_batching.py | 2 +- tests/test_llm_params_e2e.py | 10 +- tests/test_llmrails.py | 69 +------- tests/test_parallel_streaming_output_rails.py | 10 +- tests/test_streaming.py | 152 ++---------------- tests/test_streaming_output_rails.py | 8 +- tests/test_token_usage_integration.py | 41 ----- tests/utils.py | 11 +- 11 files changed, 58 insertions(+), 348 deletions(-) diff --git a/tests/cli/test_chat.py b/tests/cli/test_chat.py index 0d3548975..0d4a14b65 100644 --- a/tests/cli/test_chat.py +++ b/tests/cli/test_chat.py @@ -184,12 +184,10 @@ async def test_run_chat_v1_local_config(self, mock_rails_config, mock_llm_rails, from nemoguardrails.cli.chat import _run_chat_v1_0 mock_config = MagicMock() - mock_config.streaming_supported = False mock_rails_config.from_path.return_value = mock_config mock_rails = AsyncMock() mock_rails.generate_async = AsyncMock(return_value={"role": "assistant", "content": "Hello!"}) - mock_rails.main_llm_supports_streaming = False mock_llm_rails.return_value = mock_rails mock_input.side_effect = ["test message", KeyboardInterrupt()] @@ -203,31 +201,28 @@ async def test_run_chat_v1_local_config(self, mock_rails_config, mock_llm_rails, @pytest.mark.asyncio @patch("builtins.input") - @patch.object(chat_module, "console") @patch.object(chat_module, "LLMRails") @patch.object(chat_module, "RailsConfig") - async def test_run_chat_v1_streaming_not_supported( - self, mock_rails_config, mock_llm_rails, mock_console, mock_input - ): + async def test_run_chat_v1_streaming_not_supported(self, mock_rails_config, mock_llm_rails, mock_input): from nemoguardrails.cli.chat import _run_chat_v1_0 + from nemoguardrails.exceptions import InvalidRailsConfigurationError mock_config = MagicMock() - mock_config.streaming_supported = False mock_rails_config.from_path.return_value = mock_config - mock_rails = AsyncMock() + mock_rails = MagicMock() + + async def mock_stream_async_generator(*args, **kwargs): + raise InvalidRailsConfigurationError("Streaming not supported") + yield + + mock_rails.stream_async = mock_stream_async_generator mock_llm_rails.return_value = mock_rails - mock_input.side_effect = [KeyboardInterrupt()] + mock_input.side_effect = ["test message"] - try: + with pytest.raises(InvalidRailsConfigurationError): await _run_chat_v1_0(config_path="test_config", streaming=True) - except KeyboardInterrupt: - pass - - mock_console.print.assert_any_call( - "WARNING: The config `test_config` does not support streaming. Falling back to normal mode." - ) @pytest.mark.asyncio @patch("aiohttp.ClientSession") diff --git a/tests/cli/test_chat_v2x_integration.py b/tests/cli/test_chat_v2x_integration.py index 2adb586cb..99c21f36a 100644 --- a/tests/cli/test_chat_v2x_integration.py +++ b/tests/cli/test_chat_v2x_integration.py @@ -144,10 +144,10 @@ async def test_chat_v2x_with_real_llm(self): This requires LIVE_TEST_MODE=1 and OpenAI API key. """ - from unittest.mock import patch + from unittest.mock import MagicMock, patch from nemoguardrails import LLMRails, RailsConfig - from nemoguardrails.cli.chat import _run_chat_v2_x + from nemoguardrails.cli.chat import ChatState, _run_chat_v2_x config = RailsConfig.from_content( """ @@ -171,13 +171,22 @@ async def test_chat_v2x_with_real_llm(self): simulated_input = ["hi", "exit"] input_iter = iter(simulated_input) - def mock_input(*args, **kwargs): + async def mock_prompt_async(*args, **kwargs): try: return next(input_iter) except StopIteration: raise KeyboardInterrupt() - with patch("builtins.input", side_effect=mock_input): + mock_session = MagicMock() + mock_session.prompt_async = mock_prompt_async + + original_init = ChatState.__init__ + + def patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + self.session = mock_session + + with patch.object(ChatState, "__init__", patched_init): try: await _run_chat_v2_x(rails) except (KeyboardInterrupt, StopIteration): diff --git a/tests/rails/llm/test_config.py b/tests/rails/llm/test_config.py index b40cf2876..1ad6699e6 100644 --- a/tests/rails/llm/test_config.py +++ b/tests/rails/llm/test_config.py @@ -195,21 +195,18 @@ def test_rails_config_simple_field_overwriting(): """Tests that fields from the second config overwrite fields from the first config.""" config1 = RailsConfig( models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")], - streaming=False, lowest_temperature=0.1, colang_version="1.0", ) config2 = RailsConfig( models=[Model(type="secondary", engine="anthropic", model="claude-3")], - streaming=True, lowest_temperature=0.5, colang_version="2.x", ) result = config1 + config2 - assert result.streaming is True assert result.lowest_temperature == 0.5 assert result.colang_version == "2.x" @@ -304,12 +301,11 @@ def test_rails_config_none_config_path(): def test_llm_rails_configure_streaming_with_attr(): - """Check LLM has the streaming attribute set if RailsConfig has it""" + """Check LLM has the streaming attribute set when _configure_main_llm_streaming is called""" mock_llm = MagicMock(spec=BaseLLM) config = RailsConfig( models=[], - streaming=True, ) rails = LLMRails(config, llm=mock_llm) @@ -317,56 +313,3 @@ def test_llm_rails_configure_streaming_with_attr(): rails._configure_main_llm_streaming(llm=mock_llm) assert mock_llm.streaming - - -def test_llm_rails_configure_streaming_without_attr(caplog): - """Check LLM has the streaming attribute set if RailsConfig has it""" - - mock_llm = MagicMock(spec=BaseLLM) - config = RailsConfig( - models=[], - streaming=True, - ) - - rails = LLMRails(config, llm=mock_llm) - rails._configure_main_llm_streaming(mock_llm) - - assert caplog.messages[-1] == "Provided main LLM does not support streaming." - - -def test_rails_config_streaming_supported_no_output_flows(): - """Check `streaming_supported` property doesn't depend on RailsConfig.streaming with no output flows""" - - config = RailsConfig( - models=[], - streaming=False, - ) - assert config.streaming_supported - - -def test_rails_config_flows_streaming_supported_true(): - """Create RailsConfig and check the `streaming_supported Check LLM has the streaming attribute set if RailsConfig has it""" - - rails = { - "output": { - "flows": ["content_safety_check_output"], - "streaming": {"enabled": True}, - } - } - prompts = [{"task": "content safety check output", "content": "..."}] - rails_config = RailsConfig.model_validate({"models": [], "rails": rails, "prompts": prompts}) - assert rails_config.streaming_supported - - -def test_rails_config_flows_streaming_supported_false(): - """Create RailsConfig and check the `streaming_supported Check LLM has the streaming attribute set if RailsConfig has it""" - - rails = { - "output": { - "flows": ["content_safety_check_output"], - "streaming": {"enabled": False}, - } - } - prompts = [{"task": "content safety check output", "content": "..."}] - rails_config = RailsConfig.model_validate({"models": [], "rails": rails, "prompts": prompts}) - assert not rails_config.streaming_supported diff --git a/tests/runnable_rails/test_batching.py b/tests/runnable_rails/test_batching.py index f8bed8704..d1fb2c7c8 100644 --- a/tests/runnable_rails/test_batching.py +++ b/tests/runnable_rails/test_batching.py @@ -130,7 +130,7 @@ async def test_astream_output(): ], streaming=True, ) - config = RailsConfig.from_content(config={"models": [], "streaming": True}) + config = RailsConfig.from_content(config={"models": []}) model_with_rails = RunnableRails(config, llm=llm) # Collect all chunks from the stream diff --git a/tests/test_llm_params_e2e.py b/tests/test_llm_params_e2e.py index 306f4de37..772ab48bf 100644 --- a/tests/test_llm_params_e2e.py +++ b/tests/test_llm_params_e2e.py @@ -182,18 +182,18 @@ async def test_openai_llm_params_direct_llm_call(self, openai_config_path): async def test_openai_llm_params_streaming(self, openai_config_path): """Test llm_params work with streaming responses from OpenAI.""" config = RailsConfig.from_path(openai_config_path) - config.streaming = True rails = LLMRails(config, verbose=False) prompt = "Count from 1 to 3." - response = await rails.generate_async( + chunks = [] + async for chunk in rails.stream_async( messages=[{"role": "user", "content": prompt}], options={"llm_params": {"temperature": 0.0, "max_tokens": 20}}, - ) + ): + chunks.append(chunk) - assert response.response is not None - content = response.response[-1]["content"] + content = "".join(chunks) assert "1" in content @pytest.mark.asyncio diff --git a/tests/test_llmrails.py b/tests/test_llmrails.py index 5e918232b..a47b07f49 100644 --- a/tests/test_llmrails.py +++ b/tests/test_llmrails.py @@ -965,10 +965,8 @@ def __init__(self): @pytest.mark.asyncio @patch("nemoguardrails.rails.llm.llmrails.init_llm_model") -async def test_stream_usage_enabled_for_streaming_supported_providers( - mock_init_llm_model, -): - """Test that stream_usage=True is set when streaming is enabled for supported providers.""" +async def test_stream_usage_always_enabled(mock_init_llm_model): + """Test that stream_usage=True is always set for LLM models.""" config = RailsConfig.from_content( config={ "models": [ @@ -978,7 +976,6 @@ async def test_stream_usage_enabled_for_streaming_supported_providers( "model": "gpt-4", } ], - "streaming": True, } ) @@ -991,68 +988,6 @@ async def test_stream_usage_enabled_for_streaming_supported_providers( assert kwargs.get("stream_usage") is True -@pytest.mark.asyncio -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") -async def test_stream_usage_not_set_without_streaming(mock_init_llm_model): - """Test that stream_usage is not set when streaming is disabled.""" - config = RailsConfig.from_content( - config={ - "models": [ - { - "type": "main", - "engine": "openai", - "model": "gpt-4", - } - ], - "streaming": False, - } - ) - - LLMRails(config=config) - - mock_init_llm_model.assert_called_once() - call_args = mock_init_llm_model.call_args - kwargs = call_args.kwargs.get("kwargs", {}) - - assert "stream_usage" not in kwargs - - -@pytest.mark.asyncio -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") -async def test_stream_usage_enabled_for_all_providers_when_streaming( - mock_init_llm_model, -): - """Test that stream_usage is passed to ALL providers when streaming is enabled. - - With the new design, stream_usage=True is passed to ALL providers when - streaming is enabled. Providers that don't support it will simply ignore it. - """ - config = RailsConfig.from_content( - config={ - "models": [ - { - "type": "main", - "engine": "unsupported", - "model": "whatever", - } - ], - "streaming": True, - } - ) - - LLMRails(config=config) - - mock_init_llm_model.assert_called_once() - call_args = mock_init_llm_model.call_args - kwargs = call_args.kwargs.get("kwargs", {}) - - # stream_usage should be set for all providers when streaming is enabled - assert kwargs.get("stream_usage") is True - - -# Add this test after the existing tests, around line 1100+ - - def test_register_methods_return_self(): """Test that all register_* methods return self for method chaining.""" config = RailsConfig.from_content(config={"models": []}) diff --git a/tests/test_parallel_streaming_output_rails.py b/tests/test_parallel_streaming_output_rails.py index c99e82636..19ce8e1c5 100644 --- a/tests/test_parallel_streaming_output_rails.py +++ b/tests/test_parallel_streaming_output_rails.py @@ -24,6 +24,7 @@ from nemoguardrails import RailsConfig from nemoguardrails.actions import action +from nemoguardrails.exceptions import InvalidRailsConfigurationError from tests.utils import TestChat @@ -585,15 +586,14 @@ async def test_parallel_streaming_output_rails_default_config_behavior( llmrails = LLMRails(parallel_output_rails_default_config) - with pytest.raises(ValueError) as exc_info: - async for chunk in llmrails.stream_async(messages=[{"role": "user", "content": "Hi!"}]): + with pytest.raises(InvalidRailsConfigurationError) as exc_info: + async for _ in llmrails.stream_async(messages=[{"role": "user", "content": "Hi!"}]): pass assert str(exc_info.value) == ( - "stream_async() cannot be used when output rails are configured but " + "Streaming cannot be used when output rails are configured but " "rails.output.streaming.enabled is False. Either set " - "rails.output.streaming.enabled to True in your configuration, or use " - "generate_async() instead of stream_async()." + "rails.output.streaming.enabled to True in your configuration, or disable streaming." ) await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index c7f59a7d1..04ec75197 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -19,15 +19,16 @@ import pytest -from nemoguardrails import LLMRails, RailsConfig +from nemoguardrails import RailsConfig from nemoguardrails.actions import action +from nemoguardrails.exceptions import InvalidRailsConfigurationError from nemoguardrails.streaming import StreamingHandler -from tests.utils import FakeLLM, TestChat +from tests.utils import TestChat @pytest.fixture def chat_1(): - config: RailsConfig = RailsConfig.from_content(config={"models": [], "streaming": True}) + config: RailsConfig = RailsConfig.from_content(config={"models": []}) return TestChat( config, llm_completions=[ @@ -77,7 +78,7 @@ async def test_stream_async_api(chat_1): async def test_streaming_predefined_messages(): """Predefined messages should be streamed as a single chunk.""" config: RailsConfig = RailsConfig.from_content( - config={"models": [], "streaming": True}, + config={"models": []}, colang_content=""" define user express greeting "hi" @@ -110,7 +111,7 @@ async def test_streaming_predefined_messages(): async def test_streaming_dynamic_bot_message(): """Predefined messages should be streamed as a single chunk.""" config: RailsConfig = RailsConfig.from_content( - config={"models": [], "streaming": True}, + config={"models": []}, colang_content=""" define user express greeting "hi" @@ -146,7 +147,6 @@ async def test_streaming_single_llm_call(): config={ "models": [], "rails": {"dialog": {"single_call": {"enabled": True}}}, - "streaming": True, }, colang_content=""" define user express greeting @@ -180,7 +180,6 @@ async def test_streaming_single_llm_call_with_message_override(): config={ "models": [], "rails": {"dialog": {"single_call": {"enabled": True}}}, - "streaming": True, }, colang_content=""" define user express greeting @@ -219,7 +218,6 @@ async def test_streaming_single_llm_call_with_next_step_override_and_dynamic_mes config={ "models": [], "rails": {"dialog": {"single_call": {"enabled": True}}}, - "streaming": True, }, colang_content=""" define user express greeting @@ -267,7 +265,6 @@ def output_rails_streaming_config(): }, } }, - "streaming": True, "prompts": [{"task": "self_check_output", "content": "a test template"}], }, colang_content=""" @@ -479,7 +476,6 @@ async def test_streaming_with_output_rails_disabled_raises_error(): }, } }, - "streaming": True, "prompts": [{"task": "self_check_output", "content": "a test template"}], }, colang_content=""" @@ -498,17 +494,16 @@ async def test_streaming_with_output_rails_disabled_raises_error(): streaming=True, ) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(InvalidRailsConfigurationError) as exc_info: async for chunk in chat.app.stream_async( messages=[{"role": "user", "content": "Hi!"}], ): pass assert str(exc_info.value) == ( - "stream_async() cannot be used when output rails are configured but " + "Streaming cannot be used when output rails are configured but " "rails.output.streaming.enabled is False. Either set " - "rails.output.streaming.enabled to True in your configuration, or use " - "generate_async() instead of stream_async()." + "rails.output.streaming.enabled to True in your configuration, or disable streaming." ) @@ -522,7 +517,6 @@ async def test_streaming_with_output_rails_no_streaming_config_raises_error(): "flows": {"self check output"}, } }, - "streaming": True, "prompts": [{"task": "self_check_output", "content": "a test template"}], }, colang_content=""" @@ -541,17 +535,16 @@ async def test_streaming_with_output_rails_no_streaming_config_raises_error(): streaming=True, ) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(InvalidRailsConfigurationError) as exc_info: async for chunk in chat.app.stream_async( messages=[{"role": "user", "content": "Hi!"}], ): pass assert str(exc_info.value) == ( - "stream_async() cannot be used when output rails are configured but " + "Streaming cannot be used when output rails are configured but " "rails.output.streaming.enabled is False. Either set " - "rails.output.streaming.enabled to True in your configuration, or use " - "generate_async() instead of stream_async()." + "rails.output.streaming.enabled to True in your configuration, or disable streaming." ) @@ -568,7 +561,6 @@ async def test_streaming_error_handling(): "model": "non-existent-model", } ], - "streaming": True, } ) @@ -695,123 +687,3 @@ def _llm_type(self) -> str: _chat_providers.pop("custom_none_streaming", None) _llm_providers.pop("custom_streaming_llm", None) _llm_providers.pop("custom_none_streaming_llm", None) - - -@pytest.mark.parametrize( - "model_type,model_streaming,config_streaming,expected_result", - [ - # Chat model tests - ( - "chat", - False, - False, - False, - ), # Case 1: model streaming=no, config streaming=no, result=no - ( - "chat", - False, - True, - False, - ), # Case 2: model streaming=no, config streaming=yes, result=no - ( - "chat", - True, - False, - False, - ), # Case 3: model streaming=yes, config streaming=no, result=no - ( - "chat", - True, - True, - True, - ), # Case 4: model streaming=yes, config streaming=yes, result=yes - # LLM tests - ( - "llm", - False, - False, - False, - ), # Case 1: model streaming=no, config streaming=no, result=no - ( - "llm", - False, - True, - False, - ), # Case 2: model streaming=no, config streaming=yes, result=no - ( - "llm", - True, - False, - False, - ), # Case 3: model streaming=yes, config streaming=no, result=no - ( - "llm", - True, - True, - True, - ), # Case 4: model streaming=yes, config streaming=yes, result=yes - ], -) -def test_main_llm_supports_streaming_flag_config_combinations( - custom_streaming_providers, - model_type, - model_streaming, - config_streaming, - expected_result, -): - """Test all combinations of model streaming support and config streaming settings.""" - - # determine the engine name based on model type and streaming support - if model_type == "chat": - engine = "custom_streaming" if model_streaming else "custom_none_streaming" - else: - engine = "custom_streaming_llm" if model_streaming else "custom_none_streaming_llm" - - config = RailsConfig.from_content( - config={ - "models": [{"type": "main", "engine": engine, "model": "test-model"}], - "streaming": config_streaming, - } - ) - - rails = LLMRails(config) - - assert rails.main_llm_supports_streaming == expected_result, ( - f"main_llm_supports_streaming should be {expected_result} when " - f"model_type={model_type}, model_streaming={model_streaming}, config_streaming={config_streaming}" - ) - - -def test_main_llm_supports_streaming_flag_with_constructor(): - """Test that main_llm_supports_streaming is properly set when LLM is provided via constructor.""" - config = RailsConfig.from_content( - config={ - "models": [], - "streaming": True, - } - ) - - fake_llm = FakeLLM(responses=["test"], streaming=True) - rails = LLMRails(config, llm=fake_llm) - - assert rails.main_llm_supports_streaming is True, ( - "main_llm_supports_streaming should be True when streaming is enabled " - "and LLM provided via constructor supports streaming" - ) - - -def test_main_llm_supports_streaming_flag_disabled_when_no_streaming(): - """Test that main_llm_supports_streaming is False when streaming is disabled.""" - config = RailsConfig.from_content( - config={ - "models": [], - "streaming": False, - } - ) - - fake_llm = FakeLLM(responses=["test"], streaming=False) - rails = LLMRails(config, llm=fake_llm) - - assert rails.main_llm_supports_streaming is False, ( - "main_llm_supports_streaming should be False when streaming is disabled" - ) diff --git a/tests/test_streaming_output_rails.py b/tests/test_streaming_output_rails.py index 5354632db..d65095240 100644 --- a/tests/test_streaming_output_rails.py +++ b/tests/test_streaming_output_rails.py @@ -23,6 +23,7 @@ from nemoguardrails import RailsConfig from nemoguardrails.actions import action +from nemoguardrails.exceptions import InvalidRailsConfigurationError from nemoguardrails.rails.llm.llmrails import LLMRails from nemoguardrails.streaming import StreamingHandler from tests.utils import TestChat @@ -164,15 +165,14 @@ async def test_streaming_output_rails_blocked_default_config( llmrails = LLMRails(output_rails_streaming_config_default) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(InvalidRailsConfigurationError) as exc_info: async for chunk in llmrails.stream_async(messages=[{"role": "user", "content": "Hi!"}]): pass assert str(exc_info.value) == ( - "stream_async() cannot be used when output rails are configured but " + "Streaming cannot be used when output rails are configured but " "rails.output.streaming.enabled is False. Either set " - "rails.output.streaming.enabled to True in your configuration, or use " - "generate_async() instead of stream_async()." + "rails.output.streaming.enabled to True in your configuration, or disable streaming." ) await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) diff --git a/tests/test_token_usage_integration.py b/tests/test_token_usage_integration.py index d28957612..40a0d8392 100644 --- a/tests/test_token_usage_integration.py +++ b/tests/test_token_usage_integration.py @@ -256,47 +256,6 @@ async def math_calculation(): assert total_completion_tokens == 12 # 4 + 8 -@pytest.mark.asyncio -async def test_token_usage_not_tracked_without_streaming(llm_calls_option): - """Integration test verifying token usage is NOT tracked when streaming is disabled.""" - - config = RailsConfig.from_content( - config={ - "models": [ - { - "type": "main", - "engine": "openai", - "model": "gpt-4", - } - ], - "streaming": False, - } - ) - - token_usage_data = [{"total_tokens": 15, "prompt_tokens": 8, "completion_tokens": 7}] - - chat = TestChat( - config, - llm_completions=["Hello there!"], - streaming=False, - token_usage=token_usage_data, - ) - - result = await chat.app.generate_async(messages=[{"role": "user", "content": "Hi!"}], options=llm_calls_option) - - assert isinstance(result, GenerationResponse) - assert result.response[0]["content"] == "Hello there!" - - assert result.log is not None - assert result.log.llm_calls is not None - assert len(result.log.llm_calls) > 0 - - llm_call = result.log.llm_calls[0] - assert llm_call.total_tokens == 0 - assert llm_call.prompt_tokens == 0 - assert llm_call.completion_tokens == 0 - - @pytest.mark.asyncio async def test_token_usage_not_set_for_unsupported_provider(): """Integration test verifying token usage is NOT tracked for unsupported providers. diff --git a/tests/utils.py b/tests/utils.py index 660763ad7..732d7250a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -188,13 +188,10 @@ def __init__( """ self.llm = None if llm_completions is not None: - # check if we should simulate stream_usage=True behavior - # this mirrors the logic in LLMRails._prepare_model_kwargs - should_enable_stream_usage = False - if config.streaming: - main_model = next((model for model in config.models if model.type == "main"), None) - if main_model and main_model.engine in _TEST_PROVIDERS_WITH_TOKEN_USAGE_SUPPORT: - should_enable_stream_usage = True + main_model = next((model for model in config.models if model.type == "main"), None) + should_enable_stream_usage = bool( + main_model and main_model.engine in _TEST_PROVIDERS_WITH_TOKEN_USAGE_SUPPORT + ) self.llm = FakeLLM( responses=llm_completions, From 67a5aa40dcabceabef0bbdcb62347e6fe9c2abdb Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 12 Dec 2025 15:00:05 +0100 Subject: [PATCH 3/9] fix --- nemoguardrails/rails/llm/llmrails.py | 5 +++-- tests/test_parallel_streaming_output_rails.py | 5 +++-- tests/test_streaming.py | 10 ++++++---- tests/test_streaming_output_rails.py | 5 +++-- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 934c8cc61..89a58c14c 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -1167,9 +1167,10 @@ def _validate_streaming_with_output_rails(self) -> None: not self.config.rails.output.streaming or not self.config.rails.output.streaming.enabled ): raise InvalidRailsConfigurationError( - "Streaming cannot be used when output rails are configured but " + "stream_async() cannot be used when output rails are configured but " "rails.output.streaming.enabled is False. Either set " - "rails.output.streaming.enabled to True in your configuration, or disable streaming." + "rails.output.streaming.enabled to True in your configuration, or use " + "generate_async() instead of stream_async()." ) @overload diff --git a/tests/test_parallel_streaming_output_rails.py b/tests/test_parallel_streaming_output_rails.py index 19ce8e1c5..99104db9d 100644 --- a/tests/test_parallel_streaming_output_rails.py +++ b/tests/test_parallel_streaming_output_rails.py @@ -591,9 +591,10 @@ async def test_parallel_streaming_output_rails_default_config_behavior( pass assert str(exc_info.value) == ( - "Streaming cannot be used when output rails are configured but " + "stream_async() cannot be used when output rails are configured but " "rails.output.streaming.enabled is False. Either set " - "rails.output.streaming.enabled to True in your configuration, or disable streaming." + "rails.output.streaming.enabled to True in your configuration, or use " + "generate_async() instead of stream_async()." ) await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 04ec75197..44689b3fc 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -501,9 +501,10 @@ async def test_streaming_with_output_rails_disabled_raises_error(): pass assert str(exc_info.value) == ( - "Streaming cannot be used when output rails are configured but " + "stream_async() cannot be used when output rails are configured but " "rails.output.streaming.enabled is False. Either set " - "rails.output.streaming.enabled to True in your configuration, or disable streaming." + "rails.output.streaming.enabled to True in your configuration, or use " + "generate_async() instead of stream_async()." ) @@ -542,9 +543,10 @@ async def test_streaming_with_output_rails_no_streaming_config_raises_error(): pass assert str(exc_info.value) == ( - "Streaming cannot be used when output rails are configured but " + "stream_async() cannot be used when output rails are configured but " "rails.output.streaming.enabled is False. Either set " - "rails.output.streaming.enabled to True in your configuration, or disable streaming." + "rails.output.streaming.enabled to True in your configuration, or use " + "generate_async() instead of stream_async()." ) diff --git a/tests/test_streaming_output_rails.py b/tests/test_streaming_output_rails.py index d65095240..e44084f86 100644 --- a/tests/test_streaming_output_rails.py +++ b/tests/test_streaming_output_rails.py @@ -170,9 +170,10 @@ async def test_streaming_output_rails_blocked_default_config( pass assert str(exc_info.value) == ( - "Streaming cannot be used when output rails are configured but " + "stream_async() cannot be used when output rails are configured but " "rails.output.streaming.enabled is False. Either set " - "rails.output.streaming.enabled to True in your configuration, or disable streaming." + "rails.output.streaming.enabled to True in your configuration, or use " + "generate_async() instead of stream_async()." ) await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) From 1d3b7d039e3c149a8cf3b979f4c374f744637c93 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 12 Dec 2025 15:06:59 +0100 Subject: [PATCH 4/9] improve streaming config error message --- nemoguardrails/cli/chat.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/nemoguardrails/cli/chat.py b/nemoguardrails/cli/chat.py index ec1b09f29..ab90965e1 100644 --- a/nemoguardrails/cli/chat.py +++ b/nemoguardrails/cli/chat.py @@ -94,10 +94,21 @@ async def _run_chat_v1_0( bot_message_text = "".join(bot_message_list) bot_message = {"role": "assistant", "content": bot_message_text} except InvalidRailsConfigurationError as e: - # TODO: improve this error message - raise InvalidRailsConfigurationError( - f"The config `{config_path}` does not support streaming. {e}" - ) from e + error_msg = str(e) + if "stream_async()" in error_msg and "output rails" in error_msg: + raise InvalidRailsConfigurationError( + f"Cannot use --streaming with config `{config_path}` because output rails " + "are configured but streaming is not enabled for them.\n\n" + "To fix this, either:\n" + " 1. Enable streaming for output rails by adding to your config.yml:\n" + " rails:\n" + " output:\n" + " streaming:\n" + " enabled: True\n\n" + " 2. Or run without the --streaming flag:\n" + f" nemoguardrails chat {config_path}" + ) from e + raise else: if rails_app is None: From fe536db65f55b3b96a86798cf9acdea5c9ad9f7d Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 12 Dec 2025 15:31:30 +0100 Subject: [PATCH 5/9] update configs --- examples/configs/gs_content_safety/config/config.yml | 2 -- examples/configs/llm/hf_pipeline_dolly/config.yml | 3 --- examples/configs/streaming/config.yml | 2 -- examples/scripts/demo_streaming.py | 4 ---- 4 files changed, 11 deletions(-) diff --git a/examples/configs/gs_content_safety/config/config.yml b/examples/configs/gs_content_safety/config/config.yml index 1b94bfc1c..3f6f5cc2c 100644 --- a/examples/configs/gs_content_safety/config/config.yml +++ b/examples/configs/gs_content_safety/config/config.yml @@ -18,5 +18,3 @@ rails: enabled: True chunk_size: 200 context_size: 50 - -streaming: True diff --git a/examples/configs/llm/hf_pipeline_dolly/config.yml b/examples/configs/llm/hf_pipeline_dolly/config.yml index 702d10531..499d15737 100644 --- a/examples/configs/llm/hf_pipeline_dolly/config.yml +++ b/examples/configs/llm/hf_pipeline_dolly/config.yml @@ -2,9 +2,6 @@ models: - type: main engine: hf_pipeline_dolly -# Remove attribute / set to False if streaming is not required -streaming: True - instructions: - type: general content: | diff --git a/examples/configs/streaming/config.yml b/examples/configs/streaming/config.yml index a433b5435..5c8a61fed 100644 --- a/examples/configs/streaming/config.yml +++ b/examples/configs/streaming/config.yml @@ -13,5 +13,3 @@ rails: dialog: single_call: enabled: True - -streaming: True diff --git a/examples/scripts/demo_streaming.py b/examples/scripts/demo_streaming.py index cfaf3c6ef..ed2292c9f 100644 --- a/examples/scripts/demo_streaming.py +++ b/examples/scripts/demo_streaming.py @@ -34,8 +34,6 @@ - type: main engine: openai model: gpt-4 - -streaming: True """ @@ -99,8 +97,6 @@ async def demo_streaming_from_custom_action(): dialog: user_messages: embeddings_only: True - - streaming: True """, colang_content=""" # We need to have at least on canonical form to enable dialog rails. From 40abc9c2dabb63da0e6d93ec0e823b601fddf637 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 12 Dec 2025 15:50:17 +0100 Subject: [PATCH 6/9] fix(streaming): auto-enable streaming with handler --- .../generate_events_and_streaming.ipynb | 24 +++---------------- nemoguardrails/rails/llm/llmrails.py | 4 ++++ 2 files changed, 7 insertions(+), 21 deletions(-) diff --git a/examples/notebooks/generate_events_and_streaming.ipynb b/examples/notebooks/generate_events_and_streaming.ipynb index 94a629180..056c596ec 100644 --- a/examples/notebooks/generate_events_and_streaming.ipynb +++ b/examples/notebooks/generate_events_and_streaming.ipynb @@ -41,15 +41,11 @@ "metadata": { "collapsed": false }, - "source": [ - "## Step 1: create a config \n", - "\n", - "Let's create a simple config:" - ] + "source": "## Step 1: create a config \n\nLet's create a simple config. No special streaming configuration is needed—streaming is automatically enabled when a `StreamingHandler` is used:" }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "d9bac50b3383915e", "metadata": { "ExecuteTime": { @@ -59,21 +55,7 @@ "collapsed": false }, "outputs": [], - "source": [ - "from nemoguardrails import LLMRails, RailsConfig\n", - "\n", - "YAML_CONFIG = \"\"\"\n", - "models:\n", - " - type: main\n", - " engine: openai\n", - " model: gpt-4\n", - "\n", - "streaming: True\n", - "\"\"\"\n", - "\n", - "config = RailsConfig.from_content(yaml_content=YAML_CONFIG)\n", - "app = LLMRails(config)" - ] + "source": "from nemoguardrails import LLMRails, RailsConfig\n\nYAML_CONFIG = \"\"\"\nmodels:\n - type: main\n engine: openai\n model: gpt-4\n\"\"\"\n\nconfig = RailsConfig.from_content(yaml_content=YAML_CONFIG)\napp = LLMRails(config)" }, { "cell_type": "markdown", diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 89a58c14c..072a183ba 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -825,6 +825,7 @@ async def generate_async( if streaming_handler: streaming_handler_var.set(streaming_handler) + self._configure_main_llm_streaming(self.llm) # type: ignore # Initialize the object with additional explanation information. # We allow this to also be set externally. This is useful when multiple parallel @@ -1336,6 +1337,9 @@ async def generate_events_async( llm_stats = LLMStats() llm_stats_var.set(llm_stats) + if streaming_handler_var.get(): + self._configure_main_llm_streaming(self.llm) # type: ignore + # Compute the new events. processing_log = [] new_events = await self.runtime.generate_events(events, processing_log=processing_log) From 078e1d7f79b3dd06fcbbfe7d09f66b7d064cfe29 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 12 Dec 2025 16:02:32 +0100 Subject: [PATCH 7/9] deprecate streaming field in RailsConfig --- nemoguardrails/rails/llm/config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 4fe23799e..4c72c563b 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -1373,6 +1373,12 @@ class RailsConfig(BaseModel): description="Configuration for the various rails (input, output, etc.).", ) + streaming: bool = Field( + default=False, + deprecated="The 'streaming' field is no longer required. Use stream_async() method directly instead. This field will be removed in a future version.", + description="DEPRECATED: Use stream_async() method instead. This field is ignored.", + ) + enable_rails_exceptions: bool = Field( default=False, description="If set, the pre-defined guardrails raise exceptions instead of returning pre-defined messages.", From dd54b5fd5870d62e2ed37115f39ea0e2ee4a7de3 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 12 Dec 2025 16:18:12 +0100 Subject: [PATCH 8/9] add tests for deprecated field --- tests/test_rails_config.py | 40 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_rails_config.py b/tests/test_rails_config.py index 796011d82..ff5b22951 100644 --- a/tests/test_rails_config.py +++ b/tests/test_rails_config.py @@ -1015,3 +1015,43 @@ def test_hero_topic_safety_prompt_raises(self): content: Verify the user input is on-topic """ ) + + +class TestDeprecatedStreamingConfig: + """Tests for deprecated streaming config field.""" + + def test_streaming_config_field_accepted(self): + """Test that the deprecated streaming: True config field is still accepted.""" + config = RailsConfig.from_content( + yaml_content=""" + models: [] + streaming: True + """ + ) + assert config.streaming is True + + def test_streaming_config_field_default_false(self): + """Test that streaming defaults to False when not specified.""" + config = RailsConfig.from_content( + yaml_content=""" + models: [] + """ + ) + assert config.streaming is False + + def test_streaming_config_field_shows_deprecation_warning(self): + """Test that using streaming: True shows a deprecation warning.""" + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + config = RailsConfig.from_content( + yaml_content=""" + models: [] + streaming: True + """ + ) + assert config.streaming is True + + deprecation_warnings = [warning for warning in w if "streaming" in str(warning.message).lower()] + assert len(deprecation_warnings) > 0, "Expected a deprecation warning for 'streaming' field" From 220cbd85a985fbc69cb45f0ab02b592f028394f8 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 12 Dec 2025 16:31:55 +0100 Subject: [PATCH 9/9] refactor: introduce StreamingNotSupportedError exception Replace InvalidRailsConfigurationError with a new, more specific StreamingNotSupportedError for cases where streaming is requested but not supported by the configuration. Update all relevant imports, usages, and tests to use the new exception. This improves error clarity and guidance for users encountering streaming configuration issues. --- nemoguardrails/cli/chat.py | 31 +++++++++---------- nemoguardrails/exceptions.py | 7 +++++ nemoguardrails/rails/llm/llmrails.py | 3 +- tests/test_parallel_streaming_output_rails.py | 4 +-- tests/test_streaming.py | 6 ++-- tests/test_streaming_output_rails.py | 4 +-- 6 files changed, 30 insertions(+), 25 deletions(-) diff --git a/nemoguardrails/cli/chat.py b/nemoguardrails/cli/chat.py index ab90965e1..f91debb22 100644 --- a/nemoguardrails/cli/chat.py +++ b/nemoguardrails/cli/chat.py @@ -28,7 +28,7 @@ from nemoguardrails.colang.v2_x.runtime.eval import eval_expression from nemoguardrails.colang.v2_x.runtime.flows import State from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x -from nemoguardrails.exceptions import InvalidRailsConfigurationError +from nemoguardrails.exceptions import StreamingNotSupportedError from nemoguardrails.logging import verbose from nemoguardrails.logging.verbose import console from nemoguardrails.rails.llm.options import ( @@ -93,22 +93,19 @@ async def _run_chat_v1_0( bot_message_text = "".join(bot_message_list) bot_message = {"role": "assistant", "content": bot_message_text} - except InvalidRailsConfigurationError as e: - error_msg = str(e) - if "stream_async()" in error_msg and "output rails" in error_msg: - raise InvalidRailsConfigurationError( - f"Cannot use --streaming with config `{config_path}` because output rails " - "are configured but streaming is not enabled for them.\n\n" - "To fix this, either:\n" - " 1. Enable streaming for output rails by adding to your config.yml:\n" - " rails:\n" - " output:\n" - " streaming:\n" - " enabled: True\n\n" - " 2. Or run without the --streaming flag:\n" - f" nemoguardrails chat {config_path}" - ) from e - raise + except StreamingNotSupportedError as e: + raise StreamingNotSupportedError( + f"Cannot use --streaming with config `{config_path}` because output rails " + "are configured but streaming is not enabled for them.\n\n" + "To fix this, either:\n" + " 1. Enable streaming for output rails by adding to your config.yml:\n" + " rails:\n" + " output:\n" + " streaming:\n" + " enabled: True\n\n" + " 2. Or run without the --streaming flag:\n" + f" nemoguardrails chat {config_path}" + ) from e else: if rails_app is None: diff --git a/nemoguardrails/exceptions.py b/nemoguardrails/exceptions.py index fc5118331..3b96b7cea 100644 --- a/nemoguardrails/exceptions.py +++ b/nemoguardrails/exceptions.py @@ -19,6 +19,7 @@ "InvalidModelConfigurationError", "InvalidRailsConfigurationError", "LLMCallException", + "StreamingNotSupportedError", ] @@ -49,6 +50,12 @@ class InvalidRailsConfigurationError(ConfigurationError): pass +class StreamingNotSupportedError(InvalidRailsConfigurationError): + """Raised when streaming is requested but not supported by the configuration.""" + + pass + + class LLMCallException(Exception): """A wrapper around the LLM call invocation exception. diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 072a183ba..0b4279207 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -73,6 +73,7 @@ from nemoguardrails.exceptions import ( InvalidModelConfigurationError, InvalidRailsConfigurationError, + StreamingNotSupportedError, ) from nemoguardrails.kb.kb import KnowledgeBase from nemoguardrails.llm.cache import CacheInterface, LFUCache @@ -1167,7 +1168,7 @@ def _validate_streaming_with_output_rails(self) -> None: if len(self.config.rails.output.flows) > 0 and ( not self.config.rails.output.streaming or not self.config.rails.output.streaming.enabled ): - raise InvalidRailsConfigurationError( + raise StreamingNotSupportedError( "stream_async() cannot be used when output rails are configured but " "rails.output.streaming.enabled is False. Either set " "rails.output.streaming.enabled to True in your configuration, or use " diff --git a/tests/test_parallel_streaming_output_rails.py b/tests/test_parallel_streaming_output_rails.py index 99104db9d..62b6b269b 100644 --- a/tests/test_parallel_streaming_output_rails.py +++ b/tests/test_parallel_streaming_output_rails.py @@ -24,7 +24,7 @@ from nemoguardrails import RailsConfig from nemoguardrails.actions import action -from nemoguardrails.exceptions import InvalidRailsConfigurationError +from nemoguardrails.exceptions import StreamingNotSupportedError from tests.utils import TestChat @@ -586,7 +586,7 @@ async def test_parallel_streaming_output_rails_default_config_behavior( llmrails = LLMRails(parallel_output_rails_default_config) - with pytest.raises(InvalidRailsConfigurationError) as exc_info: + with pytest.raises(StreamingNotSupportedError) as exc_info: async for _ in llmrails.stream_async(messages=[{"role": "user", "content": "Hi!"}]): pass diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 44689b3fc..2e1b44e5c 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -21,7 +21,7 @@ from nemoguardrails import RailsConfig from nemoguardrails.actions import action -from nemoguardrails.exceptions import InvalidRailsConfigurationError +from nemoguardrails.exceptions import StreamingNotSupportedError from nemoguardrails.streaming import StreamingHandler from tests.utils import TestChat @@ -494,7 +494,7 @@ async def test_streaming_with_output_rails_disabled_raises_error(): streaming=True, ) - with pytest.raises(InvalidRailsConfigurationError) as exc_info: + with pytest.raises(StreamingNotSupportedError) as exc_info: async for chunk in chat.app.stream_async( messages=[{"role": "user", "content": "Hi!"}], ): @@ -536,7 +536,7 @@ async def test_streaming_with_output_rails_no_streaming_config_raises_error(): streaming=True, ) - with pytest.raises(InvalidRailsConfigurationError) as exc_info: + with pytest.raises(StreamingNotSupportedError) as exc_info: async for chunk in chat.app.stream_async( messages=[{"role": "user", "content": "Hi!"}], ): diff --git a/tests/test_streaming_output_rails.py b/tests/test_streaming_output_rails.py index e44084f86..e3f047b1f 100644 --- a/tests/test_streaming_output_rails.py +++ b/tests/test_streaming_output_rails.py @@ -23,7 +23,7 @@ from nemoguardrails import RailsConfig from nemoguardrails.actions import action -from nemoguardrails.exceptions import InvalidRailsConfigurationError +from nemoguardrails.exceptions import StreamingNotSupportedError from nemoguardrails.rails.llm.llmrails import LLMRails from nemoguardrails.streaming import StreamingHandler from tests.utils import TestChat @@ -165,7 +165,7 @@ async def test_streaming_output_rails_blocked_default_config( llmrails = LLMRails(output_rails_streaming_config_default) - with pytest.raises(InvalidRailsConfigurationError) as exc_info: + with pytest.raises(StreamingNotSupportedError) as exc_info: async for chunk in llmrails.stream_async(messages=[{"role": "user", "content": "Hi!"}]): pass