-
Notifications
You must be signed in to change notification settings - Fork 44
Add comprehensive test suite for utility modules and core models #297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| """Tests for inference_gateway/cost_hash_map.py — in-memory cost tracking.""" | ||
|
|
||
| import time | ||
| import pytest | ||
|
|
||
| from uuid import uuid4 | ||
| from inference_gateway.cost_hash_map import CostHashMap, CostHashMapEntry, COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS | ||
|
|
||
|
|
||
class TestCostHashMapEntry:
    """Unit tests for the CostHashMapEntry model."""

    def test_creation(self):
        """An entry preserves the cost it was constructed with."""
        now = time.time()
        entry = CostHashMapEntry(cost=1.5, last_accessed_at=now)
        assert entry.cost == 1.5

    def test_zero_cost(self):
        """Zero is a valid cost and is stored unchanged."""
        entry = CostHashMapEntry(cost=0.0, last_accessed_at=time.time())
        assert entry.cost == 0.0
||
class TestCostHashMapGetCost:
    """Unit tests for CostHashMap.get_cost."""

    def test_unknown_uuid_returns_zero(self):
        """Looking up a UUID that was never recorded yields 0."""
        cost_map = CostHashMap()
        assert cost_map.get_cost(uuid4()) == 0

    def test_known_uuid_returns_cost(self):
        """A cost recorded via add_cost is returned by get_cost."""
        cost_map = CostHashMap()
        run_id = uuid4()
        cost_map.add_cost(run_id, 3.14)
        assert cost_map.get_cost(run_id) == 3.14

    def test_multiple_uuids_independent(self):
        """Costs recorded under different UUIDs do not interfere."""
        cost_map = CostHashMap()
        first_id, second_id = uuid4(), uuid4()
        cost_map.add_cost(first_id, 1.0)
        cost_map.add_cost(second_id, 2.0)
        assert cost_map.get_cost(first_id) == 1.0
        assert cost_map.get_cost(second_id) == 2.0
||
class TestCostHashMapAddCost:
    """Unit tests for CostHashMap.add_cost."""

    def test_add_cost_accumulates(self):
        """Repeated add_cost calls for one UUID sum their amounts."""
        cost_map = CostHashMap()
        run_id = uuid4()
        for amount in (1.0, 2.5, 0.5):
            cost_map.add_cost(run_id, amount)
        assert cost_map.get_cost(run_id) == 4.0

    def test_add_cost_creates_entry(self):
        """The first add_cost for a UUID inserts a new map entry."""
        cost_map = CostHashMap()
        run_id = uuid4()
        cost_map.add_cost(run_id, 5.0)
        assert run_id in cost_map.cost_hash_map
        assert cost_map.cost_hash_map[run_id].cost == 5.0

    def test_add_cost_updates_last_accessed(self):
        """add_cost stamps the entry with the current wall-clock time."""
        cost_map = CostHashMap()
        run_id = uuid4()
        start = time.time()
        cost_map.add_cost(run_id, 1.0)
        end = time.time()
        assert start <= cost_map.cost_hash_map[run_id].last_accessed_at <= end
