Skip to content

Commit 95a6e29

Browse files
committed
Add runtime type checking
This patch adds beartype for runtime type checking. This gives us the best of both worlds: we do static type checking of our own library with mypy, and we export our static types, but for clients who do not run static type checking of their own code, runtime type checking in our library can help them catch bugs earlier. `tests/test_codex_tool.py::test_bad_argument_type` serves as an example: this fails at initialization time of `CodexTool`, whereas without runtime type checking, this would fail later (e.g., when the user calls the `query` method on the object). Because we're performing runtime type checking, some of the imports that were behind `if TYPE_CHECKING` flags have to be moved to runtime. This patch updates the linter config to allow imports that are only used for type checking. This patch also switches to consistent `from __future__ import annotations` everywhere, stops using type hints deprecated by PEP 585, and uses PEP 585 / PEP 604 syntax everywhere. This patch updates the linter config to match this style. beartype relies on `isinstance` for runtime type checks, which needs to be taken into account when using mocks by overriding the `__class__` attribute. This patch updates the tests accordingly.
1 parent 453263b commit 95a6e29

File tree

12 files changed

+78
-46
lines changed

12 files changed

+78
-46
lines changed

pyproject.toml

+6-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ classifiers = [
2727
dependencies = [
2828
"codex-sdk==0.1.0a9",
2929
"pydantic>=1.9.0, <3",
30+
"beartype>=0.17.0",
3031
]
3132

3233
[project.urls]
@@ -98,4 +99,8 @@ html = "coverage html"
9899
xml = "coverage xml"
99100

100101
[tool.ruff.lint]
101-
ignore = ["FA100", "UP007", "UP006"]
102+
ignore = [
103+
"TCH001", # this package does run-time type checking
104+
"TCH002",
105+
"TCH003"
106+
]

src/cleanlab_codex/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
# SPDX-License-Identifier: MIT
2+
3+
from beartype.claw import beartype_this_package
4+
5+
# this must run before any other imports from the cleanlab_codex package
6+
beartype_this_package()
7+
8+
# ruff: noqa: E402
29
from cleanlab_codex.codex import Codex
310
from cleanlab_codex.codex_tool import CodexTool
411

src/cleanlab_codex/codex.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Optional
4-
53
from cleanlab_codex.internal.project import create_project, query_project
64
from cleanlab_codex.internal.utils import init_codex_client
7-
8-
if TYPE_CHECKING:
9-
from cleanlab_codex.types.entry import Entry, EntryCreate
10-
from cleanlab_codex.types.organization import Organization
5+
from cleanlab_codex.types.entry import Entry, EntryCreate
6+
from cleanlab_codex.types.organization import Organization
117

128

139
class Codex:
@@ -41,7 +37,7 @@ def list_organizations(self) -> list[Organization]:
4137
"""
4238
return self._client.users.myself.organizations.list().organizations
4339

44-
def create_project(self, name: str, organization_id: str, description: Optional[str] = None) -> str:
40+
def create_project(self, name: str, organization_id: str, description: str | None = None) -> str:
4541
"""Create a new Codex project.
4642
4743
Args:
@@ -77,7 +73,7 @@ def create_project_access_key(
7773
self,
7874
project_id: str,
7975
access_key_name: str,
80-
access_key_description: Optional[str] = None,
76+
access_key_description: str | None = None,
8177
) -> str:
8278
"""Create a new access key for a project.
8379
@@ -99,15 +95,15 @@ def query(
9995
self,
10096
question: str,
10197
*,
102-
project_id: Optional[str] = None, # TODO: update to uuid once project IDs are changed to UUIDs
103-
fallback_answer: Optional[str] = None,
98+
project_id: str | None = None, # TODO: update to uuid once project IDs are changed to UUIDs
99+
fallback_answer: str | None = None,
104100
read_only: bool = False,
105-
) -> tuple[Optional[str], Optional[Entry]]:
101+
) -> tuple[str | None, Entry | None]:
106102
"""Query Codex to check if the Codex project contains an answer to this question and add the question to the Codex project for SME review if it does not.
107103
108104
Args:
109105
question (str): The question to ask the Codex API.
110-
project_id (:obj:`int`, optional): The ID of the project to query.
106+
project_id (:obj:`str`, optional): The ID of the project to query.
111107
If the client is authenticated with a user-level API Key, this is required.
112108
If the client is authenticated with a project-level Access Key, this is optional. The client will use the Access Key's project ID by default.
113109
fallback_answer (:obj:`str`, optional): Optional fallback answer to return if Codex is unable to answer the question.

src/cleanlab_codex/codex_tool.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, ClassVar, Optional
3+
from typing import Any, ClassVar
44

55
from cleanlab_codex.codex import Codex
66

@@ -23,8 +23,8 @@ def __init__(
2323
self,
2424
codex_client: Codex,
2525
*,
26-
project_id: Optional[str] = None,
27-
fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER,
26+
project_id: str | None = None,
27+
fallback_answer: str | None = DEFAULT_FALLBACK_ANSWER,
2828
):
2929
self._codex_client = codex_client
3030
self._project_id = project_id
@@ -35,8 +35,8 @@ def from_access_key(
3535
cls,
3636
access_key: str,
3737
*,
38-
project_id: Optional[str] = None,
39-
fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER,
38+
project_id: str | None = None,
39+
fallback_answer: str | None = DEFAULT_FALLBACK_ANSWER,
4040
) -> CodexTool:
4141
"""Creates a CodexTool from an access key. The project ID that the CodexTool will use is the one that is associated with the access key."""
4242
return cls(
@@ -50,8 +50,8 @@ def from_client(
5050
cls,
5151
codex_client: Codex,
5252
*,
53-
project_id: Optional[str] = None,
54-
fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER,
53+
project_id: str | None = None,
54+
fallback_answer: str | None = DEFAULT_FALLBACK_ANSWER,
5555
) -> CodexTool:
5656
"""Creates a CodexTool from a Codex client.
5757
If the Codex client is initialized with a project access key, the CodexTool will use the project ID that is associated with the access key.
@@ -74,16 +74,16 @@ def tool_description(self) -> str:
7474
return self._tool_description
7575

7676
@property
77-
def fallback_answer(self) -> Optional[str]:
77+
def fallback_answer(self) -> str | None:
7878
"""The fallback answer to use if the Codex project cannot answer the question."""
7979
return self._fallback_answer
8080

8181
@fallback_answer.setter
82-
def fallback_answer(self, value: Optional[str]) -> None:
82+
def fallback_answer(self, value: str | None) -> None:
8383
"""Sets the fallback answer to use if the Codex project cannot answer the question."""
8484
self._fallback_answer = value
8585

86-
def query(self, question: str) -> Optional[str]:
86+
def query(self, question: str) -> str | None:
8787
"""Asks an all-knowing advisor this question in cases where it cannot be answered from the provided Context. If the answer is not available, this returns a fallback answer or None.
8888
8989
Args:

src/cleanlab_codex/internal/project.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Optional
4-
5-
if TYPE_CHECKING:
6-
from codex import Codex as _Codex
7-
8-
from cleanlab_codex.types.entry import Entry
3+
from codex import Codex as _Codex
94

5+
from cleanlab_codex.types.entry import Entry
106
from cleanlab_codex.types.project import ProjectConfig
117

128

@@ -17,7 +13,7 @@ def __str__(self) -> str:
1713
return "project_id is required when authenticating with a user-level API Key"
1814

1915

20-
def create_project(client: _Codex, name: str, organization_id: str, description: Optional[str] = None) -> str:
16+
def create_project(client: _Codex, name: str, organization_id: str, description: str | None = None) -> str:
2117
project = client.projects.create(
2218
config=ProjectConfig(),
2319
organization_id=organization_id,
@@ -31,10 +27,10 @@ def query_project(
3127
client: _Codex,
3228
question: str,
3329
*,
34-
project_id: Optional[str] = None,
35-
fallback_answer: Optional[str] = None,
30+
project_id: str | None = None,
31+
fallback_answer: str | None = None,
3632
read_only: bool = False,
37-
) -> tuple[Optional[str], Optional[Entry]]:
33+
) -> tuple[str | None, Entry | None]:
3834
if client.access_key is not None:
3935
project_id = client.projects.access_keys.retrieve_project_id().project_id
4036
elif project_id is None:

src/cleanlab_codex/utils/llamaindex.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
from collections.abc import Callable
34
from inspect import signature
4-
from typing import Any, Callable
5+
from typing import Any
56

67
from llama_index.core.bridge.pydantic import BaseModel, FieldInfo, create_model
78

src/cleanlab_codex/utils/openai.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, Dict, List, Literal
3+
from typing import Any, Literal
44

55
from pydantic import BaseModel
66

@@ -12,8 +12,8 @@ class Property(BaseModel):
1212

1313
class FunctionParameters(BaseModel):
1414
type: Literal["object"] = "object"
15-
properties: Dict[str, Property]
16-
required: List[str]
15+
properties: dict[str, Property]
16+
required: list[str]
1717

1818

1919
class Function(BaseModel):
@@ -30,9 +30,9 @@ class Tool(BaseModel):
3030
def format_as_openai_tool(
3131
tool_name: str,
3232
tool_description: str,
33-
tool_properties: Dict[str, Any],
34-
required_properties: List[str],
35-
) -> Dict[str, Any]:
33+
tool_properties: dict[str, Any],
34+
required_properties: list[str],
35+
) -> dict[str, Any]:
3636
return Tool(
3737
function=Function(
3838
name=tool_name,
+6-4
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
from typing import Callable, Dict, Optional
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable
24

35
from smolagents import Tool # type: ignore
46

57

68
class CodexTool(Tool): # type: ignore[misc]
79
def __init__(
810
self,
9-
query: Callable[[str], Optional[str]],
11+
query: Callable[[str], str | None],
1012
tool_name: str,
1113
tool_description: str,
12-
inputs: Dict[str, Dict[str, str]],
14+
inputs: dict[str, dict[str, str]],
1315
):
1416
super().__init__()
1517
self._query = query
@@ -18,5 +20,5 @@ def __init__(
1820
self.inputs = inputs
1921
self.output_type = "string"
2022

21-
def forward(self, question: str) -> Optional[str]:
23+
def forward(self, question: str) -> str | None:
2224
return self._query(question)

tests/fixtures/client.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
from typing import Generator
1+
from __future__ import annotations
2+
3+
from collections.abc import Generator
24
from unittest.mock import MagicMock, patch
35

46
import pytest
57

8+
from cleanlab_codex.internal.utils import _Codex
9+
610

711
@pytest.fixture
812
def mock_client() -> Generator[MagicMock, None, None]:
913
with patch("cleanlab_codex.codex.init_codex_client") as mock_init:
1014
mock_client = MagicMock()
15+
mock_client.__class__ = _Codex
1116
mock_init.return_value = mock_client
1217
yield mock_client

tests/internal/test_utils.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from __future__ import annotations
2+
13
import os
24
from unittest.mock import MagicMock, patch
35

46
import pytest
57

6-
from cleanlab_codex.internal.utils import MissingAuthKeyError, init_codex_client, is_access_key
8+
from cleanlab_codex.internal.utils import MissingAuthKeyError, _Codex, init_codex_client, is_access_key
79

810
DUMMY_ACCESS_KEY = "sk-1-EMOh6UrRo7exTEbEi8_azzACAEdtNiib2LLa1IGo6kA"
911
DUMMY_API_KEY = "GP0FzPfA7wYy5L64luII2YaRT2JoSXkae7WEo7dH6Bw"
@@ -16,6 +18,7 @@ def test_is_access_key() -> None:
1618

1719
def test_init_codex_client_access_key() -> None:
1820
mock_client = MagicMock()
21+
mock_client.__class__ = _Codex
1922
with patch("cleanlab_codex.internal.utils._Codex", autospec=True, return_value=mock_client) as mock_init:
2023
mock_client.projects.access_keys.retrieve_project_id.return_value = "test_project_id"
2124
client = init_codex_client(DUMMY_ACCESS_KEY)
@@ -25,6 +28,7 @@ def test_init_codex_client_access_key() -> None:
2528

2629
def test_init_codex_client_api_key() -> None:
2730
mock_client = MagicMock()
31+
mock_client.__class__ = _Codex
2832
with patch("cleanlab_codex.internal.utils._Codex", autospec=True, return_value=mock_client) as mock_init:
2933
mock_client.users.myself.api_key.retrieve.return_value = "test_project_id"
3034
client = init_codex_client(DUMMY_API_KEY)
@@ -40,6 +44,7 @@ def test_init_codex_client_no_key() -> None:
4044
def test_init_codex_client_access_key_env_var() -> None:
4145
with patch.dict(os.environ, {"CODEX_ACCESS_KEY": DUMMY_ACCESS_KEY}):
4246
mock_client = MagicMock()
47+
mock_client.__class__ = _Codex
4348
with patch(
4449
"cleanlab_codex.internal.utils._Codex",
4550
autospec=True,
@@ -54,6 +59,7 @@ def test_init_codex_client_access_key_env_var() -> None:
5459
def test_init_codex_client_api_key_env_var() -> None:
5560
with patch.dict(os.environ, {"CODEX_API_KEY": DUMMY_API_KEY}):
5661
mock_client = MagicMock()
62+
mock_client.__class__ = _Codex
5763
with patch(
5864
"cleanlab_codex.internal.utils._Codex",
5965
autospec=True,

tests/test_codex.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# ruff: noqa: DTZ005
22

3+
from __future__ import annotations
4+
35
import uuid
46
from datetime import datetime
57
from unittest.mock import MagicMock
@@ -86,6 +88,9 @@ def test_create_project_access_key(mock_client: MagicMock) -> None:
8688
codex = Codex("")
8789
access_key_name = "Test Access Key"
8890
access_key_description = "Test Access Key Description"
91+
access_key = MagicMock()
92+
access_key.token.__class__ = str
93+
mock_client.projects.access_keys.create.return_value = access_key
8994
codex.create_project_access_key(FAKE_PROJECT_ID, access_key_name, access_key_description)
9095
mock_client.projects.access_keys.create.assert_called_once_with(
9196
project_id=FAKE_PROJECT_ID,

tests/test_codex_tool.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import sys
24
from unittest.mock import MagicMock
35

@@ -34,3 +36,10 @@ def test_to_smolagents_tool(mock_client: MagicMock) -> None: # noqa: ARG001
3436
assert isinstance(smolagents_tool, Tool)
3537
assert smolagents_tool.name == tool.tool_name
3638
assert smolagents_tool.description == tool.tool_description
39+
40+
41+
def test_bad_argument_type() -> None:
42+
from beartype.roar import BeartypeException
43+
44+
with pytest.raises(BeartypeException):
45+
CodexTool("asdf")

0 commit comments

Comments
 (0)