diff --git a/src/agentunit/adapters/autogen_ag2.py b/src/agentunit/adapters/autogen_ag2.py
index 377f209..51ff324 100644
--- a/src/agentunit/adapters/autogen_ag2.py
+++ b/src/agentunit/adapters/autogen_ag2.py
@@ -318,7 +318,7 @@ def end_session(self, session_id: SessionID) -> ScenarioResult:
             case_id=session_id,
             success=True,
             metrics=metrics,
-            duration_ms=0.0,  # TODO: Track actual duration
+            duration_ms=((session["end_time"] - session["start_time"]).total_seconds() * 1000),
             trace=trace,
         )
 
@@ -340,6 +340,19 @@ def _calculate_session_metrics(self, session_id: SessionID) -> dict[str, Any]:
             interaction for interaction in self.interactions if interaction.session_id == session_id
         ]
 
+        response_times = []
+
+        # Response time for each message is its gap from the previous one
+        for i in range(1, len(session_interactions)):
+            previous_message_time = session_interactions[i - 1].timestamp
+            current_message_time = session_interactions[i].timestamp
+
+            time_difference = (current_message_time - previous_message_time).total_seconds()
+
+            response_times.append(time_difference)
+
+        average_response_time = sum(response_times) / len(response_times) if response_times else 0.0
+
         # Basic metrics
         metrics = {
             "total_messages": len(session_interactions),
@@ -347,7 +360,7 @@ def _calculate_session_metrics(self, session_id: SessionID) -> dict[str, Any]:
             "duration_seconds": (
                 session.get("end_time", datetime.now()) - session["start_time"]
             ).total_seconds(),
-            "average_response_time": 0.0,  # Would need timing data
+            "average_response_time": average_response_time,
             "conversation_turns": session.get("message_count", 0),
         }
 
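Note on the metrics hunk above: average_response_time is the mean gap between
consecutive interaction timestamps, relying on the datetime `timestamp`
attribute the loop already reads. A minimal standalone sketch of the same
arithmetic, with made-up timestamps for illustration:

    from datetime import datetime

    timestamps = [
        datetime(2024, 1, 1, 12, 0, 0),
        datetime(2024, 1, 1, 12, 0, 2),  # 2 s after the first message
        datetime(2024, 1, 1, 12, 0, 7),  # 5 s after the second
    ]

    # Gap between each consecutive pair of messages, in seconds.
    gaps = [
        (timestamps[i] - timestamps[i - 1]).total_seconds()
        for i in range(1, len(timestamps))
    ]

    average = sum(gaps) / len(gaps) if gaps else 0.0
    print(average)  # (2 + 5) / 2 = 3.5

A single-message session yields an empty gap list, which is why the
`if response_times else 0.0` guard in the hunk matters.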
diff --git a/src/agentunit/datasets/base.py b/src/agentunit/datasets/base.py
index 093ba9f..9582fd5 100644
--- a/src/agentunit/datasets/base.py
+++ b/src/agentunit/datasets/base.py
@@ -79,7 +79,25 @@ def _loader() -> Iterable[DatasetCase]:
     return DatasetSource(name=str(file_path.stem), loader=_loader)
 
 
-def load_local_csv(path: str | Path) -> DatasetSource:
+def _parse_list_field(value: str | None, delimiter: str) -> list[str] | None:
+    """Safely parse a delimited list field from CSV."""
+    if not value or not isinstance(value, str):
+        return None
+
+    if not delimiter:
+        # Empty delimiter is invalid for str.split
+        return [value.strip()] if value.strip() else None
+
+    parts = [item.strip() for item in value.split(delimiter)]
+    cleaned = [p for p in parts if p]
+    return cleaned or None
+
+
+def load_local_csv(
+    path: str | Path,
+    tools_delimiter: str = ";",
+    context_delimiter: str = "||",
+) -> DatasetSource:
     file_path = Path(path)
     if not file_path.exists():
         msg = f"Dataset file not found: {file_path}"
@@ -88,18 +106,29 @@ def load_local_csv(path: str | Path) -> DatasetSource:
     def _loader() -> Iterable[DatasetCase]:
         with file_path.open(newline="", encoding="utf-8") as fh:
             reader = csv.DictReader(fh)
+
             for idx, row in enumerate(reader):
-                yield DatasetCase(
-                    id=row.get("id") or f"case-{idx}",
-                    query=row["query"],
-                    expected_output=row.get("expected_output"),
-                    tools=row.get("tools", "").split(";") if row.get("tools") else None,
-                    context=row.get("context", "").split("||") if row.get("context") else None,
-                    metadata={
-                        k: v
-                        for k, v in row.items()
-                        if k not in {"id", "query", "expected_output", "tools", "context"}
-                    },
-                )
+                try:
+                    yield DatasetCase(
+                        id=row.get("id") or f"case-{idx}",
+                        query=row["query"],
+                        expected_output=row.get("expected_output"),
+                        tools=_parse_list_field(row.get("tools"), tools_delimiter),
+                        context=_parse_list_field(row.get("context"), context_delimiter),
+                        metadata={
+                            k: v
+                            for k, v in row.items()
+                            if k
+                            not in {
+                                "id",
+                                "query",
+                                "expected_output",
+                                "tools",
+                                "context",
+                            }
+                        },
+                    )
+                except Exception as exc:
+                    raise AgentUnitError(f"Malformed CSV row at index {idx}: {row}") from exc
 
     return DatasetSource(name=str(file_path.stem), loader=_loader)
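The `_parse_list_field` helper deliberately returns None for empty or
whitespace-only fields, so DatasetCase keeps its `tools=None` /
`context=None` convention instead of receiving `[""]`. A sketch of the
intended semantics, traced directly from the helper above (the calls are
illustrative, not from a test suite):

    from agentunit.datasets.base import _parse_list_field

    assert _parse_list_field("search; calculator;", ";") == ["search", "calculator"]
    assert _parse_list_field("a||b||", "||") == ["a", "b"]
    assert _parse_list_field("", ";") is None        # empty field -> None, not [""]
    assert _parse_list_field(" ; ; ", ";") is None   # nothing survives stripping
    assert _parse_list_field("tool", "") == ["tool"] # empty-delimiter guard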
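Callers with non-default separators can now pass them explicitly. A
hypothetical usage sketch (agents.csv and its column values are invented,
and DatasetSource is assumed to expose the loader callable it is
constructed with):

    from agentunit.datasets.base import load_local_csv

    # CSV whose tools column is comma-separated, e.g. "search,calculator"
    source = load_local_csv(
        "agents.csv",
        tools_delimiter=",",
        context_delimiter="||",  # default, shown for clarity
    )

    for case in source.loader():
        print(case.id, case.tools, case.context)

With the try/except in place, a row that fails DatasetCase construction
(for example, a file with no query column) surfaces as an AgentUnitError
naming the row index rather than a bare KeyError mid-iteration.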