def reset_workflow_step(chat_thread: ChatThread, step: int = 0) -> None:
    """Set the workflow step of a chat thread and log the change.

    Args:
        chat_thread: The ChatThread whose ``workflow_step`` is overwritten.
        step: The step to reset to (default: 0). Must be a non-negative int.

    Raises:
        ValueError: If ``step`` is not an int (``None`` included) or is
            negative. The thread is left unmodified in that case.
    """
    # Fail fast at the reset site: the request builders in this module raise
    # "Workflow step is None" when they meet a None step, so storing a bad
    # value here would only surface later, far from its cause.
    # NOTE(review): negative steps are rejected on the assumption that
    # workflow_step is a zero-based position that only advances via `+= 1`
    # -- confirm against how the workflow tool is looked up.
    if not isinstance(step, int) or step < 0:
        raise ValueError(
            f"Invalid workflow step: {step!r} (expected a non-negative int)"
        )

    chat_thread.workflow_step = step
    EntityRegistry._logger.info(f"Reset workflow step for ChatThread({chat_thread.id}) to {step}")
validate_openai_request(request): + chat_thread.workflow_step += 1 + EntityRegistry._logger.info(f"Validated OpenAI request for ChatThread({chat_thread.id}) with response format {chat_thread.llm_config.response_format}") + return request + else: + EntityRegistry._logger.error(f"Failed to validate OpenAI request for ChatThread({chat_thread.id}) with response format {chat_thread.llm_config.response_format}") + return None + else: EntityRegistry._logger.error(f"Tool not found for workflow step {chat_thread.workflow_step}") + return None # Return None if tool not found + elif chat_thread.llm_config.response_format != ResponseFormat.text: raise ValueError(f"Invalid response format: {chat_thread.llm_config.response_format}") - + if validate_openai_request(request): EntityRegistry._logger.info(f"Validated OpenAI request for ChatThread({chat_thread.id}) with response format {chat_thread.llm_config.response_format}") return request @@ -268,7 +288,7 @@ def get_openai_request(chat_thread: ChatThread) -> Optional[Dict[str, Any]]: def get_anthropic_request(chat_thread: ChatThread) -> Optional[Dict[str, Any]]: """Get Anthropic format request from chat thread.""" - system_content, messages = chat_thread.anthropic_messages + system_content, messages = chat_thread.anthropic_messages request = { "model": chat_thread.llm_config.model, "max_tokens": chat_thread.llm_config.max_tokens, @@ -295,9 +315,15 @@ def get_anthropic_request(chat_thread: ChatThread) -> Optional[Dict[str, Any]]: if tool: request["tools"] = [tool.get_anthropic_tool()] request["tool_choice"] = ToolChoiceToolChoiceTool(name=tool.name, type="tool") - chat_thread.workflow_step += 1 + # Increment ONLY if validation succeeds + if validate_anthropic_request(request): + chat_thread.workflow_step += 1 + return request + else: + return None else: EntityRegistry._logger.error(f"Tool not found for workflow step {chat_thread.workflow_step}") + return None # Return None if tool not found