Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,19 @@ builder = [
"docker"
]

model-service = [
"psutil",
"alibabacloud_cr20181201==2.0.5",
"swebench",
"fastapi",
"uvicorn",
]

all = [
"rl-rock[admin]",
"rl-rock[rocklet]",
"rl-rock[builder]",
"rl-rock[model-service]",
]

[project.scripts]
Expand Down
2 changes: 2 additions & 0 deletions rock/env_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

# Model Service Config
ROCK_MODEL_SERVICE_DATA_DIR: str
ROCK_MODEL_SERVICE_PROXY_TARGET_URL: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里至少得是个map<model_name, base_url>。一个agent可能得调用多个模型,不同的模型可能是不同的服务商。


# Agentic
ROCK_AGENT_PRE_STARTUP_BASH_CMD_LIST: list[str] = []
Expand Down Expand Up @@ -79,6 +80,7 @@
"ROCK_CLI_DEFAULT_CONFIG_PATH", Path.home() / ".rock" / "config.ini"
),
"ROCK_MODEL_SERVICE_DATA_DIR": lambda: os.getenv("ROCK_MODEL_SERVICE_DATA_DIR", "/data/logs"),
"ROCK_MODEL_SERVICE_PROXY_TARGET_URL": lambda: os.getenv("ROCK_MODEL_SERVICE_PROXY_TARGET_URL", ""),
"ROCK_AGENT_PYTHON_INSTALL_CMD": lambda: os.getenv(
"ROCK_AGENT_PYTHON_INSTALL_CMD",
"[ -f cpython31114.tar.gz ] && rm cpython31114.tar.gz; [ -d python ] && rm -rf python; wget -q -O cpython31114.tar.gz https://github.com/astral-sh/python-build-standalone/releases/download/20251120/cpython-3.11.14+20251120-x86_64-unknown-linux-gnu-install_only.tar.gz && tar -xzf cpython31114.tar.gz",
Expand Down
121 changes: 119 additions & 2 deletions rock/sdk/model/server/api/proxy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,127 @@
from typing import Any

from fastapi import APIRouter, Request
import httpx
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse

from rock.logger import init_logger
from rock.sdk.model.server.config import PROXY_TARGET_URL

logger = init_logger(__name__)

proxy_router = APIRouter()


async def forward_non_streaming_request(
body: dict[str, Any], headers: dict[str, str], target_url: str
) -> tuple[Any, int]:
"""Forward non-streaming request to target API"""
async with httpx.AsyncClient() as client:
try:
logger.info(f"Forwarding non-streaming request body: {body}")
logger.info(
f"Forwarding headers: {['Authorization' if k.lower() == 'authorization' else k for k in headers.keys()] if headers else 'No headers'}"
)

# Use provided headers to forward the request
response = await client.post(
target_url,
json=body,
headers=headers,
timeout=120.0, # Set timeout to 60 seconds
)

logger.info(f"Target API non-streaming response status: {response.status_code}")

# Try to parse the response as JSON
try:
response_data = response.json()
logger.info(f"Target API non-streaming response data: {response_data}")
return response_data, response.status_code
except Exception:
# If response is not JSON, return as text
response_text = response.text
logger.info(f"Target API non-streaming response text: {response_text}")
return response_text, response.status_code

except httpx.TimeoutException:
logger.error("Request to target API timed out")
raise HTTPException(status_code=504, detail="Request to target API timed out")
except httpx.RequestError as e:
logger.error(f"Error making non-streaming request to target API: {str(e)}")
raise HTTPException(status_code=502, detail=f"Error contacting target API: {str(e)}")
except Exception as e:
logger.error(f"Unknown error making non-streaming request to target API: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal proxy error: {str(e)}")


async def forward_streaming_request(
body: dict[str, Any], headers: dict[str, str], target_url: str
) -> StreamingResponse:
"""Forward streaming request to target API"""
async with httpx.AsyncClient() as client:
try:
logger.info(f"Forwarding streaming request body: {body}")
logger.info(
f"Forwarding headers: {['Authorization' if k.lower() == 'authorization' else k for k in headers.keys()] if headers else 'No headers'}"
)

# Use provided headers to forward the request
response = await client.post(
target_url,
json=body,
headers=headers,
timeout=120.0, # Set timeout to 60 seconds
)

logger.info(f"Target API streaming response status: {response.status_code}")

# Handle streaming response
content_type = response.headers.get("content-type", "")

async def generate():
# Stream response data in chunks
async for chunk in response.aiter_bytes():
yield chunk

return StreamingResponse(generate(), media_type=content_type)

except httpx.TimeoutException:
logger.error("Request to target API timed out")
raise HTTPException(status_code=504, detail="Request to target API timed out")
except httpx.RequestError as e:
logger.error(f"Error making streaming request to target API: {str(e)}")
raise HTTPException(status_code=502, detail=f"Error contacting target API: {str(e)}")
except Exception as e:
logger.error(f"Unknown error making streaming request to target API: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal proxy error: {str(e)}")


@proxy_router.post("/v1/chat/completions")
async def chat_completions(body: dict[str, Any], request: Request):
raise NotImplementedError("Proxy chat completions not implemented yet")
# Build forwarded headers while preserving original request headers
forwarded_headers = {}
for key, value in request.headers.items():
# Copy all headers, but skip certain headers that httpx should set automatically
if key.lower() in ["content-length", "content-type", "host", "transfer-encoding"]:
continue # Let httpx set these headers
forwarded_headers[key] = value

logger.info(f"Received request at proxy endpoint with body: {body}")

# Determine target URL
target_url = PROXY_TARGET_URL

# Choose handler based on stream parameter
if body.get("stream", False):
# Forward streaming request
result = await forward_streaming_request(body, forwarded_headers, target_url)
return result
else:
# Forward non-streaming request
response_data, status_code = await forward_non_streaming_request(body, forwarded_headers, target_url)

if status_code == 200:
return response_data
else:
return JSONResponse(content=response_data, status_code=status_code)
3 changes: 3 additions & 0 deletions rock/sdk/model/server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@
RESPONSE_START_MARKER = "LLM_RESPONSE_START"
RESPONSE_END_MARKER = "LLM_RESPONSE_END"
SESSION_END_MARKER = "SESSION_END"

# proxy_api
PROXY_TARGET_URL = env_vars.ROCK_MODEL_SERVICE_PROXY_TARGET_URL
Loading