Skip to content

Commit aea7e0f

Browse files
refactor(mm): split configs into separate files
1 parent 508c488 commit aea7e0f

20 files changed

+2982
-0
lines changed

invokeai/backend/model_manager/configs/__init__.py

Whitespace-only changes.
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
from abc import ABC, abstractmethod
2+
from enum import Enum
3+
from inspect import isabstract
4+
from typing import (
5+
TYPE_CHECKING,
6+
Any,
7+
ClassVar,
8+
Literal,
9+
Self,
10+
Type,
11+
)
12+
13+
from pydantic import BaseModel, ConfigDict, Field, Tag
14+
from pydantic_core import PydanticUndefined
15+
16+
from invokeai.app.util.misc import uuid_string
17+
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
18+
from invokeai.backend.model_manager.taxonomy import (
19+
AnyVariant,
20+
BaseModelType,
21+
ModelFormat,
22+
ModelRepoVariant,
23+
ModelSourceType,
24+
ModelType,
25+
)
26+
27+
if TYPE_CHECKING:
28+
pass
29+
30+
31+
class Config_Base(ABC, BaseModel):
    """
    Abstract base class for model configurations. A model config describes a specific combination of model base, type
    and format, along with other metadata about the model. For example, a Stable Diffusion 1.x main model in
    checkpoint format would have base=sd-1, type=main, format=checkpoint.

    To create a new config type, inherit from this class and implement its interface:
    - Define method 'from_model_on_disk' that returns an instance of the class or raises NotAMatchError. This method
      will be called during model installation to determine the correct config class for a model.
    - Define fields 'type', 'base' and 'format' as pydantic fields. These should be Literals with a single value. A
      default must be provided for each of these fields.

    If multiple combinations of base, type and format need to be supported, create a separate subclass for each.

    See MinimalConfigExample in test_model_probe.py for an example implementation.
    """

    # These fields are common to all model configs.

    key: str = Field(
        default_factory=uuid_string,
        description="A unique key for this model.",
    )
    hash: str = Field(
        description="The hash of the model file(s).",
    )
    path: str = Field(
        description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory.",
    )
    file_size: int = Field(
        description="The size of the model in bytes.",
    )
    name: str = Field(
        description="Name of the model.",
    )
    description: str | None = Field(
        default=None,
        description="Model description",
    )
    source: str = Field(
        description="The original source of the model (path, URL or repo_id).",
    )
    source_type: ModelSourceType = Field(
        description="The type of source",
    )
    source_api_response: str | None = Field(
        default=None,
        description="The original API response from the source, as stringified JSON.",
    )
    cover_image: str | None = Field(
        default=None,
        description="Url for image to preview model",
    )
    usage_info: str | None = Field(
        default=None,
        description="Usage information for this model",
    )

    CONFIG_CLASSES: ClassVar[set[Type["Config_Base"]]] = set()
    """Set of all non-abstract subclasses of Config_Base, for use during model probing. In other words, this is the set
    of all known model config types."""

    model_config = ConfigDict(
        validate_assignment=True,
        json_schema_serialization_defaults_required=True,
        json_schema_mode_override="serialization",
    )

    @classmethod
    def __init_subclass__(cls, **kwargs):
        """Register each concrete subclass in CONFIG_CLASSES as soon as it is defined."""
        super().__init_subclass__(**kwargs)
        # Register non-abstract subclasses so we can iterate over them later during model probing. Note that
        # isabstract() will return False if the class does not have any abstract methods, even if it inherits from ABC.
        # We must check for ABC lest we unintentionally register some abstract model config classes.
        if not isabstract(cls) and ABC not in cls.__bases__:
            cls.CONFIG_CLASSES.add(cls)

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs):
        """Validate that a subclass declares 'type', 'base' and 'format' fields, each with a default.

        Raises:
            NotImplementedError: If one of the fields is missing or has no default.
        """
        # Ensure that model configs define 'base', 'type' and 'format' fields and provide defaults for them. Each
        # subclass is expected to represent a single combination of base, type and format.
        #
        # This pydantic dunder method is called after the pydantic model for a class is created. The normal
        # __init_subclass__ is too early to do this check.
        for name in ("type", "base", "format"):
            if name not in cls.model_fields:
                raise NotImplementedError(f"{cls.__name__} must define a '{name}' field")
            if cls.model_fields[name].default is PydanticUndefined:
                raise NotImplementedError(f"{cls.__name__} must define a default for the '{name}' field")

    @classmethod
    def get_tag(cls) -> Tag:
        """Constructs a pydantic discriminated union tag for this model config class. When a config is deserialized,
        pydantic uses the tag to determine which subclass to instantiate.

        The tag is a dot-separated string of the type, format, base and variant (if applicable).
        """
        tag_strings: list[str] = []
        for name in ("type", "format", "base", "variant"):
            if field := cls.model_fields.get(name):
                # The check in __pydantic_init_subclass__ ensures that type, format and base are always present with
                # defaults. variant does not require a default, but if it has one, we need to add it to the tag. We can
                # check for the presence of a default by seeing if it's not PydanticUndefined, a sentinel value used by
                # pydantic to indicate that no default was provided.
                if field.default is not PydanticUndefined:
                    # We expect each of these fields has an Enum for its default; we want the value of the enum.
                    tag_strings.append(field.default.value)
        return Tag(".".join(tag_strings))

    @staticmethod
    def get_model_discriminator_value(v: Any) -> str:
        """Computes the discriminator value for a model config discriminated union.

        Args:
            v: Either a Config_Base instance or a dict of config fields.

        Returns:
            The dot-separated tag built from type, format and base (plus variant for CLIP Embed diffusers models).

        Raises:
            TypeError: If `v` is neither a dict nor a Config_Base instance, or a tag field is not a string or Enum.
            ValueError: If a CLIP Embed diffusers config dict has no 'variant'.
        """
        # This is called by pydantic during deserialization and serialization to determine which model the data
        # represents. It can get either a dict (during deserialization) or an instance of a Config_Base subclass
        # (during serialization).
        #
        # See: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
        if isinstance(v, Config_Base):
            # We have an instance of a Config_Base subclass - use its tag directly.
            return v.get_tag().tag
        if not isinstance(v, dict):
            raise TypeError("Model config discriminator value must be computed from a dict or ModelConfigBase instance")

        def _tag_part(name: str) -> str | None:
            # Normalise one dict field to its string form; missing/falsy values contribute nothing to the tag.
            raw = v.get(name)
            if not raw:
                return None
            if isinstance(raw, Enum):
                return str(raw.value)
            if not isinstance(raw, str):
                raise TypeError(f"Model config dict '{name}' field must be a string or Enum")
            return raw

        # We have a dict - attempt to compute a tag from its fields, in the same order used by get_tag().
        parts = {name: _tag_part(name) for name in ("type", "format", "base")}
        tag_strings: list[str] = [part for part in parts.values() if part is not None]

        # Special case: CLIP Embed models also need the variant to distinguish them.
        if (
            parts["type"] == ModelType.CLIPEmbed.value
            and parts["format"] == ModelFormat.Diffusers.value
            and parts["base"] == BaseModelType.Any.value
        ):
            variant_ = _tag_part("variant")
            if variant_ is None:
                raise ValueError("CLIP Embed model config dict must include a 'variant' field")
            tag_strings.append(variant_)

        return ".".join(tag_strings)

    # NOTE: @abstractmethod must be the innermost decorator. Applying it on top of @classmethod fails when the class
    # body is executed, because classmethod objects do not allow __isabstractmethod__ to be assigned.
    @classmethod
    @abstractmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Given the model on disk and any override fields, attempt to construct an instance of this config class.

        This method serves to identify whether the model on disk matches this config class, and if so, to extract any
        additional metadata needed to instantiate the config.

        Implementations should raise a NotAMatchError if the model does not match this config class."""
        raise NotImplementedError(f"from_model_on_disk not implemented for {cls.__name__}")
204+
205+
206+
class Checkpoint_Config_Base(ABC, BaseModel):
    """Base class for checkpoint-style models."""

    # Optional field; it is None unless a config path is supplied.
    config_path: str | None = Field(
        default=None,
        description="Path to the config for this model, if any.",
    )
213+
214+
215+
class Diffusers_Config_Base(ABC, BaseModel):
    """Base class for diffusers-style models."""

    # The format is pinned to the single Diffusers value.
    format: Literal[ModelFormat.Diffusers] = Field(ModelFormat.Diffusers)
    # Repo variant of the weights; falls back to the default variant.
    repo_variant: ModelRepoVariant = Field(default=ModelRepoVariant.Default)

    @classmethod
    def _get_repo_variant_or_raise(cls, mod: ModelOnDisk) -> ModelRepoVariant:
        """Infer the repo variant from the weight files found under the model directory.

        Every `*.safetensors` file, followed by every `*.bin` file, is inspected recursively. The first file whose
        name identifies a variant decides the result; when none does, the default variant is returned.

        NOTE(review): only `.safetensors` and `.bin` files are collected, so the `.onnx` check below cannot match as
        written. Also, despite its name, nothing in this method raises explicitly. Confirm the intent before
        changing either.
        """
        candidates = [
            *mod.path.glob("**/*.safetensors"),
            *mod.path.glob("**/*.bin"),
        ]
        for weight_file in candidates:
            file_name = weight_file.name
            if ".fp16" in weight_file.suffixes:
                return ModelRepoVariant.FP16
            if "openvino_model" in file_name:
                return ModelRepoVariant.OpenVINO
            if "flax_model" in file_name:
                return ModelRepoVariant.Flax
            if weight_file.suffix == ".onnx":
                return ModelRepoVariant.ONNX
        return ModelRepoVariant.Default
236+
237+
238+
class SubmodelDefinition(BaseModel):
    # Plain data record describing one submodel: where it lives, its model type and an optional variant.

    # Clear pydantic's protected namespaces; presumably needed because the `model_type` field below starts with the
    # `model_` prefix that pydantic protects by default.
    model_config = ConfigDict(protected_namespaces=())

    path_or_prefix: str
    model_type: ModelType
    variant: AnyVariant | None = None
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from typing import (
2+
Literal,
3+
Self,
4+
)
5+
6+
from pydantic import Field
7+
from typing_extensions import Any
8+
9+
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
10+
from invokeai.backend.model_manager.configs.identification_utils import (
11+
NotAMatchError,
12+
get_config_dict_or_raise,
13+
raise_for_class_name,
14+
raise_for_override_fields,
15+
raise_if_not_dir,
16+
)
17+
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
18+
from invokeai.backend.model_manager.taxonomy import (
19+
BaseModelType,
20+
ClipVariantType,
21+
ModelFormat,
22+
ModelType,
23+
)
24+
25+
26+
def get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None:
    """Map a CLIP config's `hidden_size` to a CLIP variant.

    A hidden size of 1280 maps to the G variant and 768 to the L variant. Any other value, a missing key, or any
    error raised while inspecting `config` yields None (best-effort lookup).
    """
    try:
        width = config.get("hidden_size")
        if width == 1280:
            return ClipVariantType.G
        if width == 768:
            return ClipVariantType.L
    except Exception:
        # Deliberately best-effort: an unreadable config is treated as "variant unknown".
        return None
    return None
38+
39+
40+
class CLIPEmbed_Diffusers_Config_Base(Diffusers_Config_Base):
    # Shared identification logic for CLIP Embed models in diffusers format. Subclasses are expected to declare a
    # `variant` field with a default, which `_validate_variant` reads.

    base: Literal[BaseModelType.Any] = Field(BaseModelType.Any)
    type: Literal[ModelType.CLIPEmbed] = Field(ModelType.CLIPEmbed)
    format: Literal[ModelFormat.Diffusers] = Field(ModelFormat.Diffusers)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build this config from a model on disk, provided the model passes every identification check.

        The checks run in order: directory check, override-field check, class-name check against the model's
        `config.json` (top-level or under `text_encoder/`), then the variant check. Each helper is expected to raise
        when its check fails (see identification_utils); otherwise the config is built from `override_fields`.
        """
        raise_if_not_dir(mod)
        raise_for_override_fields(cls, override_fields)

        config_candidates = {
            mod.path / "config.json",
            mod.path / "text_encoder" / "config.json",
        }
        accepted_class_names = {
            "CLIPModel",
            "CLIPTextModel",
            "CLIPTextModelWithProjection",
        }
        raise_for_class_name(config_candidates, accepted_class_names)

        cls._validate_variant(mod)

        return cls(**override_fields)

    @classmethod
    def _validate_variant(cls, mod: ModelOnDisk) -> None:
        """Raise `NotAMatchError` if the model variant does not match this config class.

        The expected variant is the default of this class's `variant` field; the actual variant is derived from the
        model's config via `get_clip_variant_type_from_config`.
        """
        expected_variant = cls.model_fields["variant"].default

        config_candidates = {
            mod.path / "config.json",
            mod.path / "text_encoder" / "config.json",
        }
        loaded_config = get_config_dict_or_raise(config_candidates)
        recognized_variant = get_clip_variant_type_from_config(loaded_config)

        if recognized_variant is None:
            raise NotAMatchError("unable to determine CLIP variant from config")

        if recognized_variant is not expected_variant:
            raise NotAMatchError(f"variant is {recognized_variant}, not {expected_variant}")
