def _parse_list_field(value: str | None, delimiter: str) -> list[str] | None:
    """Split a delimited CSV cell into a list of stripped, non-empty items.

    Args:
        value: Raw cell text. ``None``, empty strings and non-string values
            all produce ``None``.
        delimiter: Separator between items. When it is empty the whole
            stripped cell is returned as a single item.

    Returns:
        The stripped, non-empty items in their original order, or ``None``
        when no item remains.
    """
    if not value or not isinstance(value, str):
        return None

    if not delimiter:
        # str.split rejects an empty separator, so treat the cell as one item.
        stripped = value.strip()
        return [stripped] if stripped else None

    items = [part.strip() for part in value.split(delimiter)]
    return [item for item in items if item] or None


def load_local_csv(
    path: str | Path,
    tools_delimiter: str = ";",
    context_delimiter: str = "||",
) -> DatasetSource:
    """Build a ``DatasetSource`` that reads dataset cases from a local CSV file.

    The file is opened each time the returned loader is called and rows are
    converted to ``DatasetCase`` objects one at a time. The ``id``, ``query``,
    ``expected_output``, ``tools`` and ``context`` columns map to the case
    fields of the same name; every other column is copied into ``metadata``.
    A row without an ``id`` value gets ``case-<row index>`` (zero-based).

    Args:
        path: Location of the CSV file.
        tools_delimiter: Separator used inside the ``tools`` column.
        context_delimiter: Separator used inside the ``context`` column.

    Returns:
        A ``DatasetSource`` named after the file stem.

    Raises:
        AgentUnitError: If the file does not exist, or (while iterating the
            loader) if a row cannot be turned into a ``DatasetCase``, for
            example because the ``query`` column is missing.
    """
    file_path = Path(path)
    if not file_path.exists():
        msg = f"Dataset file not found: {file_path}"
        # NOTE(review): this raise sits in the region elided between the two
        # diff hunks; confirm it matches the existing line in the file.
        raise AgentUnitError(msg)

    # Columns consumed as first-class case fields; built once, not per row.
    reserved_columns = frozenset({"id", "query", "expected_output", "tools", "context"})

    def _loader() -> Iterable[DatasetCase]:
        with file_path.open(newline="", encoding="utf-8") as fh:
            reader = csv.DictReader(fh)

            for idx, row in enumerate(reader):
                # Only the row conversion is guarded. The yield stays outside
                # the try so that exceptions thrown into the generator at the
                # yield point are not reported as malformed CSV rows.
                try:
                    case = DatasetCase(
                        id=row.get("id") or f"case-{idx}",
                        query=row["query"],
                        expected_output=row.get("expected_output"),
                        tools=_parse_list_field(row.get("tools"), tools_delimiter),
                        context=_parse_list_field(row.get("context"), context_delimiter),
                        metadata={k: v for k, v in row.items() if k not in reserved_columns},
                    )
                except Exception as exc:
                    msg = f"Malformed CSV row at index {idx}: {row}"
                    raise AgentUnitError(msg) from exc
                yield case

    return DatasetSource(name=str(file_path.stem), loader=_loader)