diff --git a/examples/configs/gs_content_safety/config/config.yml b/examples/configs/gs_content_safety/config/config.yml
index 1b94bfc1c..3f6f5cc2c 100644
--- a/examples/configs/gs_content_safety/config/config.yml
+++ b/examples/configs/gs_content_safety/config/config.yml
@@ -18,5 +18,3 @@ rails:
       enabled: True
       chunk_size: 200
       context_size: 50
-
-streaming: True
diff --git a/examples/configs/llm/hf_pipeline_dolly/config.yml b/examples/configs/llm/hf_pipeline_dolly/config.yml
index 702d10531..499d15737 100644
--- a/examples/configs/llm/hf_pipeline_dolly/config.yml
+++ b/examples/configs/llm/hf_pipeline_dolly/config.yml
@@ -2,9 +2,6 @@ models:
   - type: main
     engine: hf_pipeline_dolly
 
-# Remove attribute / set to False if streaming is not required
-streaming: True
-
 instructions:
   - type: general
     content: |
diff --git a/examples/configs/streaming/config.yml b/examples/configs/streaming/config.yml
index a433b5435..5c8a61fed 100644
--- a/examples/configs/streaming/config.yml
+++ b/examples/configs/streaming/config.yml
@@ -13,5 +13,3 @@ rails:
   dialog:
     single_call:
       enabled: True
-
-streaming: True
diff --git a/examples/notebooks/generate_events_and_streaming.ipynb b/examples/notebooks/generate_events_and_streaming.ipynb
index 94a629180..056c596ec 100644
--- a/examples/notebooks/generate_events_and_streaming.ipynb
+++ b/examples/notebooks/generate_events_and_streaming.ipynb
@@ -41,15 +41,11 @@
    "metadata": {
     "collapsed": false
    },