||
class TestCostHashMapCleanup:
    """Unit tests for the CostHashMap._cleanup method."""

    def test_cleanup_removes_stale_entries(self):
        """Entries untouched for longer than the interval are evicted."""
        cost_map = CostHashMap()
        run_id = uuid4()
        cost_map.add_cost(run_id, 1.0)
        # Age the entry past the cleanup interval.
        cost_map.cost_hash_map[run_id].last_accessed_at = (
            time.time() - COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS - 10
        )
        # Push last_cleanup_at into the past so the sweep is eligible to run.
        cost_map.last_cleanup_at = (
            time.time() - COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS - 10
        )
        cost_map._cleanup()
        assert run_id not in cost_map.cost_hash_map

    def test_cleanup_preserves_fresh_entries(self):
        """A recently accessed entry survives an eligible cleanup pass."""
        cost_map = CostHashMap()
        run_id = uuid4()
        cost_map.add_cost(run_id, 1.0)
        cost_map.last_cleanup_at = (
            time.time() - COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS - 10
        )
        cost_map._cleanup()
        assert run_id in cost_map.cost_hash_map

    def test_cleanup_skipped_when_recent(self):
        """No eviction happens when the last sweep was just performed."""
        cost_map = CostHashMap()
        stale_id = uuid4()
        cost_map.add_cost(stale_id, 1.0)
        cost_map.cost_hash_map[stale_id].last_accessed_at = (
            time.time() - COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS - 10
        )
        # last_cleanup_at is current, so _cleanup should return without sweeping.
        cost_map.last_cleanup_at = time.time()
        cost_map._cleanup()
        # The stale entry must still be present because the sweep was skipped.
        assert stale_id in cost_map.cost_hash_map
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| """Tests for utils/diff.py — file diff computation, validation, and application.""" | ||
|
|
||
| import os | ||
| import pytest | ||
| import tempfile | ||
| import subprocess | ||
|
|
||
| from unittest.mock import patch, MagicMock | ||
| from utils.diff import get_file_diff, validate_diff_for_local_repo, apply_diff_to_local_repo | ||
|
|
||
|
|
||
class TestGetFileDiff:
    """Unit tests for the get_file_diff function."""

    def _write_temp(self, content: str) -> str:
        """Write *content* to a fresh temp .txt file and return its path."""
        handle, path = tempfile.mkstemp(suffix=".txt")
        with os.fdopen(handle, "w") as out:
            out.write(content)
        return path

    def test_identical_files_returns_empty_diff(self):
        """Identical inputs produce an empty (whitespace-only) diff."""
        left = self._write_temp("hello\nworld\n")
        right = self._write_temp("hello\nworld\n")
        try:
            assert get_file_diff(left, right).strip() == ""
        finally:
            os.unlink(left)
            os.unlink(right)

    def test_different_files_returns_unified_diff(self):
        """Changed lines show up with -/+ markers in the diff."""
        left = self._write_temp("line1\nline2\n")
        right = self._write_temp("line1\nmodified\n")
        try:
            diff_text = get_file_diff(left, right)
            assert "-line2" in diff_text
            assert "+modified" in diff_text
        finally:
            os.unlink(left)
            os.unlink(right)

    def test_diff_header_uses_basename(self):
        """The ---/+++ header lines reference the first file's basename."""
        left = self._write_temp("a\n")
        right = self._write_temp("b\n")
        try:
            diff_text = get_file_diff(left, right)
            name = os.path.basename(left)
            assert f"--- {name}" in diff_text
            assert f"+++ {name}" in diff_text
        finally:
            os.unlink(left)
            os.unlink(right)

    def test_missing_file_raises_exception(self):
        """A nonexistent second path raises rather than returning a diff."""
        existing = self._write_temp("content\n")
        try:
            with pytest.raises(Exception):
                get_file_diff(existing, "/nonexistent/file.txt")
        finally:
            os.unlink(existing)

    def test_both_files_missing_raises_exception(self):
        """Two nonexistent paths also raise."""
        with pytest.raises(Exception):
            get_file_diff("/nonexistent/a.txt", "/nonexistent/b.txt")

    def test_added_lines_in_diff(self):
        """Lines present only in the second file appear as additions."""
        left = self._write_temp("line1\n")
        right = self._write_temp("line1\nline2\nline3\n")
        try:
            diff_text = get_file_diff(left, right)
            assert "+line2" in diff_text
            assert "+line3" in diff_text
        finally:
            os.unlink(left)
            os.unlink(right)

    def test_removed_lines_in_diff(self):
        """Lines present only in the first file appear as removals."""
        left = self._write_temp("line1\nline2\nline3\n")
        right = self._write_temp("line1\n")
        try:
            diff_text = get_file_diff(left, right)
            assert "-line2" in diff_text
            assert "-line3" in diff_text
        finally:
            os.unlink(left)
            os.unlink(right)
