diff --git a/CHANGELOG.md b/CHANGELOG.md index 59ed3bb..5c41134 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - nsjail-based sandboxing for code execution (replaces Docker socket-based approach) -- Single unified Docker image with all 12 language runtimes +- Single unified Docker image with all 13 language runtimes - Hour and day periods for execution heatmap visualizations - MyPy type checking integration with comprehensive type hints - Dynamic Content Security Policy headers based on request path @@ -33,7 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added #### Core Features -- Multi-language code execution supporting 12 languages: Python, JavaScript, TypeScript, Go, Java, C, C++, PHP, Rust, R, Fortran, and D +- Multi-language code execution supporting 13 languages: Python, JavaScript, TypeScript, Go, Java, C, C++, PHP, Rust, R, Fortran, D, and Bash - FastAPI-based REST API with interactive documentation - Sandboxed execution environments with comprehensive security controls - Redis-based session management with automatic cleanup diff --git a/Dockerfile b/Dockerfile index 46f580c..469a35a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -267,8 +267,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ # REPL Server + entrypoint # ============================================ COPY docker/repl_server.py /opt/repl_server.py +COPY docker/ptc_server.py /opt/ptc_server.py COPY docker/entrypoint.sh /opt/entrypoint.sh -RUN chmod +x /opt/repl_server.py /opt/entrypoint.sh +RUN chmod +x /opt/repl_server.py /opt/ptc_server.py /opt/entrypoint.sh # ============================================ # Sandbox directory structure diff --git a/README.md b/README.md index 67bdc6c..418b213 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Get up and running in minutes by building the execution environment. 
docker build -t code-interpreter:nsjail . ``` - This builds a single image containing all 12 language runtimes and nsjail for sandboxed execution. + This builds a single image containing all 13 language runtimes and nsjail for sandboxed execution. 4. **Start the API** @@ -55,7 +55,7 @@ The dashboard requires the master API key for authentication. ## Features -- **Multi-language Support**: Execute code in 12 languages - Python, JavaScript, TypeScript, Go, Java, C, C++, PHP, Rust, R, Fortran, and D +- **Multi-language Support**: Execute code in 13 languages - Python, JavaScript, TypeScript, Go, Java, C, C++, PHP, Rust, R, Fortran, D, and Bash - **Sub-50ms Python Execution**: Pre-warmed REPL sandboxes achieve ~20-40ms latency for simple Python code - **Sandbox Pool**: Pre-warmed nsjail sandboxes provide ~3ms acquisition time (vs 500-2000ms cold start) - **High Concurrency**: Thread-safe execution supporting 10+ concurrent requests @@ -88,7 +88,7 @@ For a deep dive into the system design, components, and request flows, see [ARCH The API provides endpoints for code execution, file management, and session state control. -- `POST /exec`: Execute code in one of the 12 supported languages. +- `POST /exec`: Execute code in one of the 13 supported languages. - `POST /upload`: Upload files for processing. - `GET /download`: Retrieve generated files. @@ -98,7 +98,7 @@ For detailed information on all endpoints and specific language notes, see [ARCH ## Supported Languages -We support 12 programming languages including Python, JavaScript, TypeScript, Go, Rust, and more. Each language has optimized execution paths and resource limits. +We support 13 programming languages including Python, JavaScript, TypeScript, Go, Rust, Bash, and more. Each language has optimized execution paths and resource limits. See the [Supported Languages table](docs/ARCHITECTURE.md#supported-languages) for details on versions and included libraries. 
diff --git a/docker/ptc_server.py b/docker/ptc_server.py new file mode 100644 index 0000000..1b9b5ca --- /dev/null +++ b/docker/ptc_server.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +"""Programmatic Tool Calling (PTC) Server for nsjail sandbox execution. + +This script runs INSIDE the nsjail sandbox and provides a Python execution +environment where code can call externally-defined tools. Tool calls are +serialized as JSON over stdin/stdout, allowing the host process to fulfill +them and send results back. + +Protocol: +1. Host sends initial request via stdin: + {"code": "...", "tools": [{"name": "...", "description": "...", "parameters": {...}}]} + +2. Code executes. When a tool stub is called, PTC server writes to stdout: + {"type": "tool_calls", "calls": [{"id": "...", "name": "...", "input": {...}}]} + +3. Host reads tool_calls, fulfills them, and writes results to stdin: + {"type": "tool_results", "results": [{"call_id": "...", "result": ..., "is_error": false}]} + +4. Code continues. On completion, PTC server writes: + {"type": "completed", "stdout": "...", "stderr": "..."} + +5. On error, PTC server writes: + {"type": "error", "error": "..."} +""" + +import asyncio +import json +import os +import sys +import traceback +import uuid +from io import StringIO + +DELIMITER = "\n---PTC_END---\n" + +# Keep references to the REAL stdin/stdout for protocol communication. +# User code's print() will be redirected to a StringIO capture buffer. 
+_real_stdin = sys.stdin +_real_stdout = sys.stdout +_real_stderr = sys.stderr + + +def _write_message(msg: dict) -> None: + """Write a JSON message to the host via the real stdout.""" + data = json.dumps(msg) + DELIMITER + _real_stdout.write(data) + _real_stdout.flush() + + +def _read_message() -> dict: + """Read a JSON message from the host via the real stdin.""" + buf = "" + while True: + line = _real_stdin.readline() + if not line: + raise EOFError("stdin closed") + buf += line + if DELIMITER in buf: + json_part = buf.split(DELIMITER)[0] + return json.loads(json_part) + + +# Pending tool calls collected during async execution +_pending_calls = [] +_tool_results_map = {} # call_id -> result + + +def _make_tool_stub(tool_name: str) -> callable: + """Create an async function stub for a tool.""" + + async def tool_stub(**kwargs): + call_id = uuid.uuid4().hex[:12] + call_info = { + "id": call_id, + "name": tool_name, + "input": kwargs, + } + _pending_calls.append(call_info) + + # Wait for result - the main loop will flush calls and read results + while call_id not in _tool_results_map: + await asyncio.sleep(0.01) + + result_info = _tool_results_map.pop(call_id) + if result_info.get("is_error"): + raise RuntimeError( + result_info.get("error_message", "Tool call failed") + ) + return result_info.get("result") + + tool_stub.__name__ = tool_name + tool_stub.__qualname__ = tool_name + return tool_stub + + +async def _execute_with_tools( + code: str, tools: list, user_stdout: StringIO, user_stderr: StringIO +) -> dict: + """Execute code with tool stubs, capturing user output.""" + global _pending_calls, _tool_results_map + + _pending_calls = [] + _tool_results_map = {} + + # Build namespace with tool stubs + namespace = {"__builtins__": __builtins__, "__name__": "__main__"} + + try: + import json as _json + + namespace["json"] = _json + except ImportError: + pass + + for tool in tools: + namespace[tool["name"]] = _make_tool_stub(tool["name"]) + + # Wrap user code in 
async function + indented_code = "\n".join(" " + line for line in code.split("\n")) + wrapped_code = f"async def __ptc_main__():\n{indented_code}\n" + + try: + compiled = compile(wrapped_code, "", "exec") + exec(compiled, namespace) + except SyntaxError as e: + return {"type": "error", "error": f"SyntaxError: {e}"} + + main_func = namespace["__ptc_main__"] + main_task = asyncio.ensure_future(main_func()) + + try: + while not main_task.done(): + # Let the task run briefly to accumulate batched calls + await asyncio.sleep(0.05) + + if _pending_calls and not main_task.done(): + calls_to_send = list(_pending_calls) + _pending_calls.clear() + + _write_message({ + "type": "tool_calls", + "calls": calls_to_send, + }) + + # Wait for results from host + response = _read_message() + + if response.get("type") != "tool_results": + return { + "type": "error", + "error": f"Expected tool_results, got " + f"{response.get('type')}", + } + + for result in response.get("results", []): + _tool_results_map[result["call_id"]] = result + + # Task completed + main_task.result() + return {"type": "completed"} + + except Exception as e: + tb = traceback.format_exc() + return { + "type": "error", + "error": str(e), + "stderr_extra": tb, + } + + +def main(): + """Main entry point for PTC server.""" + try: + os.chdir("/mnt/data") + except OSError: + pass + + # Read initial request + try: + request = _read_message() + except Exception as e: + _write_message({ + "type": "error", + "error": f"Failed to read initial request: {e}", + }) + return + + code = request.get("code", "") + tools = request.get("tools", []) + + if not code: + _write_message({"type": "error", "error": "No code provided"}) + return + + # Redirect sys.stdout and sys.stderr so user print() calls + # are captured, not mixed with our protocol messages. 
+ user_stdout = StringIO() + user_stderr = StringIO() + sys.stdout = user_stdout + sys.stderr = user_stderr + + try: + result = asyncio.run( + _execute_with_tools(code, tools, user_stdout, user_stderr) + ) + except Exception as e: + result = { + "type": "error", + "error": str(e), + } + + # Restore real stdout for final message + sys.stdout = _real_stdout + sys.stderr = _real_stderr + + # Attach captured user output + result["stdout"] = user_stdout.getvalue() + stderr_val = user_stderr.getvalue() + if result.get("stderr_extra"): + stderr_val += result.pop("stderr_extra") + result["stderr"] = stderr_val + + _write_message(result) + + +if __name__ == "__main__": + main() diff --git a/scripts/load_test/config.py b/scripts/load_test/config.py index 3e4bca4..3b9e744 100644 --- a/scripts/load_test/config.py +++ b/scripts/load_test/config.py @@ -60,7 +60,21 @@ } # Supported languages -SUPPORTED_LANGUAGES = ["py", "js", "ts", "go", "java", "c", "cpp", "php", "rs", "r", "f90", "d"] +SUPPORTED_LANGUAGES = [ + "py", + "js", + "ts", + "go", + "java", + "c", + "cpp", + "php", + "rs", + "r", + "f90", + "d", + "bash", +] @dataclass @@ -112,11 +126,7 @@ def get_api_key(self) -> str: } -def get_vm_type( - cpu_cores: int, - memory_gb: int, - provider: str = "azure" -) -> str: +def get_vm_type(cpu_cores: int, memory_gb: int, provider: str = "azure") -> str: """Get recommended VM type for given resources.""" vm_maps = { "azure": AZURE_VM_TYPES, diff --git a/scripts/load_test/scenarios/multi_language.py b/scripts/load_test/scenarios/multi_language.py index bb14cb3..3043d44 100644 --- a/scripts/load_test/scenarios/multi_language.py +++ b/scripts/load_test/scenarios/multi_language.py @@ -1,9 +1,8 @@ -"""Multi-language test scenarios for all 12 supported languages.""" +"""Multi-language test scenarios for all 13 supported languages.""" from typing import List from .base import BaseScenario - # Language-specific hello world and compute code LANGUAGE_CODE = { "py": { @@ -150,6 +149,15 @@ 
writeln("D compute result: ", result); }""", }, + "bash": { + "baseline": 'echo "Hello from Bash"', + "compute": """sum=0 +for i in $(seq 0 9999); do + sum=$((sum + i * i)) +done +echo "Bash compute result: $sum" +""", + }, } diff --git a/scripts/locustfile.py b/scripts/locustfile.py index 6b6b00d..a1519b1 100644 --- a/scripts/locustfile.py +++ b/scripts/locustfile.py @@ -19,7 +19,6 @@ import os from locust import HttpUser, task, between, tag - # API key from environment or default API_KEY = os.environ.get("API_KEY", "test-api-key-for-development-only") @@ -46,78 +45,108 @@ def on_start(self): @task(10) def cpu_light(self): """Light CPU computation - simple math.""" - self.client.post("/exec", json={ - "lang": "py", - "code": "result = sum(range(10000))\nprint(f'Sum: {result}')", - }, headers=self.headers, name="CPU Light") + self.client.post( + "/exec", + json={ + "lang": "py", + "code": "result = sum(range(10000))\nprint(f'Sum: {result}')", + }, + headers=self.headers, + name="CPU Light", + ) @tag("cpu", "cpu_medium") @task(5) def cpu_medium(self): """Medium CPU computation - moderate math.""" - self.client.post("/exec", json={ - "lang": "py", - "code": "result = sum(i**2 for i in range(100000))\nprint(f'Sum of squares: {result}')", - }, headers=self.headers, name="CPU Medium") + self.client.post( + "/exec", + json={ + "lang": "py", + "code": "result = sum(i**2 for i in range(100000))\nprint(f'Sum of squares: {result}')", + }, + headers=self.headers, + name="CPU Medium", + ) @tag("cpu", "cpu_heavy") @task(2) def cpu_heavy(self): """Heavy CPU computation - matrix multiplication.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import numpy as np + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import numpy as np size = 500 a = np.random.rand(size, size) b = np.random.rand(size, size) c = np.dot(a, b) print(f'Matrix: shape={c.shape}, sum={c.sum():.4f}')""", - }, headers=self.headers, name="CPU Heavy") + }, + 
headers=self.headers, + name="CPU Heavy", + ) @tag("cpu", "cpu_sklearn") @task(1) def cpu_sklearn(self): """ML model training with sklearn.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import numpy as np + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import numpy as np from sklearn.ensemble import RandomForestClassifier from sklearn.datasets import make_classification X, y = make_classification(n_samples=500, n_features=20, n_informative=10, random_state=42) model = RandomForestClassifier(n_estimators=50, random_state=42) model.fit(X, y) print(f'RandomForest score={model.score(X, y):.4f}')""", - }, headers=self.headers, name="CPU Sklearn") + }, + headers=self.headers, + name="CPU Sklearn", + ) @tag("cpu", "cpu_prime") @task(3) def cpu_prime(self): """Prime number computation.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """def is_prime(n): + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """def is_prime(n): if n < 2: return False for i in range(2, int(n**0.5) + 1): if n % i == 0: return False return True primes = [n for n in range(10000) if is_prime(n)] print(f'Found {len(primes)} primes, largest={primes[-1]}')""", - }, headers=self.headers, name="CPU Prime") + }, + headers=self.headers, + name="CPU Prime", + ) @tag("cpu", "cpu_fibonacci") @task(3) def cpu_fibonacci(self): """Fibonacci sequence computation.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """def fib(n): + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """def fib(n): a, b = 0, 1 for _ in range(n): a, b = b, a + b return a result = fib(10000) print(f'fib={str(result)[:50]}...')""", - }, headers=self.headers, name="CPU Fibonacci") + }, + headers=self.headers, + name="CPU Fibonacci", + ) # ========================================================================= # Memory-Bound Tests (6 scenarios) @@ -127,45 +156,62 @@ def cpu_fibonacci(self): @task(5) def mem_10mb(self): """Allocate 10MB 
NumPy array.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import numpy as np + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import numpy as np size = 1310720 # 10MB arr = np.random.rand(size) print(f'Allocated 10MB, sum={arr.sum():.4f}')""", - }, headers=self.headers, name="Memory 10MB") + }, + headers=self.headers, + name="Memory 10MB", + ) @tag("memory", "mem_50mb") @task(3) def mem_50mb(self): """Allocate 50MB NumPy array.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import numpy as np + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import numpy as np size = 6553600 # 50MB arr = np.random.rand(size) print(f'Allocated 50MB, mean={arr.mean():.6f}')""", - }, headers=self.headers, name="Memory 50MB") + }, + headers=self.headers, + name="Memory 50MB", + ) @tag("memory", "mem_100mb") @task(2) def mem_100mb(self): """Allocate 100MB NumPy array.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import numpy as np + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import numpy as np size = 13107200 # 100MB arr = np.random.rand(size) print(f'Allocated 100MB, std={arr.std():.6f}')""", - }, headers=self.headers, name="Memory 100MB") + }, + headers=self.headers, + name="Memory 100MB", + ) @tag("memory", "mem_pandas") @task(2) def mem_pandas(self): """1M row DataFrame operations.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import pandas as pd + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import pandas as pd import numpy as np n_rows = 1000000 df = pd.DataFrame({ @@ -176,36 +222,49 @@ def mem_pandas(self): }) grouped = df.groupby('d').agg({'a': 'mean', 'b': 'sum', 'c': 'max'}) print(f'DataFrame shape={df.shape}, memory={df.memory_usage(deep=True).sum() / 1e6:.1f}MB')""", - }, headers=self.headers, name="Memory Pandas") + }, + headers=self.headers, + name="Memory Pandas", + ) @tag("memory", "mem_list") @task(3) 
def mem_list(self): """Large Python list (5M integers).""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import sys + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import sys size = 5000000 data = list(range(size)) total = sum(data) filtered = [x for x in data if x % 2 == 0] mem_mb = sys.getsizeof(data) / (1024 * 1024) print(f'List size={size}, sum={total}, even_count={len(filtered)}, mem~{mem_mb:.1f}MB')""", - }, headers=self.headers, name="Memory List") + }, + headers=self.headers, + name="Memory List", + ) @tag("memory", "mem_dict") @task(3) def mem_dict(self): """Large dictionary (1M entries).""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import sys + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import sys size = 1000000 data = {i: f'value_{i}' for i in range(size)} keys = list(data.keys()) mem_mb = sys.getsizeof(data) / (1024 * 1024) print(f'Dict size={len(data)}, first_key={keys[0]}, mem~{mem_mb:.1f}MB')""", - }, headers=self.headers, name="Memory Dict") + }, + headers=self.headers, + name="Memory Dict", + ) # ========================================================================= # I/O-Bound Tests (6 scenarios) @@ -216,35 +275,47 @@ def mem_dict(self): def io_small(self): """Write 10 x 100KB files.""" self._iteration_counter += 1 - self.client.post("/exec", json={ - "lang": "py", - "code": f"""import os + self.client.post( + "/exec", + json={ + "lang": "py", + "code": f"""import os for i in range(10): with open(f'/mnt/data/small_{{i}}.txt', 'w') as f: f.write('x' * 102400) print('Created 10 x 100KB files')""", - }, headers=self.headers, name="I/O Small Files") + }, + headers=self.headers, + name="I/O Small Files", + ) @tag("io", "io_large") @task(2) def io_large(self): """Write 3 x 1MB files.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import os + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import os for i in range(3): with 
open(f'/mnt/data/large_{i}.txt', 'w') as f: f.write('y' * 1048576) print('Created 3 x 1MB files')""", - }, headers=self.headers, name="I/O Large Files") + }, + headers=self.headers, + name="I/O Large Files", + ) @tag("io", "io_matplotlib") @task(2) def io_matplotlib(self): """Generate matplotlib PNG plot.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import matplotlib + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import numpy as np @@ -256,15 +327,20 @@ def io_matplotlib(self): plt.close() import os print(f'Plot size: {os.path.getsize("/mnt/data/plot.png")/1024:.1f}KB')""", - }, headers=self.headers, name="I/O Matplotlib") + }, + headers=self.headers, + name="I/O Matplotlib", + ) @tag("io", "io_csv") @task(3) def io_csv(self): """CSV read/write with pandas.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import pandas as pd + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import pandas as pd import numpy as np import os df = pd.DataFrame({ @@ -277,15 +353,20 @@ def io_csv(self): df_read['sum'] = df_read['value_a'] + df_read['value_b'] df_read.to_csv('/mnt/data/output.csv', index=False) print(f'CSV size: {os.path.getsize("/mnt/data/output.csv")/1024:.0f}KB')""", - }, headers=self.headers, name="I/O CSV") + }, + headers=self.headers, + name="I/O CSV", + ) @tag("io", "io_json") @task(3) def io_json(self): """JSON read/write with nested data.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import json + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import json import os data = { 'records': [ @@ -298,15 +379,20 @@ def io_json(self): with open('/mnt/data/data.json', 'r') as f: loaded = json.load(f) print(f'Records: {len(loaded["records"])}, Size: {os.path.getsize("/mnt/data/data.json")/1024:.0f}KB')""", - }, headers=self.headers, name="I/O JSON") + }, + headers=self.headers, + 
name="I/O JSON", + ) @tag("io", "io_image") @task(1) def io_image(self): """OpenCV image processing.""" - self.client.post("/exec", json={ - "lang": "py", - "code": """import cv2 + self.client.post( + "/exec", + json={ + "lang": "py", + "code": """import cv2 import numpy as np import os img = np.random.randint(0, 255, (800, 1200, 3), dtype=np.uint8) @@ -316,151 +402,354 @@ def io_image(self): cv2.imwrite('/mnt/data/processed.png', img_blur) cv2.imwrite('/mnt/data/edges.png', edges) print(f'Processed: {os.path.getsize("/mnt/data/processed.png")/1024:.0f}KB')""", - }, headers=self.headers, name="I/O Image") + }, + headers=self.headers, + name="I/O Image", + ) # ========================================================================= - # Multi-Language Tests (24 scenarios - 12 languages x 2) + # Multi-Language Tests (26 scenarios - 13 languages x 2) # ========================================================================= # Python @tag("language", "py") @task(2) def python_baseline(self): - self.client.post("/exec", json={"lang": "py", "code": 'print("Hello from Python")'}, headers=self.headers, name="Python Baseline") + self.client.post( + "/exec", + json={"lang": "py", "code": 'print("Hello from Python")'}, + headers=self.headers, + name="Python Baseline", + ) @tag("language", "py") @task(1) def python_compute(self): - self.client.post("/exec", json={"lang": "py", "code": 'result = sum(i*i for i in range(10000))\nprint(f"Result: {result}")'}, headers=self.headers, name="Python Compute") + self.client.post( + "/exec", + json={ + "lang": "py", + "code": 'result = sum(i*i for i in range(10000))\nprint(f"Result: {result}")', + }, + headers=self.headers, + name="Python Compute", + ) # JavaScript @tag("language", "js") @task(2) def javascript_baseline(self): - self.client.post("/exec", json={"lang": "js", "code": 'console.log("Hello from JavaScript");'}, headers=self.headers, name="JavaScript Baseline") + self.client.post( + "/exec", + json={"lang": "js", "code": 
'console.log("Hello from JavaScript");'}, + headers=self.headers, + name="JavaScript Baseline", + ) @tag("language", "js") @task(1) def javascript_compute(self): - self.client.post("/exec", json={"lang": "js", "code": 'let r=0; for(let i=0;i<10000;i++) r+=i*i; console.log("Result:",r);'}, headers=self.headers, name="JavaScript Compute") + self.client.post( + "/exec", + json={ + "lang": "js", + "code": 'let r=0; for(let i=0;i<10000;i++) r+=i*i; console.log("Result:",r);', + }, + headers=self.headers, + name="JavaScript Compute", + ) # TypeScript @tag("language", "ts") @task(2) def typescript_baseline(self): - self.client.post("/exec", json={"lang": "ts", "code": 'console.log("Hello from TypeScript");'}, headers=self.headers, name="TypeScript Baseline") + self.client.post( + "/exec", + json={"lang": "ts", "code": 'console.log("Hello from TypeScript");'}, + headers=self.headers, + name="TypeScript Baseline", + ) @tag("language", "ts") @task(1) def typescript_compute(self): - self.client.post("/exec", json={"lang": "ts", "code": 'let r:number=0; for(let i:number=0;i<10000;i++) r+=i*i; console.log("Result:",r);'}, headers=self.headers, name="TypeScript Compute") + self.client.post( + "/exec", + json={ + "lang": "ts", + "code": 'let r:number=0; for(let i:number=0;i<10000;i++) r+=i*i; console.log("Result:",r);', + }, + headers=self.headers, + name="TypeScript Compute", + ) # Go @tag("language", "go") @task(2) def go_baseline(self): - self.client.post("/exec", json={"lang": "go", "code": 'package main\nimport "fmt"\nfunc main() { fmt.Println("Hello from Go") }'}, headers=self.headers, name="Go Baseline") + self.client.post( + "/exec", + json={ + "lang": "go", + "code": 'package main\nimport "fmt"\nfunc main() { fmt.Println("Hello from Go") }', + }, + headers=self.headers, + name="Go Baseline", + ) @tag("language", "go") @task(1) def go_compute(self): - self.client.post("/exec", json={"lang": "go", "code": 'package main\nimport "fmt"\nfunc main() { r:=0; for 
i:=0;i<10000;i++ { r+=i*i }; fmt.Println("Result:",r) }'}, headers=self.headers, name="Go Compute") + self.client.post( + "/exec", + json={ + "lang": "go", + "code": 'package main\nimport "fmt"\nfunc main() { r:=0; for i:=0;i<10000;i++ { r+=i*i }; fmt.Println("Result:",r) }', + }, + headers=self.headers, + name="Go Compute", + ) # Java @tag("language", "java") @task(2) def java_baseline(self): - self.client.post("/exec", json={"lang": "java", "code": 'public class Main { public static void main(String[] args) { System.out.println("Hello from Java"); } }'}, headers=self.headers, name="Java Baseline") + self.client.post( + "/exec", + json={ + "lang": "java", + "code": 'public class Main { public static void main(String[] args) { System.out.println("Hello from Java"); } }', + }, + headers=self.headers, + name="Java Baseline", + ) @tag("language", "java") @task(1) def java_compute(self): - self.client.post("/exec", json={"lang": "java", "code": 'public class Main { public static void main(String[] args) { long r=0; for(int i=0;i<10000;i++) r+=(long)i*i; System.out.println("Result: "+r); } }'}, headers=self.headers, name="Java Compute") + self.client.post( + "/exec", + json={ + "lang": "java", + "code": 'public class Main { public static void main(String[] args) { long r=0; for(int i=0;i<10000;i++) r+=(long)i*i; System.out.println("Result: "+r); } }', + }, + headers=self.headers, + name="Java Compute", + ) # C @tag("language", "c") @task(2) def c_baseline(self): - self.client.post("/exec", json={"lang": "c", "code": '#include \nint main() { printf("Hello from C\\n"); return 0; }'}, headers=self.headers, name="C Baseline") + self.client.post( + "/exec", + json={ + "lang": "c", + "code": '#include \nint main() { printf("Hello from C\\n"); return 0; }', + }, + headers=self.headers, + name="C Baseline", + ) @tag("language", "c") @task(1) def c_compute(self): - self.client.post("/exec", json={"lang": "c", "code": '#include \nint main() { long long r=0; for(int 
i=0;i<10000;i++) r+=(long long)i*i; printf("Result: %lld\\n",r); return 0; }'}, headers=self.headers, name="C Compute") + self.client.post( + "/exec", + json={ + "lang": "c", + "code": '#include \nint main() { long long r=0; for(int i=0;i<10000;i++) r+=(long long)i*i; printf("Result: %lld\\n",r); return 0; }', + }, + headers=self.headers, + name="C Compute", + ) # C++ @tag("language", "cpp") @task(2) def cpp_baseline(self): - self.client.post("/exec", json={"lang": "cpp", "code": '#include \nint main() { std::cout << "Hello from C++" << std::endl; return 0; }'}, headers=self.headers, name="C++ Baseline") + self.client.post( + "/exec", + json={ + "lang": "cpp", + "code": '#include \nint main() { std::cout << "Hello from C++" << std::endl; return 0; }', + }, + headers=self.headers, + name="C++ Baseline", + ) @tag("language", "cpp") @task(1) def cpp_compute(self): - self.client.post("/exec", json={"lang": "cpp", "code": '#include \nint main() { long long r=0; for(int i=0;i<10000;i++) r+=(long long)i*i; std::cout << "Result: " << r << std::endl; return 0; }'}, headers=self.headers, name="C++ Compute") + self.client.post( + "/exec", + json={ + "lang": "cpp", + "code": '#include \nint main() { long long r=0; for(int i=0;i<10000;i++) r+=(long long)i*i; std::cout << "Result: " << r << std::endl; return 0; }', + }, + headers=self.headers, + name="C++ Compute", + ) # PHP @tag("language", "php") @task(2) def php_baseline(self): - self.client.post("/exec", json={"lang": "php", "code": ''}, headers=self.headers, name="PHP Baseline") + self.client.post( + "/exec", + json={"lang": "php", "code": ''}, + headers=self.headers, + name="PHP Baseline", + ) @tag("language", "php") @task(1) def php_compute(self): - self.client.post("/exec", json={"lang": "php", "code": ''}, headers=self.headers, name="PHP Compute") + self.client.post( + "/exec", + json={ + "lang": "php", + "code": '', + }, + headers=self.headers, + name="PHP Compute", + ) # Rust @tag("language", "rs") @task(2) def 
rust_baseline(self): - self.client.post("/exec", json={"lang": "rs", "code": 'fn main() { println!("Hello from Rust"); }'}, headers=self.headers, name="Rust Baseline") + self.client.post( + "/exec", + json={"lang": "rs", "code": 'fn main() { println!("Hello from Rust"); }'}, + headers=self.headers, + name="Rust Baseline", + ) @tag("language", "rs") @task(1) def rust_compute(self): - self.client.post("/exec", json={"lang": "rs", "code": 'fn main() { let r: i64 = (0..10000).map(|i: i64| i * i).sum(); println!("Result: {}", r); }'}, headers=self.headers, name="Rust Compute") + self.client.post( + "/exec", + json={ + "lang": "rs", + "code": 'fn main() { let r: i64 = (0..10000).map(|i: i64| i * i).sum(); println!("Result: {}", r); }', + }, + headers=self.headers, + name="Rust Compute", + ) # R @tag("language", "r") @task(2) def r_baseline(self): - self.client.post("/exec", json={"lang": "r", "code": 'print("Hello from R")'}, headers=self.headers, name="R Baseline") + self.client.post( + "/exec", + json={"lang": "r", "code": 'print("Hello from R")'}, + headers=self.headers, + name="R Baseline", + ) @tag("language", "r") @task(1) def r_compute(self): - self.client.post("/exec", json={"lang": "r", "code": 'r <- sum((0:9999)^2)\nprint(paste("Result:", r))'}, headers=self.headers, name="R Compute") + self.client.post( + "/exec", + json={ + "lang": "r", + "code": 'r <- sum((0:9999)^2)\nprint(paste("Result:", r))', + }, + headers=self.headers, + name="R Compute", + ) # Fortran @tag("language", "f90") @task(2) def fortran_baseline(self): - self.client.post("/exec", json={"lang": "f90", "code": 'program hello\n print *, "Hello from Fortran"\nend program hello'}, headers=self.headers, name="Fortran Baseline") + self.client.post( + "/exec", + json={ + "lang": "f90", + "code": 'program hello\n print *, "Hello from Fortran"\nend program hello', + }, + headers=self.headers, + name="Fortran Baseline", + ) @tag("language", "f90") @task(1) def fortran_compute(self): - 
self.client.post("/exec", json={"lang": "f90", "code": 'program compute\n integer(8) :: r, i\n r = 0\n do i = 0, 9999\n r = r + i * i\n end do\n print *, "Result:", r\nend program compute'}, headers=self.headers, name="Fortran Compute") + self.client.post( + "/exec", + json={ + "lang": "f90", + "code": 'program compute\n integer(8) :: r, i\n r = 0\n do i = 0, 9999\n r = r + i * i\n end do\n print *, "Result:", r\nend program compute', + }, + headers=self.headers, + name="Fortran Compute", + ) # D @tag("language", "d") @task(2) def d_baseline(self): - self.client.post("/exec", json={"lang": "d", "code": 'import std.stdio;\nvoid main() { writeln("Hello from D"); }'}, headers=self.headers, name="D Baseline") + self.client.post( + "/exec", + json={ + "lang": "d", + "code": 'import std.stdio;\nvoid main() { writeln("Hello from D"); }', + }, + headers=self.headers, + name="D Baseline", + ) @tag("language", "d") @task(1) def d_compute(self): - self.client.post("/exec", json={"lang": "d", "code": 'import std.stdio;\nimport std.algorithm;\nimport std.range;\nvoid main() { long r = iota(0, 10000).map!(i => cast(long)i * i).sum; writeln("Result: ", r); }'}, headers=self.headers, name="D Compute") + self.client.post( + "/exec", + json={ + "lang": "d", + "code": 'import std.stdio;\nimport std.algorithm;\nimport std.range;\nvoid main() { long r = iota(0, 10000).map!(i => cast(long)i * i).sum; writeln("Result: ", r); }', + }, + headers=self.headers, + name="D Compute", + ) + + # Bash + @tag("language", "bash") + @task(2) + def bash_baseline(self): + self.client.post( + "/exec", + json={"lang": "bash", "code": 'echo "Hello from Bash"'}, + headers=self.headers, + name="Bash Baseline", + ) + + @tag("language", "bash") + @task(1) + def bash_compute(self): + self.client.post( + "/exec", + json={ + "lang": "bash", + "code": 'sum=0; for i in $(seq 1 10000); do sum=$((sum + i * i)); done; echo "Result: $sum"', + }, + headers=self.headers, + name="Bash Compute", + ) # 
============================================================================= # Specialized User Classes for targeted testing # ============================================================================= + class CPUUser(HttpUser): """CPU-bound workloads only.""" + wait_time = between(0.5, 1.5) def on_start(self): @@ -469,23 +758,50 @@ def on_start(self): @task(10) def cpu_light(self): - self.client.post("/exec", json={"lang": "py", "code": "print(sum(range(10000)))"}, headers=self.headers, name="CPU Light") + self.client.post( + "/exec", + json={"lang": "py", "code": "print(sum(range(10000)))"}, + headers=self.headers, + name="CPU Light", + ) @task(5) def cpu_medium(self): - self.client.post("/exec", json={"lang": "py", "code": "print(sum(i**2 for i in range(100000)))"}, headers=self.headers, name="CPU Medium") + self.client.post( + "/exec", + json={"lang": "py", "code": "print(sum(i**2 for i in range(100000)))"}, + headers=self.headers, + name="CPU Medium", + ) @task(2) def cpu_heavy(self): - self.client.post("/exec", json={"lang": "py", "code": "import numpy as np; a=np.random.rand(500,500); b=np.random.rand(500,500); print(np.dot(a,b).sum())"}, headers=self.headers, name="CPU Heavy") + self.client.post( + "/exec", + json={ + "lang": "py", + "code": "import numpy as np; a=np.random.rand(500,500); b=np.random.rand(500,500); print(np.dot(a,b).sum())", + }, + headers=self.headers, + name="CPU Heavy", + ) @task(1) def cpu_sklearn(self): - self.client.post("/exec", json={"lang": "py", "code": "from sklearn.ensemble import RandomForestClassifier; from sklearn.datasets import make_classification; X,y=make_classification(500,20); m=RandomForestClassifier(50); m.fit(X,y); print(m.score(X,y))"}, headers=self.headers, name="CPU Sklearn") + self.client.post( + "/exec", + json={ + "lang": "py", + "code": "from sklearn.ensemble import RandomForestClassifier; from sklearn.datasets import make_classification; X,y=make_classification(500,20); m=RandomForestClassifier(50); 
m.fit(X,y); print(m.score(X,y))", + }, + headers=self.headers, + name="CPU Sklearn", + ) class MemoryUser(HttpUser): """Memory-bound workloads only.""" + wait_time = between(1, 2) def on_start(self): @@ -494,23 +810,56 @@ def on_start(self): @task(5) def mem_10mb(self): - self.client.post("/exec", json={"lang": "py", "code": "import numpy as np; arr=np.random.rand(1310720); print(arr.sum())"}, headers=self.headers, name="Memory 10MB") + self.client.post( + "/exec", + json={ + "lang": "py", + "code": "import numpy as np; arr=np.random.rand(1310720); print(arr.sum())", + }, + headers=self.headers, + name="Memory 10MB", + ) @task(3) def mem_50mb(self): - self.client.post("/exec", json={"lang": "py", "code": "import numpy as np; arr=np.random.rand(6553600); print(arr.mean())"}, headers=self.headers, name="Memory 50MB") + self.client.post( + "/exec", + json={ + "lang": "py", + "code": "import numpy as np; arr=np.random.rand(6553600); print(arr.mean())", + }, + headers=self.headers, + name="Memory 50MB", + ) @task(2) def mem_100mb(self): - self.client.post("/exec", json={"lang": "py", "code": "import numpy as np; arr=np.random.rand(13107200); print(arr.std())"}, headers=self.headers, name="Memory 100MB") + self.client.post( + "/exec", + json={ + "lang": "py", + "code": "import numpy as np; arr=np.random.rand(13107200); print(arr.std())", + }, + headers=self.headers, + name="Memory 100MB", + ) @task(2) def mem_pandas(self): - self.client.post("/exec", json={"lang": "py", "code": "import pandas as pd; import numpy as np; df=pd.DataFrame({'a':np.random.rand(1000000)}); print(df.shape)"}, headers=self.headers, name="Memory Pandas") + self.client.post( + "/exec", + json={ + "lang": "py", + "code": "import pandas as pd; import numpy as np; df=pd.DataFrame({'a':np.random.rand(1000000)}); print(df.shape)", + }, + headers=self.headers, + name="Memory Pandas", + ) class IOUser(HttpUser): """I/O-bound workloads only.""" + wait_time = between(1, 2) def on_start(self): @@ -519,19 
+868,44 @@ def on_start(self): @task(3) def io_files(self): - self.client.post("/exec", json={"lang": "py", "code": "for i in range(5): open(f'/mnt/data/f{i}.txt','w').write('x'*50000)\nprint('done')"}, headers=self.headers, name="I/O Files") + self.client.post( + "/exec", + json={ + "lang": "py", + "code": "for i in range(5): open(f'/mnt/data/f{i}.txt','w').write('x'*50000)\nprint('done')", + }, + headers=self.headers, + name="I/O Files", + ) @task(2) def io_matplotlib(self): - self.client.post("/exec", json={"lang": "py", "code": "import matplotlib; matplotlib.use('Agg'); import matplotlib.pyplot as plt; import numpy as np; plt.plot(np.sin(np.linspace(0,10,100))); plt.savefig('/mnt/data/p.png'); print('done')"}, headers=self.headers, name="I/O Matplotlib") + self.client.post( + "/exec", + json={ + "lang": "py", + "code": "import matplotlib; matplotlib.use('Agg'); import matplotlib.pyplot as plt; import numpy as np; plt.plot(np.sin(np.linspace(0,10,100))); plt.savefig('/mnt/data/p.png'); print('done')", + }, + headers=self.headers, + name="I/O Matplotlib", + ) @task(3) def io_csv(self): - self.client.post("/exec", json={"lang": "py", "code": "import pandas as pd; import numpy as np; pd.DataFrame({'a':np.random.rand(10000)}).to_csv('/mnt/data/d.csv'); print('done')"}, headers=self.headers, name="I/O CSV") + self.client.post( + "/exec", + json={ + "lang": "py", + "code": "import pandas as pd; import numpy as np; pd.DataFrame({'a':np.random.rand(10000)}).to_csv('/mnt/data/d.csv'); print('done')", + }, + headers=self.headers, + name="I/O CSV", + ) class LanguageUser(HttpUser): """Multi-language tests only.""" + wait_time = between(0.5, 1.5) def on_start(self): @@ -540,31 +914,71 @@ def on_start(self): @task(2) def python(self): - self.client.post("/exec", json={"lang": "py", "code": 'print("Hello Python")'}, headers=self.headers, name="Python") + self.client.post( + "/exec", + json={"lang": "py", "code": 'print("Hello Python")'}, + headers=self.headers, + 
name="Python", + ) @task(2) def javascript(self): - self.client.post("/exec", json={"lang": "js", "code": 'console.log("Hello JS");'}, headers=self.headers, name="JavaScript") + self.client.post( + "/exec", + json={"lang": "js", "code": 'console.log("Hello JS");'}, + headers=self.headers, + name="JavaScript", + ) @task(2) def go(self): - self.client.post("/exec", json={"lang": "go", "code": 'package main\nimport "fmt"\nfunc main(){fmt.Println("Hello Go")}'}, headers=self.headers, name="Go") + self.client.post( + "/exec", + json={ + "lang": "go", + "code": 'package main\nimport "fmt"\nfunc main(){fmt.Println("Hello Go")}', + }, + headers=self.headers, + name="Go", + ) @task(1) def java(self): - self.client.post("/exec", json={"lang": "java", "code": 'public class Main{public static void main(String[]a){System.out.println("Hello Java");}}'}, headers=self.headers, name="Java") + self.client.post( + "/exec", + json={ + "lang": "java", + "code": 'public class Main{public static void main(String[]a){System.out.println("Hello Java");}}', + }, + headers=self.headers, + name="Java", + ) @task(1) def rust(self): - self.client.post("/exec", json={"lang": "rs", "code": 'fn main(){println!("Hello Rust");}'}, headers=self.headers, name="Rust") + self.client.post( + "/exec", + json={"lang": "rs", "code": 'fn main(){println!("Hello Rust");}'}, + headers=self.headers, + name="Rust", + ) @task(1) def cpp(self): - self.client.post("/exec", json={"lang": "cpp", "code": '#include\nint main(){std::cout<<"Hello C++"<\nint main(){std::cout<<"Hello C++"< ProgrammaticService: + """Get or create the PTC service singleton.""" + global _ptc_service + if _ptc_service is None: + _ptc_service = ProgrammaticService() + return _ptc_service + + +@router.post("/exec/programmatic", response_model=ProgrammaticExecResponse) +async def execute_programmatic( + request: ProgrammaticExecRequest, + http_request: Request, + session_service: SessionServiceDep, +) -> ProgrammaticExecResponse: + """Execute code 
with programmatic tool calling support. + + Supports two modes: + 1. Initial execution: provide code + tools + 2. Continuation: provide continuation_token + tool_results + + Args: + request: PTC execution request + http_request: HTTP request for auth state + session_service: Session service for session management + + Returns: + ProgrammaticExecResponse with status and optional tool_calls + """ + request_id = generate_request_id()[:8] + ptc_service = _get_ptc_service() + + # Continuation mode + if request.continuation_token: + logger.info( + "PTC continuation request", + request_id=request_id, + continuation_token=request.continuation_token[:12], + tool_results_count=len(request.tool_results), + ) + + response = await ptc_service.continue_execution( + continuation_token=request.continuation_token, + tool_results=request.tool_results, + ) + + logger.info( + "PTC continuation completed", + request_id=request_id, + status=response.status, + ) + + return response + + # Initial execution mode + if not request.code: + return ProgrammaticExecResponse( + status="error", + error="Either 'code' or 'continuation_token' must be provided", + ) + + # Get or create session + session_id = request.session_id + if not session_id: + metadata = {} + if request.entity_id: + metadata["entity_id"] = request.entity_id + if request.user_id: + metadata["user_id"] = request.user_id + + session = await session_service.create_session(SessionCreate(metadata=metadata)) + session_id = session.session_id + + logger.info( + "PTC execution request", + request_id=request_id, + session_id=session_id[:12], + code_length=len(request.code), + tools_count=len(request.tools), + ) + + response = await ptc_service.start_execution( + code=request.code, + tools=request.tools, + session_id=session_id, + timeout=request.timeout, + files=request.files, + ) + + # Ensure session_id is set in response + if not response.session_id: + response.session_id = session_id + + logger.info( + "PTC execution completed", + 
request_id=request_id, + session_id=session_id[:12], + status=response.status, + ) + + return response diff --git a/src/config/languages.py b/src/config/languages.py index a6310f9..e5546f4 100644 --- a/src/config/languages.py +++ b/src/config/languages.py @@ -26,7 +26,7 @@ class LanguageConfig: environment: Dict[str, str] = field(default_factory=dict) -# All 12 supported languages with complete configuration +# All 13 supported languages with complete configuration LANGUAGES: Dict[str, LanguageConfig] = { "py": LanguageConfig( code="py", @@ -149,6 +149,16 @@ class LanguageConfig: timeout_multiplier=2.0, memory_multiplier=1.2, ), + "bash": LanguageConfig( + code="bash", + name="Bash", + user_id=1001, + file_extension="sh", + execution_command="bash", + uses_stdin=True, + timeout_multiplier=1.0, + memory_multiplier=1.0, + ), } diff --git a/src/main.py b/src/main.py index 5e892c5..4e874bf 100644 --- a/src/main.py +++ b/src/main.py @@ -15,7 +15,7 @@ from pydantic import ValidationError # Local application imports -from .api import files, exec, health, admin, dashboard_metrics +from .api import files, exec, health, admin, dashboard_metrics, programmatic from .config import settings from .middleware.security import SecurityMiddleware, RequestLoggingMiddleware from .middleware.metrics import MetricsMiddleware @@ -140,13 +140,23 @@ async def _perform_health_checks() -> None: async def _shutdown_services(app: FastAPI) -> None: - """Stop monitoring services, sandbox pool, and cleanup scheduler.""" + """Stop monitoring services, sandbox pool, PTC contexts, and cleanup scheduler.""" try: await metrics_service.stop() logger.info("Metrics service stopped") except Exception as e: logger.error("Error stopping metrics service", error=str(e)) + # Clean up PTC paused contexts + try: + from .api.programmatic import _ptc_service + + if _ptc_service is not None: + await _ptc_service.cleanup_all() + logger.info("PTC service cleaned up") + except Exception as e: + logger.error("Error 
cleaning up PTC service", error=str(e)) + if hasattr(app.state, "sandbox_pool") and app.state.sandbox_pool: try: await app.state.sandbox_pool.stop() @@ -257,6 +267,8 @@ async def config_info(): app.include_router(exec.router, tags=["exec"]) +app.include_router(programmatic.router, tags=["exec", "programmatic"]) + app.include_router(health.router, tags=["health", "monitoring"]) app.include_router(admin.router, prefix="/api/v1", tags=["admin"]) diff --git a/src/models/__init__.py b/src/models/__init__.py index 5cbb100..577e214 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -23,6 +23,13 @@ FileDeleteResponse, ) from .exec import ExecRequest, ExecResponse, FileRef, RequestFile +from .programmatic import ( + PTCToolDefinition, + PTCToolCall, + PTCToolResult, + ProgrammaticExecRequest, + ProgrammaticExecResponse, +) from .errors import ( ErrorType, ErrorDetail, @@ -58,6 +65,12 @@ "ExecResponse", "FileRef", "RequestFile", + # PTC models + "PTCToolDefinition", + "PTCToolCall", + "PTCToolResult", + "ProgrammaticExecRequest", + "ProgrammaticExecResponse", # Error models "ErrorType", "ErrorDetail", diff --git a/src/models/programmatic.py b/src/models/programmatic.py new file mode 100644 index 0000000..7ccc620 --- /dev/null +++ b/src/models/programmatic.py @@ -0,0 +1,117 @@ +"""Models for the Programmatic Tool Calling (PTC) API. + +PTC allows code running inside the sandbox to call external tools +(defined by the caller) and receive results back before continuing +execution. This enables agentic workflows where code can request +information or actions from the outside world. 
+""" + +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class PTCToolDefinition(BaseModel): + """Definition of a tool available to sandbox code.""" + + name: str = Field(..., description="Tool function name") + description: str = Field( + default="", description="Human-readable description of the tool" + ) + parameters: Dict[str, Any] = Field( + default_factory=dict, + description="JSON Schema describing the tool's parameters", + ) + + +class PTCToolCall(BaseModel): + """A tool call requested by sandbox code.""" + + id: str = Field(..., description="Unique identifier for this tool call") + name: str = Field(..., description="Name of the tool to call") + input: Dict[str, Any] = Field( + default_factory=dict, description="Arguments for the tool call" + ) + + +class PTCToolResult(BaseModel): + """Result of a tool call to be sent back to sandbox code.""" + + call_id: str = Field(..., description="ID of the tool call this result is for") + result: Any = Field(default=None, description="Tool call result value") + is_error: bool = Field(default=False, description="Whether the tool call errored") + error_message: Optional[str] = Field( + default=None, description="Error message if is_error is True" + ) + + +class ProgrammaticExecRequest(BaseModel): + """Request model for POST /exec/programmatic. + + Supports two modes: + 1. Initial execution: provide code + tools (+ optional session_id, etc.) + 2. 
Continuation: provide continuation_token + tool_results + """ + + # Initial execution fields + code: Optional[str] = Field( + default=None, description="Python code to execute (initial request)" + ) + tools: List[PTCToolDefinition] = Field( + default_factory=list, + description="Tools available to the code (initial request)", + ) + session_id: Optional[str] = Field( + default=None, description="Optional session ID for continuity" + ) + user_id: Optional[str] = Field(default=None, description="Optional user identifier") + entity_id: Optional[str] = Field( + default=None, + description="Optional assistant/agent identifier", + max_length=40, + pattern=r"^[A-Za-z0-9_-]+$", + ) + timeout: Optional[int] = Field( + default=None, description="Execution timeout in seconds" + ) + files: List[Dict[str, Any]] = Field( + default_factory=list, description="Files to mount in sandbox" + ) + + # Continuation fields + continuation_token: Optional[str] = Field( + default=None, + description="Token from a previous tool_call_required response", + ) + tool_results: List[PTCToolResult] = Field( + default_factory=list, + description="Results for tool calls (continuation request)", + ) + + +class ProgrammaticExecResponse(BaseModel): + """Response model for POST /exec/programmatic.""" + + status: str = Field( + ..., + description="Execution status: tool_call_required, completed, or error", + ) + session_id: Optional[str] = Field( + default=None, description="Session ID for this execution" + ) + continuation_token: Optional[str] = Field( + default=None, + description="Token to continue execution after providing tool results", + ) + tool_calls: List[PTCToolCall] = Field( + default_factory=list, + description="Tool calls requested by the code (when status=tool_call_required)", + ) + stdout: str = Field(default="", description="Standard output from code execution") + stderr: str = Field(default="", description="Standard error from code execution") + files: List[Dict[str, Any]] = Field( + 
default_factory=list, description="Files generated during execution" + ) + error: Optional[str] = Field( + default=None, description="Error message when status=error" + ) diff --git a/src/services/orchestrator.py b/src/services/orchestrator.py index c4e9c35..8559dd7 100644 --- a/src/services/orchestrator.py +++ b/src/services/orchestrator.py @@ -217,9 +217,16 @@ async def _get_or_create_session(self, ctx: ExecutionContext) -> str: Session lookup priority: 1. Use session_id from request (for explicit session continuity/state persistence) - 2. Reuse session from file references (for file-based workflows) + 2. Reuse session from file references, but ONLY if the session belongs to + the same user (prevents cross-user session sharing via shared agent files) 3. Reuse session by entity_id (for session continuity within same entity) 4. Create new session + + SECURITY: File references carry a session_id that indicates where the file + is stored, NOT which session to execute in. When multiple users share an + agent with attached files, they all reference the same upload session. + Blindly reusing that session would leak state between users. We only reuse + a file-referenced session if its user_id matches the current request. """ request = ctx.request @@ -240,8 +247,12 @@ async def _get_or_create_session(self, ctx: ExecutionContext) -> str: error=str(e), ) - # Priority 2: Try to reuse session from files array - if request.files: + # Priority 2: Try to reuse session from files array, but only if the + # session was created by the same user. This enables same-user session + # continuity (ToolNode injects files from previous execution) while + # preventing cross-user sharing (agent files reference a shared upload + # session that has no user_id). 
+ if request.files and request.user_id: for file_ref in request.files: if file_ref.session_id: try: @@ -249,11 +260,17 @@ async def _get_or_create_session(self, ctx: ExecutionContext) -> str: file_ref.session_id ) if existing and existing.status.value == "active": - logger.debug( - "Reusing session from file reference", - session_id=file_ref.session_id, + session_user = ( + existing.metadata.get("user_id") + if existing.metadata + else None ) - return file_ref.session_id + if session_user and session_user == request.user_id: + logger.debug( + "Reusing session from file reference (same user)", + session_id=file_ref.session_id[:12], + ) + return file_ref.session_id except Exception as e: logger.warning( "Error looking up session", @@ -261,7 +278,11 @@ async def _get_or_create_session(self, ctx: ExecutionContext) -> str: error=str(e), ) - # Try to reuse session by entity_id (enables session continuity) + # Priority 3: Try to reuse session by entity_id. + # Only use explicit entity_id — do NOT fall back to user_id. + # LibreChat manages session continuity via file references (priority 2), + # not entity_id. Using user_id here would incorrectly share sessions + # across different conversations of the same user. if request.entity_id: try: entity_sessions = await self.session_service.list_sessions_by_entity( diff --git a/src/services/programmatic.py b/src/services/programmatic.py new file mode 100644 index 0000000..a06758f --- /dev/null +++ b/src/services/programmatic.py @@ -0,0 +1,543 @@ +"""Programmatic Tool Calling (PTC) service. + +Manages sandbox lifecycle for PTC executions where code can pause +to request external tool calls and resume with results. 
+""" + +import asyncio +import json +import os +import re +import shlex +import signal +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +import structlog + +from ..config import settings +from ..models.programmatic import ( + PTCToolCall, + PTCToolDefinition, + PTCToolResult, + ProgrammaticExecResponse, +) +from .sandbox.manager import SandboxManager +from .sandbox.nsjail import NsjailConfig, SandboxInfo + +logger = structlog.get_logger(__name__) + +# Protocol delimiter must match docker/ptc_server.py +PTC_DELIMITER = "\n---PTC_END---\n" + +# Default timeout for paused contexts (seconds) +PTC_PAUSE_TIMEOUT = 300 # 5 minutes + +# Maximum round trips per execution +PTC_MAX_ROUND_TRIPS = 50 + + +@dataclass +class PausedContext: + """Stores state for a paused PTC execution waiting for tool results.""" + + sandbox_info: SandboxInfo + process: asyncio.subprocess.Process + session_id: str + round_trip_count: int = 0 + timeout_handle: Optional[asyncio.TimerHandle] = None + accumulated_stdout: str = "" + accumulated_stderr: str = "" + + +class ProgrammaticService: + """Manages PTC execution lifecycle. + + Creates nsjail sandboxes, runs ptc_server.py inside them, and + manages the request/response protocol for tool calls. + """ + + def __init__(self, sandbox_manager: Optional[SandboxManager] = None): + self._sandbox_manager = sandbox_manager or SandboxManager() + self._nsjail_config = NsjailConfig() + self._paused_contexts: Dict[str, PausedContext] = {} + + async def start_execution( + self, + code: str, + tools: List[PTCToolDefinition], + session_id: str, + timeout: Optional[int] = None, + files: Optional[List[Dict[str, Any]]] = None, + ) -> ProgrammaticExecResponse: + """Start a new PTC execution. + + Creates an nsjail sandbox, copies ptc_server.py into it, + and starts execution with the provided code and tools. 
+ + Args: + code: Python code to execute + tools: Tool definitions available to the code + session_id: Session identifier + timeout: Execution timeout in seconds + files: Optional files to mount in sandbox + + Returns: + ProgrammaticExecResponse with status and optional tool_calls + """ + execution_timeout = timeout or settings.max_execution_time + + # Create sandbox + sandbox_info = self._sandbox_manager.create_sandbox( + session_id=session_id, + language="py", + repl_mode=False, + ) + + try: + # Copy ptc_server.py into the sandbox data dir + ptc_server_path = Path("/opt/ptc_server.py") + if not ptc_server_path.exists(): + # Fallback: try relative path (local development) + ptc_server_path = ( + Path(__file__).parent.parent.parent / "docker" / "ptc_server.py" + ) + + if ptc_server_path.exists(): + self._sandbox_manager.copy_content_to_sandbox( + sandbox_info, + ptc_server_path.read_bytes(), + "/mnt/data/ptc_server.py", + language="py", + ) + else: + return ProgrammaticExecResponse( + status="error", + session_id=session_id, + error="PTC server script not found", + ) + + # Mount any provided files + if files: + for file_info in files: + filename = file_info.get("filename", "") + content = file_info.get("content", b"") + if filename and content: + self._sandbox_manager.copy_content_to_sandbox( + sandbox_info, + content if isinstance(content, bytes) else content.encode(), + f"/mnt/data/{filename}", + language="py", + ) + + # Build nsjail command - wrap in /bin/sh -c like SandboxExecutor + env = self._sandbox_manager.executor._build_sanitized_env("py") + shell_command = [ + "/bin/sh", + "-c", + "python3 /mnt/data/ptc_server.py", + ] + nsjail_args = self._nsjail_config.build_args( + sandbox_dir=str(sandbox_info.data_dir), + command=shell_command, + language="py", + timeout=execution_timeout, + env=env, + ) + + # Build wrapper command (same pattern as SandboxExecutor) + nsjail_cmd = " ".join( + shlex.quote(str(a)) for a in [settings.nsjail_binary] + nsjail_args + ) + + 
wrapper_cmd = ( + f"mount --bind {shlex.quote(str(sandbox_info.data_dir))} /mnt/data && " + f"mount -t tmpfs -o size=1k tmpfs /var/lib/code-interpreter/sandboxes && " + f"mount -t tmpfs -o size=1k tmpfs /app/data && " + f"mount -t tmpfs -o size=1k tmpfs /var/log && " + f"mount -t tmpfs -o size=1k tmpfs /app/ssl && " + f"mount -t tmpfs -o size=1k tmpfs /app/dashboard && " + f"mount -t tmpfs -o size=1k tmpfs /app/src && " + f"mount --bind /tmp/empty_proc /proc && " + f"{nsjail_cmd}" + ) + + # Start subprocess + proc = await asyncio.create_subprocess_exec( + "unshare", + "--mount", + "--", + "/bin/sh", + "-c", + wrapper_cmd, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + start_new_session=True, + ) + + # Send initial request with code and tools + tools_payload = [ + { + "name": t.name, + "description": t.description, + "parameters": t.parameters, + } + for t in tools + ] + initial_request = json.dumps({"code": code, "tools": tools_payload}) + initial_request += PTC_DELIMITER + + assert proc.stdin is not None + proc.stdin.write(initial_request.encode("utf-8")) + await proc.stdin.drain() + + # Read response from ptc_server + return await self._read_ptc_response( + proc=proc, + sandbox_info=sandbox_info, + session_id=session_id, + timeout=execution_timeout, + ) + + except Exception as e: + # Cleanup sandbox on error + self._sandbox_manager.destroy_sandbox(sandbox_info) + logger.error( + "PTC execution failed", + session_id=session_id[:12], + error=str(e), + ) + return ProgrammaticExecResponse( + status="error", + session_id=session_id, + error=f"Execution failed: {str(e)}", + ) + + async def continue_execution( + self, + continuation_token: str, + tool_results: List[PTCToolResult], + ) -> ProgrammaticExecResponse: + """Continue a paused PTC execution with tool results. 
+ + Args: + continuation_token: Token from a previous tool_call_required response + tool_results: Results for the requested tool calls + + Returns: + ProgrammaticExecResponse with updated status + """ + ctx = self._paused_contexts.get(continuation_token) + if not ctx: + return ProgrammaticExecResponse( + status="error", + error="Invalid or expired continuation token", + ) + + # Cancel the timeout + if ctx.timeout_handle: + ctx.timeout_handle.cancel() + ctx.timeout_handle = None + + # Check round trip limit + ctx.round_trip_count += 1 + if ctx.round_trip_count > PTC_MAX_ROUND_TRIPS: + await self._cleanup_paused_context(continuation_token) + return ProgrammaticExecResponse( + status="error", + session_id=ctx.session_id, + error=f"Maximum round trips ({PTC_MAX_ROUND_TRIPS}) exceeded", + ) + + try: + # Send tool results to the subprocess + results_payload = { + "type": "tool_results", + "results": [ + { + "call_id": r.call_id, + "result": r.result, + "is_error": r.is_error, + "error_message": r.error_message, + } + for r in tool_results + ], + } + data = json.dumps(results_payload) + PTC_DELIMITER + assert ctx.process.stdin is not None + ctx.process.stdin.write(data.encode("utf-8")) + await ctx.process.stdin.drain() + + # Remove from paused (will be re-added if another tool_call happens) + del self._paused_contexts[continuation_token] + + # Read next response + return await self._read_ptc_response( + proc=ctx.process, + sandbox_info=ctx.sandbox_info, + session_id=ctx.session_id, + timeout=settings.max_execution_time, + accumulated_stdout=ctx.accumulated_stdout, + accumulated_stderr=ctx.accumulated_stderr, + round_trip_count=ctx.round_trip_count, + ) + + except Exception as e: + await self._cleanup_paused_context(continuation_token) + logger.error( + "PTC continuation failed", + continuation_token=continuation_token[:12], + error=str(e), + ) + return ProgrammaticExecResponse( + status="error", + session_id=ctx.session_id, + error=f"Continuation failed: {str(e)}", + ) + 
async def _read_ptc_response(
    self,
    proc: asyncio.subprocess.Process,
    sandbox_info: SandboxInfo,
    session_id: str,
    timeout: int,
    accumulated_stdout: str = "",
    accumulated_stderr: str = "",
    round_trip_count: int = 0,
) -> ProgrammaticExecResponse:
    """Read one PTC message from the subprocess and convert it to a response.

    Stdout is consumed until ``PTC_DELIMITER`` appears (or the stream ends),
    the text before the delimiter is parsed as JSON, and the message ``type``
    decides the outcome:

    * ``tool_calls``  -> the process is kept alive, a continuation token is
      issued and a cleanup timer of ``PTC_PAUSE_TIMEOUT`` seconds is armed.
    * ``completed``   -> process and sandbox are torn down, status "completed".
    * ``error`` / anything else -> process and sandbox are torn down, status
      "error".

    Args:
        proc: The subprocess running ptc_server.py.
        sandbox_info: Sandbox to destroy when the execution ends.
        session_id: Session identifier echoed back in the response.
        timeout: Execution timeout in seconds (5 extra seconds are granted
            for the read itself).
        accumulated_stdout: Stdout gathered in earlier round trips.
        accumulated_stderr: Stderr gathered in earlier round trips.
        round_trip_count: Current round trip count, stored with the paused
            context.

    Returns:
        ProgrammaticExecResponse describing the outcome. Failures are
        reported through ``status="error"`` rather than raised.
    """
    try:
        if proc.stdout is None:
            # Explicit check instead of `assert`: asserts vanish under -O and
            # an AssertionError would surface with an empty message.
            raise RuntimeError("PTC subprocess has no stdout pipe")
        stdout_stream = proc.stdout

        # Buffer raw bytes and decode once. Decoding each 4096-byte chunk
        # separately corrupts multi-byte UTF-8 characters that straddle a
        # chunk boundary.
        delimiter = PTC_DELIMITER.encode("utf-8")
        raw_stdout = bytearray()

        async def read_until_delimiter() -> None:
            while delimiter not in raw_stdout:
                chunk = await stdout_stream.read(4096)
                if not chunk:
                    break  # EOF: the process closed stdout
                raw_stdout.extend(chunk)

        try:
            await asyncio.wait_for(read_until_delimiter(), timeout=timeout + 5)
        except asyncio.TimeoutError:
            return self._abort_ptc(
                proc,
                sandbox_info,
                session_id,
                error=f"Execution timed out after {timeout} seconds",
                stdout=accumulated_stdout,
                stderr=accumulated_stderr,
            )

        stdout_buf = raw_stdout.decode("utf-8", errors="replace")

        # Best-effort stderr grab: wait at most half a second so a paused
        # process that wrote nothing does not stall the round trip.
        stderr_buf = ""
        if proc.stderr is not None:
            try:
                stderr_data = await asyncio.wait_for(
                    proc.stderr.read(65536),
                    timeout=0.5,
                )
            except asyncio.TimeoutError:
                stderr_data = b""
            if stderr_data:
                stderr_buf = stderr_data.decode("utf-8", errors="replace")

        json_part, found, _ = stdout_buf.partition(PTC_DELIMITER)
        if not found:
            # Process may have exited without sending the delimiter; hand the
            # raw output back so the caller can see what was printed.
            return self._abort_ptc(
                proc,
                sandbox_info,
                session_id,
                error="PTC server exited without response",
                stdout=accumulated_stdout + stdout_buf,
                stderr=accumulated_stderr + stderr_buf,
            )

        # Sanitize control characters before parsing
        json_part = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", "", json_part)

        try:
            response = json.loads(json_part)
        except json.JSONDecodeError as e:
            return self._abort_ptc(
                proc,
                sandbox_info,
                session_id,
                error=f"Invalid response from PTC server: {e}",
                stdout=accumulated_stdout,
                stderr=accumulated_stderr + stderr_buf,
            )

        if not isinstance(response, dict):
            # Valid JSON but not a message object (e.g. a bare list/string).
            return self._abort_ptc(
                proc,
                sandbox_info,
                session_id,
                error="Invalid response from PTC server: expected a JSON object",
                stdout=accumulated_stdout,
                stderr=accumulated_stderr + stderr_buf,
            )

        msg_type = response.get("type", "")
        # `or ""` also covers an explicit JSON null.
        total_stdout = accumulated_stdout + (response.get("stdout") or "")
        total_stderr = (
            accumulated_stderr + (response.get("stderr") or "") + stderr_buf
        )

        if msg_type == "tool_calls":
            # Code is paused waiting for tool results
            calls = [
                PTCToolCall(
                    id=c["id"],
                    name=c["name"],
                    input=c.get("input", {}),
                )
                for c in response.get("calls", [])
            ]

            # Generate continuation token and store context
            token = uuid.uuid4().hex
            ctx = PausedContext(
                sandbox_info=sandbox_info,
                process=proc,
                session_id=session_id,
                round_trip_count=round_trip_count,
                accumulated_stdout=total_stdout,
                accumulated_stderr=total_stderr,
            )

            def _schedule_cleanup() -> None:
                # call_later needs a plain callable; hop back into async land.
                asyncio.ensure_future(self._cleanup_paused_context(token))

            # We are inside a coroutine, so the running loop is the right one.
            ctx.timeout_handle = asyncio.get_running_loop().call_later(
                PTC_PAUSE_TIMEOUT,
                _schedule_cleanup,
            )
            self._paused_contexts[token] = ctx

            return ProgrammaticExecResponse(
                status="tool_call_required",
                session_id=session_id,
                continuation_token=token,
                tool_calls=calls,
                stdout=total_stdout,
                stderr=total_stderr,
            )

        if msg_type == "completed":
            # Execution completed successfully
            self._kill_process(proc)
            self._sandbox_manager.destroy_sandbox(sandbox_info)
            return ProgrammaticExecResponse(
                status="completed",
                session_id=session_id,
                stdout=total_stdout,
                stderr=total_stderr,
            )

        if msg_type == "error":
            failure = response.get("error", "Unknown error")
        else:
            failure = f"Unknown PTC message type: {msg_type}"
        return self._abort_ptc(
            proc,
            sandbox_info,
            session_id,
            error=failure,
            stdout=total_stdout,
            stderr=total_stderr,
        )

    except Exception as e:
        logger.error(
            "PTC response reading failed",
            session_id=session_id[:12],
            error=str(e),
        )
        return self._abort_ptc(
            proc,
            sandbox_info,
            session_id,
            error=f"Failed to read PTC response: {str(e)}",
            stdout=accumulated_stdout,
            stderr=accumulated_stderr,
        )


def _abort_ptc(
    self,
    proc: asyncio.subprocess.Process,
    sandbox_info: SandboxInfo,
    session_id: str,
    error: str,
    stdout: str,
    stderr: str,
) -> ProgrammaticExecResponse:
    """Kill the subprocess, destroy the sandbox and build an error response."""
    self._kill_process(proc)
    self._sandbox_manager.destroy_sandbox(sandbox_info)
    return ProgrammaticExecResponse(
        status="error",
        session_id=session_id,
        error=error,
        stdout=stdout,
        stderr=stderr,
    )


def _kill_process(self, proc: asyncio.subprocess.Process) -> None:
    """Kill a subprocess and its process group.

    Does nothing when the process has already exited. The whole group
    identified by ``proc.pid`` is signalled first; if that group does not
    exist or may not be signalled, only the process itself is killed.
    """
    if proc.returncode is not None:
        return
    try:
        os.killpg(proc.pid, signal.SIGKILL)
    except (ProcessLookupError, PermissionError):
        try:
            proc.kill()
        except ProcessLookupError:
            # Raced with process exit; nothing left to kill.
            pass


async def _cleanup_paused_context(self, token: str) -> None:
    """Clean up a paused PTC context (on timeout or error).

    Unknown tokens are ignored, so calling this twice for the same token is
    harmless.
    """
    ctx = self._paused_contexts.pop(token, None)
    if ctx is None:
        return

    if ctx.timeout_handle:
        ctx.timeout_handle.cancel()

    try:
        self._kill_process(ctx.process)
        try:
            await ctx.process.wait()
        except Exception as exc:
            # Best effort: the process was already signalled.
            logger.debug(
                "Waiting for killed PTC process failed",
                token=token[:12],
                error=str(exc),
            )
    finally:
        # Always release the sandbox, even if killing/waiting blew up.
        self._sandbox_manager.destroy_sandbox(ctx.sandbox_info)

    logger.debug(
        "Cleaned up paused PTC context",
        token=token[:12],
        session_id=ctx.session_id[:12],
    )


async def cleanup_all(self) -> None:
    """Clean up all paused PTC contexts. Called during shutdown.

    A failure while cleaning one context is logged and does not prevent the
    remaining contexts from being cleaned.
    """
    tokens = list(self._paused_contexts)
    for token in tokens:
        try:
            await self._cleanup_paused_context(token)
        except Exception as exc:
            logger.error(
                "Failed to clean up PTC context",
                token=token[:12],
                error=str(exc),
            )
    logger.info("Cleaned up all PTC contexts", count=len(tokens))
"""Functional tests for bash execution against a live API endpoint."""

import pytest


async def _run_bash(async_client, auth_headers, entity_id, code):
    """POST *code* to /exec as a bash snippet and return the HTTP response."""
    payload = {"code": code, "lang": "bash", "entity_id": entity_id}
    return await async_client.post("/exec", headers=auth_headers, json=payload)


class TestBashExecution:
    """Exercise POST /exec with lang='bash' end to end."""

    @pytest.mark.asyncio
    async def test_bash_echo(self, async_client, auth_headers, unique_entity_id):
        """A plain echo shows up in stdout."""
        response = await _run_bash(
            async_client, auth_headers, unique_entity_id, "echo hello-from-bash"
        )

        assert response.status_code == 200
        body = response.json()
        assert "hello-from-bash" in body["stdout"]

    @pytest.mark.asyncio
    async def test_bash_response_has_librechat_fields(
        self, async_client, auth_headers, unique_entity_id
    ):
        """The response carries the four LibreChat fields with the right types."""
        response = await _run_bash(
            async_client, auth_headers, unique_entity_id, "echo ok"
        )

        assert response.status_code == 200
        body = response.json()

        expected_types = {
            "session_id": str,
            "stdout": str,
            "stderr": str,
            "files": list,
        }
        # Presence first, then types - mirrors what the client relies on.
        for field in expected_types:
            assert field in body
        for field, kind in expected_types.items():
            assert isinstance(body[field], kind)

    @pytest.mark.asyncio
    async def test_bash_error_returns_200(
        self, async_client, auth_headers, unique_entity_id
    ):
        """A syntax error still answers HTTP 200 and reports via stderr."""
        response = await _run_bash(
            async_client, auth_headers, unique_entity_id, "if then fi done"
        )

        # CRITICAL: Execution errors return 200 (LibreChat compatibility)
        assert response.status_code == 200
        body = response.json()
        assert "session_id" in body
        assert len(body["stderr"]) > 0

    @pytest.mark.asyncio
    async def test_bash_variables_and_arithmetic(
        self, async_client, auth_headers, unique_entity_id
    ):
        """Variable expansion and $(( )) arithmetic are evaluated."""
        response = await _run_bash(
            async_client,
            auth_headers,
            unique_entity_id,
            'x=42; echo "result=$((x * 2))"',
        )

        assert response.status_code == 200
        assert "result=84" in response.json()["stdout"]

    @pytest.mark.asyncio
    async def test_bash_multiline_script(
        self, async_client, auth_headers, unique_entity_id
    ):
        """A multi-line script with a loop runs to completion."""
        script = "\n".join(
            [
                "total=0",
                "for i in 1 2 3 4 5; do",
                " total=$((total + i))",
                "done",
                'echo "sum=$total"',
            ]
        )
        response = await _run_bash(
            async_client, auth_headers, unique_entity_id, script
        )

        assert response.status_code == 200
        assert "sum=15" in response.json()["stdout"]

    @pytest.mark.asyncio
    async def test_bash_exit_code_nonzero(
        self, async_client, auth_headers, unique_entity_id
    ):
        """A non-zero exit status keeps HTTP 200 and the output printed so far."""
        response = await _run_bash(
            async_client,
            auth_headers,
            unique_entity_id,
            "echo before-error; exit 1",
        )

        assert response.status_code == 200
        body = response.json()
        assert "before-error" in body["stdout"]

    @pytest.mark.asyncio
    async def test_bash_piped_commands(
        self, async_client, auth_headers, unique_entity_id
    ):
        """Output piped through sort comes back ordered."""
        response = await _run_bash(
            async_client,
            auth_headers,
            unique_entity_id,
            'echo -e "cherry\\napple\\nbanana" | sort',
        )

        assert response.status_code == 200
        stdout = response.json()["stdout"]
        non_empty = [line for line in stdout.strip().split("\n") if line]
        assert non_empty == ["apple", "banana", "cherry"]
import pytest


# ---------------------------------------------------------------------------
# tests/functional/test_exec_workflow.py
# ---------------------------------------------------------------------------


class TestSessionIsolation:
    """Test session isolation for agent file sharing scenarios.

    When multiple users share an agent with attached files, each user
    must get their own session. The upload session_id in file references
    should NOT be blindly reused.
    """

    @pytest.mark.asyncio
    async def test_different_users_get_different_sessions(
        self, async_client, auth_headers, unique_entity_id
    ):
        """Two users sharing an entity_id each receive a usable session.

        With entity_id-based session reuse both users may legitimately end up
        in the same session, so this test only checks that each call returns
        a non-empty session_id. Divergence is asserted by
        ``test_file_ref_does_not_leak_session_across_users``.
        """
        session_ids = {}
        for label in ("user-a", "user-b"):
            response = await async_client.post(
                "/exec",
                headers=auth_headers,
                json={
                    "code": f"print('{label}')",
                    "lang": "py",
                    "user_id": f"{label}-isolation-test",
                    "entity_id": unique_entity_id,
                },
            )
            assert response.status_code == 200
            session_ids[label] = response.json()["session_id"]

        for label, session_id in session_ids.items():
            assert len(session_id) > 0, f"{label} received an empty session_id"

    @pytest.mark.asyncio
    async def test_file_ref_does_not_leak_session_across_users(
        self, async_client, auth_headers, unique_entity_id
    ):
        """File references from an upload session should not share execution sessions.

        Simulates: Agent uploads file (creates upload session S1),
        then userA and userB both execute with a reference to that file.
        Each should get their own execution session, not reuse S1.
        """
        # Upload a file (simulating agent upload with entity_id)
        upload = await async_client.post(
            "/upload",
            headers={"x-api-key": auth_headers["x-api-key"]},
            files={"file": ("shared.txt", b"shared content", "text/plain")},
            data={"entity_id": unique_entity_id},
        )
        assert upload.status_code == 200
        upload_data = upload.json()
        upload_session = upload_data["session_id"]
        uploaded = upload_data["files"][0]
        filename = uploaded["filename"]
        file_ref = {
            "id": uploaded["fileId"],
            "session_id": upload_session,
            "name": filename,
        }

        async def run_as(user_id):
            """Execute with the shared file reference; return the session id."""
            response = await async_client.post(
                "/exec",
                headers=auth_headers,
                json={
                    "code": f"print(open('{filename}').read())",
                    "lang": "py",
                    "user_id": user_id,
                    "files": [file_ref],
                },
            )
            assert response.status_code == 200
            body = response.json()
            session_id = body["session_id"]
            assert "shared content" in body["stdout"]
            return session_id

        session_a = await run_as("isolation-user-a")
        session_b = await run_as("isolation-user-b")

        # Neither user should reuse the upload session
        assert session_a != upload_session, (
            "User A should not reuse the upload session"
        )
        assert session_b != upload_session, (
            "User B should not reuse the upload session"
        )
        # Each user should get a different session
        assert session_a != session_b, (
            "Different users should get different sessions"
        )


# ---------------------------------------------------------------------------
# tests/functional/test_files.py
# ---------------------------------------------------------------------------


class TestFileMetadata:
    """Test file metadata fields required by LibreChat."""

    @staticmethod
    async def _upload_and_list_full(async_client, auth_headers, entity_id, upload):
        """Upload one file and return the parsed ``?detail=full`` listing.

        Both HTTP calls are asserted to succeed so a failed upload is reported
        as such instead of as a KeyError on the missing ``session_id``.
        """
        # NOTE(review): this suite posts the multipart field as "files" while
        # the session isolation test uses "file" - confirm the endpoint
        # accepts both spellings.
        upload_response = await async_client.post(
            "/upload",
            headers={"x-api-key": auth_headers["x-api-key"]},
            files={"files": upload},
            data={"entity_id": entity_id},
        )
        assert upload_response.status_code == 200
        session_id = upload_response.json()["session_id"]

        listing = await async_client.get(
            f"/files/{session_id}?detail=full",
            headers=auth_headers,
        )
        assert listing.status_code == 200
        return listing.json()

    @pytest.mark.asyncio
    async def test_detail_full_has_original_filename_metadata(
        self, async_client, auth_headers, unique_entity_id
    ):
        """GET /files/{sid}?detail=full must include metadata['original-filename'].

        LibreChat reads this field at CodeExecutor.ts:170 to map sanitized
        filenames back to original upload names.
        """
        # Upload a file with a distinctive name
        data = await self._upload_and_list_full(
            async_client,
            auth_headers,
            unique_entity_id,
            ("My Report (2024).csv", b"a,b\n1,2", "text/csv"),
        )

        assert isinstance(data, list)
        assert len(data) >= 1

        for item in data:
            assert "metadata" in item, "Full detail must include 'metadata'"
            metadata = item["metadata"]
            assert "original-filename" in metadata, (
                "metadata must include 'original-filename'"
            )
            assert isinstance(metadata["original-filename"], str)
            assert len(metadata["original-filename"]) > 0

    @pytest.mark.asyncio
    async def test_detail_full_has_required_fields(
        self, async_client, auth_headers, unique_entity_id
    ):
        """GET /files/{sid}?detail=full returns all fields LibreChat expects."""
        data = await self._upload_and_list_full(
            async_client,
            auth_headers,
            unique_entity_id,
            ("test.txt", b"content", "text/plain"),
        )

        assert len(data) >= 1
        item = data[0]

        # Fields LibreChat expects in full detail
        for field in ("id", "name", "size", "lastModified", "contentType", "metadata"):
            assert field in item
        for field in ("content-type", "original-filename"):
            assert field in item["metadata"]


# ---------------------------------------------------------------------------
# tests/functional/test_ptc.py
# Functional tests for Programmatic Tool Calling (PTC) against a live API
# endpoint.
# ---------------------------------------------------------------------------

PTC_ENDPOINT = "/exec/programmatic"


async def _post_ptc(async_client, payload, headers=None):
    """POST *payload* to the PTC endpoint and return the HTTP response."""
    return await async_client.post(PTC_ENDPOINT, headers=headers, json=payload)


def _require_tool_call(response):
    """Assert the execution paused on a tool call and return the parsed body.

    The failure message includes stderr and error so a sandbox-side problem
    is visible directly in the test report.
    """
    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "tool_call_required", (
        f"Expected tool_call_required, got {data['status']}. "
        f"stderr: {data.get('stderr', '')}, error: {data.get('error', '')}"
    )
    return data


async def _resume(async_client, auth_headers, data, result_fields):
    """Answer the first pending tool call and return the parsed continuation."""
    tool_call = data["tool_calls"][0]
    response = await _post_ptc(
        async_client,
        {
            "continuation_token": data["continuation_token"],
            "tool_results": [{"call_id": tool_call["id"], **result_fields}],
        },
        headers=auth_headers,
    )
    assert response.status_code == 200
    return response.json()


class TestPTCInitialExecution:
    """Test POST /exec/programmatic with initial code execution."""

    @pytest.mark.asyncio
    async def test_ptc_simple_code_completes(self, async_client, auth_headers):
        """PTC request with code that doesn't call any tools completes immediately."""
        response = await _post_ptc(
            async_client,
            {
                "code": "print('hello from ptc')",
                "tools": [{"name": "unused_tool", "description": "Not called"}],
            },
            headers=auth_headers,
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "completed"
        assert "session_id" in data
        assert "hello from ptc" in data["stdout"]

    @pytest.mark.asyncio
    async def test_ptc_response_has_all_fields(self, async_client, auth_headers):
        """PTC response includes all expected fields."""
        response = await _post_ptc(
            async_client,
            {"code": "x = 1 + 1", "tools": []},
            headers=auth_headers,
        )

        assert response.status_code == 200
        data = response.json()
        for field in (
            "status",
            "session_id",
            "continuation_token",
            "tool_calls",
            "stdout",
            "stderr",
            "files",
            "error",
        ):
            assert field in data

    @pytest.mark.asyncio
    async def test_ptc_no_code_returns_error(self, async_client, auth_headers):
        """PTC request without code or continuation_token returns error."""
        response = await _post_ptc(async_client, {}, headers=auth_headers)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "error"
        assert data["error"] is not None


class TestPTCToolCallFlow:
    """Test the full PTC tool call round-trip: code calls tool, we supply result."""

    @pytest.mark.asyncio
    async def test_ptc_tool_call_and_continuation(self, async_client, auth_headers):
        """Full PTC round-trip: code calls a tool, receives result, completes."""
        # Step 1: Send code that calls a tool
        initial = await _post_ptc(
            async_client,
            {
                "code": "result = await get_number()\nprint(f'got: {result}')",
                "tools": [
                    {
                        "name": "get_number",
                        "description": "Returns a number",
                        "parameters": {"type": "object", "properties": {}},
                    }
                ],
            },
            headers=auth_headers,
        )

        data = _require_tool_call(initial)
        assert data["continuation_token"] is not None
        assert len(data["tool_calls"]) >= 1

        tool_call = data["tool_calls"][0]
        assert "id" in tool_call
        assert tool_call["name"] == "get_number"
        assert "input" in tool_call

        # Step 2: Send tool result back
        result = await _resume(
            async_client, auth_headers, data, {"result": 42, "is_error": False}
        )
        assert result["status"] == "completed"
        assert "got: 42" in result["stdout"]

    @pytest.mark.asyncio
    async def test_ptc_tool_with_arguments(self, async_client, auth_headers):
        """Tool call passes arguments correctly."""
        initial = await _post_ptc(
            async_client,
            {
                "code": "result = await add(a=3, b=7)\nprint(f'sum={result}')",
                "tools": [
                    {
                        "name": "add",
                        "description": "Add two numbers",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "a": {"type": "integer"},
                                "b": {"type": "integer"},
                            },
                        },
                    }
                ],
            },
            headers=auth_headers,
        )

        data = _require_tool_call(initial)
        tool_call = data["tool_calls"][0]
        assert tool_call["name"] == "add"
        assert tool_call["input"]["a"] == 3
        assert tool_call["input"]["b"] == 7

        # Return sum
        result = await _resume(
            async_client, auth_headers, data, {"result": 10, "is_error": False}
        )
        assert result["status"] == "completed"
        assert "sum=10" in result["stdout"]

    @pytest.mark.asyncio
    async def test_ptc_tool_error_result(self, async_client, auth_headers):
        """Tool result with is_error=true is handled by the code."""
        initial = await _post_ptc(
            async_client,
            {
                "code": (
                    "try:\n"
                    "    result = await failing_tool()\n"
                    "    print(f'unexpected: {result}')\n"
                    "except Exception as e:\n"
                    "    print(f'caught: {e}')"
                ),
                "tools": [{"name": "failing_tool", "description": "Will fail"}],
            },
            headers=auth_headers,
        )

        data = _require_tool_call(initial)
        result = await _resume(
            async_client,
            auth_headers,
            data,
            {
                "result": None,
                "is_error": True,
                "error_message": "Service unavailable",
            },
        )
        # Code should have caught the error or completed with error info
        assert result["status"] in ("completed", "error")


class TestPTCInvalidToken:
    """Test PTC continuation with invalid/expired tokens."""

    @pytest.mark.asyncio
    async def test_ptc_invalid_continuation_token(self, async_client, auth_headers):
        """Invalid continuation token returns error status."""
        response = await _post_ptc(
            async_client,
            {
                "continuation_token": "nonexistent-token-xyz",
                "tool_results": [{"call_id": "fake-call", "result": "data"}],
            },
            headers=auth_headers,
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "error"
        assert data["error"] is not None


class TestPTCAuth:
    """Test authentication on PTC endpoint."""

    @pytest.mark.asyncio
    async def test_ptc_no_auth_returns_401(self, async_client):
        """PTC request without auth returns 401."""
        response = await _post_ptc(async_client, {"code": "print('hello')"})
        assert response.status_code == 401

    @pytest.mark.asyncio
    async def test_ptc_invalid_auth_returns_401(self, async_client):
        """PTC request with wrong API key returns 401."""
        response = await _post_ptc(
            async_client,
            {"code": "print('hello')"},
            headers={"x-api-key": "wrong-key-12345"},
        )
        assert response.status_code == 401
b/tests/integration/test_exec_api.py index 1640764..3c0fb62 100644 --- a/tests/integration/test_exec_api.py +++ b/tests/integration/test_exec_api.py @@ -282,8 +282,8 @@ def test_exec_invalid_language(self, client, auth_headers): response = client.post("/exec", json=request_data, headers=auth_headers) - # Should either return error or handle gracefully - assert response.status_code in [200, 400, 422] + # Should return 400 for unsupported language + assert response.status_code == 400 def test_exec_empty_code(self, client, auth_headers): """Test executing empty code.""" diff --git a/tests/integration/test_file_api.py b/tests/integration/test_file_api.py index a46525d..e77af84 100644 --- a/tests/integration/test_file_api.py +++ b/tests/integration/test_file_api.py @@ -139,9 +139,8 @@ def test_download_uploaded_file(self, client, auth_headers): follow_redirects=False, ) - # Should redirect to MinIO presigned URL - assert download_response.status_code == 302 - assert "location" in download_response.headers + # Should return file content directly + assert download_response.status_code == 200 def test_download_nonexistent_file(self, client, auth_headers, unique_session_id): """Test downloading a file that doesn't exist.""" diff --git a/tests/integration/test_librechat_compat.py b/tests/integration/test_librechat_compat.py index 1b9f502..0e3090b 100644 --- a/tests/integration/test_librechat_compat.py +++ b/tests/integration/test_librechat_compat.py @@ -496,11 +496,13 @@ def test_files_endpoint_with_detail_summary(self, client, auth_headers): assert "name" in item, "Summary must have 'name' field" assert "lastModified" in item, "Summary must have 'lastModified' field" # LibreChat parses name with: file.name.startsWith(path) where path = "session_id/fileId" - assert item["name"] == "test-session-123/file-123", \ - f"name must be 'session_id/fileId' format, got: {item['name']}" + assert ( + item["name"] == "test-session-123/file-123" + ), f"name must be 'session_id/fileId' 
format, got: {item['name']}" # lastModified must be ISO 8601 with Z suffix for LibreChat's Date parsing - assert item["lastModified"].endswith("Z"), \ - f"lastModified must end with 'Z', got: {item['lastModified']}" + assert item["lastModified"].endswith( + "Z" + ), f"lastModified must end with 'Z', got: {item['lastModified']}" def test_files_endpoint_with_detail_full(self, client, auth_headers): """ @@ -700,7 +702,9 @@ def setup_mocks(self): from src.dependencies.services import get_file_service, get_session_service app.dependency_overrides[get_file_service] = lambda: self.mock_file_service - app.dependency_overrides[get_session_service] = lambda: self.mock_session_service + app.dependency_overrides[get_session_service] = ( + lambda: self.mock_session_service + ) yield @@ -713,7 +717,9 @@ def test_upload_then_check_summary(self, client, auth_headers): This is the primeFiles check: upload -> GET /files/{session_id}?detail=summary """ # Step 1: Upload file (LibreChat uses 'file' singular) - upload_files = {"file": ("data.csv", io.BytesIO(b"col1,col2\n1,2\n"), "text/csv")} + upload_files = { + "file": ("data.csv", io.BytesIO(b"col1,col2\n1,2\n"), "text/csv") + } upload_data = {"entity_id": "asst_test_agent"} upload_response = client.post( @@ -762,7 +768,10 @@ def test_upload_then_exec_with_file_ref(self, mock_execute, client, auth_headers # Step 1: Upload upload_files = {"file": ("input.txt", io.BytesIO(b"hello world"), "text/plain")} upload_response = client.post( - "/upload", files=upload_files, data={"entity_id": "asst_test"}, headers=auth_headers + "/upload", + files=upload_files, + data={"entity_id": "asst_test"}, + headers=auth_headers, ) assert upload_response.status_code == 200 upload_result = upload_response.json() @@ -782,7 +791,9 @@ def test_upload_then_exec_with_file_ref(self, mock_execute, client, auth_headers json={ "code": "with open('/mnt/data/input.txt') as f: print(f.read())", "lang": "py", - "files": [{"id": file_id, "session_id": session_id, 
"name": "input.txt"}], + "files": [ + {"id": file_id, "session_id": session_id, "name": "input.txt"} + ], }, headers=auth_headers, ) @@ -812,9 +823,7 @@ def test_download_output_file(self, client, auth_headers): ) self.mock_file_service.get_file_content.return_value = file_content - response = client.get( - f"/download/{session_id}/{file_id}", headers=auth_headers - ) + response = client.get(f"/download/{session_id}/{file_id}", headers=auth_headers) assert response.status_code == 200 assert response.content == file_content @@ -875,7 +884,9 @@ def setup_mocks(self): from src.dependencies.services import get_file_service, get_session_service app.dependency_overrides[get_file_service] = lambda: self.mock_file_service - app.dependency_overrides[get_session_service] = lambda: self.mock_session_service + app.dependency_overrides[get_session_service] = ( + lambda: self.mock_session_service + ) yield @@ -915,8 +926,9 @@ def test_prime_files_check_existing(self, client, auth_headers): # file.name.startsWith("session_id/fileId") file_identifier = f"{session_id}/{file_id}" matching = [f for f in data if f["name"].startswith(file_identifier)] - assert len(matching) == 1, \ - f"LibreChat expects to find file by name.startsWith('{file_identifier}')" + assert ( + len(matching) == 1 + ), f"LibreChat expects to find file by name.startsWith('{file_identifier}')" def test_prime_files_reupload_flow(self, client, auth_headers): """ @@ -1002,12 +1014,16 @@ def test_prime_files_name_format_matches_client_parsing(self, client, auth_heade name = data[0]["name"] # Simulate LibreChat's parsing parts = name.split("/") - assert len(parts) == 2, f"name must have exactly 2 parts split by '/', got: {name}" + assert ( + len(parts) == 2 + ), f"name must have exactly 2 parts split by '/', got: {name}" parsed_session_id, parsed_file_id = parts - assert parsed_session_id == session_id, \ - f"First part must be session_id '{session_id}', got: '{parsed_session_id}'" - assert parsed_file_id == file_id, \ 
- f"Second part must be file_id '{file_id}', got: '{parsed_file_id}'" + assert ( + parsed_session_id == session_id + ), f"First part must be session_id '{session_id}', got: '{parsed_session_id}'" + assert ( + parsed_file_id == file_id + ), f"Second part must be file_id '{file_id}', got: '{parsed_file_id}'" def test_prime_files_last_modified_is_parseable_date(self, client, auth_headers): """ @@ -1039,8 +1055,9 @@ def test_prime_files_last_modified_is_parseable_date(self, client, auth_headers) parsed = datetime.fromisoformat(last_modified.replace("Z", "+00:00")) assert parsed is not None, "lastModified must be parseable ISO 8601" # Must end with Z (UTC) for JavaScript Date compatibility - assert last_modified.endswith("Z"), \ - f"lastModified must end with 'Z' for JS Date parsing, got: {last_modified}" + assert last_modified.endswith( + "Z" + ), f"lastModified must end with 'Z' for JS Date parsing, got: {last_modified}" # ============================================================================= @@ -1224,10 +1241,545 @@ def test_download_with_full_librechat_headers(self, client): "User-Agent": "LibreChat/1.0", } + response = client.get("/download/dl-session/dl-file", headers=headers) + assert response.status_code == 200 + assert response.content == b"hello" + finally: + app.dependency_overrides.clear() + + +# ============================================================================= +# FIELD NAME GUARD - exec vs upload field name consistency +# ============================================================================= + + +class TestFieldNameGuard: + """Verify that exec response files use 'id'/'name' and upload uses 'fileId'/'filename'. + + LibreChat expects: + - /exec response files[]: {id, name, path?, session_id?} + - /upload response files[]: {fileId, filename} + + These are DIFFERENT field names by design (matching LibreChat's expectations). 
+ """ + + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_exec_response_uses_id_and_name(self, mock_execute, client, auth_headers): + """Exec response files must use 'id' and 'name' fields.""" + mock_execute.return_value = ExecResponse( + session_id="guard-session", + stdout="", + stderr="", + files=[FileRef(id="gen-file-1", name="output.png", path="/output.png")], + ) + + response = client.post( + "/exec", + json={"code": "generate image", "lang": "py"}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + file_ref = data["files"][0] + assert "id" in file_ref, "exec response files must use 'id'" + assert "name" in file_ref, "exec response files must use 'name'" + assert "fileId" not in file_ref, "exec response must NOT use 'fileId'" + assert "filename" not in file_ref, "exec response must NOT use 'filename'" + + def test_upload_response_uses_fileid_and_filename(self, client, auth_headers): + """Upload response files must use 'fileId' and 'filename' fields.""" + mock_file_service = AsyncMock() + mock_file_service.store_uploaded_file.return_value = "upload-file-001" + mock_file_service.validate_uploads = MagicMock(return_value=None) + + mock_session_service = AsyncMock() + mock_session_service.create_session.return_value = Session( + session_id="guard-upload-session", + status=SessionStatus.ACTIVE, + created_at=datetime.now(timezone.utc), + last_activity=datetime.now(timezone.utc), + expires_at=datetime.now(timezone.utc) + timedelta(hours=24), + metadata={}, + ) + + from src.dependencies.services import get_file_service, get_session_service + + app.dependency_overrides[get_file_service] = lambda: mock_file_service + app.dependency_overrides[get_session_service] = lambda: mock_session_service + + try: + files = {"file": ("test.txt", io.BytesIO(b"content"), "text/plain")} + response = client.post("/upload", files=files, headers=auth_headers) + + assert response.status_code == 200 + data = 
response.json() + file_info = data["files"][0] + assert "fileId" in file_info, "upload response must use 'fileId'" + assert "filename" in file_info, "upload response must use 'filename'" + assert "id" not in file_info, "upload response must NOT use 'id'" + assert "name" not in file_info, "upload response must NOT use 'name'" + finally: + app.dependency_overrides.clear() + + +# ============================================================================= +# LIBRECHAT EDGE CASES +# ============================================================================= + + +class TestLibreChatEdgeCases: + """Test edge-case behaviors that LibreChat relies on.""" + + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_session_id_always_present_in_response( + self, mock_execute, client, auth_headers + ): + """Every exec response must include a non-empty session_id string.""" + mock_execute.return_value = ExecResponse( + session_id="edge-session-123", stdout="ok\n", stderr="", files=[] + ) + + response = client.post( + "/exec", + json={"code": "print('ok')", "lang": "py"}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert "session_id" in data + assert isinstance(data["session_id"], str) + assert len(data["session_id"]) > 0 + + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_extra_fields_in_request_ignored(self, mock_execute, client, auth_headers): + """Extra/unknown fields in the request body must be silently ignored.""" + mock_execute.return_value = ExecResponse( + session_id="extra-session", stdout="ok\n", stderr="", files=[] + ) + + request = { + "code": "print('ok')", + "lang": "py", + "unknown_field": "should be ignored", + "another_extra": 42, + } + + response = client.post("/exec", json=request, headers=auth_headers) + assert response.status_code == 200 + + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_empty_files_array_accepted(self, 
mock_execute, client, auth_headers): + """Request with files:[] (empty array) must be accepted.""" + mock_execute.return_value = ExecResponse( + session_id="empty-files-session", stdout="ok\n", stderr="", files=[] + ) + + request = { + "code": "print('ok')", + "lang": "py", + "files": [], + } + + response = client.post("/exec", json=request, headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert data["files"] == [] + + def test_detail_full_name_format(self, client, auth_headers): + """GET /files/{session_id}?detail=full must return name as 'session_id/fileId'.""" + mock_file_service = AsyncMock() + mock_file_service.list_files.return_value = [ + FileInfo( + file_id="full-file-789", + filename="report.pdf", + size=4096, + content_type="application/pdf", + created_at=datetime.now(timezone.utc), + path="/report.pdf", + ) + ] + + from src.dependencies.services import get_file_service + + app.dependency_overrides[get_file_service] = lambda: mock_file_service + + try: + session_id = "edge-full-session" response = client.get( - "/download/dl-session/dl-file", headers=headers + f"/files/{session_id}?detail=full", headers=auth_headers ) + assert response.status_code == 200 - assert response.content == b"hello" + data = response.json() + assert isinstance(data, list) + assert len(data) == 1 + item = data[0] + assert item["name"] == f"{session_id}/full-file-789" + assert item["id"] == "full-file-789" finally: app.dependency_overrides.clear() + + def test_detail_full_has_original_filename_metadata(self, client, auth_headers): + """ + GET /files/{sid}?detail=full must include metadata['original-filename']. + + LibreChat reads this field at CodeExecutor.ts:170 to map sanitized + filenames back to original upload names. 
+ """ + mock_file_service = AsyncMock() + mock_file_service.list_files.return_value = [ + FileInfo( + file_id="meta-file-001", + filename="my_report.pdf", + size=2048, + content_type="application/pdf", + created_at=datetime.now(timezone.utc), + path="/my_report.pdf", + ), + FileInfo( + file_id="meta-file-002", + filename="data.csv", + size=512, + content_type="text/csv", + created_at=datetime.now(timezone.utc), + path="/data.csv", + ), + ] + + from src.dependencies.services import get_file_service + + app.dependency_overrides[get_file_service] = lambda: mock_file_service + + try: + response = client.get( + "/files/meta-test-session?detail=full", headers=auth_headers + ) + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) == 2 + + for item in data: + assert "metadata" in item, "Full detail must include 'metadata'" + assert ( + "original-filename" in item["metadata"] + ), "metadata must include 'original-filename'" + assert isinstance( + item["metadata"]["original-filename"], str + ), "original-filename must be a string" + assert ( + len(item["metadata"]["original-filename"]) > 0 + ), "original-filename must not be empty" + finally: + app.dependency_overrides.clear() + + +# ============================================================================= +# LIBRECHAT BASH EXECUTION +# ============================================================================= + + +class TestLibreChatBashExecution: + """Test bash execution compatibility with LibreChat. + + Bash was added as a supported language. These tests verify that bash + requests follow the same API contract as other languages. 
+ """ + + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_bash_minimal_request(self, mock_execute, client, auth_headers): + """Minimal bash request should return 200 with standard response shape.""" + mock_execute.return_value = ExecResponse( + session_id="bash-session-123", + stdout="hello\n", + stderr="", + files=[], + ) + + response = client.post( + "/exec", + json={"code": "echo hello", "lang": "bash"}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["session_id"] == "bash-session-123" + assert data["stdout"] == "hello\n" + assert data["stderr"] == "" + + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_bash_error_returns_200(self, mock_execute, client, auth_headers): + """Bash syntax error should return 200 with error in stderr.""" + mock_execute.return_value = ExecResponse( + session_id="bash-err-session", + stdout="", + stderr="bash: syntax error near unexpected token\n", + files=[], + ) + + response = client.post( + "/exec", + json={"code": "if then fi done", "lang": "bash"}, + headers=auth_headers, + ) + + # CRITICAL: Same as other languages - execution errors return 200 + assert response.status_code == 200 + data = response.json() + assert "session_id" in data + assert data["stderr"] != "" + + @patch("src.services.orchestrator.ExecutionOrchestrator.execute") + def test_bash_response_matches_python_shape( + self, mock_execute, client, auth_headers + ): + """Bash response must have the same 4 fields as Python response.""" + mock_execute.return_value = ExecResponse( + session_id="bash-shape-session", + stdout="result\n", + stderr="", + files=[], + ) + + response = client.post( + "/exec", + json={"code": "echo result", "lang": "bash"}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + + # Must have the same 4 required fields as any other language + assert "session_id" in data + assert "stdout" in data + 
assert "stderr" in data + assert "files" in data + + # Type validation + assert isinstance(data["session_id"], str) + assert isinstance(data["stdout"], str) + assert isinstance(data["stderr"], str) + assert isinstance(data["files"], list) + + +# ============================================================================= +# LIBRECHAT PROGRAMMATIC TOOL CALLING (PTC) +# ============================================================================= + + +class TestLibreChatProgrammaticToolCalling: + """Test PTC endpoint compatibility from a LibreChat client perspective. + + These tests verify the /exec/programmatic endpoint contract using the + same mocking pattern as test_programmatic_api.py. + """ + + @patch("src.api.programmatic._get_ptc_service") + def test_ptc_initial_request_completed( + self, mock_get_service, client, auth_headers + ): + """Initial PTC request with code+tools should return completed response.""" + from src.models.programmatic import ProgrammaticExecResponse + + mock_service = AsyncMock() + mock_service.start_execution.return_value = ProgrammaticExecResponse( + status="completed", + session_id="ptc-compat-session", + stdout="Hello from PTC\n", + stderr="", + ) + mock_get_service.return_value = mock_service + + from src.dependencies.services import get_session_service + + mock_session_svc = AsyncMock() + mock_session_svc.create_session.return_value = Session( + session_id="ptc-compat-session", + status=SessionStatus.ACTIVE, + created_at=datetime.now(timezone.utc), + last_activity=datetime.now(timezone.utc), + expires_at=datetime.now(timezone.utc), + metadata={}, + ) + app.dependency_overrides[get_session_service] = lambda: mock_session_svc + + try: + response = client.post( + "/exec/programmatic", + json={ + "code": "print('Hello from PTC')", + "tools": [{"name": "get_data", "description": "Get data"}], + }, + headers=auth_headers, + ) + finally: + app.dependency_overrides.clear() + + assert response.status_code == 200 + data = response.json() + 
assert data["status"] == "completed" + assert "session_id" in data + assert isinstance(data["stdout"], str) + assert isinstance(data["stderr"], str) + + @patch("src.api.programmatic._get_ptc_service") + def test_ptc_tool_call_required_response( + self, mock_get_service, client, auth_headers + ): + """PTC request that calls a tool should return tool_call_required.""" + from src.models.programmatic import ProgrammaticExecResponse, PTCToolCall + + mock_service = AsyncMock() + mock_service.start_execution.return_value = ProgrammaticExecResponse( + status="tool_call_required", + session_id="ptc-tool-session", + continuation_token="cont-token-xyz", + tool_calls=[ + PTCToolCall( + id="call-abc", name="get_weather", input={"city": "NYC"} + ), + ], + stdout="", + stderr="", + ) + mock_get_service.return_value = mock_service + + from src.dependencies.services import get_session_service + + mock_session_svc = AsyncMock() + mock_session_svc.create_session.return_value = Session( + session_id="ptc-tool-session", + status=SessionStatus.ACTIVE, + created_at=datetime.now(timezone.utc), + last_activity=datetime.now(timezone.utc), + expires_at=datetime.now(timezone.utc), + metadata={}, + ) + app.dependency_overrides[get_session_service] = lambda: mock_session_svc + + try: + response = client.post( + "/exec/programmatic", + json={ + "code": "result = await get_weather(city='NYC')", + "tools": [ + { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + } + ], + }, + headers=auth_headers, + ) + finally: + app.dependency_overrides.clear() + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "tool_call_required" + assert "continuation_token" in data + assert isinstance(data["continuation_token"], str) + assert len(data["tool_calls"]) == 1 + + tool_call = data["tool_calls"][0] + assert "id" in tool_call + assert "name" in tool_call + assert "input" in 
tool_call + assert tool_call["name"] == "get_weather" + + @patch("src.api.programmatic._get_ptc_service") + def test_ptc_continuation_flow(self, mock_get_service, client, auth_headers): + """Continuation with tool_results should return completed response.""" + from src.models.programmatic import ProgrammaticExecResponse + + mock_service = AsyncMock() + mock_service.continue_execution.return_value = ProgrammaticExecResponse( + status="completed", + session_id="ptc-cont-session", + stdout="Weather in NYC: 72F\n", + stderr="", + ) + mock_get_service.return_value = mock_service + + from src.dependencies.services import get_session_service + + mock_session_svc = AsyncMock() + app.dependency_overrides[get_session_service] = lambda: mock_session_svc + + try: + response = client.post( + "/exec/programmatic", + json={ + "continuation_token": "cont-token-xyz", + "tool_results": [ + { + "call_id": "call-abc", + "result": {"temp": 72, "conditions": "sunny"}, + "is_error": False, + } + ], + }, + headers=auth_headers, + ) + finally: + app.dependency_overrides.clear() + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "completed" + assert data["stdout"] == "Weather in NYC: 72F\n" + + mock_service.continue_execution.assert_called_once() + + @patch("src.api.programmatic._get_ptc_service") + def test_ptc_error_tool_result(self, mock_get_service, client, auth_headers): + """Tool result with is_error=true should be handled correctly.""" + from src.models.programmatic import ProgrammaticExecResponse + + mock_service = AsyncMock() + mock_service.continue_execution.return_value = ProgrammaticExecResponse( + status="completed", + session_id="ptc-err-session", + stdout="Tool failed: API unavailable\n", + stderr="", + ) + mock_get_service.return_value = mock_service + + from src.dependencies.services import get_session_service + + mock_session_svc = AsyncMock() + app.dependency_overrides[get_session_service] = lambda: mock_session_svc + + try: + 
response = client.post( + "/exec/programmatic", + json={ + "continuation_token": "cont-token-err", + "tool_results": [ + { + "call_id": "call-fail", + "result": None, + "is_error": True, + "error_message": "API unavailable", + } + ], + }, + headers=auth_headers, + ) + finally: + app.dependency_overrides.clear() + + assert response.status_code == 200 + data = response.json() + # Response should be valid regardless of tool error + assert data["status"] in ("completed", "error") + mock_service.continue_execution.assert_called_once() diff --git a/tests/integration/test_programmatic_api.py b/tests/integration/test_programmatic_api.py new file mode 100644 index 0000000..e920114 --- /dev/null +++ b/tests/integration/test_programmatic_api.py @@ -0,0 +1,461 @@ +"""Integration tests for the Programmatic Tool Calling (PTC) API endpoint. + +Tests use TestClient with mocked ProgrammaticService to verify the API +contract without requiring actual sandbox infrastructure. +""" + +import pytest +from fastapi.testclient import TestClient +from unittest.mock import AsyncMock, patch + +from src.main import app +from src.models.programmatic import ( + ProgrammaticExecResponse, + PTCToolCall, +) +from src.models.session import Session, SessionStatus +from datetime import datetime, timezone + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +@pytest.fixture +def auth_headers(): + """Provide authentication headers for tests.""" + return {"x-api-key": "test-api-key-for-testing-12345"} + + +@pytest.fixture +def mock_session(): + """Create a mock session for session service.""" + return Session( + session_id="ptc-session-123", + status=SessionStatus.ACTIVE, + created_at=datetime.now(timezone.utc), + last_activity=datetime.now(timezone.utc), + expires_at=datetime.now(timezone.utc), + metadata={}, + ) + + +@pytest.fixture +def mock_ptc_completed_response(): + """A completed PTC response.""" + return ProgrammaticExecResponse( + status="completed", + 
session_id="ptc-session-123", + stdout="Hello from PTC\n", + stderr="", + ) + + +@pytest.fixture +def mock_ptc_tool_call_response(): + """A tool_call_required PTC response.""" + return ProgrammaticExecResponse( + status="tool_call_required", + session_id="ptc-session-123", + continuation_token="cont-token-abc", + tool_calls=[ + PTCToolCall(id="call-1", name="get_weather", input={"city": "NYC"}), + ], + stdout="", + stderr="", + ) + + +@pytest.fixture +def mock_ptc_error_response(): + """An error PTC response.""" + return ProgrammaticExecResponse( + status="error", + error="Invalid or expired continuation token", + ) + + +# ============================================================================= +# INITIAL EXECUTION +# ============================================================================= + + +class TestProgrammaticInitialExecution: + """Tests for POST /exec/programmatic with initial execution.""" + + @patch("src.api.programmatic._get_ptc_service") + def test_initial_request_returns_completed( + self, + mock_get_service, + client, + auth_headers, + mock_session, + mock_ptc_completed_response, + ): + """Initial request with code should return completed response.""" + mock_service = AsyncMock() + mock_service.start_execution.return_value = mock_ptc_completed_response + mock_get_service.return_value = mock_service + + with patch( + "src.api.programmatic.SessionServiceDep", + ): + from src.dependencies.services import get_session_service + + mock_session_svc = AsyncMock() + mock_session_svc.create_session.return_value = mock_session + app.dependency_overrides[get_session_service] = lambda: mock_session_svc + + try: + response = client.post( + "/exec/programmatic", + json={ + "code": "print('hello')", + "tools": [ + {"name": "get_weather", "description": "Get weather"} + ], + }, + headers=auth_headers, + ) + finally: + app.dependency_overrides.clear() + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "completed" + 
assert data["session_id"] == "ptc-session-123" + assert data["stdout"] == "Hello from PTC\n" + + @patch("src.api.programmatic._get_ptc_service") + def test_initial_request_returns_tool_calls( + self, + mock_get_service, + client, + auth_headers, + mock_session, + mock_ptc_tool_call_response, + ): + """Initial request should return tool_call_required when code calls tools.""" + mock_service = AsyncMock() + mock_service.start_execution.return_value = mock_ptc_tool_call_response + mock_get_service.return_value = mock_service + + from src.dependencies.services import get_session_service + + mock_session_svc = AsyncMock() + mock_session_svc.create_session.return_value = mock_session + app.dependency_overrides[get_session_service] = lambda: mock_session_svc + + try: + response = client.post( + "/exec/programmatic", + json={ + "code": "result = await get_weather(city='NYC')", + "tools": [ + { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + } + ], + }, + headers=auth_headers, + ) + finally: + app.dependency_overrides.clear() + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "tool_call_required" + assert data["continuation_token"] == "cont-token-abc" + assert len(data["tool_calls"]) == 1 + assert data["tool_calls"][0]["name"] == "get_weather" + assert data["tool_calls"][0]["id"] == "call-1" + + @patch("src.api.programmatic._get_ptc_service") + def test_initial_request_with_session_id( + self, + mock_get_service, + client, + auth_headers, + mock_session, + mock_ptc_completed_response, + ): + """Initial request with session_id should use existing session.""" + mock_service = AsyncMock() + mock_service.start_execution.return_value = mock_ptc_completed_response + mock_get_service.return_value = mock_service + + from src.dependencies.services import get_session_service + + mock_session_svc = AsyncMock() + 
app.dependency_overrides[get_session_service] = lambda: mock_session_svc + + try: + response = client.post( + "/exec/programmatic", + json={ + "code": "print('hello')", + "session_id": "existing-session-456", + }, + headers=auth_headers, + ) + finally: + app.dependency_overrides.clear() + + assert response.status_code == 200 + # Should not have created a new session + mock_session_svc.create_session.assert_not_called() + + +# ============================================================================= +# CONTINUATION +# ============================================================================= + + +class TestProgrammaticContinuation: + """Tests for POST /exec/programmatic with continuation.""" + + @patch("src.api.programmatic._get_ptc_service") + def test_continuation_with_tool_results( + self, + mock_get_service, + client, + auth_headers, + mock_ptc_completed_response, + ): + """Continuation with tool_results should return response.""" + mock_service = AsyncMock() + mock_service.continue_execution.return_value = mock_ptc_completed_response + mock_get_service.return_value = mock_service + + from src.dependencies.services import get_session_service + + mock_session_svc = AsyncMock() + app.dependency_overrides[get_session_service] = lambda: mock_session_svc + + try: + response = client.post( + "/exec/programmatic", + json={ + "continuation_token": "cont-token-abc", + "tool_results": [ + { + "call_id": "call-1", + "result": {"temp": 72, "conditions": "sunny"}, + } + ], + }, + headers=auth_headers, + ) + finally: + app.dependency_overrides.clear() + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "completed" + + mock_service.continue_execution.assert_called_once() + + @patch("src.api.programmatic._get_ptc_service") + def test_continuation_invalid_token( + self, + mock_get_service, + client, + auth_headers, + mock_ptc_error_response, + ): + """Continuation with invalid token should return error.""" + mock_service = 
AsyncMock() + mock_service.continue_execution.return_value = mock_ptc_error_response + mock_get_service.return_value = mock_service + + from src.dependencies.services import get_session_service + + mock_session_svc = AsyncMock() + app.dependency_overrides[get_session_service] = lambda: mock_session_svc + + try: + response = client.post( + "/exec/programmatic", + json={ + "continuation_token": "invalid-token-xyz", + "tool_results": [{"call_id": "call-1", "result": "data"}], + }, + headers=auth_headers, + ) + finally: + app.dependency_overrides.clear() + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "error" + assert "Invalid or expired" in data["error"] + + +# ============================================================================= +# VALIDATION ERRORS +# ============================================================================= + + +class TestProgrammaticValidation: + """Tests for request validation on the PTC endpoint.""" + + @patch("src.api.programmatic._get_ptc_service") + def test_missing_code_returns_error( + self, mock_get_service, client, auth_headers, mock_session + ): + """Request without code or continuation_token should return error.""" + mock_service = AsyncMock() + mock_get_service.return_value = mock_service + + from src.dependencies.services import get_session_service + + mock_session_svc = AsyncMock() + mock_session_svc.create_session.return_value = mock_session + app.dependency_overrides[get_session_service] = lambda: mock_session_svc + + try: + response = client.post( + "/exec/programmatic", + json={}, + headers=auth_headers, + ) + finally: + app.dependency_overrides.clear() + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "error" + assert ( + "code" in data["error"].lower() or "continuation" in data["error"].lower() + ) + + def test_invalid_json_returns_422(self, client, auth_headers): + """Sending invalid JSON should return 422.""" + response = 
client.post( + "/exec/programmatic", + content="not-json", + headers={**auth_headers, "content-type": "application/json"}, + ) + assert response.status_code == 422 + + +# ============================================================================= +# RESPONSE SCHEMA +# ============================================================================= + + +class TestProgrammaticResponseSchema: + """Tests for response schema compliance.""" + + @patch("src.api.programmatic._get_ptc_service") + def test_completed_response_has_expected_fields( + self, + mock_get_service, + client, + auth_headers, + mock_session, + mock_ptc_completed_response, + ): + """Completed response should have all expected fields.""" + mock_service = AsyncMock() + mock_service.start_execution.return_value = mock_ptc_completed_response + mock_get_service.return_value = mock_service + + from src.dependencies.services import get_session_service + + mock_session_svc = AsyncMock() + mock_session_svc.create_session.return_value = mock_session + app.dependency_overrides[get_session_service] = lambda: mock_session_svc + + try: + response = client.post( + "/exec/programmatic", + json={"code": "print('hi')"}, + headers=auth_headers, + ) + finally: + app.dependency_overrides.clear() + + data = response.json() + # All fields should be present in response + assert "status" in data + assert "session_id" in data + assert "continuation_token" in data + assert "tool_calls" in data + assert "stdout" in data + assert "stderr" in data + assert "files" in data + assert "error" in data + + @patch("src.api.programmatic._get_ptc_service") + def test_tool_call_response_has_expected_fields( + self, + mock_get_service, + client, + auth_headers, + mock_session, + mock_ptc_tool_call_response, + ): + """Tool call response should have tool_calls with id, name, input.""" + mock_service = AsyncMock() + mock_service.start_execution.return_value = mock_ptc_tool_call_response + mock_get_service.return_value = mock_service + + from 
src.dependencies.services import get_session_service + + mock_session_svc = AsyncMock() + mock_session_svc.create_session.return_value = mock_session + app.dependency_overrides[get_session_service] = lambda: mock_session_svc + + try: + response = client.post( + "/exec/programmatic", + json={ + "code": "await get_weather(city='NYC')", + "tools": [{"name": "get_weather"}], + }, + headers=auth_headers, + ) + finally: + app.dependency_overrides.clear() + + data = response.json() + assert data["status"] == "tool_call_required" + assert len(data["tool_calls"]) > 0 + tool_call = data["tool_calls"][0] + assert "id" in tool_call + assert "name" in tool_call + assert "input" in tool_call + + +# ============================================================================= +# AUTHENTICATION +# ============================================================================= + + +class TestProgrammaticAuth: + """Tests for authentication on the PTC endpoint.""" + + def test_missing_auth_returns_401(self, client): + """Request without auth headers should return 401.""" + response = client.post( + "/exec/programmatic", + json={"code": "print('hello')"}, + ) + assert response.status_code == 401 + + def test_invalid_auth_returns_401(self, client): + """Request with invalid API key should return 401.""" + response = client.post( + "/exec/programmatic", + json={"code": "print('hello')"}, + headers={"x-api-key": "wrong-key"}, + ) + assert response.status_code == 401 diff --git a/tests/unit/test_language_config.py b/tests/unit/test_language_config.py new file mode 100644 index 0000000..7ca25e1 --- /dev/null +++ b/tests/unit/test_language_config.py @@ -0,0 +1,324 @@ +"""Unit tests for language configuration module (src/config/languages.py). + +Tests the unified language configuration: language definitions, lookup +functions, and correctness of all 13 supported languages. 
+""" + +import pytest + +from src.config.languages import ( + LANGUAGES, + LanguageConfig, + get_language, + get_supported_languages, + is_supported_language, + get_user_id_for_language, + get_execution_command, + uses_stdin, + get_file_extension, +) + +# All expected language codes +ALL_LANGUAGE_CODES = [ + "py", + "js", + "ts", + "go", + "java", + "c", + "cpp", + "php", + "rs", + "r", + "f90", + "d", + "bash", +] + + +class TestLanguageRegistry: + """Test the LANGUAGES registry has the correct entries.""" + + def test_exactly_13_languages_registered(self): + """There must be exactly 13 supported languages.""" + assert len(LANGUAGES) == 13 + + @pytest.mark.parametrize("code", ALL_LANGUAGE_CODES) + def test_language_code_present(self, code): + """Every expected language code must exist in the registry.""" + assert code in LANGUAGES + + @pytest.mark.parametrize("code", ALL_LANGUAGE_CODES) + def test_language_config_is_frozen_dataclass(self, code): + """Each language config must be a frozen LanguageConfig dataclass.""" + lang = LANGUAGES[code] + assert isinstance(lang, LanguageConfig) + with pytest.raises(AttributeError): + lang.code = "modified" + + def test_all_codes_match_dict_keys(self): + """The code field of each LanguageConfig must match its dict key.""" + for key, lang in LANGUAGES.items(): + assert lang.code == key + + +class TestBashLanguage: + """Test bash-specific configuration.""" + + def test_bash_code(self): + lang = get_language("bash") + assert lang is not None + assert lang.code == "bash" + + def test_bash_name(self): + lang = get_language("bash") + assert lang.name == "Bash" + + def test_bash_extension(self): + lang = get_language("bash") + assert lang.file_extension == "sh" + + def test_bash_uses_stdin(self): + lang = get_language("bash") + assert lang.uses_stdin is True + + def test_bash_user_id(self): + lang = get_language("bash") + assert lang.user_id == 1001 + + def test_bash_execution_command(self): + lang = get_language("bash") + assert 
lang.execution_command == "bash" + + def test_bash_timeout_multiplier(self): + lang = get_language("bash") + assert lang.timeout_multiplier == 1.0 + + def test_bash_memory_multiplier(self): + lang = get_language("bash") + assert lang.memory_multiplier == 1.0 + + +class TestPythonLanguage: + """Test Python-specific configuration.""" + + def test_python_user_id(self): + lang = get_language("py") + assert lang.user_id == 999 + + def test_python_uses_stdin(self): + assert uses_stdin("py") is True + + def test_python_extension(self): + assert get_file_extension("py") == "py" + + +class TestStdinVsFileLanguages: + """Test that stdin and file-based language sets are correct.""" + + EXPECTED_STDIN = {"py", "js", "php", "bash"} + EXPECTED_FILE = {"ts", "go", "java", "c", "cpp", "rs", "r", "f90", "d"} + + def test_stdin_languages(self): + """Languages that pass code via stdin.""" + stdin_langs = {code for code, lang in LANGUAGES.items() if lang.uses_stdin} + assert stdin_langs == self.EXPECTED_STDIN + + def test_file_languages(self): + """Languages that use file-based execution.""" + file_langs = {code for code, lang in LANGUAGES.items() if not lang.uses_stdin} + assert file_langs == self.EXPECTED_FILE + + +class TestGetLanguage: + """Test get_language() lookup function.""" + + @pytest.mark.parametrize("code", ALL_LANGUAGE_CODES) + def test_returns_config_for_known_code(self, code): + result = get_language(code) + assert result is not None + assert isinstance(result, LanguageConfig) + assert result.code == code + + def test_returns_none_for_unknown(self): + assert get_language("unknown") is None + + def test_case_insensitive(self): + assert get_language("PY") is not None + assert get_language("Py") is not None + assert get_language("BASH") is not None + + +class TestGetSupportedLanguages: + """Test get_supported_languages() function.""" + + def test_returns_list_of_strings(self): + result = get_supported_languages() + assert isinstance(result, list) + assert 
all(isinstance(code, str) for code in result) + + def test_contains_all_expected_codes(self): + result = get_supported_languages() + for code in ALL_LANGUAGE_CODES: + assert code in result + + def test_length_matches_registry(self): + assert len(get_supported_languages()) == len(LANGUAGES) + + +class TestIsSupportedLanguage: + """Test is_supported_language() function.""" + + @pytest.mark.parametrize("code", ALL_LANGUAGE_CODES) + def test_true_for_known_code(self, code): + assert is_supported_language(code) is True + + def test_false_for_unknown(self): + assert is_supported_language("unknown") is False + assert is_supported_language("python") is False + assert is_supported_language("") is False + + def test_case_insensitive(self): + assert is_supported_language("PY") is True + assert is_supported_language("BASH") is True + + +class TestGetUserIdForLanguage: + """Test get_user_id_for_language() function.""" + + def test_python_user_id(self): + assert get_user_id_for_language("py") == 999 + + def test_java_user_id(self): + assert get_user_id_for_language("java") == 999 + + def test_bash_user_id(self): + assert get_user_id_for_language("bash") == 1001 + + def test_d_user_id(self): + assert get_user_id_for_language("d") == 0 + + def test_raises_for_unknown(self): + with pytest.raises(ValueError, match="Unsupported language"): + get_user_id_for_language("unknown") + + +class TestGetExecutionCommand: + """Test get_execution_command() function.""" + + def test_python_command(self): + assert get_execution_command("py") == "python3 -" + + def test_bash_command(self): + assert get_execution_command("bash") == "bash" + + def test_go_command(self): + cmd = get_execution_command("go") + assert "go build" in cmd + + def test_raises_for_unknown(self): + with pytest.raises(ValueError, match="Unsupported language"): + get_execution_command("unknown") + + +class TestUsesStdin: + """Test uses_stdin() function.""" + + def test_true_for_stdin_languages(self): + for code in ["py", "js", 
"php", "bash"]: + assert uses_stdin(code) is True, f"{code} should use stdin" + + def test_false_for_file_languages(self): + for code in ["ts", "go", "java", "c", "cpp", "rs", "r", "f90", "d"]: + assert uses_stdin(code) is False, f"{code} should not use stdin" + + def test_false_for_unknown(self): + assert uses_stdin("unknown") is False + + +class TestGetFileExtension: + """Test get_file_extension() function.""" + + def test_python_extension(self): + assert get_file_extension("py") == "py" + + def test_bash_extension(self): + assert get_file_extension("bash") == "sh" + + def test_java_extension(self): + assert get_file_extension("java") == "java" + + def test_cpp_extension(self): + assert get_file_extension("cpp") == "cpp" + + def test_fortran_extension(self): + assert get_file_extension("f90") == "f90" + + def test_raises_for_unknown(self): + with pytest.raises(ValueError, match="Unsupported language"): + get_file_extension("unknown") + + +class TestResourceMultipliers: + """Test timeout and memory multiplier values.""" + + def test_rust_has_highest_timeout(self): + """Rust compilation is slow, so it should have a high timeout.""" + rs = get_language("rs") + assert rs.timeout_multiplier == 3.0 + + def test_java_has_high_memory(self): + """Java needs more memory for the JVM.""" + java = get_language("java") + assert java.memory_multiplier == 1.5 + + def test_typescript_has_above_default_timeout(self): + """TypeScript needs extra time for tsc compilation.""" + ts = get_language("ts") + assert ts.timeout_multiplier > 1.0 + + @pytest.mark.parametrize("code", ALL_LANGUAGE_CODES) + def test_multipliers_are_positive(self, code): + lang = get_language(code) + assert lang.timeout_multiplier > 0 + assert lang.memory_multiplier > 0 + + +class TestLanguageConfigFields: + """Test that all LanguageConfig instances have valid field values.""" + + @pytest.mark.parametrize("code", ALL_LANGUAGE_CODES) + def test_code_is_nonempty_string(self, code): + lang = get_language(code) + 
assert isinstance(lang.code, str) + assert len(lang.code) > 0 + + @pytest.mark.parametrize("code", ALL_LANGUAGE_CODES) + def test_name_is_nonempty_string(self, code): + lang = get_language(code) + assert isinstance(lang.name, str) + assert len(lang.name) > 0 + + @pytest.mark.parametrize("code", ALL_LANGUAGE_CODES) + def test_file_extension_is_nonempty(self, code): + lang = get_language(code) + assert isinstance(lang.file_extension, str) + assert len(lang.file_extension) > 0 + assert "." not in lang.file_extension, "extension should not contain dot" + + @pytest.mark.parametrize("code", ALL_LANGUAGE_CODES) + def test_execution_command_is_nonempty(self, code): + lang = get_language(code) + assert isinstance(lang.execution_command, str) + assert len(lang.execution_command) > 0 + + @pytest.mark.parametrize("code", ALL_LANGUAGE_CODES) + def test_user_id_is_non_negative(self, code): + lang = get_language(code) + assert isinstance(lang.user_id, int) + assert lang.user_id >= 0 + + @pytest.mark.parametrize("code", ALL_LANGUAGE_CODES) + def test_environment_is_dict(self, code): + lang = get_language(code) + assert isinstance(lang.environment, dict) diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py index e6de994..4605bbc 100644 --- a/tests/unit/test_orchestrator.py +++ b/tests/unit/test_orchestrator.py @@ -276,6 +276,187 @@ def test_file_ref_session_id_optional(self): assert ref.session_id is None +class TestAgentFileSessionIsolation: + """Tests for session isolation when files reference shared agent sessions. + + When multiple users share an agent with attached files, the file references + carry the upload session_id. The orchestrator must NOT blindly reuse that + session, as it would leak state between users. It should only reuse a + file-referenced session if user_id matches. 
+ """ + + @pytest.mark.asyncio + async def test_agent_file_does_not_reuse_upload_session( + self, orchestrator, mock_session_service + ): + """Files reference an upload session (no user_id). New session should be created.""" + from src.models.exec import RequestFile + + # Upload session S1 has no user_id in metadata (agent upload sessions don't) + upload_session = Session( + session_id="upload-session-S1", + status=SessionStatus.ACTIVE, + created_at=datetime.now(), + last_activity=datetime.now(), + expires_at=datetime.now(), + files={}, + metadata={}, # No user_id + working_directory="/workspace", + ) + mock_session_service.get_session = AsyncMock(return_value=upload_session) + + request = ExecRequest( + code="print('hello')", + lang="py", + user_id="userA", + files=[ + RequestFile( + id="file-1", session_id="upload-session-S1", name="data.csv" + ), + ], + ) + ctx = ExecutionContext(request=request, request_id="test-isolation-1") + + session_id = await orchestrator._get_or_create_session(ctx) + + # Should NOT reuse S1 (no user_id in session metadata) + assert session_id == "new-session-456" + mock_session_service.create_session.assert_called_once() + + @pytest.mark.asyncio + async def test_same_user_reuses_own_session( + self, orchestrator, mock_session_service + ): + """Files reference a session created by the same user. 
Should reuse it.""" + from src.models.exec import RequestFile + + # Session S2 has user_id: "userA" in metadata + user_session = Session( + session_id="user-session-S2", + status=SessionStatus.ACTIVE, + created_at=datetime.now(), + last_activity=datetime.now(), + expires_at=datetime.now(), + files={}, + metadata={"user_id": "userA"}, + working_directory="/workspace", + ) + mock_session_service.get_session = AsyncMock(return_value=user_session) + + request = ExecRequest( + code="print('hello')", + lang="py", + user_id="userA", + files=[ + RequestFile( + id="file-1", session_id="user-session-S2", name="data.csv" + ), + ], + ) + ctx = ExecutionContext(request=request, request_id="test-isolation-2") + + session_id = await orchestrator._get_or_create_session(ctx) + + # Should reuse S2 (same user_id) + assert session_id == "user-session-S2" + mock_session_service.create_session.assert_not_called() + + @pytest.mark.asyncio + async def test_different_user_does_not_reuse_session( + self, orchestrator, mock_session_service + ): + """Files reference a session owned by a different user. 
New session should be created.""" + from src.models.exec import RequestFile + + # Session S2 has user_id: "userA" + user_a_session = Session( + session_id="user-session-S2", + status=SessionStatus.ACTIVE, + created_at=datetime.now(), + last_activity=datetime.now(), + expires_at=datetime.now(), + files={}, + metadata={"user_id": "userA"}, + working_directory="/workspace", + ) + mock_session_service.get_session = AsyncMock(return_value=user_a_session) + + request = ExecRequest( + code="print('hello')", + lang="py", + user_id="userB", # Different user + files=[ + RequestFile( + id="file-1", session_id="user-session-S2", name="data.csv" + ), + ], + ) + ctx = ExecutionContext(request=request, request_id="test-isolation-3") + + session_id = await orchestrator._get_or_create_session(ctx) + + # Should NOT reuse S2 (different user_id) + assert session_id == "new-session-456" + mock_session_service.create_session.assert_called_once() + + @pytest.mark.asyncio + async def test_no_user_id_creates_new_session( + self, orchestrator, mock_session_service + ): + """Request without user_id should create a new session (no ownership check possible).""" + from src.models.exec import RequestFile + + upload_session = Session( + session_id="upload-session-S1", + status=SessionStatus.ACTIVE, + created_at=datetime.now(), + last_activity=datetime.now(), + expires_at=datetime.now(), + files={}, + metadata={}, + working_directory="/workspace", + ) + mock_session_service.get_session = AsyncMock(return_value=upload_session) + + request = ExecRequest( + code="print('hello')", + lang="py", + # No user_id + files=[ + RequestFile( + id="file-1", session_id="upload-session-S1", name="data.csv" + ), + ], + ) + ctx = ExecutionContext(request=request, request_id="test-isolation-4") + + session_id = await orchestrator._get_or_create_session(ctx) + + # Should create new session (priority 2 requires request.user_id) + assert session_id == "new-session-456" + 
mock_session_service.create_session.assert_called_once() + + @pytest.mark.asyncio + async def test_entity_id_not_fallback_to_user_id( + self, orchestrator, mock_session_service + ): + """user_id should NOT be used as fallback for entity_id in session lookup.""" + request = ExecRequest( + code="print('hello')", + lang="py", + user_id="userA", + # No entity_id + ) + ctx = ExecutionContext(request=request, request_id="test-isolation-5") + + session_id = await orchestrator._get_or_create_session(ctx) + + # list_sessions_by_entity should NOT be called (no entity_id) + mock_session_service.list_sessions_by_entity.assert_not_called() + # Should create a new session + assert session_id == "new-session-456" + + class TestExplicitFileMounting: """Tests for explicit file mounting behavior.""" diff --git a/tests/unit/test_programmatic.py b/tests/unit/test_programmatic.py new file mode 100644 index 0000000..b0a5439 --- /dev/null +++ b/tests/unit/test_programmatic.py @@ -0,0 +1,595 @@ +"""Unit tests for the Programmatic Tool Calling (PTC) models and service. 
+ +Tests cover: +- Model validation for PTC request/response/tool models +- ProgrammaticService logic with mocked sandbox +""" + +import json +from datetime import datetime +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import ValidationError + +from src.models.programmatic import ( + ProgrammaticExecRequest, + ProgrammaticExecResponse, + PTCToolCall, + PTCToolDefinition, + PTCToolResult, +) +from src.services.programmatic import ( + PTC_DELIMITER, + PTC_MAX_ROUND_TRIPS, + PausedContext, + ProgrammaticService, +) +from src.services.sandbox.nsjail import SandboxInfo + +# ============================================================================= +# FIXTURES +# ============================================================================= + + +@pytest.fixture +def mock_sandbox_info(): + """Create a mock SandboxInfo.""" + return SandboxInfo( + sandbox_id="test-sandbox-123", + sandbox_dir=Path("/tmp/test-sandbox"), + data_dir=Path("/tmp/test-sandbox/data"), + language="py", + session_id="test-session", + created_at=datetime.utcnow(), + repl_mode=False, + ) + + +@pytest.fixture +def mock_sandbox_manager(mock_sandbox_info): + """Create a mock SandboxManager for PTC tests.""" + manager = MagicMock() + manager.create_sandbox.return_value = mock_sandbox_info + manager.destroy_sandbox.return_value = True + manager.copy_content_to_sandbox.return_value = True + manager.executor = MagicMock() + manager.executor._build_sanitized_env.return_value = {"PATH": "/usr/bin"} + return manager + + +@pytest.fixture +def ptc_service(mock_sandbox_manager): + """Create a ProgrammaticService with mocked sandbox manager.""" + return ProgrammaticService(sandbox_manager=mock_sandbox_manager) + + +# ============================================================================= +# MODEL VALIDATION: PTCToolDefinition +# ============================================================================= + + +class TestPTCToolDefinition: + 
"""Tests for PTCToolDefinition model.""" + + def test_minimal_tool_definition(self): + """Tool with just a name should be valid.""" + tool = PTCToolDefinition(name="get_weather") + assert tool.name == "get_weather" + assert tool.description == "" + assert tool.parameters == {} + + def test_full_tool_definition(self): + """Tool with all fields should be valid.""" + tool = PTCToolDefinition( + name="search", + description="Search the web", + parameters={ + "type": "object", + "properties": {"query": {"type": "string"}}, + }, + ) + assert tool.name == "search" + assert tool.description == "Search the web" + assert "properties" in tool.parameters + + def test_tool_definition_requires_name(self): + """Tool without name should fail validation.""" + with pytest.raises(ValidationError): + PTCToolDefinition() + + +# ============================================================================= +# MODEL VALIDATION: PTCToolCall +# ============================================================================= + + +class TestPTCToolCall: + """Tests for PTCToolCall model.""" + + def test_valid_tool_call(self): + """Tool call with id and name should be valid.""" + call = PTCToolCall(id="call-1", name="get_weather") + assert call.id == "call-1" + assert call.name == "get_weather" + assert call.input == {} + + def test_tool_call_with_input(self): + """Tool call with input arguments should be valid.""" + call = PTCToolCall( + id="call-2", + name="search", + input={"query": "python", "limit": 10}, + ) + assert call.input == {"query": "python", "limit": 10} + + def test_tool_call_requires_id(self): + """Tool call without id should fail validation.""" + with pytest.raises(ValidationError): + PTCToolCall(name="get_weather") + + def test_tool_call_requires_name(self): + """Tool call without name should fail validation.""" + with pytest.raises(ValidationError): + PTCToolCall(id="call-1") + + +# ============================================================================= +# MODEL 
VALIDATION: PTCToolResult +# ============================================================================= + + +class TestPTCToolResult: + """Tests for PTCToolResult model.""" + + def test_valid_result(self): + """Tool result with call_id and result should be valid.""" + result = PTCToolResult(call_id="call-1", result={"temp": 72}) + assert result.call_id == "call-1" + assert result.result == {"temp": 72} + assert result.is_error is False + assert result.error_message is None + + def test_error_result(self): + """Tool result with error should be valid.""" + result = PTCToolResult( + call_id="call-1", + is_error=True, + error_message="Tool not found", + ) + assert result.is_error is True + assert result.error_message == "Tool not found" + assert result.result is None + + def test_result_requires_call_id(self): + """Tool result without call_id should fail validation.""" + with pytest.raises(ValidationError): + PTCToolResult(result="data") + + def test_result_with_string_value(self): + """Tool result with string value should be valid.""" + result = PTCToolResult(call_id="call-1", result="plain text") + assert result.result == "plain text" + + def test_result_with_none_value(self): + """Tool result with None should be valid (default).""" + result = PTCToolResult(call_id="call-1") + assert result.result is None + + +# ============================================================================= +# MODEL VALIDATION: ProgrammaticExecRequest +# ============================================================================= + + +class TestProgrammaticExecRequest: + """Tests for ProgrammaticExecRequest model.""" + + def test_initial_request_with_code(self): + """Initial request with code should be valid.""" + req = ProgrammaticExecRequest( + code="print('hello')", + tools=[PTCToolDefinition(name="tool1")], + ) + assert req.code == "print('hello')" + assert len(req.tools) == 1 + assert req.continuation_token is None + assert req.tool_results == [] + + def 
test_initial_request_with_all_fields(self): + """Initial request with all optional fields should be valid.""" + req = ProgrammaticExecRequest( + code="print('hello')", + tools=[PTCToolDefinition(name="tool1")], + session_id="sess-123", + user_id="user-456", + entity_id="asst_abc", + timeout=60, + files=[{"filename": "test.txt", "content": "data"}], + ) + assert req.session_id == "sess-123" + assert req.user_id == "user-456" + assert req.entity_id == "asst_abc" + assert req.timeout == 60 + assert len(req.files) == 1 + + def test_continuation_request(self): + """Continuation request with token and results should be valid.""" + req = ProgrammaticExecRequest( + continuation_token="abc123", + tool_results=[ + PTCToolResult(call_id="call-1", result="data"), + ], + ) + assert req.continuation_token == "abc123" + assert len(req.tool_results) == 1 + assert req.code is None + + def test_empty_request_is_valid(self): + """Empty request should pass model validation (API handles logic).""" + req = ProgrammaticExecRequest() + assert req.code is None + assert req.continuation_token is None + + def test_entity_id_pattern_valid(self): + """Entity ID with valid pattern should pass.""" + req = ProgrammaticExecRequest(code="x", entity_id="asst_abc-123") + assert req.entity_id == "asst_abc-123" + + def test_entity_id_pattern_invalid(self): + """Entity ID with invalid characters should fail validation.""" + with pytest.raises(ValidationError): + ProgrammaticExecRequest(code="x", entity_id="invalid entity!@#") + + def test_entity_id_max_length(self): + """Entity ID exceeding max length should fail validation.""" + with pytest.raises(ValidationError): + ProgrammaticExecRequest(code="x", entity_id="a" * 41) + + def test_request_no_tools_defaults_empty(self): + """Request without tools should default to empty list.""" + req = ProgrammaticExecRequest(code="print('hello')") + assert req.tools == [] + + def test_request_no_files_defaults_empty(self): + """Request without files should default 
to empty list.""" + req = ProgrammaticExecRequest(code="print('hello')") + assert req.files == [] + + +# ============================================================================= +# MODEL VALIDATION: ProgrammaticExecResponse +# ============================================================================= + + +class TestProgrammaticExecResponse: + """Tests for ProgrammaticExecResponse model.""" + + def test_completed_response(self): + """Completed response should have status=completed.""" + resp = ProgrammaticExecResponse( + status="completed", + session_id="sess-123", + stdout="Hello, World!\n", + ) + assert resp.status == "completed" + assert resp.session_id == "sess-123" + assert resp.stdout == "Hello, World!\n" + assert resp.continuation_token is None + assert resp.tool_calls == [] + assert resp.error is None + + def test_tool_call_required_response(self): + """Tool call required response should have token and calls.""" + resp = ProgrammaticExecResponse( + status="tool_call_required", + session_id="sess-123", + continuation_token="token-abc", + tool_calls=[ + PTCToolCall(id="call-1", name="search", input={"q": "test"}), + ], + ) + assert resp.status == "tool_call_required" + assert resp.continuation_token == "token-abc" + assert len(resp.tool_calls) == 1 + assert resp.tool_calls[0].name == "search" + + def test_error_response(self): + """Error response should have status=error and error message.""" + resp = ProgrammaticExecResponse( + status="error", + error="Something went wrong", + ) + assert resp.status == "error" + assert resp.error == "Something went wrong" + assert resp.session_id is None + + def test_response_defaults(self): + """Response should have sensible defaults.""" + resp = ProgrammaticExecResponse(status="completed") + assert resp.stdout == "" + assert resp.stderr == "" + assert resp.files == [] + assert resp.tool_calls == [] + assert resp.continuation_token is None + assert resp.error is None + + def test_response_requires_status(self): + 
"""Response without status should fail validation.""" + with pytest.raises(ValidationError): + ProgrammaticExecResponse() + + +# ============================================================================= +# SERVICE: start_execution +# ============================================================================= + + +class TestProgrammaticServiceStartExecution: + """Tests for ProgrammaticService.start_execution.""" + + async def test_start_execution_ptc_server_not_found( + self, ptc_service, mock_sandbox_manager + ): + """Should return error if ptc_server.py not found.""" + with patch("pathlib.Path.exists", return_value=False): + response = await ptc_service.start_execution( + code="print('hello')", + tools=[], + session_id="sess-123", + ) + + assert response.status == "error" + assert "PTC server script not found" in response.error + + async def test_start_execution_creates_sandbox( + self, ptc_service, mock_sandbox_manager + ): + """Should create sandbox with correct parameters.""" + with ( + patch("pathlib.Path.exists", return_value=True), + patch("pathlib.Path.read_bytes", return_value=b"# ptc_server.py"), + patch("src.services.programmatic.NsjailConfig") as mock_nsjail_config, + patch( + "asyncio.create_subprocess_exec", + new_callable=AsyncMock, + ) as mock_subprocess, + ): + mock_nsjail_config.return_value.build_args.return_value = [ + "--config", + "/tmp/test.cfg", + ] + + # Mock process that returns completed response + mock_proc = AsyncMock() + mock_proc.stdin = AsyncMock() + mock_proc.stdin.write = MagicMock() + mock_proc.stdin.drain = AsyncMock() + mock_proc.returncode = None + mock_proc.pid = 12345 + + completed_response = ( + json.dumps({"type": "completed", "stdout": "hello\n", "stderr": ""}) + + PTC_DELIMITER + ) + + mock_proc.stdout = AsyncMock() + mock_proc.stdout.read = AsyncMock(return_value=completed_response.encode()) + mock_proc.stderr = AsyncMock() + mock_proc.stderr.read = AsyncMock(return_value=b"") + + mock_subprocess.return_value = 
mock_proc + + await ptc_service.start_execution( + code="print('hello')", + tools=[], + session_id="sess-123", + ) + + mock_sandbox_manager.create_sandbox.assert_called_once_with( + session_id="sess-123", + language="py", + repl_mode=False, + ) + + async def test_start_execution_cleanup_on_exception( + self, ptc_service, mock_sandbox_manager + ): + """Should destroy sandbox on exception.""" + with ( + patch("pathlib.Path.exists", return_value=True), + patch("pathlib.Path.read_bytes", return_value=b"# ptc_server.py"), + patch("src.services.programmatic.NsjailConfig") as mock_nsjail_config, + patch( + "asyncio.create_subprocess_exec", + side_effect=OSError("Cannot start process"), + ), + ): + mock_nsjail_config.return_value.build_args.return_value = [] + + response = await ptc_service.start_execution( + code="print('hello')", + tools=[], + session_id="sess-123", + ) + + assert response.status == "error" + assert "Execution failed" in response.error + mock_sandbox_manager.destroy_sandbox.assert_called_once() + + +# ============================================================================= +# SERVICE: continue_execution +# ============================================================================= + + +class TestProgrammaticServiceContinueExecution: + """Tests for ProgrammaticService.continue_execution.""" + + async def test_continue_invalid_token(self, ptc_service): + """Should return error for invalid continuation token.""" + response = await ptc_service.continue_execution( + continuation_token="nonexistent-token", + tool_results=[], + ) + + assert response.status == "error" + assert "Invalid or expired continuation token" in response.error + + async def test_continue_max_round_trips_exceeded(self, ptc_service): + """Should return error when max round trips exceeded.""" + token = "test-token-123" + + # Create a paused context at max round trips + mock_proc = AsyncMock() + mock_proc.returncode = None + mock_proc.pid = 12345 + + ctx = PausedContext( + 
sandbox_info=SandboxInfo( + sandbox_id="sb-1", + sandbox_dir=Path("/tmp/sb"), + data_dir=Path("/tmp/sb/data"), + language="py", + session_id="sess-1", + created_at=datetime.utcnow(), + repl_mode=False, + ), + process=mock_proc, + session_id="sess-1", + round_trip_count=PTC_MAX_ROUND_TRIPS, + ) + ptc_service._paused_contexts[token] = ctx + + response = await ptc_service.continue_execution( + continuation_token=token, + tool_results=[PTCToolResult(call_id="c1", result="ok")], + ) + + assert response.status == "error" + assert "Maximum round trips" in response.error + # Context should be cleaned up + assert token not in ptc_service._paused_contexts + + async def test_continue_cancels_timeout(self, ptc_service): + """Should cancel timeout handle when continuing.""" + token = "test-token-456" + + mock_proc = AsyncMock() + mock_proc.returncode = None + mock_proc.pid = 12345 + mock_proc.stdin = AsyncMock() + mock_proc.stdin.write = MagicMock() + mock_proc.stdin.drain = AsyncMock() + + mock_timeout = MagicMock() + + completed_response = ( + json.dumps({"type": "completed", "stdout": "done\n", "stderr": ""}) + + PTC_DELIMITER + ) + mock_proc.stdout = AsyncMock() + mock_proc.stdout.read = AsyncMock(return_value=completed_response.encode()) + mock_proc.stderr = AsyncMock() + mock_proc.stderr.read = AsyncMock(return_value=b"") + + ctx = PausedContext( + sandbox_info=SandboxInfo( + sandbox_id="sb-2", + sandbox_dir=Path("/tmp/sb2"), + data_dir=Path("/tmp/sb2/data"), + language="py", + session_id="sess-2", + created_at=datetime.utcnow(), + repl_mode=False, + ), + process=mock_proc, + session_id="sess-2", + round_trip_count=0, + timeout_handle=mock_timeout, + ) + ptc_service._paused_contexts[token] = ctx + + with patch.object(ptc_service, "_sandbox_manager"): + await ptc_service.continue_execution( + continuation_token=token, + tool_results=[PTCToolResult(call_id="c1", result="ok")], + ) + + mock_timeout.cancel.assert_called_once() + + +# 
============================================================================= +# SERVICE: cleanup +# ============================================================================= + + +class TestProgrammaticServiceCleanup: + """Tests for ProgrammaticService cleanup methods.""" + + async def test_cleanup_paused_context(self, ptc_service): + """Should clean up a specific paused context.""" + token = "cleanup-token" + + mock_proc = AsyncMock() + mock_proc.returncode = None + mock_proc.pid = 12345 + mock_proc.wait = AsyncMock() + + mock_timeout = MagicMock() + + ctx = PausedContext( + sandbox_info=SandboxInfo( + sandbox_id="sb-3", + sandbox_dir=Path("/tmp/sb3"), + data_dir=Path("/tmp/sb3/data"), + language="py", + session_id="sess-3", + created_at=datetime.utcnow(), + repl_mode=False, + ), + process=mock_proc, + session_id="sess-3", + timeout_handle=mock_timeout, + ) + ptc_service._paused_contexts[token] = ctx + + await ptc_service._cleanup_paused_context(token) + + assert token not in ptc_service._paused_contexts + mock_timeout.cancel.assert_called_once() + ptc_service._sandbox_manager.destroy_sandbox.assert_called_once() + + async def test_cleanup_nonexistent_token(self, ptc_service): + """Should handle cleanup of nonexistent token gracefully.""" + await ptc_service._cleanup_paused_context("does-not-exist") + # Should not raise + + async def test_cleanup_all(self, ptc_service): + """Should clean up all paused contexts.""" + for i in range(3): + token = f"token-{i}" + mock_proc = AsyncMock() + mock_proc.returncode = None + mock_proc.pid = 12345 + i + mock_proc.wait = AsyncMock() + + ctx = PausedContext( + sandbox_info=SandboxInfo( + sandbox_id=f"sb-{i}", + sandbox_dir=Path(f"/tmp/sb-{i}"), + data_dir=Path(f"/tmp/sb-{i}/data"), + language="py", + session_id=f"sess-{i}", + created_at=datetime.utcnow(), + repl_mode=False, + ), + process=mock_proc, + session_id=f"sess-{i}", + ) + ptc_service._paused_contexts[token] = ctx + + await ptc_service.cleanup_all() + + assert 
len(ptc_service._paused_contexts) == 0