-   "source": [
-    "## Step 1: create a config \n",
-    "\n",
-    "Let's create a simple config:"
-   ]
+   "source": "## Step 1: create a config \n\nLet's create a simple config. No special streaming configuration is needed—streaming is automatically enabled when a `StreamingHandler` is used:"
  },
  {
   "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
   "id": "d9bac50b3383915e",
   "metadata": {
    "ExecuteTime": {
@@ -59,21 +55,7 @@
     "collapsed": false
    },
    "outputs": [],
-   "source": [
-    "from nemoguardrails import LLMRails, RailsConfig\n",
-    "\n",
-    "YAML_CONFIG = \"\"\"\n",
-    "models:\n",
-    "  - type: main\n",
-    "    engine: openai\n",
-    "    model: gpt-4\n",
-    "\n",
-    "streaming: True\n",
-    "\"\"\"\n",
-    "\n",
-    "config = RailsConfig.from_content(yaml_content=YAML_CONFIG)\n",
-    "app = LLMRails(config)"
-   ]
+   "source": "from nemoguardrails import LLMRails, RailsConfig\n\nYAML_CONFIG = \"\"\"\nmodels:\n  - type: main\n    engine: openai\n    model: gpt-4\n\"\"\"\n\nconfig = RailsConfig.from_content(yaml_content=YAML_CONFIG)\napp = LLMRails(config)"
  },
  {
   "cell_type": "markdown",
diff --git a/examples/scripts/demo_streaming.py b/examples/scripts/demo_streaming.py
index cfaf3c6ef..ed2292c9f 100644
--- a/examples/scripts/demo_streaming.py
+++ b/examples/scripts/demo_streaming.py
@@ -34,8 +34,6 @@
   - type: main
     engine: openai
     model: gpt-4
-
-streaming: True
 """
 
 
@@ -99,8 +97,6 @@ async def demo_streaming_from_custom_action():
        dialog:
          user_messages:
            embeddings_only: True
-
-        streaming: True
 """,
        colang_content="""
 # We need to have at least on canonical form to enable dialog rails.
diff --git a/nemoguardrails/cli/chat.py b/nemoguardrails/cli/chat.py
index 1561159c7..f91debb22 100644
--- a/nemoguardrails/cli/chat.py
+++ b/nemoguardrails/cli/chat.py
@@ -28,6 +28,7 @@
 from nemoguardrails.colang.v2_x.runtime.eval import eval_expression
 from nemoguardrails.colang.v2_x.runtime.flows import State
 from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x
+from nemoguardrails.exceptions import StreamingNotSupportedError
 from nemoguardrails.logging import verbose
 from nemoguardrails.logging.verbose import console
 from nemoguardrails.rails.llm.options import (
@@ -65,11 +66,6 @@ async def _run_chat_v1_0(
             raise RuntimeError("config_path cannot be None when server_url is None")
         rails_config = RailsConfig.from_path(config_path)
         rails_app = LLMRails(rails_config, verbose=verbose)
-        if streaming and not rails_config.streaming_supported:
-            console.print(
-                f"WARNING: The config `{config_path}` does not support streaming. Falling back to normal mode."
-            )
-            streaming = False
     else:
         rails_app = None
 
@@ -83,19 +79,33 @@ async def _run_chat_v1_0(
 
         if not server_url:
             # If we have streaming from a locally loaded config, we initialize the handler.
-            if streaming and not server_url and rails_app and rails_app.main_llm_supports_streaming:
-                bot_message_list = []
-                async for chunk in rails_app.stream_async(messages=history):
-                    if '{"event": "ABORT"' in chunk:
-                        dict_chunk = json.loads(chunk)
-                        console.print("\n\n[red]" + f"ABORT streaming. {dict_chunk['data']}" + "[/]")
-                        break
-
-                    console.print("[green]" + f"{chunk}" + "[/]", end="")
-                    bot_message_list.append(chunk)
-
-                bot_message_text = "".join(bot_message_list)
-                bot_message = {"role": "assistant", "content": bot_message_text}
+            if streaming and not server_url and rails_app:
+                try:
+                    bot_message_list = []
+                    async for chunk in rails_app.stream_async(messages=history):
+                        if '{"event": "ABORT"' in chunk:
+                            dict_chunk = json.loads(chunk)
+                            console.print("\n\n[red]" + f"ABORT streaming. {dict_chunk['data']}" + "[/]")
+                            break
+
+                        console.print("[green]" + f"{chunk}" + "[/]", end="")
+                        bot_message_list.append(chunk)
+
+                    bot_message_text = "".join(bot_message_list)
+                    bot_message = {"role": "assistant", "content": bot_message_text}
+                except StreamingNotSupportedError as e:
+                    raise StreamingNotSupportedError(
+                        f"Cannot use --streaming with config `{config_path}` because output rails "
+                        "are configured but streaming is not enabled for them.\n\n"
+                        "To fix this, either:\n"
+                        "  1. Enable streaming for output rails by adding to your config.yml:\n"
+                        "     rails:\n"
+                        "       output:\n"
+                        "         streaming:\n"
+                        "           enabled: True\n\n"
+                        "  2. Or run without the --streaming flag:\n"
+                        f"     nemoguardrails chat {config_path}"
+                    ) from e
 
            else:
                if rails_app is None:
@@ -124,7 +134,7 @@ async def _run_chat_v1_0(
                    # String or other fallback case
                    bot_message = {"role": "assistant", "content": str(response)}
 
-            if not streaming or not rails_app.main_llm_supports_streaming:
+            if not streaming:
                # We print bot messages in green.
                content = bot_message.get("content", str(bot_message))
                console.print("[green]" + f"{content}" + "[/]")
diff --git a/nemoguardrails/exceptions.py b/nemoguardrails/exceptions.py
index fc5118331..3b96b7cea 100644
--- a/nemoguardrails/exceptions.py
+++ b/nemoguardrails/exceptions.py
@@ -19,6 +19,7 @@
     "InvalidModelConfigurationError",
     "InvalidRailsConfigurationError",
     "LLMCallException",
+    "StreamingNotSupportedError",
 ]
 
 
@@ -49,6 +50,12 @@ class InvalidRailsConfigurationError(ConfigurationError):
     pass
 
 
+class StreamingNotSupportedError(InvalidRailsConfigurationError):
+    """Raised when streaming is requested but not supported by the configuration."""
+
+    pass
+
+
 class LLMCallException(Exception):
     """A wrapper around the LLM call invocation exception.
 
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py
index c3909fafa..4c72c563b 100644
--- a/nemoguardrails/rails/llm/config.py
+++ b/nemoguardrails/rails/llm/config.py
@@ -1375,7 +1375,8 @@ class RailsConfig(BaseModel):
 
     streaming: bool = Field(
         default=False,
-        description="Whether this configuration should use streaming mode or not.",
+        deprecated="The 'streaming' field is no longer required. Use stream_async() method directly instead. This field will be removed in a future version.",
+        description="DEPRECATED: Use stream_async() method instead. This field is ignored.",
     )
 
     enable_rails_exceptions: bool = Field(
@@ -1665,20 +1666,6 @@ def parse_object(cls, obj):
 
         return cls.parse_obj(obj)
 
-    @property
-    def streaming_supported(self):
-        """Whether the current config supports streaming or not."""
-
-        if len(self.rails.output.flows) > 0:
-            # if we have output rails streaming enabled
-            # we keep it in case it was needed when we have
-            # support per rails
-            if self.rails.output.streaming and self.rails.output.streaming.enabled:
-                return True
-            return False
-
-        return True
-
     def __add__(self, other):
         """Adds two RailsConfig objects."""
         return _join_rails_configs(self, other)
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index 0833710fa..0b4279207 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -73,6 +73,7 @@
 from nemoguardrails.exceptions import (
     InvalidModelConfigurationError,
     InvalidRailsConfigurationError,
+    StreamingNotSupportedError,
 )
 from nemoguardrails.kb.kb import KnowledgeBase
 from nemoguardrails.llm.cache import CacheInterface, LFUCache
@@ -155,9 +156,6 @@ def __init__(
         # should be removed
         self.events_history_cache = {}
 
-        # Weather the main LLM supports streaming
-        self.main_llm_supports_streaming = False
-
         # We also load the default flows from the `default_flows.yml` file in the current folder.
         # But only for version 1.0.
         # TODO: decide on the default flows for 2.x.
@@ -377,10 +375,9 @@ def _prepare_model_kwargs(self, model_config):
         if api_key:
             kwargs["api_key"] = api_key
 
-        # enable streaming token usage when streaming is enabled
+        # enable streaming token usage
         # providers that don't support this parameter will simply ignore it
-        if self.config.streaming:
-            kwargs["stream_usage"] = True
+        kwargs["stream_usage"] = True
 
         return kwargs
 
@@ -398,22 +395,9 @@ def _configure_main_llm_streaming(
             provider_name (Optional[str], optional): Optional provider name for logging.
 
""" - if not self.config.streaming: - return if hasattr(llm, "streaming"): setattr(llm, "streaming", True) - self.main_llm_supports_streaming = True - else: - self.main_llm_supports_streaming = False - if model_name and provider_name: - log.warning( - "Model %s from provider %s does not support streaming.", - model_name, - provider_name, - ) - else: - log.warning("Provided main LLM does not support streaming.") def _init_llms(self): """ @@ -442,7 +426,6 @@ def _init_llms(self): ) self.runtime.register_action_param("llm", self.llm) - self._configure_main_llm_streaming(self.llm) else: # Otherwise, initialize the main LLM from the config main_model = next((model for model in self.config.models if model.type == "main"), None) @@ -457,11 +440,6 @@ def _init_llms(self): ) self.runtime.register_action_param("llm", self.llm) - self._configure_main_llm_streaming( - self.llm, - model_name=main_model.model, - provider_name=main_model.engine, - ) else: log.warning("No main LLM specified in the config and no LLM provided via constructor.") @@ -848,6 +826,7 @@ async def generate_async( if streaming_handler: streaming_handler_var.set(streaming_handler) + self._configure_main_llm_streaming(self.llm) # type: ignore # Initialize the object with additional explanation information. # We allow this to also be set externally. This is useful when multiple parallel @@ -1189,7 +1168,7 @@ def _validate_streaming_with_output_rails(self) -> None: if len(self.config.rails.output.flows) > 0 and ( not self.config.rails.output.streaming or not self.config.rails.output.streaming.enabled ): - raise InvalidRailsConfigurationError( + raise StreamingNotSupportedError( "stream_async() cannot be used when output rails are configured but " "rails.output.streaming.enabled is False. Either set " "rails.output.streaming.enabled to True in your configuration, or use " @@ -1246,6 +1225,8 @@ def stream_async( streaming_handler = StreamingHandler(include_generation_metadata=include_generation_metadata) + self._configure_main_llm_streaming(self.llm) # type: ignore + # Create a properly managed task with exception handling async def _generation_task(): try: @@ -1357,6 +1338,9 @@ async def generate_events_async( llm_stats = LLMStats() llm_stats_var.set(llm_stats) + if streaming_handler_var.get(): + self._configure_main_llm_streaming(self.llm) # type: ignore + # Compute the new events. processing_log = [] new_events = await self.runtime.generate_events(events, processing_log=processing_log) diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index 658cffd01..9bfa0f78c 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -37,7 +37,6 @@ GenerationResponse, ) from nemoguardrails.server.datastore.datastore import DataStore -from nemoguardrails.streaming import StreamingHandler logging.basicConfig(level=logging.INFO) log = logging.getLogger(__name__) @@ -426,18 +425,11 @@ async def chat_completion(body: RequestBody, request: Request): # And prepend them. 
            messages = thread_messages + messages
 
-    if body.stream and llm_rails.config.streaming_supported and llm_rails.main_llm_supports_streaming:
-        # Create the streaming handler instance
-        streaming_handler = StreamingHandler()
-
-        # Start the generation
-        asyncio.create_task(
-            llm_rails.generate_async(
-                messages=messages,
-                streaming_handler=streaming_handler,
-                options=body.options,
-                state=body.state,
-            )
+    if body.stream:
+        streaming_handler = llm_rails.stream_async(
+            messages=messages,
+            options=body.options,
+            state=body.state,
        )
 
        # TODO: Add support for thread_ids in streaming mode
diff --git a/tests/cli/test_chat.py b/tests/cli/test_chat.py
index 0d3548975..0d4a14b65 100644
--- a/tests/cli/test_chat.py
+++ b/tests/cli/test_chat.py
@@ -184,12 +184,10 @@ async def test_run_chat_v1_local_config(self, mock_rails_config, mock_llm_rails,
        from nemoguardrails.cli.chat import _run_chat_v1_0
 
        mock_config = MagicMock()
-        mock_config.streaming_supported = False
        mock_rails_config.from_path.return_value = mock_config
 
        mock_rails = AsyncMock()
        mock_rails.generate_async = AsyncMock(return_value={"role": "assistant", "content": "Hello!"})
-        mock_rails.main_llm_supports_streaming = False
        mock_llm_rails.return_value = mock_rails
 
        mock_input.side_effect = ["test message", KeyboardInterrupt()]
@@ -203,31 +201,28 @@
    @pytest.mark.asyncio
    @patch("builtins.input")
-    @patch.object(chat_module, "console")
    @patch.object(chat_module, "LLMRails")
    @patch.object(chat_module, "RailsConfig")
-    async def test_run_chat_v1_streaming_not_supported(
-        self, mock_rails_config, mock_llm_rails, mock_console, mock_input
-    ):
+    async def test_run_chat_v1_streaming_not_supported(self, mock_rails_config, mock_llm_rails, mock_input):
        from nemoguardrails.cli.chat import _run_chat_v1_0
+        from nemoguardrails.exceptions import InvalidRailsConfigurationError
 
        mock_config = MagicMock()
-        mock_config.streaming_supported = False
        mock_rails_config.from_path.return_value = mock_config
 
-        mock_rails = AsyncMock()
+        mock_rails = MagicMock()
+
+        async def mock_stream_async_generator(*args, **kwargs):
+            raise InvalidRailsConfigurationError("Streaming not supported")
+            yield
+
+        mock_rails.stream_async = mock_stream_async_generator
        mock_llm_rails.return_value = mock_rails
 
-        mock_input.side_effect = [KeyboardInterrupt()]
+        mock_input.side_effect = ["test message"]
 
-        try:
+        with pytest.raises(InvalidRailsConfigurationError):
            await _run_chat_v1_0(config_path="test_config", streaming=True)
-        except KeyboardInterrupt:
-            pass
-
-        mock_console.print.assert_any_call(
-            "WARNING: The config `test_config` does not support streaming. Falling back to normal mode."
-        )
 
    @pytest.mark.asyncio
    @patch("aiohttp.ClientSession")
diff --git a/tests/cli/test_chat_v2x_integration.py b/tests/cli/test_chat_v2x_integration.py
index 2adb586cb..99c21f36a 100644
--- a/tests/cli/test_chat_v2x_integration.py
+++ b/tests/cli/test_chat_v2x_integration.py
@@ -144,10 +144,10 @@ async def test_chat_v2x_with_real_llm(self):
 
        This requires LIVE_TEST_MODE=1 and OpenAI API key.
""" - from unittest.mock import patch + from unittest.mock import MagicMock, patch from nemoguardrails import LLMRails, RailsConfig - from nemoguardrails.cli.chat import _run_chat_v2_x + from nemoguardrails.cli.chat import ChatState, _run_chat_v2_x config = RailsConfig.from_content( """ @@ -171,13 +171,22 @@ async def test_chat_v2x_with_real_llm(self): simulated_input = ["hi", "exit"] input_iter = iter(simulated_input) - def mock_input(*args, **kwargs): + async def mock_prompt_async(*args, **kwargs): try: return next(input_iter) except StopIteration: raise KeyboardInterrupt() - with patch("builtins.input", side_effect=mock_input): + mock_session = MagicMock() + mock_session.prompt_async = mock_prompt_async + + original_init = ChatState.__init__ + + def patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + self.session = mock_session + + with patch.object(ChatState, "__init__", patched_init): try: await _run_chat_v2_x(rails) except (KeyboardInterrupt, StopIteration): diff --git a/tests/rails/llm/test_config.py b/tests/rails/llm/test_config.py index b40cf2876..1ad6699e6 100644 --- a/tests/rails/llm/test_config.py +++ b/tests/rails/llm/test_config.py @@ -195,21 +195,18 @@ def test_rails_config_simple_field_overwriting(): """Tests that fields from the second config overwrite fields from the first config.""" config1 = RailsConfig( models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")], - streaming=False, lowest_temperature=0.1, colang_version="1.0", ) config2 = RailsConfig( models=[Model(type="secondary", engine="anthropic", model="claude-3")], - streaming=True, lowest_temperature=0.5, colang_version="2.x", ) result = config1 + config2 - assert result.streaming is True assert result.lowest_temperature == 0.5 assert result.colang_version == "2.x" @@ -304,12 +301,11 @@ def test_rails_config_none_config_path(): def test_llm_rails_configure_streaming_with_attr(): - """Check LLM has the streaming attribute set if RailsConfig has it""" + """Check LLM has the streaming attribute set when _configure_main_llm_streaming is called""" mock_llm = MagicMock(spec=BaseLLM) config = RailsConfig( models=[], - streaming=True, ) rails = LLMRails(config, llm=mock_llm) @@ -317,56 +313,3 @@ def test_llm_rails_configure_streaming_with_attr(): rails._configure_main_llm_streaming(llm=mock_llm) assert mock_llm.streaming - - -def test_llm_rails_configure_streaming_without_attr(caplog): - """Check LLM has the streaming attribute set if RailsConfig has it""" - - mock_llm = MagicMock(spec=BaseLLM) - config = RailsConfig( - models=[], - streaming=True, - ) - - rails = LLMRails(config, llm=mock_llm) - rails._configure_main_llm_streaming(mock_llm) - - assert caplog.messages[-1] == "Provided main LLM does not support streaming." 
-
-
-def test_rails_config_streaming_supported_no_output_flows():
-    """Check `streaming_supported` property doesn't depend on RailsConfig.streaming with no output flows"""
-
-    config = RailsConfig(
-        models=[],
-        streaming=False,
-    )
-    assert config.streaming_supported
-
-
-def test_rails_config_flows_streaming_supported_true():
-    """Create RailsConfig and check the `streaming_supported Check LLM has the streaming attribute set if RailsConfig has it"""
-
-    rails = {
-        "output": {
-            "flows": ["content_safety_check_output"],
-            "streaming": {"enabled": True},
-        }
-    }
-    prompts = [{"task": "content safety check output", "content": "..."}]
-    rails_config = RailsConfig.model_validate({"models": [], "rails": rails, "prompts": prompts})
-    assert rails_config.streaming_supported
-
-
-def test_rails_config_flows_streaming_supported_false():
-    """Create RailsConfig and check the `streaming_supported Check LLM has the streaming attribute set if RailsConfig has it"""
-
-    rails = {
-        "output": {
-            "flows": ["content_safety_check_output"],
-            "streaming": {"enabled": False},
-        }
-    }
-    prompts = [{"task": "content safety check output", "content": "..."}]
-    rails_config = RailsConfig.model_validate({"models": [], "rails": rails, "prompts": prompts})
-    assert not rails_config.streaming_supported
diff --git a/tests/runnable_rails/test_batching.py b/tests/runnable_rails/test_batching.py
index f8bed8704..d1fb2c7c8 100644
--- a/tests/runnable_rails/test_batching.py
+++ b/tests/runnable_rails/test_batching.py
@@ -130,7 +130,7 @@ async def test_astream_output():
        ],
        streaming=True,
    )
-    config = RailsConfig.from_content(config={"models": [], "streaming": True})
+    config = RailsConfig.from_content(config={"models": []})
    model_with_rails = RunnableRails(config, llm=llm)
 
    # Collect all chunks from the stream
diff --git a/tests/test_llm_params_e2e.py b/tests/test_llm_params_e2e.py
index 306f4de37..772ab48bf 100644
--- a/tests/test_llm_params_e2e.py
+++ b/tests/test_llm_params_e2e.py
@@ -182,18 +182,18 @@ async def test_openai_llm_params_direct_llm_call(self, openai_config_path):
    async def test_openai_llm_params_streaming(self, openai_config_path):
        """Test llm_params work with streaming responses from OpenAI."""
        config = RailsConfig.from_path(openai_config_path)
-        config.streaming = True
        rails = LLMRails(config, verbose=False)
 
        prompt = "Count from 1 to 3."
 
-        response = await rails.generate_async(
+        chunks = []
+        async for chunk in rails.stream_async(
            messages=[{"role": "user", "content": prompt}],
            options={"llm_params": {"temperature": 0.0, "max_tokens": 20}},
-        )
+        ):
+            chunks.append(chunk)
 
-        assert response.response is not None
-        content = response.response[-1]["content"]
+        content = "".join(chunks)
        assert "1" in content
 
    @pytest.mark.asyncio
diff --git a/tests/test_llmrails.py b/tests/test_llmrails.py
index 5e918232b..a47b07f49 100644
--- a/tests/test_llmrails.py
+++ b/tests/test_llmrails.py
@@ -965,10 +965,8 @@ def __init__(self):
 
 @pytest.mark.asyncio
 @patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
-async def test_stream_usage_enabled_for_streaming_supported_providers(
-    mock_init_llm_model,
-):
-    """Test that stream_usage=True is set when streaming is enabled for supported providers."""
+async def test_stream_usage_always_enabled(mock_init_llm_model):
+    """Test that stream_usage=True is always set for LLM models."""
    config = RailsConfig.from_content(
        config={
            "models": [
@@ -978,7 +976,6 @@
                    "model": "gpt-4",
                }
            ],
-            "streaming": True,
        }
    )
 
@@ -991,68 +988,6 @@
    assert kwargs.get("stream_usage") is True
 
 
-@pytest.mark.asyncio
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
-async def test_stream_usage_not_set_without_streaming(mock_init_llm_model):
-    """Test that stream_usage is not set when streaming is disabled."""
-    config = RailsConfig.from_content(
-        config={
-            "models": [
-                {
-                    "type": "main",
-                    "engine": "openai",
-                    "model": "gpt-4",
-                }
-            ],
-            "streaming": False,
-        }
-    )
-
-    LLMRails(config=config)
-
-    mock_init_llm_model.assert_called_once()
-    call_args = mock_init_llm_model.call_args
-    kwargs = call_args.kwargs.get("kwargs", {})
-
-    assert "stream_usage" not in kwargs
-
-
-@pytest.mark.asyncio
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
-async def test_stream_usage_enabled_for_all_providers_when_streaming(
-    mock_init_llm_model,
-):
-    """Test that stream_usage is passed to ALL providers when streaming is enabled.
-
-    With the new design, stream_usage=True is passed to ALL providers when
-    streaming is enabled. Providers that don't support it will simply ignore it.
-    """
-    config = RailsConfig.from_content(
-        config={
-            "models": [
-                {
-                    "type": "main",
-                    "engine": "unsupported",
-                    "model": "whatever",
-                }
-            ],
-            "streaming": True,
-        }
-    )
-
-    LLMRails(config=config)
-
-    mock_init_llm_model.assert_called_once()
-    call_args = mock_init_llm_model.call_args
-    kwargs = call_args.kwargs.get("kwargs", {})
-
-    # stream_usage should be set for all providers when streaming is enabled
-    assert kwargs.get("stream_usage") is True
-
-
-# Add this test after the existing tests, around line 1100+
-
-
 def test_register_methods_return_self():
    """Test that all register_* methods return self for method chaining."""
    config = RailsConfig.from_content(config={"models": []})
diff --git a/tests/test_parallel_streaming_output_rails.py b/tests/test_parallel_streaming_output_rails.py
index c99e82636..62b6b269b 100644
--- a/tests/test_parallel_streaming_output_rails.py
+++ b/tests/test_parallel_streaming_output_rails.py
@@ -24,6 +24,7 @@
 
 from nemoguardrails import RailsConfig
 from nemoguardrails.actions import action
+from nemoguardrails.exceptions import StreamingNotSupportedError
 from tests.utils import TestChat
 
 
@@ -585,8 +586,8 @@ async def test_parallel_streaming_output_rails_default_config_behavior(
 
    llmrails = LLMRails(parallel_output_rails_default_config)
 
-    with pytest.raises(ValueError) as exc_info:
-        async for chunk in llmrails.stream_async(messages=[{"role": "user", "content": "Hi!"}]):
+    with pytest.raises(StreamingNotSupportedError) as exc_info:
+        async for _ in llmrails.stream_async(messages=[{"role": "user", "content": "Hi!"}]):
            pass
 
    assert str(exc_info.value) == (
diff --git a/tests/test_rails_config.py b/tests/test_rails_config.py
index 796011d82..ff5b22951 100644
--- a/tests/test_rails_config.py
+++ b/tests/test_rails_config.py
@@ -1015,3 +1015,43 @@ def test_hero_topic_safety_prompt_raises(self):
                    content: Verify the user input is on-topic
            """
        )
+
+
+class TestDeprecatedStreamingConfig:
+    """Tests for deprecated streaming config field."""
+
+    def test_streaming_config_field_accepted(self):
+        """Test that the deprecated streaming: True config field is still accepted."""
+        config = RailsConfig.from_content(
+            yaml_content="""
+            models: []
+            streaming: True
+            """
+        )
+        assert config.streaming is True
+
+    def test_streaming_config_field_default_false(self):
+        """Test that streaming defaults to False when not specified."""
+        config = RailsConfig.from_content(
+            yaml_content="""
+            models: []
+            """
+        )
+        assert config.streaming is False
+
+    def test_streaming_config_field_shows_deprecation_warning(self):
+        """Test that using streaming: True shows a deprecation warning."""
+        import warnings
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            config = RailsConfig.from_content(
+                yaml_content="""
+                models: []
+                streaming: True
+                """
+            )
+            assert config.streaming is True
+
+        deprecation_warnings = [warning for warning in w if "streaming" in str(warning.message).lower()]
+        assert len(deprecation_warnings) > 0, "Expected a deprecation warning for 'streaming' field"
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index c7f59a7d1..2e1b44e5c 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -19,15 +19,16 @@
 
 import pytest
 
-from nemoguardrails import LLMRails, RailsConfig
+from nemoguardrails import RailsConfig
 from nemoguardrails.actions import action
+from nemoguardrails.exceptions import StreamingNotSupportedError
 from nemoguardrails.streaming import StreamingHandler
-from tests.utils import FakeLLM, TestChat
+from tests.utils import TestChat
 
 
 @pytest.fixture
 def chat_1():
-    config: RailsConfig = RailsConfig.from_content(config={"models": [], "streaming": True})
+    config: RailsConfig = RailsConfig.from_content(config={"models": []})
    return TestChat(
        config,
        llm_completions=[
@@ -77,7 +78,7 @@ async def test_stream_async_api(chat_1):
 async def test_streaming_predefined_messages():
    """Predefined messages should be streamed as a single chunk."""
    config: RailsConfig = RailsConfig.from_content(
-        config={"models": [], "streaming": True},
+        config={"models": []},
        colang_content="""
        define user express greeting
            "hi"
@@ -110,7 +111,7 @@ async def test_streaming_predefined_messages():
 async def test_streaming_dynamic_bot_message():
    """Predefined messages should be streamed as a single chunk."""
    config: RailsConfig = RailsConfig.from_content(
-        config={"models": [], "streaming": True},
+        config={"models": []},
        colang_content="""
        define user express greeting
            "hi"
@@ -146,7 +147,6 @@ async def test_streaming_single_llm_call():
        config={
            "models": [],
            "rails": {"dialog": {"single_call": {"enabled": True}}},
-            "streaming": True,
        },
        colang_content="""
        define user express greeting
@@ -180,7 +180,6 @@ async def test_streaming_single_llm_call_with_message_override():
        config={
            "models": [],
            "rails": {"dialog": {"single_call": {"enabled": True}}},
-            "streaming": True,
        },
        colang_content="""
        define user express greeting
@@ -219,7 +218,6 @@ async def test_streaming_single_llm_call_with_next_step_override_and_dynamic_mes
        config={
            "models": [],
            "rails": {"dialog": {"single_call": {"enabled": True}}},
-            "streaming": True,
        },
        colang_content="""
        define user express greeting
@@ -267,7 +265,6 @@ def output_rails_streaming_config():
                },
            }
        },
-        "streaming": True,
        "prompts": [{"task": "self_check_output", "content": "a test template"}],
    },
    colang_content="""
@@ -479,7 +476,6 @@ async def test_streaming_with_output_rails_disabled_raises_error():
                    },
                }
            },
-            "streaming": True,
            "prompts": [{"task": "self_check_output", "content": "a test template"}],
        },
        colang_content="""
@@ -498,7 +494,7 @@ async def test_streaming_with_output_rails_disabled_raises_error():
        streaming=True,
    )
 
-    with pytest.raises(ValueError) as exc_info:
+    with pytest.raises(StreamingNotSupportedError) as exc_info:
        async for chunk in chat.app.stream_async(
            messages=[{"role": "user", "content": "Hi!"}],
        ):
@@ -522,7 +518,6 @@ async def test_streaming_with_output_rails_no_streaming_config_raises_error():
                    "flows": {"self check output"},
                }
            },
