diff --git a/dev/clint/src/clint/config.py b/dev/clint/src/clint/config.py index 12c0642977ddc..20464eb3f54c4 100644 --- a/dev/clint/src/clint/config.py +++ b/dev/clint/src/clint/config.py @@ -10,6 +10,7 @@ class Config: exclude: list[str] # Path -> List of modules that should not be imported globally under that path forbidden_top_level_imports: dict[str, list[str]] + typing_extensions_allowlist: list[str] @classmethod def load(cls) -> Config: @@ -17,4 +18,9 @@ def load(cls) -> Config: data = tomli.load(f) exclude = data["tool"]["clint"]["exclude"] forbidden_imports = data["tool"]["clint"]["forbidden-top-level-imports"] - return cls(exclude, forbidden_imports) + typing_extensions_allowlist = data["tool"]["clint"]["typing-extensions-allowlist"] + return cls( + exclude, + forbidden_imports, + typing_extensions_allowlist, + ) diff --git a/dev/clint/src/clint/linter.py b/dev/clint/src/clint/linter.py index 247b62236b4d4..59edf85890ed6 100644 --- a/dev/clint/src/clint/linter.py +++ b/dev/clint/src/clint/linter.py @@ -352,21 +352,45 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: self.stack.pop() def visit_Import(self, node: ast.Import) -> None: - if self._is_in_function(): - for alias in node.names: - if alias.name.split(".", 1)[0] in BUILTIN_MODULES: - self._check(Location.from_node(node), rules.LazyBuiltinImport()) + for alias in node.names: + root_module = alias.name.split(".", 1)[0] + if self._is_in_function() and root_module in BUILTIN_MODULES: + self._check(Location.from_node(node), rules.LazyBuiltinImport()) - if self._is_at_top_level(): - for alias in node.names: - self._check_forbidden_top_level_import(node, alias.name.split(".", 1)[0]) + if ( + alias.name.split(".", 1)[0] == "typing_extensions" + and alias.name not in self.config.typing_extensions_allowlist + ): + self._check( + Location.from_node(node), + rules.TypingExtensions( + full_name=alias.name, + allowlist=self.config.typing_extensions_allowlist, + ), + ) + + if self._is_at_top_level(): + self._check_forbidden_top_level_import(node, root_module) self.generic_visit(node) def visit_ImportFrom(self, node: ast.ImportFrom) -> None: - if self._is_in_function() and node.module.split(".", 1)[0] in BUILTIN_MODULES: + root_module = node.module and node.module.split(".", 1)[0] + if self._is_in_function() and root_module in BUILTIN_MODULES: self._check(Location.from_node(node), rules.LazyBuiltinImport()) + if root_module == "typing_extensions": + for alias in node.names: + full_name = f"{node.module}.{alias.name}" + if full_name not in self.config.typing_extensions_allowlist: + self._check( + Location.from_node(node), + rules.TypingExtensions( + full_name=full_name, + allowlist=self.config.typing_extensions_allowlist, + ), + ) + if self._is_at_top_level(): self._check_forbidden_top_level_import(node, node.module) diff --git a/dev/clint/src/clint/rules.py b/dev/clint/src/clint/rules.py index 508a871e841e0..28cb32de172a8 100644 --- a/dev/clint/src/clint/rules.py +++ b/dev/clint/src/clint/rules.py @@ -302,3 +302,20 @@ def _message(self) -> str: raise ValueError( f"Unexpected type: {self.type_hint}. It must be one of {list(self.MAPPING)}." ) + + +class TypingExtensions(Rule): + def __init__(self, *, full_name: str, allowlist: list[str]) -> None: + self.full_name = full_name + self.allowlist = allowlist + + def _id(self) -> str: + return "MLF0017" + + def _message(self) -> str: + return ( + f"`{self.full_name}` is not allowed to use. Only {self.allowlist} are allowed. " + "You can extend `tool.clint.typing-extensions-allowlist` in `pyproject.toml` if needed " + "but make sure that the version requirement for `typing-extensions` is compatible with " + "the added types." + ) diff --git a/mlflow/data/code_dataset_source.py b/mlflow/data/code_dataset_source.py index fc4e685f0f2e3..31d5400c5b5a8 100644 --- a/mlflow/data/code_dataset_source.py +++ b/mlflow/data/code_dataset_source.py @@ -1,5 +1,7 @@ from typing import Any +from typing_extensions import Self + from mlflow.data.dataset_source import DatasetSource @@ -25,14 +27,14 @@ def _can_resolve(raw_source: Any): return False @classmethod - def _resolve(cls, raw_source: str) -> "CodeDatasetSource": + def _resolve(cls, raw_source: str) -> Self: raise NotImplementedError def to_dict(self) -> dict[Any, Any]: return {"tags": self._tags} @classmethod - def from_dict(cls, source_dict: dict[Any, Any]) -> "CodeDatasetSource": + def from_dict(cls, source_dict: dict[Any, Any]) -> Self: return cls( tags=source_dict.get("tags"), ) diff --git a/pyproject.skinny.toml b/pyproject.skinny.toml index 653e705958996..01f4c8a971405 100644 --- a/pyproject.skinny.toml +++ b/pyproject.skinny.toml @@ -36,6 +36,7 @@ dependencies = [ "pyyaml<7,>=5.1", "requests<3,>=2.17.3", "sqlparse<1,>=0.4.0", + "typing-extensions<5,>=4.0.0", ] [[project.maintainers]] name = "Databricks" diff --git a/pyproject.toml b/pyproject.toml index d69c2a1d3c67d..15931725d3b2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "scipy<2", "sqlalchemy<3,>=1.4.0", "sqlparse<1,>=0.4.0", + "typing-extensions<5,>=4.0.0", "waitress<4; platform_system == 'Windows'", ] [[project.maintainers]] @@ -270,6 +271,10 @@ exclude = [ "mlflow/store/db_migrations", "tests/protos", ] +typing-extensions-allowlist = [ + # Docs: https://typing-extensions.readthedocs.io/en/latest/ + "typing_extensions.Self", # Added in 4.0.0 +] [tool.clint.forbidden-top-level-imports] "mlflow/gateway/providers/*" = ["fastapi", "starlette"] diff --git a/requirements/skinny-requirements.txt b/requirements/skinny-requirements.txt index 123b3d022128a..2eb1c57ba87ee 100644 --- a/requirements/skinny-requirements.txt +++ b/requirements/skinny-requirements.txt @@ -16,3 +16,4 @@ opentelemetry-api<3,>=1.9.0 opentelemetry-sdk<3,>=1.9.0 databricks-sdk<1,>=0.20.0 pydantic<3,>=1.0 +typing-extensions<5,>=4.0.0 diff --git a/requirements/skinny-requirements.yaml b/requirements/skinny-requirements.yaml index 26f5f86a0e7a6..ae5e04bcf67e3 100644 --- a/requirements/skinny-requirements.yaml +++ b/requirements/skinny-requirements.yaml @@ -79,3 +79,8 @@ pydantic: pip_release: pydantic minimum: "1.0" max_major_version: 2 + +typing-extensions: + pip_release: typing-extensions + minimum: "4.0.0" + max_major_version: 4