||
class TestValidateDiffForLocalRepo:
    """Tests for the validate_diff_for_local_repo function."""

    def setup_method(self):
        # Track every repo created by a test so teardown_method can remove it.
        # Previously each tempfile.mkdtemp() directory leaked across runs.
        self._repo_dirs: list = []

    def teardown_method(self):
        """Remove all git repos created during the test."""
        import shutil
        for repo_dir in self._repo_dirs:
            shutil.rmtree(repo_dir, ignore_errors=True)

    def _create_git_repo(self, files: dict) -> str:
        """Create a throwaway git repo containing *files* (name -> content).

        Returns the repo path; the directory is registered for cleanup in
        teardown_method so repeated test runs leave no temp residue behind.
        """
        repo_dir = tempfile.mkdtemp()
        self._repo_dirs.append(repo_dir)
        subprocess.run(["git", "init"], cwd=repo_dir, capture_output=True, check=True)
        subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_dir, capture_output=True)
        subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_dir, capture_output=True)
        for name, content in files.items():
            filepath = os.path.join(repo_dir, name)
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            with open(filepath, "w") as f:
                f.write(content)
        subprocess.run(["git", "add", "."], cwd=repo_dir, capture_output=True, check=True)
        subprocess.run(["git", "commit", "-m", "init"], cwd=repo_dir, capture_output=True, check=True)
        return repo_dir

    def test_valid_diff_returns_true(self):
        """A diff that applies cleanly validates with no error."""
        repo = self._create_git_repo({"hello.txt": "line1\nline2\n"})
        diff = "--- a/hello.txt\n+++ b/hello.txt\n@@ -1,2 +1,2 @@\n line1\n-line2\n+modified\n"
        is_valid, error = validate_diff_for_local_repo(diff, repo)
        assert is_valid is True
        assert error is None

    def test_invalid_diff_returns_false(self):
        """A diff against a missing file fails validation with an error."""
        repo = self._create_git_repo({"hello.txt": "line1\n"})
        diff = "--- a/nonexistent.txt\n+++ b/nonexistent.txt\n@@ -1 +1 @@\n-old\n+new\n"
        is_valid, error = validate_diff_for_local_repo(diff, repo)
        assert is_valid is False
        assert error is not None

    def test_empty_diff_is_valid(self):
        """An empty diff is trivially valid and reports no error."""
        repo = self._create_git_repo({"hello.txt": "content\n"})
        is_valid, error = validate_diff_for_local_repo("", repo)
        assert is_valid is True
        # Previously the error value was left unchecked for the empty case.
        assert error is None
||
class TestApplyDiffToLocalRepo:
    """Tests for the apply_diff_to_local_repo function."""

    def setup_method(self):
        # Track every repo created by a test so teardown_method can remove it.
        # Previously each tempfile.mkdtemp() directory leaked across runs.
        self._repo_dirs: list = []

    def teardown_method(self):
        """Remove all git repos created during the test."""
        import shutil
        for repo_dir in self._repo_dirs:
            shutil.rmtree(repo_dir, ignore_errors=True)

    def _create_git_repo(self, files: dict) -> str:
        """Create a throwaway git repo containing *files* (name -> content).

        Returns the repo path; the directory is registered for cleanup in
        teardown_method so repeated test runs leave no temp residue behind.
        """
        repo_dir = tempfile.mkdtemp()
        self._repo_dirs.append(repo_dir)
        subprocess.run(["git", "init"], cwd=repo_dir, capture_output=True, check=True)
        subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_dir, capture_output=True)
        subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_dir, capture_output=True)
        for name, content in files.items():
            filepath = os.path.join(repo_dir, name)
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            with open(filepath, "w") as f:
                f.write(content)
        subprocess.run(["git", "add", "."], cwd=repo_dir, capture_output=True, check=True)
        subprocess.run(["git", "commit", "-m", "init"], cwd=repo_dir, capture_output=True, check=True)
        return repo_dir

    def test_apply_valid_diff_modifies_file(self):
        """Applying a clean diff rewrites the target file's contents."""
        repo = self._create_git_repo({"hello.txt": "line1\nline2\n"})
        diff = "--- a/hello.txt\n+++ b/hello.txt\n@@ -1,2 +1,2 @@\n line1\n-line2\n+modified\n"
        apply_diff_to_local_repo(diff, repo)
        with open(os.path.join(repo, "hello.txt")) as f:
            assert f.read() == "line1\nmodified\n"

    def test_apply_invalid_diff_raises(self):
        """A diff that cannot apply raises instead of silently failing."""
        repo = self._create_git_repo({"hello.txt": "content\n"})
        diff = "--- a/nonexistent.txt\n+++ b/nonexistent.txt\n@@ -1 +1 @@\n-old\n+new\n"
        with pytest.raises(Exception):
            apply_diff_to_local_repo(diff, repo)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| """Tests for models/ — Pydantic model validation for core domain objects.""" | ||
|
|
||
| import pytest | ||
| from uuid import uuid4 | ||
| from datetime import datetime | ||
|
|
||
| from models.evaluation_run import EvaluationRun, EvaluationRunStatus, EvaluationRunErrorCode | ||
|
|
||
|
|
||
class TestEvaluationRunStatus:
    """Unit tests for the EvaluationRunStatus enum."""

    def test_all_statuses_exist(self):
        """Every lifecycle status the platform relies on is defined."""
        required = {
            "pending",
            "initializing_agent",
            "running_agent",
            "initializing_eval",
            "running_eval",
            "finished",
            "error",
        }
        defined = {member.value for member in EvaluationRunStatus}
        assert required.issubset(defined)

    def test_status_values_are_strings(self):
        """Each enum member carries a string value."""
        assert all(isinstance(member.value, str) for member in EvaluationRunStatus)
