Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions src/winml/modelkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,20 @@
model = WinMLAutoModel.from_pretrained("facebook/convnext-tiny-224", config=config)
"""

from __future__ import annotations

import importlib
from importlib.metadata import PackageNotFoundError, version
from typing import TYPE_CHECKING

from . import _warnings # Configure warning filters before importing subpackages
from .config import WinMLBuildConfig
from .models import (
WinMLAutoModel,
WinMLModelForImageClassification,
WinMLPreTrainedModel,
)

if TYPE_CHECKING:
from .config import WinMLBuildConfig
from .models import (
WinMLAutoModel,
WinMLModelForImageClassification,
WinMLPreTrainedModel,
)

try:
__version__ = version("winml-modelkit")
Expand All @@ -51,3 +55,28 @@
"WinMLPreTrainedModel",
"__version__",
]

# Lazy imports — heavy ML dependencies (torch, transformers, optimum,
# diffusers) are only loaded when a symbol is actually accessed, so
# lightweight entry-points like ``winml sys`` stay fast.
_LAZY_IMPORT_MAP: dict[str, str] = {
"WinMLBuildConfig": ".config",
"WinMLAutoModel": ".models",
"WinMLModelForImageClassification": ".models",
"WinMLPreTrainedModel": ".models",
}


def __getattr__(name: str) -> object:
module_path = _LAZY_IMPORT_MAP.get(name)
if module_path is not None:
# Configure warning filters once before the first heavy import
if not globals().get("_warnings_loaded"):
globals()["_warnings_loaded"] = True
from . import _warnings
mod = importlib.import_module(module_path, __name__)
attr = getattr(mod, name)
# Cache on the module so __getattr__ is not called again
globals()[name] = attr
return attr
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
132 changes: 72 additions & 60 deletions src/winml/modelkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,78 @@
logger = logging.getLogger(__name__)


@click.group()
class _LazyGroup(click.Group):
"""Click group that discovers and imports commands lazily.

Command modules under ``commands/`` are imported only when the user
actually invokes them (or asks for ``--help``). This avoids pulling
in heavy ML dependencies (torch, transformers, …) for lightweight
sub-commands like ``winml sys``.
"""

def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
super().__init__(*args, **kwargs)
self._commands_dir = Path(__file__).parent / "commands"
# Lazily-discovered module names (without .py, excluding _private)
self._lazy_names: list[str] | None = None

# ------------------------------------------------------------------
def _scan_names(self) -> list[str]:
if self._lazy_names is None:
if self._commands_dir.exists():
self._lazy_names = [
f.stem for f in self._commands_dir.glob("*.py") if not f.name.startswith("_")
]
else:
self._lazy_names = []
return self._lazy_names

# ------------------------------------------------------------------
def list_commands(self, ctx: click.Context) -> list[str]:
# Merge eagerly-registered commands (if any) with lazy ones
eager = set(super().list_commands(ctx))
lazy = set(self._scan_names())
return sorted(eager | lazy)

# ------------------------------------------------------------------
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
# Already registered?
cmd = super().get_command(ctx, cmd_name)
if cmd is not None:
return cmd

# Lazy-load from commands/<cmd_name>.py
if cmd_name not in self._scan_names():
return None

try:
module = import_module(
f".commands.{cmd_name}",
package=__package__,
)
except ImportError as exc:
logger.warning("Failed to import command module %s: %s", cmd_name, exc)
return None
except Exception as exc:
logger.error("Error loading command %s: %s", cmd_name, exc)
return None

# Find the Click command in the module
discovered: click.Command | None = None
for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, click.Group):
discovered = attr
break
if isinstance(attr, click.Command) and discovered is None:
discovered = attr

if discovered is not None:
self.add_command(discovered, name=cmd_name)
return discovered


@click.group(cls=_LazyGroup)
@click.version_option(version=__version__, prog_name="winml")
@click.option(
"--debug",
Expand All @@ -57,64 +128,5 @@ def main(ctx: click.Context, debug: bool) -> None:
ctx.obj["debug"] = debug


def _discover_commands() -> None:
"""Auto-discover Click commands from commands/ directory.

This function scans the commands/ directory for Python modules and
registers any Click commands found. Commands are registered using
the module filename as the command name.

Command Discovery Rules:
- Skips files starting with underscore (_)
- Looks for any object that is a click.Command instance
- Uses module filename (without .py) as command name
"""
commands_dir = Path(__file__).parent / "commands"

# Early exit if commands directory doesn't exist
if not commands_dir.exists():
logger.debug("Commands directory not found: %s", commands_dir)
return

# Scan for Python modules
for py_file in commands_dir.glob("*.py"):
# Skip private modules
if py_file.name.startswith("_"):
continue

module_name = py_file.stem
try:
# Import the module
module = import_module(
f".commands.{module_name}",
package=__package__,
)

# Find Click command in module
# Prefer click.Group over click.Command for hierarchical commands
discovered_command = None
for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, click.Group):
discovered_command = attr
break
if isinstance(attr, click.Command) and discovered_command is None:
discovered_command = attr

if discovered_command:
# Register command with module name
main.add_command(discovered_command, name=module_name)
logger.debug("Discovered command: %s", module_name)

except ImportError as e:
logger.warning("Failed to import command module %s: %s", module_name, e)
except Exception as e:
logger.error("Error loading command %s: %s", module_name, e)


# Discover and register commands at module load time
_discover_commands()


if __name__ == "__main__":
main()
Loading
Loading