diff --git a/openhands-sdk/openhands/sdk/agent/agent.py b/openhands-sdk/openhands/sdk/agent/agent.py
index 83ef39ae65..da5ff351d4 100644
--- a/openhands-sdk/openhands/sdk/agent/agent.py
+++ b/openhands-sdk/openhands/sdk/agent/agent.py
@@ -4,6 +4,7 @@
 
 import openhands.sdk.security.risk as risk
 from openhands.sdk.agent.base import AgentBase
+from openhands.sdk.context.prompts.prompt import render_template
 from openhands.sdk.context.view import View
 from openhands.sdk.conversation import (
     ConversationCallbackType,
@@ -17,6 +18,7 @@
     LLMConvertibleEvent,
     MessageEvent,
     ObservationEvent,
+    SecurityPromptEvent,
     SystemPromptEvent,
 )
 from openhands.sdk.event.condenser import Condensation, CondensationRequest
@@ -112,6 +114,21 @@ def init_state(
         )
         on_event(event)
 
+        # Add security prompt if template is available
+        try:
+            security_prompt_text = render_template(
+                prompt_dir=self.prompt_dir,
+                template_name="security_analyzer_info.j2",
+            )
+            security_event = SecurityPromptEvent(
+                source="agent",
+                security_prompt=TextContent(text=security_prompt_text),
+            )
+            on_event(security_event)
+        except Exception:
+            # Template not found or other error - skip security prompt
+            pass
+
     def _execute_actions(
         self,
         conversation: LocalConversation,
@@ -144,7 +161,12 @@ def step(
         # of events, exactly as expected, or a new condensation that needs to be
         # processed before the agent can sample another action.
         if self.condenser is not None:
-            view = View.from_events(state.events)
+            is_security_analyzer_enabled = isinstance(
+                self.security_analyzer, LLMSecurityAnalyzer
+            )
+            view = View.from_events(
+                state.events, is_security_analyzer_enabled=is_security_analyzer_enabled
+            )
             condensation_result = self.condenser.condense(view)
 
             match condensation_result:
@@ -156,9 +178,13 @@ def step(
                     return None
 
         else:
-            llm_convertible_events = [
-                e for e in state.events if isinstance(e, LLMConvertibleEvent)
-            ]
+            is_security_analyzer_enabled = isinstance(
+                self.security_analyzer, LLMSecurityAnalyzer
+            )
+            view = View.from_events(
+                state.events, is_security_analyzer_enabled=is_security_analyzer_enabled
+            )
+            llm_convertible_events = view.events
 
         # Get LLM Response (Action)
         _messages = LLMConvertibleEvent.events_to_messages(llm_convertible_events)
diff --git a/openhands-sdk/openhands/sdk/context/prompts/templates/security_analyzer_info.j2 b/openhands-sdk/openhands/sdk/context/prompts/templates/security_analyzer_info.j2
new file mode 100644
index 0000000000..397b5c08d8
--- /dev/null
+++ b/openhands-sdk/openhands/sdk/context/prompts/templates/security_analyzer_info.j2
@@ -0,0 +1,18 @@
+
+You have a security analyzer enabled that will evaluate the risk level of your actions.
+
+When using tools, you must include a `security_risk` parameter in your function calls with one of these values:
+- "LOW": Safe operations like reading files, listing directories, or simple calculations
+- "MEDIUM": Operations that modify files, install packages, or run user code within the sandbox
+- "HIGH": Operations that could potentially access sensitive data, connect to external services, or perform privileged operations
+
+The security analyzer will use your risk assessment to determine if user confirmation is needed before executing actions.
+
+Examples:
+- Reading a file: security_risk="LOW"
+- Creating/editing a file: security_risk="MEDIUM"
+- Running a bash command that installs software: security_risk="MEDIUM"
+- Running a command that could access network or sensitive data: security_risk="HIGH"
+
+Always be thoughtful about the security implications of your actions and provide accurate risk assessments.
+ 
\ No newline at end of file
diff --git a/openhands-sdk/openhands/sdk/context/view.py b/openhands-sdk/openhands/sdk/context/view.py
index 3ee77eb124..7cc5f81065 100644
--- a/openhands-sdk/openhands/sdk/context/view.py
+++ b/openhands-sdk/openhands/sdk/context/view.py
@@ -14,6 +14,7 @@
 from openhands.sdk.event.llm_convertible import (
     ActionEvent,
     ObservationBaseEvent,
+    SecurityPromptEvent,
 )
 from openhands.sdk.event.types import ToolCallID
 
@@ -180,9 +181,17 @@ def _should_keep_event(
         return True
 
     @staticmethod
-    def from_events(events: Sequence[Event]) -> "View":
+    def from_events(
+        events: Sequence[Event], *, is_security_analyzer_enabled: bool = False
+    ) -> "View":
         """Create a view from a list of events, respecting the semantics of any
         condensation events.
+
+        Args:
+            events: Sequence of events to create the view from
+            is_security_analyzer_enabled: Whether security analyzer is enabled.
+                If True, SecurityPromptEvent instances will be included in the view.
+                If False, they will be excluded from the view.
         """
         forgotten_event_ids: set[EventID] = set()
         condensations: list[Condensation] = []
@@ -205,6 +214,10 @@ def from_events(events: Sequence[Event]) -> "View":
             for event in events
             if event.id not in forgotten_event_ids
             and isinstance(event, LLMConvertibleEvent)
+            and (
+                not isinstance(event, SecurityPromptEvent)
+                or is_security_analyzer_enabled
+            )
         ]
 
         # If we have a summary, insert it at the specified offset.
diff --git a/openhands-sdk/openhands/sdk/event/__init__.py b/openhands-sdk/openhands/sdk/event/__init__.py
index 578afcbb8b..5b7bde56d7 100644
--- a/openhands-sdk/openhands/sdk/event/__init__.py
+++ b/openhands-sdk/openhands/sdk/event/__init__.py
@@ -11,6 +11,7 @@
     MessageEvent,
     ObservationBaseEvent,
     ObservationEvent,
+    SecurityPromptEvent,
     SystemPromptEvent,
     UserRejectObservation,
 )
@@ -22,6 +23,7 @@
     "Event",
     "LLMConvertibleEvent",
     "SystemPromptEvent",
+    "SecurityPromptEvent",
     "ActionEvent",
     "ObservationEvent",
     "ObservationBaseEvent",
diff --git a/openhands-sdk/openhands/sdk/event/llm_convertible/__init__.py b/openhands-sdk/openhands/sdk/event/llm_convertible/__init__.py
index 1154668330..734f00d3e8 100644
--- a/openhands-sdk/openhands/sdk/event/llm_convertible/__init__.py
+++ b/openhands-sdk/openhands/sdk/event/llm_convertible/__init__.py
@@ -6,11 +6,13 @@
     ObservationEvent,
     UserRejectObservation,
 )
+from openhands.sdk.event.llm_convertible.security import SecurityPromptEvent
 from openhands.sdk.event.llm_convertible.system import SystemPromptEvent
 
 
 __all__ = [
     "SystemPromptEvent",
+    "SecurityPromptEvent",
     "ActionEvent",
     "ObservationEvent",
     "ObservationBaseEvent",
diff --git a/openhands-sdk/openhands/sdk/event/llm_convertible/security.py b/openhands-sdk/openhands/sdk/event/llm_convertible/security.py
new file mode 100644
index 0000000000..c085f667ce
--- /dev/null
+++ b/openhands-sdk/openhands/sdk/event/llm_convertible/security.py
@@ -0,0 +1,36 @@
+from pydantic import Field
+from rich.text import Text
+
+from openhands.sdk.event.base import N_CHAR_PREVIEW, LLMConvertibleEvent
+from openhands.sdk.event.types import SourceType
+from openhands.sdk.llm import Message, TextContent
+
+
+class SecurityPromptEvent(LLMConvertibleEvent):
+    """Security-related prompt added by the agent when security analyzer is enabled."""
+
+    source: SourceType = "agent"
+    security_prompt: TextContent = Field(
+        ..., description="The security analyzer prompt text"
+    )
+
+    @property
+    def visualize(self) -> Text:
+        """Return Rich Text representation of this security prompt event."""
+        content = Text()
+        content.append("Security Prompt:\n", style="bold")
+        content.append(self.security_prompt.text)
+        return content
+
+    def to_llm_message(self) -> Message:
+        return Message(role="system", content=[self.security_prompt])
+
+    def __str__(self) -> str:
+        """Plain text string representation for SecurityPromptEvent."""
+        base_str = f"{self.__class__.__name__} ({self.source})"
+        prompt_preview = (
+            self.security_prompt.text[:N_CHAR_PREVIEW] + "..."
+            if len(self.security_prompt.text) > N_CHAR_PREVIEW
+            else self.security_prompt.text
+        )
+        return f"{base_str}\n Security: {prompt_preview}"
diff --git a/openhands-sdk/openhands/sdk/event/types.py b/openhands-sdk/openhands/sdk/event/types.py
index 28c2f3d713..441de3fe8b 100644
--- a/openhands-sdk/openhands/sdk/event/types.py
+++ b/openhands-sdk/openhands/sdk/event/types.py
@@ -1,7 +1,14 @@
 from typing import Literal
 
 
-EventType = Literal["action", "observation", "message", "system_prompt", "agent_error"]
+EventType = Literal[
+    "action",
+    "observation",
+    "message",
+    "system_prompt",
+    "security_prompt",
+    "agent_error",
+]
 
 SourceType = Literal["agent", "user", "environment"]
 EventID = str
diff --git a/tests/sdk/context/condenser/test_llm_summarizing_condenser.py b/tests/sdk/context/condenser/test_llm_summarizing_condenser.py
index 8b8f344b05..78067a13df 100644
--- a/tests/sdk/context/condenser/test_llm_summarizing_condenser.py
+++ b/tests/sdk/context/condenser/test_llm_summarizing_condenser.py
@@ -165,7 +165,9 @@ def test_get_condensation_with_previous_summary(mock_llm: LLM) -> None:
         events[:keep_first] + [condensation] + events[keep_first:]
     )
 
-    view = View.from_events(events_with_condensation)
+    view = View.from_events(
+        events_with_condensation, is_security_analyzer_enabled=False
+    )
 
     result = condenser.get_condensation(view)
 
diff --git a/tests/sdk/context/test_view.py b/tests/sdk/context/test_view.py
index 92aa73d2fb..231ad28062 100644
--- a/tests/sdk/context/test_view.py
+++ b/tests/sdk/context/test_view.py
@@ -281,7 +281,9 @@ def test_condensation_request_always_removed_from_view() -> None:
         CondensationRequest(),
         message_event(content="Event 1"),
     ]
-    view_unhandled = View.from_events(events_unhandled)
+    view_unhandled = View.from_events(
+        events_unhandled, is_security_analyzer_enabled=False
+    )
     assert view_unhandled.unhandled_condensation_request is True
     assert len(view_unhandled) == 2  # Only MessageEvents
 
@@ -363,7 +365,9 @@ def test_most_recent_condensation_property() -> None:
 
     # Test with no condensations
     events_no_condensation: list[Event] = cast(list[Event], message_events.copy())
-    view_no_condensation = View.from_events(events_no_condensation)
+    view_no_condensation = View.from_events(
+        events_no_condensation, is_security_analyzer_enabled=False
+    )
     assert view_no_condensation.most_recent_condensation is None
 
     # Test with single condensation
@@ -395,7 +399,9 @@ def test_most_recent_condensation_property() -> None:
         message_events[2],
         condensation3,
     ]
-    view_multiple = View.from_events(events_multiple)
+    view_multiple = View.from_events(
+        events_multiple, is_security_analyzer_enabled=False
+    )
     assert view_multiple.most_recent_condensation == condensation3
 
 
diff --git a/tests/sdk/event/test_security_prompt_event.py b/tests/sdk/event/test_security_prompt_event.py
new file mode 100644
index 0000000000..7178d9556e
--- /dev/null
+++ b/tests/sdk/event/test_security_prompt_event.py
@@ -0,0 +1,84 @@
+"""Tests for SecurityPromptEvent."""
+
+from openhands.sdk.context.view import View
+from openhands.sdk.event import SecurityPromptEvent, SystemPromptEvent
+from openhands.sdk.llm import TextContent
+
+
+def test_security_prompt_event_creation():
+    """Test SecurityPromptEvent creation and basic functionality."""
+    security_event = SecurityPromptEvent(
+        source="agent",
+        security_prompt=TextContent(text="This is a security prompt for testing."),
+    )
+
+    assert security_event.source == "agent"
+    assert (
+        security_event.security_prompt.text == "This is a security prompt for testing."
+    )
+    assert security_event.kind == "SecurityPromptEvent"
+
+
+def test_security_prompt_event_to_llm_message():
+    """Test SecurityPromptEvent to_llm_message conversion."""
+    security_event = SecurityPromptEvent(
+        source="agent",
+        security_prompt=TextContent(text="Security analyzer instructions."),
+    )
+
+    message = security_event.to_llm_message()
+
+    assert message.role == "system"
+    assert len(message.content) == 1
+    content_item = message.content[0]
+    assert isinstance(content_item, TextContent)
+    assert content_item.text == "Security analyzer instructions."
+
+
+def test_security_prompt_event_visualize():
+    """Test SecurityPromptEvent visualize method."""
+    security_event = SecurityPromptEvent(
+        source="agent", security_prompt=TextContent(text="Security prompt content.")
+    )
+
+    visualization = security_event.visualize
+
+    assert "Security Prompt:" in visualization
+    assert "Security prompt content." in visualization
+
+
+def test_security_prompt_event_str():
+    """Test SecurityPromptEvent string representation."""
+    security_event = SecurityPromptEvent(
+        source="agent", security_prompt=TextContent(text="Security prompt content.")
+    )
+
+    str_repr = str(security_event)
+
+    assert "SecurityPromptEvent (agent)" in str_repr
+    assert "Security: Security prompt content." in str_repr
+
+
+def test_view_from_events_security_analyzer_enabled():
+    """Test View.from_events includes SecurityPromptEvent when security analyzer is enabled."""  # noqa: E501
+    system_event = SystemPromptEvent(
+        source="agent",
+        system_prompt=TextContent(text="System prompt"),
+        tools=[],
+    )
+    security_event = SecurityPromptEvent(
+        source="agent",
+        security_prompt=TextContent(text="Security prompt"),
+    )
+    events = [system_event, security_event]
+
+    # When security analyzer is enabled, SecurityPromptEvent should be included
+    view_enabled = View.from_events(events, is_security_analyzer_enabled=True)
+    assert len(view_enabled.events) == 2
+    assert any(isinstance(e, SecurityPromptEvent) for e in view_enabled.events)
+
+    # When security analyzer is disabled, SecurityPromptEvent should be excluded
+    view_disabled = View.from_events(events)
+    assert len(view_disabled.events) == 1
+    assert not any(isinstance(e, SecurityPromptEvent) for e in view_disabled.events)
+    assert any(isinstance(e, SystemPromptEvent) for e in view_disabled.events)
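Usage sketch (not part of the patch): a minimal example of how the new event and the is_security_analyzer_enabled flag fit together, mirroring the behavior exercised in tests/sdk/event/test_security_prompt_event.py.

from openhands.sdk.context.view import View
from openhands.sdk.event import SecurityPromptEvent, SystemPromptEvent
from openhands.sdk.llm import TextContent

# Events as the agent would emit them during init_state()
events = [
    SystemPromptEvent(
        source="agent", system_prompt=TextContent(text="System prompt"), tools=[]
    ),
    SecurityPromptEvent(
        source="agent", security_prompt=TextContent(text="Security prompt")
    ),
]

# With the analyzer enabled, the security prompt stays in the LLM-visible history
assert len(View.from_events(events, is_security_analyzer_enabled=True).events) == 2
# With it disabled (the default), the security prompt is filtered out of the view
assert len(View.from_events(events).events) == 1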