34 changes: 34 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,34 @@
name: Tests

on:
push:
branches: [ main, fix/tests-ci-security ]
pull_request:
branches: [ main ]

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11", "3.12"]

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e core/
# Also install dev dependencies if needed for tests
pip install -e "core/[dev]"

- name: Run tests
run: |
python -m pytest core/tests/ -v
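
Since the rest of this PR turns provider and graph methods into coroutines, the suite this workflow runs needs an async-aware entry point. A minimal sketch of such a test, using a hypothetical stub edge and plain asyncio.run (with pytest-asyncio you would mark the test async instead):

# core/tests/test_edge_async.py (hypothetical) — drives an async edge API from sync pytest
import asyncio

class StubEdge:
    # Stand-in for EdgeSpec: an ON_SUCCESS edge simply mirrors source_success.
    async def should_traverse(self, source_success, source_output, memory):
        return source_success

def test_should_traverse_on_success():
    # asyncio.run bridges the sync test function and the coroutine under test.
    assert asyncio.run(StubEdge().should_traverse(True, {}, {}))
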
6 changes: 3 additions & 3 deletions core/framework/credentials/storage.py
@@ -224,7 +224,7 @@ def list_all(self) -> list[str]:
index_path = self.base_path / "metadata" / "index.json"
if not index_path.exists():
return []
with open(index_path) as f:
with open(index_path, encoding="utf-8") as f:
index = json.load(f)
return list(index.get("credentials", {}).keys())

@@ -265,7 +265,7 @@ def _update_index(
index_path = self.base_path / "metadata" / "index.json"

if index_path.exists():
with open(index_path) as f:
with open(index_path, encoding="utf-8") as f:
index = json.load(f)
else:
index = {"credentials": {}, "version": "1.0"}
@@ -280,7 +280,7 @@

index["last_modified"] = datetime.now(UTC).isoformat()

with open(index_path, "w") as f:
with open(index_path, "w", encoding="utf-8") as f:
json.dump(index, f, indent=2)


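A minimal sketch (stdlib only, hypothetical data) of what the explicit encoding buys: without encoding="utf-8", open() falls back to the locale codepage (e.g. cp1252 on Windows), so non-ASCII index entries may not round-trip.

import json
import tempfile
from pathlib import Path

path = Path(tempfile.mkdtemp()) / "index.json"
data = {"credentials": {"naïve-service": {}}, "version": "1.0"}

# ensure_ascii=False lets non-ASCII bytes actually reach the file, which is
# exactly the case where a platform-dependent default encoding bites.
with open(path, "w", encoding="utf-8") as f:
    json.dump(data, f, indent=2, ensure_ascii=False)

with open(path, encoding="utf-8") as f:  # symmetric read, as in the diff
    assert json.load(f) == data
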
8 changes: 4 additions & 4 deletions core/framework/graph/edge.py
@@ -98,7 +98,7 @@ class EdgeSpec(BaseModel):

model_config = {"extra": "allow"}

def should_traverse(
async def should_traverse(
self,
source_success: bool,
source_output: dict[str, Any],
@@ -139,7 +139,7 @@ def should_traverse(
if llm is None or goal is None:
# Fallback to ON_SUCCESS if LLM not available
return source_success
return self._llm_decide(
return await self._llm_decide(
llm=llm,
goal=goal,
source_success=source_success,
@@ -184,7 +184,7 @@ def _evaluate_condition(
logger.warning(f" Available context keys: {list(context.keys())}")
return False

def _llm_decide(
async def _llm_decide(
self,
llm: Any,
goal: Any,
@@ -230,7 +230,7 @@ def _llm_decide(
{{"proceed": true/false, "reasoning": "brief explanation"}}"""

try:
response = llm.complete(
response = await llm.complete(
messages=[{"role": "user", "content": prompt}],
system="You are a routing agent. Respond with JSON only.",
max_tokens=150,
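With _llm_decide awaited, the decision path reads roughly as below — a sketch, not the actual implementation; the response attribute name and the fall-back-to-False handling are assumptions the diff does not show:

import json

async def llm_decide(llm, prompt):
    response = await llm.complete(
        messages=[{"role": "user", "content": prompt}],
        system="You are a routing agent. Respond with JSON only.",
        max_tokens=150,
    )
    try:
        # response.content is a hypothetical attribute holding the raw text.
        decision = json.loads(response.content)
        return bool(decision.get("proceed", False))
    except (json.JSONDecodeError, AttributeError, TypeError):
        return False  # treat malformed output as "do not traverse"
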
16 changes: 8 additions & 8 deletions core/framework/graph/executor.py
@@ -453,7 +453,7 @@ async def execute(
current_node_id = result.next_node
else:
# Get all traversable edges for fan-out detection
traversable_edges = self._get_all_traversable_edges(
traversable_edges = await self._get_all_traversable_edges(
graph=graph,
goal=goal,
current_node_id=current_node_id,
@@ -500,7 +500,7 @@
break
else:
# Sequential: follow single edge (existing logic via _follow_edges)
next_node = self._follow_edges(
next_node = await self._follow_edges(
graph=graph,
goal=goal,
current_node_id=current_node_id,
@@ -650,7 +650,7 @@ def _get_node_implementation(
# Should never reach here due to validation above
raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")

def _follow_edges(
async def _follow_edges(
self,
graph: GraphSpec,
goal: Goal,
@@ -665,7 +665,7 @@ def _follow_edges(
for edge in edges:
target_node_spec = graph.get_node(edge.target)

if edge.should_traverse(
if await edge.should_traverse(
source_success=result.success,
source_output=result.output,
memory=memory.read_all(),
@@ -688,7 +688,7 @@ def _follow_edges(
self.logger.warning(f"⚠ Output validation failed: {validation.errors}")

# Clean the output
cleaned_output = self.output_cleaner.clean_output(
cleaned_output = await self.output_cleaner.clean_output(
output=output_to_validate,
source_node_id=current_node_id,
target_node_spec=target_node_spec,
@@ -726,7 +726,7 @@ def _follow_edges(

return None

def _get_all_traversable_edges(
async def _get_all_traversable_edges(
self,
graph: GraphSpec,
goal: Goal,
@@ -746,7 +746,7 @@

for edge in edges:
target_node_spec = graph.get_node(edge.target)
if edge.should_traverse(
if await edge.should_traverse(
source_success=result.success,
source_output=result.output,
memory=memory.read_all(),
@@ -859,7 +859,7 @@ async def execute_single_branch(
f"⚠ Output validation failed for branch "
f"{branch.node_id}: {validation.errors}"
)
cleaned_output = self.output_cleaner.clean_output(
cleaned_output = await self.output_cleaner.clean_output(
output=source_result.output,
source_node_id=source_node_spec.id if source_node_spec else "unknown",
target_node_spec=node_spec,
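The diff shows each branch awaiting its own cleaner call but not how branches are scheduled; one natural pattern (an assumption, not shown in this diff) is to fan out the traversable edges with asyncio.gather:

import asyncio

async def run_branches(executor, branches):
    # Hypothetical fan-out driver: one coroutine per traversable edge.
    # return_exceptions=True keeps one failed branch from cancelling the rest.
    return await asyncio.gather(
        *(executor.execute_single_branch(branch) for branch in branches),
        return_exceptions=True,
    )
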
10 changes: 5 additions & 5 deletions core/framework/graph/node.py
@@ -795,15 +795,15 @@ def executor(tool_use: ToolUse) -> ToolResult:

# Retry the call with compaction instruction
if ctx.available_tools and self.tool_executor:
response = ctx.llm.complete_with_tools(
response = await ctx.llm.complete_with_tools(
messages=compaction_messages,
system=system,
tools=ctx.available_tools,
tool_executor=executor,
max_tokens=ctx.max_tokens,
)
else:
response = ctx.llm.complete(
response = await ctx.llm.complete(
messages=compaction_messages,
system=system,
json_mode=use_json_mode,
@@ -884,15 +884,15 @@ def executor(tool_use: ToolUse) -> ToolResult:

# Re-call LLM with feedback
if ctx.available_tools and self.tool_executor:
response = ctx.llm.complete_with_tools(
response = await ctx.llm.complete_with_tools(
messages=current_messages,
system=system,
tools=ctx.available_tools,
tool_executor=executor,
max_tokens=ctx.max_tokens,
)
else:
response = ctx.llm.complete(
response = await ctx.llm.complete(
messages=current_messages,
system=system,
json_mode=use_json_mode,
@@ -1514,7 +1514,7 @@ async def _llm_route(
logger.info(" 🤔 Router using LLM to choose path...")

try:
response = ctx.llm.complete(
response = await ctx.llm.complete(
messages=[{"role": "user", "content": prompt}],
system=ctx.node_spec.system_prompt
or "You are a routing agent. Respond with JSON only.",
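Both retry paths in node.py now share the same shape: take the tool loop when tools are available, otherwise a single completion, and await either way. A condensed sketch (context attributes as shown in the diff; executor is the node's sync tool callback):

async def call_model(ctx, messages, system, executor, use_json_mode):
    if ctx.available_tools:
        return await ctx.llm.complete_with_tools(
            messages=messages,
            system=system,
            tools=ctx.available_tools,
            tool_executor=executor,
            max_tokens=ctx.max_tokens,
        )
    return await ctx.llm.complete(
        messages=messages,
        system=system,
        json_mode=use_json_mode,
        max_tokens=ctx.max_tokens,
    )
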
4 changes: 2 additions & 2 deletions core/framework/graph/output_cleaner.py
@@ -204,7 +204,7 @@ def validate_output(
warnings=warnings,
)

def clean_output(
async def clean_output(
self,
output: dict[str, Any],
source_node_id: str,
@@ -286,7 +286,7 @@ def clean_output(
f"🧹 Cleaning output from '{source_node_id}' using {self.config.fast_model}"
)

response = self.llm.complete(
response = await self.llm.complete(
messages=[{"role": "user", "content": prompt}],
system=(
"You clean malformed agent outputs. Return only valid JSON matching the schema."
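clean_output becomes a coroutine because the repair path calls the fast model; validation itself stays synchronous. A sketch of the resulting call-site shape (validation.valid is an assumed attribute; the diff only shows validation.errors):

async def clean_if_invalid(cleaner, output, source_id, target_spec, validation):
    if validation.valid:
        return output  # nothing to repair, no LLM round-trip needed
    return await cleaner.clean_output(
        output=output,
        source_node_id=source_id,
        target_node_spec=target_spec,
    )
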
16 changes: 11 additions & 5 deletions core/framework/graph/safe_eval.py
@@ -73,17 +73,23 @@ def visit_Expr(self, node: ast.Expr) -> Any:
return self.visit(node.value)

def visit_Constant(self, node: ast.Constant) -> Any:
return node.value
# Strictly allow only basic types: int, float, str, bool, bytes, NoneType
if isinstance(node.value, (int, float, str, bool, bytes, type(None))):
return node.value
raise ValueError(f"Constant of type {type(node.value).__name__} is not allowed")

# --- Number/String/Bytes/NameConstant (Python < 3.8 compat if needed) ---
# --- Number/String/Bytes/NameConstant (Python < 3.8 compat) ---
def visit_Num(self, node: ast.Num) -> Any:
return node.n
return self.visit_Constant(ast.Constant(value=node.n))

def visit_Str(self, node: ast.Str) -> Any:
return node.s
return self.visit_Constant(ast.Constant(value=node.s))

def visit_Bytes(self, node: ast.Bytes) -> Any:
return self.visit_Constant(ast.Constant(value=node.s))

def visit_NameConstant(self, node: ast.NameConstant) -> Any:
return node.value
return self.visit_Constant(ast.Constant(value=node.value))

# --- Data Structures ---
def visit_List(self, node: ast.List) -> list:
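A self-contained sketch of the same whitelist rule: parse one expression and accept the constant only when it is a plain scalar, so constants like Ellipsis or complex literals are rejected even though ast happily parses them.

import ast

SAFE_TYPES = (int, float, str, bool, bytes, type(None))

def eval_literal(expr: str):
    node = ast.parse(expr, mode="eval").body
    if isinstance(node, ast.Constant) and isinstance(node.value, SAFE_TYPES):
        return node.value
    raise ValueError(f"{expr!r} is not an allowed constant")

assert eval_literal("42") == 42
assert eval_literal("b'raw'") == b"raw"
try:
    eval_literal("...")  # Ellipsis parses as ast.Constant but is not whitelisted
except ValueError:
    pass
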
13 changes: 9 additions & 4 deletions core/framework/llm/anthropic.py
@@ -6,6 +6,7 @@

from framework.llm.litellm import LiteLLMProvider
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.resilience import ResilienceConfig


def _get_api_key_from_credential_manager() -> str | None:
@@ -39,6 +40,7 @@ def __init__(
self,
api_key: str | None = None,
model: str = "claude-haiku-4-5-20251001",
resilience_config: ResilienceConfig | None = None,
):
"""
Initialize the Anthropic provider.
@@ -47,7 +49,9 @@
api_key: Anthropic API key. If not provided, uses CredentialManager
or ANTHROPIC_API_KEY env var.
model: Model to use (default: claude-haiku-4-5-20251001)
resilience_config: Optional resilience configuration.
"""
super().__init__(resilience_config)
# Delegate to LiteLLMProvider internally.
self.api_key = api_key or _get_api_key_from_credential_manager()
if not self.api_key:
@@ -60,9 +64,10 @@
self._provider = LiteLLMProvider(
model=model,
api_key=self.api_key,
resilience_config=self.resilience_config,
)

def complete(
async def complete(
self,
messages: list[dict[str, Any]],
system: str = "",
@@ -72,7 +77,7 @@ def complete(
json_mode: bool = False,
) -> LLMResponse:
"""Generate a completion from Claude (via LiteLLM)."""
return self._provider.complete(
return await self._provider.complete(
messages=messages,
system=system,
tools=tools,
Expand All @@ -81,7 +86,7 @@ def complete(
json_mode=json_mode,
)

def complete_with_tools(
async def complete_with_tools(
self,
messages: list[dict[str, Any]],
system: str,
@@ -90,7 +95,7 @@
max_iterations: int = 10,
) -> LLMResponse:
"""Run a tool-use loop until Claude produces a final response (via LiteLLM)."""
return self._provider.complete_with_tools(
return await self._provider.complete_with_tools(
messages=messages,
system=system,
tools=tools,
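End-to-end, the provider is now used like any other awaitable client. A usage sketch (ResilienceConfig's constructor arguments are not shown in the diff, so the default instance here is an assumption):

import asyncio

from framework.llm.anthropic import AnthropicProvider
from framework.llm.resilience import ResilienceConfig

async def main():
    provider = AnthropicProvider(
        model="claude-haiku-4-5-20251001",
        resilience_config=ResilienceConfig(),  # assumed default-constructible
    )
    response = await provider.complete(
        messages=[{"role": "user", "content": "ping"}],
        system="Reply with one word.",
    )
    print(response)

asyncio.run(main())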