diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d68172..68d2ef5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## 0.9.0 - 2025-01-22 + +### Added + +- Support arbitrary command classes through new `:command_class` option + ## 0.8.1 - 2023-09-18 ### Fixed diff --git a/README.md b/README.md index 5818aad..718cb58 100644 --- a/README.md +++ b/README.md @@ -140,3 +140,4 @@ Options: - `show_hidden`: _(Optional, default: `False`)_ Show commands and options that are marked as hidden. - `list_subcommands`: _(Optional, default: `False`)_ List subcommands of a given command. If _attr_list_ is installed, add links to subcommands also. +- `command_class`: _(Optional, default: `click.BaseCommand`)_ The class of the `command` option. diff --git a/mkdocs_click/_extension.py b/mkdocs_click/_extension.py index 60f1e07..01abdb9 100644 --- a/mkdocs_click/_extension.py +++ b/mkdocs_click/_extension.py @@ -28,8 +28,9 @@ def replace_command_docs(has_attr_list: bool = False, **options: Any) -> Iterato remove_ascii_art = options.get("remove_ascii_art", False) show_hidden = options.get("show_hidden", False) list_subcommands = options.get("list_subcommands", False) + command_class = options.get("command_class", "click.BaseCommand") - command_obj = load_command(module, command) + command_obj = load_command(module, command, command_class) prog_name = prog_name or command_obj.name or command diff --git a/mkdocs_click/_loader.py b/mkdocs_click/_loader.py index 1e95d95..742d063 100644 --- a/mkdocs_click/_loader.py +++ b/mkdocs_click/_loader.py @@ -6,20 +6,18 @@ import importlib from typing import Any -import click - from ._exceptions import MkDocsClickException -def load_command(module: str, attribute: str) -> click.BaseCommand: +def load_command(module: str, attribute: str, command_class: str = "click.BaseCommand") -> Any: """ Load and return the Click command object located at ':'. """ command = _load_obj(module, attribute) - if not isinstance(command, click.BaseCommand): + if not isinstance(command, _load_command_class(command_class)): raise MkDocsClickException( - f"{attribute!r} must be a 'click.BaseCommand' object, got {type(command)}" + f"{attribute!r} must be a '{command_class}' object, got {type(command)}" ) return command @@ -35,3 +33,11 @@ def _load_obj(module: str, attribute: str) -> Any: return getattr(mod, attribute) except AttributeError: raise MkDocsClickException(f"Module {module!r} has no attribute {attribute!r}") + + +def _load_command_class(command_class: str) -> Any: + module, attribute = command_class.rsplit(".", 1) + try: + return _load_obj(module, attribute) + except ModuleNotFoundError: + raise MkDocsClickException(f"Could not import {module!r}") diff --git a/tests/test_loader.py b/tests/test_loader.py index 7825b6c..78ac746 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -10,16 +10,42 @@ @pytest.mark.parametrize( - "module, command, exc", + "module, command, exc, command_class", [ - pytest.param("tests.app.cli", "cli", None, id="ok"), + pytest.param("tests.app.cli", "cli", None, "click.BaseCommand", id="ok"), pytest.param( - "tests.app.cli", "doesnotexist", MkDocsClickException, id="command-does-not-exist" + "tests.app.cli", + "doesnotexist", + MkDocsClickException, + "click.BaseCommand", + id="command-does-not-exist", + ), + pytest.param( + "doesnotexist", "cli", ImportError, "click.BaseCommand", id="module-does-not-exist" + ), + pytest.param( + "tests.app.cli", + "NOT_A_COMMAND", + MkDocsClickException, + "click.BaseCommand", + id="not-a-command", + ), + pytest.param( + "tests.app.cli", + "cli", + MkDocsClickException, + "foo.Bar", + id="bad-command-class", + ), + pytest.param( + "tests.app.cli", + "cli", + MkDocsClickException, + "pathlib.Path", + id="arbitrary-command-class", ), - pytest.param("doesnotexist", "cli", ImportError, id="module-does-not-exist"), - pytest.param("tests.app.cli", "NOT_A_COMMAND", MkDocsClickException, id="not-a-command"), ], ) -def test_load_command(module: str, command: str, exc): +def test_load_command(module: str, command: str, exc, command_class: str): with pytest.raises(exc) if exc is not None else nullcontext(): - load_command(module, command) + load_command(module, command, command_class)