|
| 1 | +from abc import ABC, abstractmethod |
| 2 | +from enum import Enum |
| 3 | +from inspect import isabstract |
| 4 | +from typing import ( |
| 5 | + TYPE_CHECKING, |
| 6 | + Any, |
| 7 | + ClassVar, |
| 8 | + Literal, |
| 9 | + Self, |
| 10 | + Type, |
| 11 | +) |
| 12 | + |
| 13 | +from pydantic import BaseModel, ConfigDict, Field, Tag |
| 14 | +from pydantic_core import PydanticUndefined |
| 15 | + |
| 16 | +from invokeai.app.util.misc import uuid_string |
| 17 | +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk |
| 18 | +from invokeai.backend.model_manager.taxonomy import ( |
| 19 | + AnyVariant, |
| 20 | + BaseModelType, |
| 21 | + ModelFormat, |
| 22 | + ModelRepoVariant, |
| 23 | + ModelSourceType, |
| 24 | + ModelType, |
| 25 | +) |
| 26 | + |
| 27 | +if TYPE_CHECKING: |
| 28 | + pass |
| 29 | + |
| 30 | + |
| 31 | +class Config_Base(ABC, BaseModel): |
| 32 | + """ |
| 33 | + Abstract base class for model configurations. A model config describes a specific combination of model base, type and |
| 34 | + format, along with other metadata about the model. For example, a Stable Diffusion 1.x main model in checkpoint format |
| 35 | + would have base=sd-1, type=main, format=checkpoint. |
| 36 | +
|
| 37 | + To create a new config type, inherit from this class and implement its interface: |
| 38 | + - Define method 'from_model_on_disk' that returns an instance of the class or raises NotAMatch. This method will be |
| 39 | + called during model installation to determine the correct config class for a model. |
| 40 | + - Define fields 'type', 'base' and 'format' as pydantic fields. These should be Literals with a single value. A |
| 41 | + default must be provided for each of these fields. |
| 42 | +
|
| 43 | + If multiple combinations of base, type and format need to be supported, create a separate subclass for each. |
| 44 | +
|
| 45 | + See MinimalConfigExample in test_model_probe.py for an example implementation. |
| 46 | + """ |
| 47 | + |
| 48 | + # These fields are common to all model configs. |
| 49 | + |
| 50 | + key: str = Field( |
| 51 | + default_factory=uuid_string, |
| 52 | + description="A unique key for this model.", |
| 53 | + ) |
| 54 | + hash: str = Field( |
| 55 | + description="The hash of the model file(s).", |
| 56 | + ) |
| 57 | + path: str = Field( |
| 58 | + description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory.", |
| 59 | + ) |
| 60 | + file_size: int = Field( |
| 61 | + description="The size of the model in bytes.", |
| 62 | + ) |
| 63 | + name: str = Field( |
| 64 | + description="Name of the model.", |
| 65 | + ) |
| 66 | + description: str | None = Field( |
| 67 | + default=None, |
| 68 | + description="Model description", |
| 69 | + ) |
| 70 | + source: str = Field( |
| 71 | + description="The original source of the model (path, URL or repo_id).", |
| 72 | + ) |
| 73 | + source_type: ModelSourceType = Field( |
| 74 | + description="The type of source", |
| 75 | + ) |
| 76 | + source_api_response: str | None = Field( |
| 77 | + default=None, |
| 78 | + description="The original API response from the source, as stringified JSON.", |
| 79 | + ) |
| 80 | + cover_image: str | None = Field( |
| 81 | + default=None, |
| 82 | + description="Url for image to preview model", |
| 83 | + ) |
| 84 | + usage_info: str | None = Field( |
| 85 | + default=None, |
| 86 | + description="Usage information for this model", |
| 87 | + ) |
| 88 | + |
| 89 | + CONFIG_CLASSES: ClassVar[set[Type["Config_Base"]]] = set() |
| 90 | + """Set of all non-abstract subclasses of Config_Base, for use during model probing. In other words, this is the set |
| 91 | + of all known model config types.""" |
| 92 | + |
| 93 | + model_config = ConfigDict( |
| 94 | + validate_assignment=True, |
| 95 | + json_schema_serialization_defaults_required=True, |
| 96 | + json_schema_mode_override="serialization", |
| 97 | + ) |
| 98 | + |
| 99 | + @classmethod |
| 100 | + def __init_subclass__(cls, **kwargs): |
| 101 | + super().__init_subclass__(**kwargs) |
| 102 | + # Register non-abstract subclasses so we can iterate over them later during model probing. Note that |
| 103 | + # isabstract() will return False if the class does not have any abstract methods, even if it inherits from ABC. |
| 104 | + # We must check for ABC lest we unintentionally register some abstract model config classes. |
| 105 | + if not isabstract(cls) and ABC not in cls.__bases__: |
| 106 | + cls.CONFIG_CLASSES.add(cls) |
| 107 | + |
| 108 | + @classmethod |
| 109 | + def __pydantic_init_subclass__(cls, **kwargs): |
| 110 | + # Ensure that model configs define 'base', 'type' and 'format' fields and provide defaults for them. Each |
| 111 | + # subclass is expected to represent a single combination of base, type and format. |
| 112 | + # |
| 113 | + # This pydantic dunder method is called after the pydantic model for a class is created. The normal |
| 114 | + # __init_subclass__ is too early to do this check. |
| 115 | + for name in ("type", "base", "format"): |
| 116 | + if name not in cls.model_fields: |
| 117 | + raise NotImplementedError(f"{cls.__name__} must define a '{name}' field") |
| 118 | + if cls.model_fields[name].default is PydanticUndefined: |
| 119 | + raise NotImplementedError(f"{cls.__name__} must define a default for the '{name}' field") |
| 120 | + |
| 121 | + @classmethod |
| 122 | + def get_tag(cls) -> Tag: |
| 123 | + """Constructs a pydantic discriminated union tag for this model config class. When a config is deserialized, |
| 124 | + pydantic uses the tag to determine which subclass to instantiate. |
| 125 | +
|
| 126 | + The tag is a dot-separated string of the type, format, base and variant (if applicable). |
| 127 | + """ |
| 128 | + tag_strings: list[str] = [] |
| 129 | + for name in ("type", "format", "base", "variant"): |
| 130 | + if field := cls.model_fields.get(name): |
| 131 | + # The check in __pydantic_init_subclass__ ensures that type, format and base are always present with |
| 132 | + # defaults. variant does not require a default, but if it has one, we need to add it to the tag. We can |
| 133 | + # check for the presence of a default by seeing if it's not PydanticUndefined, a sentinel value used by |
| 134 | + # pydantic to indicate that no default was provided. |
| 135 | + if field.default is not PydanticUndefined: |
| 136 | + # We expect each of these fields has an Enum for its default; we want the value of the enum. |
| 137 | + tag_strings.append(field.default.value) |
| 138 | + return Tag(".".join(tag_strings)) |
| 139 | + |
| 140 | + @staticmethod |
| 141 | + def get_model_discriminator_value(v: Any) -> str: |
| 142 | + """Computes the discriminator value for a model config discriminated union.""" |
| 143 | + # This is called by pydantic during deserialization and serialization to determine which model the data |
| 144 | + # represents. It can get either a dict (during deserialization) or an instance of a Config_Base subclass |
| 145 | + # (during serialization). |
| 146 | + # |
| 147 | + # See: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator |
| 148 | + if isinstance(v, Config_Base): |
| 149 | + # We have an instance of a ModelConfigBase subclass - use its tag directly. |
| 150 | + return v.get_tag().tag |
| 151 | + if isinstance(v, dict): |
| 152 | + # We have a dict - attempt to compute a tag from its fields. |
| 153 | + tag_strings: list[str] = [] |
| 154 | + if type_ := v.get("type"): |
| 155 | + if isinstance(type_, Enum): |
| 156 | + type_ = str(type_.value) |
| 157 | + elif not isinstance(type_, str): |
| 158 | + raise TypeError("Model config dict 'type' field must be a string or Enum") |
| 159 | + tag_strings.append(type_) |
| 160 | + |
| 161 | + if format_ := v.get("format"): |
| 162 | + if isinstance(format_, Enum): |
| 163 | + format_ = str(format_.value) |
| 164 | + elif not isinstance(format_, str): |
| 165 | + raise TypeError("Model config dict 'format' field must be a string or Enum") |
| 166 | + tag_strings.append(format_) |
| 167 | + |
| 168 | + if base_ := v.get("base"): |
| 169 | + if isinstance(base_, Enum): |
| 170 | + base_ = str(base_.value) |
| 171 | + elif not isinstance(base_, str): |
| 172 | + raise TypeError("Model config dict 'base' field must be a string or Enum") |
| 173 | + tag_strings.append(base_) |
| 174 | + |
| 175 | + # Special case: CLIP Embed models also need the variant to distinguish them. |
| 176 | + if ( |
| 177 | + type_ == ModelType.CLIPEmbed.value |
| 178 | + and format_ == ModelFormat.Diffusers.value |
| 179 | + and base_ == BaseModelType.Any.value |
| 180 | + ): |
| 181 | + if variant_ := v.get("variant"): |
| 182 | + if isinstance(variant_, Enum): |
| 183 | + variant_ = variant_.value |
| 184 | + elif not isinstance(variant_, str): |
| 185 | + raise TypeError("Model config dict 'variant' field must be a string or Enum") |
| 186 | + tag_strings.append(variant_) |
| 187 | + else: |
| 188 | + raise ValueError("CLIP Embed model config dict must include a 'variant' field") |
| 189 | + |
| 190 | + return ".".join(tag_strings) |
| 191 | + else: |
| 192 | + raise TypeError("Model config discriminator value must be computed from a dict or ModelConfigBase instance") |
| 193 | + |
| 194 | + @abstractmethod |
| 195 | + @classmethod |
| 196 | + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: |
| 197 | + """Given the model on disk and any override fields, attempt to construct an instance of this config class. |
| 198 | +
|
| 199 | + This method serves to identify whether the model on disk matches this config class, and if so, to extract any |
| 200 | + additional metadata needed to instantiate the config. |
| 201 | +
|
| 202 | + Implementations should raise a NotAMatchError if the model does not match this config class.""" |
| 203 | + raise NotImplementedError(f"from_model_on_disk not implemented for {cls.__name__}") |
| 204 | + |
| 205 | + |
| 206 | +class Checkpoint_Config_Base(ABC, BaseModel): |
| 207 | + """Base class for checkpoint-style models.""" |
| 208 | + |
| 209 | + config_path: str | None = Field( |
| 210 | + description="Path to the config for this model, if any.", |
| 211 | + default=None, |
| 212 | + ) |
| 213 | + |
| 214 | + |
| 215 | +class Diffusers_Config_Base(ABC, BaseModel): |
| 216 | + """Base class for diffusers-style models.""" |
| 217 | + |
| 218 | + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) |
| 219 | + repo_variant: ModelRepoVariant = Field(ModelRepoVariant.Default) |
| 220 | + |
| 221 | + @classmethod |
| 222 | + def _get_repo_variant_or_raise(cls, mod: ModelOnDisk) -> ModelRepoVariant: |
| 223 | + # get all files ending in .bin or .safetensors |
| 224 | + weight_files = list(mod.path.glob("**/*.safetensors")) |
| 225 | + weight_files.extend(list(mod.path.glob("**/*.bin"))) |
| 226 | + for x in weight_files: |
| 227 | + if ".fp16" in x.suffixes: |
| 228 | + return ModelRepoVariant.FP16 |
| 229 | + if "openvino_model" in x.name: |
| 230 | + return ModelRepoVariant.OpenVINO |
| 231 | + if "flax_model" in x.name: |
| 232 | + return ModelRepoVariant.Flax |
| 233 | + if x.suffix == ".onnx": |
| 234 | + return ModelRepoVariant.ONNX |
| 235 | + return ModelRepoVariant.Default |
| 236 | + |
| 237 | + |
| 238 | +class SubmodelDefinition(BaseModel): |
| 239 | + path_or_prefix: str |
| 240 | + model_type: ModelType |
| 241 | + variant: AnyVariant | None = None |
| 242 | + |
| 243 | + model_config = ConfigDict(protected_namespaces=()) |
0 commit comments