||
class TestEvaluationRunErrorCode:
    """Unit tests for the EvaluationRunErrorCode enum."""

    def test_validator_internal_error_exists(self):
        """The VALIDATOR_INTERNAL_ERROR member is defined."""
        assert hasattr(EvaluationRunErrorCode, "VALIDATOR_INTERNAL_ERROR")

    def test_validator_unknown_problem_exists(self):
        """The VALIDATOR_UNKNOWN_PROBLEM member is defined."""
        assert hasattr(EvaluationRunErrorCode, "VALIDATOR_UNKNOWN_PROBLEM")

    def test_get_error_message_returns_string(self):
        """Every code maps to a non-empty human-readable message."""
        for error_code in EvaluationRunErrorCode:
            message = error_code.get_error_message()
            assert isinstance(message, str)
            assert message

    def test_error_code_categories(self):
        """Category predicates are mutually exclusive for sample codes."""
        agent_code = EvaluationRunErrorCode.AGENT_EXCEPTION_RUNNING_AGENT
        assert agent_code.is_agent_error()
        assert not agent_code.is_validator_error()

        validator_code = EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR
        assert validator_code.is_validator_error()
        assert not validator_code.is_agent_error()

        platform_code = EvaluationRunErrorCode.PLATFORM_RESTARTED_WHILE_PENDING
        assert platform_code.is_platform_error()
        assert not platform_code.is_agent_error()

    def test_all_agent_errors_in_1xxx_range(self):
        """Agent-category codes occupy the 1000-1999 numeric band."""
        agent_codes = [c for c in EvaluationRunErrorCode if c.is_agent_error()]
        assert all(1000 <= c.value < 2000 for c in agent_codes)

    def test_all_validator_errors_in_2xxx_range(self):
        """Validator-category codes occupy the 2000-2999 numeric band."""
        validator_codes = [c for c in EvaluationRunErrorCode if c.is_validator_error()]
        assert all(2000 <= c.value < 3000 for c in validator_codes)

    def test_all_platform_errors_in_3xxx_range(self):
        """Platform-category codes occupy the 3000-3999 numeric band."""
        platform_codes = [c for c in EvaluationRunErrorCode if c.is_platform_error()]
        assert all(3000 <= c.value < 4000 for c in platform_codes)
||
class TestEvaluationRun:
    """Unit tests for the EvaluationRun model."""

    def test_minimal_creation(self):
        """A run built with only required fields gets None optionals."""
        minimal_run = EvaluationRun(
            evaluation_run_id=uuid4(),
            evaluation_id=uuid4(),
            problem_name="test-problem",
            status=EvaluationRunStatus.pending,
            created_at=datetime.now(),
        )
        assert minimal_run.status == EvaluationRunStatus.pending
        assert minimal_run.patch is None
        assert minimal_run.error_code is None

    def test_finished_run_with_results(self):
        """A finished run can carry a patch and per-test results."""
        from models.problem import ProblemTestResult, ProblemTestResultStatus
        test_result = ProblemTestResult(
            name="test1", category="default", status=ProblemTestResultStatus.PASS
        )
        finished_run = EvaluationRun(
            evaluation_run_id=uuid4(),
            evaluation_id=uuid4(),
            problem_name="test-problem",
            status=EvaluationRunStatus.finished,
            patch="--- a/file.py\n+++ b/file.py\n",
            test_results=[test_result],
            created_at=datetime.now(),
            finished_or_errored_at=datetime.now(),
        )
        assert finished_run.status == EvaluationRunStatus.finished
        assert len(finished_run.test_results) == 1

    def test_error_run(self):
        """An errored run records its error code and message."""
        errored_run = EvaluationRun(
            evaluation_run_id=uuid4(),
            evaluation_id=uuid4(),
            problem_name="test-problem",
            status=EvaluationRunStatus.error,
            error_code=EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR,
            error_message="Something went wrong",
            created_at=datetime.now(),
        )
        assert errored_run.error_code == EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion (non-blocking): Since
`Exception` is the root of most exception hierarchies, these assertions would also pass if the source code raised something unexpected, like an `AttributeError` from a typo. If the function raises a more specific type (e.g. `FileNotFoundError` or a custom exception), catching that explicitly would make the tests a bit stricter. That said, I don't know the project's conventions here — if broad catching is the preferred style, please disregard.