|
| 1 | +import os |
| 2 | +from typing import Any, Optional |
| 3 | + |
| 4 | +from codegen.agents.client.openapi_client.api.agents_api import AgentsApi |
| 5 | +from codegen.agents.client.openapi_client.api_client import ApiClient |
| 6 | +from codegen.agents.client.openapi_client.configuration import Configuration |
| 7 | +from codegen.agents.client.openapi_client.models.agent_run_response import AgentRunResponse |
| 8 | +from codegen.agents.client.openapi_client.models.create_agent_run_input import CreateAgentRunInput |
| 9 | +from codegen.agents.constants import CODEGEN_BASE_API_URL |
| 10 | + |
| 11 | + |
| 12 | +class AgentTask: |
| 13 | + """Represents an agent run job.""" |
| 14 | + |
| 15 | + def __init__(self, task_data: AgentRunResponse, api_client: ApiClient, org_id: int): |
| 16 | + self.id = task_data.id |
| 17 | + self.org_id = org_id |
| 18 | + self.status = task_data.status |
| 19 | + self.result = task_data.result |
| 20 | + self.web_url = task_data.web_url |
| 21 | + self._api_client = api_client |
| 22 | + self._agents_api = AgentsApi(api_client) |
| 23 | + |
| 24 | + def refresh(self) -> None: |
| 25 | + """Refresh the job status from the API.""" |
| 26 | + if self.id is None: |
| 27 | + return |
| 28 | + |
| 29 | + job_data = self._agents_api.get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get( |
| 30 | + agent_run_id=int(self.id), org_id=int(self.org_id), authorization=f"Bearer {self._api_client.configuration.access_token}" |
| 31 | + ) |
| 32 | + |
| 33 | + # Convert API response to dict for attribute access |
| 34 | + job_dict = {} |
| 35 | + if hasattr(job_data, "__dict__"): |
| 36 | + job_dict = job_data.__dict__ |
| 37 | + elif isinstance(job_data, dict): |
| 38 | + job_dict = job_data |
| 39 | + |
| 40 | + self.status = job_dict.get("status") |
| 41 | + self.result = job_dict.get("result") |
| 42 | + |
| 43 | + |
| 44 | +class Agent: |
| 45 | + """API client for interacting with Codegen AI agents.""" |
| 46 | + |
| 47 | + def __init__(self, token: str, org_id: Optional[int] = None, base_url: Optional[str] = CODEGEN_BASE_API_URL): |
| 48 | + """Initialize a new Agent client. |
| 49 | +
|
| 50 | + Args: |
| 51 | + token: API authentication token |
| 52 | + org_id: Optional organization ID. If not provided, default org will be used. |
| 53 | + """ |
| 54 | + self.token = token |
| 55 | + self.org_id = org_id or int(os.environ.get("CODEGEN_ORG_ID", "1")) # Default to org ID 1 if not specified |
| 56 | + |
| 57 | + # Configure API client |
| 58 | + config = Configuration(host=base_url, access_token=token) |
| 59 | + self.api_client = ApiClient(configuration=config) |
| 60 | + self.agents_api = AgentsApi(self.api_client) |
| 61 | + |
| 62 | + # Current job |
| 63 | + self.current_job = None |
| 64 | + |
| 65 | + def run(self, prompt: str) -> AgentTask: |
| 66 | + """Run an agent with the given prompt. |
| 67 | +
|
| 68 | + Args: |
| 69 | + prompt: The instruction for the agent to execute |
| 70 | +
|
| 71 | + Returns: |
| 72 | + Job: A job object representing the agent run |
| 73 | + """ |
| 74 | + run_input = CreateAgentRunInput(prompt=prompt) |
| 75 | + agent_run_response = self.agents_api.create_agent_run_v1_organizations_org_id_agent_run_post( |
| 76 | + org_id=int(self.org_id), create_agent_run_input=run_input, authorization=f"Bearer {self.token}", _headers={"Content-Type": "application/json"} |
| 77 | + ) |
| 78 | + # Convert API response to dict for Job initialization |
| 79 | + |
| 80 | + job = AgentTask(agent_run_response, self.api_client, self.org_id) |
| 81 | + self.current_job = job |
| 82 | + return job |
| 83 | + |
| 84 | + def get_status(self) -> Optional[dict[str, Any]]: |
| 85 | + """Get the status of the current job. |
| 86 | +
|
| 87 | + Returns: |
| 88 | + dict: A dictionary containing job status information, |
| 89 | + or None if no job has been run. |
| 90 | + """ |
| 91 | + if self.current_job: |
| 92 | + self.current_job.refresh() |
| 93 | + return {"id": self.current_job.id, "status": self.current_job.status, "result": self.current_job.result, "web_url": self.current_job.web_url} |
| 94 | + return None |
0 commit comments