84+
85+
86+
class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base):
    # Concrete CLIP Embed diffusers config pinned to the G variant.
    variant: Literal[ClipVariantType.G] = Field(ClipVariantType.G)
88+
89+
90+
class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base):
    # Concrete CLIP Embed diffusers config pinned to the L variant.
    variant: Literal[ClipVariantType.L] = Field(ClipVariantType.L)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from typing import (
2+
Literal,
3+
Self,
4+
)
5+
6+
from pydantic import Field
7+
from typing_extensions import Any
8+
9+
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
10+
from invokeai.backend.model_manager.configs.identification_utils import (
11+
common_config_paths,
12+
raise_for_class_name,
13+
raise_for_override_fields,
14+
raise_if_not_dir,
15+
)
16+
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
17+
from invokeai.backend.model_manager.taxonomy import (
18+
BaseModelType,
19+
ModelFormat,
20+
ModelType,
21+
)
22+
23+
24+
class CLIPVision_Diffusers_Config(Diffusers_Config_Base, Config_Base):
    """Model config for CLIPVision."""

    base: Literal[BaseModelType.Any] = Field(BaseModelType.Any)
    type: Literal[ModelType.CLIPVision] = Field(ModelType.CLIPVision)
    format: Literal[ModelFormat.Diffusers] = Field(ModelFormat.Diffusers)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build this config from a model on disk, provided the model passes every identification check.

        The checks run in order: directory check, override-field check, then a class-name check that accepts only
        `CLIPVisionModelWithProjection`. Each helper is expected to raise when its check fails (see
        identification_utils); otherwise the config is built from `override_fields`.
        """
        raise_if_not_dir(mod)
        raise_for_override_fields(cls, override_fields)

        candidate_paths = common_config_paths(mod.path)
        accepted_class_names = {"CLIPVisionModelWithProjection"}
        raise_for_class_name(candidate_paths, accepted_class_names)

        return cls(**override_fields)

0 commit comments

Comments
 (0)