-            "streaming": True,
            "prompts": [{"task": "self_check_output", "content": "a test template"}],
        },
        colang_content="""
@@ -541,7 +536,7 @@ async def test_streaming_with_output_rails_no_streaming_config_raises_error():
        streaming=True,
    )
 
-    with pytest.raises(ValueError) as exc_info:
+    with pytest.raises(StreamingNotSupportedError) as exc_info:
        async for chunk in chat.app.stream_async(
            messages=[{"role": "user", "content": "Hi!"}],
        ):
@@ -568,7 +563,6 @@ async def test_streaming_error_handling():
                    "model": "non-existent-model",
                }
            ],
-            "streaming": True,
        }
    )
 
@@ -695,123 +689,3 @@ def _llm_type(self) -> str:
    _chat_providers.pop("custom_none_streaming", None)
    _llm_providers.pop("custom_streaming_llm", None)
    _llm_providers.pop("custom_none_streaming_llm", None)
-
-
-@pytest.mark.parametrize(
-    "model_type,model_streaming,config_streaming,expected_result",
-    [
-        # Chat model tests
-        (
-            "chat",
-            False,
-            False,
-            False,
-        ),  # Case 1: model streaming=no, config streaming=no, result=no
-        (
-            "chat",
-            False,
-            True,
-            False,
-        ),  # Case 2: model streaming=no, config streaming=yes, result=no
-        (
-            "chat",
-            True,
-            False,
-            False,
-        ),  # Case 3: model streaming=yes, config streaming=no, result=no
-        (
-            "chat",
-            True,
-            True,
-            True,
-        ),  # Case 4: model streaming=yes, config streaming=yes, result=yes
-        # LLM tests
-        (
-            "llm",
-            False,
-            False,
-            False,
-        ),  # Case 1: model streaming=no, config streaming=no, result=no
-        (
-            "llm",
-            False,
-            True,
-            False,
-        ),  # Case 2: model streaming=no, config streaming=yes, result=no
-        (
-            "llm",
-            True,
-            False,
-            False,
-        ),  # Case 3: model streaming=yes, config streaming=no, result=no
-        (
-            "llm",
-            True,
-            True,
-            True,
-        ),  # Case 4: model streaming=yes, config streaming=yes, result=yes
-    ],
-)
-def test_main_llm_supports_streaming_flag_config_combinations(
-    custom_streaming_providers,
-    model_type,
-    model_streaming,
-    config_streaming,
-    expected_result,
-):
-    """Test all combinations of model streaming support and config streaming settings."""
-
-    # determine the engine name based on model type and streaming support
-    if model_type == "chat":
-        engine = "custom_streaming" if model_streaming else "custom_none_streaming"
-    else:
-        engine = "custom_streaming_llm" if model_streaming else "custom_none_streaming_llm"
-
-    config = RailsConfig.from_content(
-        config={
-            "models": [{"type": "main", "engine": engine, "model": "test-model"}],
-            "streaming": config_streaming,
-        }
-    )
-
-    rails = LLMRails(config)
-
-    assert rails.main_llm_supports_streaming == expected_result, (
-        f"main_llm_supports_streaming should be {expected_result} when "
-        f"model_type={model_type}, model_streaming={model_streaming}, config_streaming={config_streaming}"
-    )
-
-
-def test_main_llm_supports_streaming_flag_with_constructor():
-    """Test that main_llm_supports_streaming is properly set when LLM is provided via constructor."""
-    config = RailsConfig.from_content(
-        config={
-            "models": [],
-            "streaming": True,
-        }
-    )
-
-    fake_llm = FakeLLM(responses=["test"], streaming=True)
-    rails = LLMRails(config, llm=fake_llm)
-
-    assert rails.main_llm_supports_streaming is True, (
-        "main_llm_supports_streaming should be True when streaming is enabled "
-        "and LLM provided via constructor supports streaming"
-    )
-
-
-def test_main_llm_supports_streaming_flag_disabled_when_no_streaming():
-    """Test that main_llm_supports_streaming is False when streaming is disabled."""
-    config = RailsConfig.from_content(
-        config={
-            "models": [],
-            "streaming": False,
-        }
-    )
-
-    fake_llm = FakeLLM(responses=["test"], streaming=False)
-    rails = LLMRails(config, llm=fake_llm)
-
-    assert rails.main_llm_supports_streaming is False, (
-        "main_llm_supports_streaming should be False when streaming is disabled"
-    )
diff --git a/tests/test_streaming_output_rails.py b/tests/test_streaming_output_rails.py
index 5354632db..e3f047b1f 100644
--- a/tests/test_streaming_output_rails.py
+++ b/tests/test_streaming_output_rails.py
@@ -23,6 +23,7 @@
 
 from nemoguardrails import RailsConfig
 from nemoguardrails.actions import action
+from nemoguardrails.exceptions import StreamingNotSupportedError
 from nemoguardrails.rails.llm.llmrails import LLMRails
 from nemoguardrails.streaming import StreamingHandler
 from tests.utils import TestChat
@@ -164,7 +165,7 @@ async def test_streaming_output_rails_blocked_default_config(
 
    llmrails = LLMRails(output_rails_streaming_config_default)
 
-    with pytest.raises(ValueError) as exc_info:
+    with pytest.raises(StreamingNotSupportedError) as exc_info:
        async for chunk in llmrails.stream_async(messages=[{"role": "user", "content": "Hi!"}]):
"user", "content": "Hi!"}]): pass diff --git a/tests/test_token_usage_integration.py b/tests/test_token_usage_integration.py index d28957612..40a0d8392 100644 --- a/tests/test_token_usage_integration.py +++ b/tests/test_token_usage_integration.py @@ -256,47 +256,6 @@ async def math_calculation(): assert total_completion_tokens == 12 # 4 + 8 -@pytest.mark.asyncio -async def test_token_usage_not_tracked_without_streaming(llm_calls_option): - """Integration test verifying token usage is NOT tracked when streaming is disabled.""" - - config = RailsConfig.from_content( - config={ - "models": [ - { - "type": "main", - "engine": "openai", - "model": "gpt-4", - } - ], - "streaming": False, - } - ) - - token_usage_data = [{"total_tokens": 15, "prompt_tokens": 8, "completion_tokens": 7}] - - chat = TestChat( - config, - llm_completions=["Hello there!"], - streaming=False, - token_usage=token_usage_data, - ) - - result = await chat.app.generate_async(messages=[{"role": "user", "content": "Hi!"}], options=llm_calls_option) - - assert isinstance(result, GenerationResponse) - assert result.response[0]["content"] == "Hello there!" - - assert result.log is not None - assert result.log.llm_calls is not None - assert len(result.log.llm_calls) > 0 - - llm_call = result.log.llm_calls[0] - assert llm_call.total_tokens == 0 - assert llm_call.prompt_tokens == 0 - assert llm_call.completion_tokens == 0 - - @pytest.mark.asyncio async def test_token_usage_not_set_for_unsupported_provider(): """Integration test verifying token usage is NOT tracked for unsupported providers. diff --git a/tests/utils.py b/tests/utils.py index 660763ad7..732d7250a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -188,13 +188,10 @@ def __init__( """ self.llm = None if llm_completions is not None: - # check if we should simulate stream_usage=True behavior - # this mirrors the logic in LLMRails._prepare_model_kwargs - should_enable_stream_usage = False - if config.streaming: - main_model = next((model for model in config.models if model.type == "main"), None) - if main_model and main_model.engine in _TEST_PROVIDERS_WITH_TOKEN_USAGE_SUPPORT: - should_enable_stream_usage = True + main_model = next((model for model in config.models if model.type == "main"), None) + should_enable_stream_usage = bool( + main_model and main_model.engine in _TEST_PROVIDERS_WITH_TOKEN_USAGE_SUPPORT + ) self.llm = FakeLLM( responses=llm_completions,