Draft
Changes from all commits (48 commits)
7c72824
refactor: port MM probes to new api
psychedelicious Sep 23, 2025
8ae9716
feat(mm): port TIs to new API
psychedelicious Sep 23, 2025
8b6fe5c
tidy(mm): remove unused probes
psychedelicious Sep 23, 2025
cdcdecc
feat(mm): port spandrel to new API
psychedelicious Sep 23, 2025
12c3cbc
fix(mm): parsing for spandrel
psychedelicious Sep 23, 2025
7ab6042
fix(mm): loader for clip embed
psychedelicious Sep 23, 2025
1db1264
fix(mm): tis use existing weight_files method
psychedelicious Sep 23, 2025
82ffb58
feat(mm): port vae to new API
psychedelicious Sep 23, 2025
20a0231
fix(mm): vae class inheritance and config_path
psychedelicious Sep 23, 2025
c88fee6
tidy(mm): patcher types and import paths
psychedelicious Sep 23, 2025
5996e31
feat(mm): better errors when invalid model config found in db
psychedelicious Sep 23, 2025
8217fd9
feat(mm): port t5 to new API
psychedelicious Sep 23, 2025
1d3f6c4
feat(mm): make config_path optional
psychedelicious Sep 23, 2025
881f063
refactor(mm): simplify model classification process
psychedelicious Sep 24, 2025
049e9f2
refactor(mm): remove unused methods in config.py
psychedelicious Sep 24, 2025
8b6929b
refactor(mm): add model config parsing utils
psychedelicious Sep 24, 2025
4220657
fix(mm): abstractmethod bork
psychedelicious Sep 24, 2025
6c60e6d
tidy(mm): clarify that model id utils are private
psychedelicious Sep 24, 2025
b1780f9
fix(mm): fall back to UnknownModelConfig correctly
psychedelicious Sep 24, 2025
cfef478
feat(mm): port CLIPVisionDiffusersConfig to new api
psychedelicious Sep 24, 2025
4f4268e
feat(mm): port SigLIPDiffusersConfig to new api
psychedelicious Sep 24, 2025
01104f5
feat(mm): make match helpers more succinct
psychedelicious Sep 24, 2025
6c66013
feat(mm): port flux redux to new api
psychedelicious Sep 24, 2025
20db2cb
feat(mm): port ip adapter to new api
psychedelicious Sep 24, 2025
f0e931c
tidy(mm): skip optimistic override handling for now
psychedelicious Sep 24, 2025
2813ec4
refactor(mm): continue iterating on config
psychedelicious Sep 25, 2025
e0d91ef
feat(mm): port flux "control lora" and t2i adapter to new api
psychedelicious Sep 25, 2025
5deb9bb
tidy(ui): use Extract to get model config types
psychedelicious Sep 25, 2025
07e99c9
fix(mm): t2i base determination
psychedelicious Sep 25, 2025
d27bef1
feat(mm): port cnet to new api
psychedelicious Sep 25, 2025
1268b23
refactor(mm): add config validation utils, make it all consistent and…
psychedelicious Sep 25, 2025
5f45a9c
feat(mm): wip port of main models to new api
psychedelicious Sep 25, 2025
7765c83
feat(mm): wip port of main models to new api
psychedelicious Sep 25, 2025
3a44fde
feat(mm): wip port of main models to new api
psychedelicious Sep 25, 2025
69efdc3
docs(mm): add todos
psychedelicious Sep 26, 2025
7765df4
tidy(mm): removed unused model merge class
psychedelicious Sep 29, 2025
9676cb8
feat(mm): wip port main models to new api
psychedelicious Sep 29, 2025
09449cf
tidy(mm): clean up model heuristic utils
psychedelicious Oct 1, 2025
d63348b
tidy(mm): clean up ModelOnDisk caching
psychedelicious Oct 1, 2025
bab7f62
tidy(mm): flux lora format util
psychedelicious Oct 1, 2025
935fafe
refactor(mm): make config classes narrow
psychedelicious Oct 1, 2025
17c5ad2
refactor(mm): diffusers loras
psychedelicious Oct 1, 2025
29087af
feat(mm): consistent naming for all model config classes
psychedelicious Oct 1, 2025
32a9ad1
fix(mm): tag generation & scattered probe fixes
psychedelicious Oct 1, 2025
508c488
tidy(mm): consistent class names
psychedelicious Oct 2, 2025
aea7e0f
refactor(mm): split configs into separate files
psychedelicious Oct 3, 2025
03b3191
docs(mm): add comments for identification utils
psychedelicious Oct 6, 2025
c0fff3a
chore(ui): typegen
psychedelicious Oct 6, 2025
20 changes: 16 additions & 4 deletions invokeai/app/api/routers/model_manager.py
@@ -31,7 +31,10 @@
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.config import (
AnyModelConfig,
MainCheckpointConfig,
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
)
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
@@ -741,9 +744,18 @@ async def convert_model(
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))

if not isinstance(model_config, MainCheckpointConfig):
logger.error(f"The model with key {key} is not a main checkpoint model.")
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
if not isinstance(
model_config,
(
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
),
):
msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."
logger.error(msg)
raise HTTPException(400, msg)

with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
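One possible follow-up for call sites like this guard (not part of this PR) is to collect the narrowed main-checkpoint configs into a single tuple so the `isinstance` check stays short. `MAIN_SD_CHECKPOINT_CONFIGS` is a hypothetical name; the four config classes are the ones imported above.

```python
# Hypothetical alias over the config classes this PR introduces.
MAIN_SD_CHECKPOINT_CONFIGS = (
    Main_Checkpoint_SD1_Config,
    Main_Checkpoint_SD2_Config,
    Main_Checkpoint_SDXL_Config,
    Main_Checkpoint_SDXLRefiner_Config,
)

if not isinstance(model_config, MAIN_SD_CHECKPOINT_CONFIGS):
    msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."
    logger.error(msg)
    raise HTTPException(400, msg)
```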
4 changes: 2 additions & 2 deletions invokeai/app/invocations/create_gradient_mask.py
@@ -21,7 +21,7 @@
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase
from invokeai.backend.model_manager.config import Main_Config_Base
from invokeai.backend.model_manager.taxonomy import ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor

@@ -182,7 +182,7 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput:
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
assert isinstance(main_model_config, Main_Config_Base)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = dilated_mask_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
6 changes: 3 additions & 3 deletions invokeai/app/invocations/flux_denoise.py
@@ -48,7 +48,7 @@
unpack,
)
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantType
from invokeai.backend.model_manager.taxonomy import FluxVariantType, ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -232,7 +232,7 @@ def _run_diffusion(
)

transformer_config = context.models.get_config(self.transformer.transformer)
is_schnell = "schnell" in getattr(transformer_config, "config_path", "")
is_schnell = transformer_config.variant is FluxVariantType.Schnell

# Calculate the timestep schedule.
timesteps = get_schedule(
@@ -277,7 +277,7 @@

# Prepare the extra image conditioning tensor (img_cond) for either FLUX structural control or FLUX Fill.
img_cond: torch.Tensor | None = None
is_flux_fill = transformer_config.variant == ModelVariantType.Inpaint # type: ignore
is_flux_fill = transformer_config.variant is FluxVariantType.DevFill
if is_flux_fill:
img_cond = self._prep_flux_fill_img_cond(
context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype
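The substantive change in this file is the switch from string sniffing on `config_path` to an enum comparison. A minimal self-contained sketch of why that is safer, assuming a `FluxVariantType` along these lines (the real enum lives in `invokeai.backend.model_manager.taxonomy` and may differ in members and values):

```python
from enum import Enum


class FluxVariantType(str, Enum):
    # Assumed members; the diff confirms Schnell and DevFill exist.
    Dev = "dev"
    DevFill = "dev_fill"
    Schnell = "schnell"


def is_schnell(variant: FluxVariantType) -> bool:
    # An identity check on an enum member cannot false-positive the way
    # `"schnell" in getattr(config, "config_path", "")` could for an
    # arbitrary file path, and it keeps working now that config_path is
    # optional.
    return variant is FluxVariantType.Schnell
```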
5 changes: 2 additions & 3 deletions invokeai/app/invocations/flux_ip_adapter.py
@@ -17,8 +17,7 @@
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
IPAdapter_Checkpoint_FLUX_Config,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType

@@ -68,7 +67,7 @@ def validate_begin_end_step_percent(self) -> Self:
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
assert isinstance(ip_adapter_info, IPAdapter_Checkpoint_FLUX_Config)

# Note: There is an IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
8 changes: 4 additions & 4 deletions invokeai/app/invocations/flux_model_loader.py
@@ -13,9 +13,9 @@
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.flux.util import get_flux_max_seq_length
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
Checkpoint_Config_Base,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType

@@ -87,12 +87,12 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)

transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
assert isinstance(transformer_config, Checkpoint_Config_Base)

return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
max_seq_len=get_flux_max_seq_length(transformer_config.variant),
)
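Here `get_flux_max_seq_length(variant)` replaces the old `max_seq_lengths[config_path]` dict lookup, so the loader no longer depends on `config_path` being set. A plausible sketch of the helper, assuming schnell runs with a 256-token T5 sequence and the dev variants with 512 (the authoritative mapping lives in `invokeai.backend.flux.util`):

```python
from invokeai.backend.model_manager.taxonomy import FluxVariantType


def get_flux_max_seq_length(variant: FluxVariantType) -> int:
    # Assumed values based on common FLUX defaults; verify against
    # invokeai/backend/flux/util.py.
    return 256 if variant is FluxVariantType.Schnell else 512
```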
8 changes: 4 additions & 4 deletions invokeai/app/invocations/ip_adapter.py
@@ -13,8 +13,8 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
AnyModelConfig,
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
IPAdapter_Checkpoint_Config_Base,
IPAdapter_InvokeAI_Config_Base,
)
from invokeai.backend.model_manager.starter_models import (
StarterModel,
@@ -123,9 +123,9 @@ def validate_begin_end_step_percent(self) -> Self:
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapter_Checkpoint_Config_Base))

if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else:
5 changes: 3 additions & 2 deletions invokeai/app/invocations/model.py
@@ -24,8 +24,9 @@ class ModelIdentifierField(BaseModel):
name: str = Field(description="The model's name")
base: BaseModelType = Field(description="The model's base model type")
type: ModelType = Field(description="The model's type")
submodel_type: Optional[SubModelType] = Field(
description="The submodel to load, if this is a main model", default=None
submodel_type: SubModelType | None = Field(
description="The submodel to load, if this is a main model",
default=None,
)

@classmethod
27 changes: 11 additions & 16 deletions invokeai/app/services/model_install/model_install_default.py
@@ -5,6 +5,7 @@
import re
import threading
import time
from copy import deepcopy
from pathlib import Path
from queue import Empty, Queue
from shutil import move, rmtree
@@ -36,11 +37,10 @@
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
Checkpoint_Config_Base,
InvalidModelConfigException,
ModelConfigBase,
ModelConfigFactory,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
HuggingFaceMetadataFetch,
@@ -370,6 +370,8 @@ def unconditionally_delete(self, key: str) -> None: # noqa D102
model_path = self.app_config.models_path / model.path
if model_path.is_file() or model_path.is_symlink():
model_path.unlink()
assert model_path.parent != self.app_config.models_path
os.rmdir(model_path.parent)
elif model_path.is_dir():
rmtree(model_path)
self.unregister(key)
@@ -598,18 +600,11 @@ def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
hash_algo = self._app_config.hashing_algorithm
fields = config.model_dump()

# WARNING!
# The legacy probe relies on the implicit order of tests to determine model classification.
# This can lead to regressions between the legacy and new probes.
# Do NOT change the order of `probe` and `classify` without implementing one of the following fixes:
# Short-term fix: `classify` tests `matches` in the same order as the legacy probe.
# Long-term fix: Improve `matches` to be more specific so that only one config matches
# any given model - eliminating ambiguity and removing reliance on order.
# After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
try:
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
except InvalidModelConfigException:
return ModelConfigBase.classify(model_path, hash_algo, **fields)
return ModelConfigFactory.from_model_on_disk(
mod=model_path,
overrides=deepcopy(fields),
hash_algo=hash_algo,
)

def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
@@ -630,7 +625,7 @@ def _register(

info.path = model_path.as_posix()

if isinstance(info, CheckpointConfigBase):
if isinstance(info, Checkpoint_Config_Base) and info.config_path is not None:
# Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
# invoke-managed legacy config dir, we use a relative path.
legacy_config_path = self.app_config.legacy_conf_path / info.config_path
5 changes: 4 additions & 1 deletion invokeai/app/services/model_records/model_records_base.py
@@ -21,6 +21,7 @@
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
FluxVariantType,
ModelFormat,
ModelSourceType,
ModelType,
@@ -90,7 +91,9 @@ class ModelRecordChanges(BaseModelExcludeNull):

# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType] = Field(
description="The variant of the model.", default=None
)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)
28 changes: 24 additions & 4 deletions invokeai/app/services/model_records/model_records_sql.py
@@ -141,10 +141,25 @@ def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
with self._db.transaction() as cursor:
record = self.get_model(key)

# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
# The changes may mean the model config class changes. So we need to:
#
# 1. convert the existing record to a dict
# 2. apply the changes to the dict
# 3. create a new model config from the updated dict
#
# This way we ensure that the update does not inadvertently create an invalid model config.

# 1. convert the existing record to a dict
record_as_dict = record.model_dump()

# 2. apply the changes to the dict
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
record_as_dict[field_name] = getattr(changes, field_name)

# 3. create a new model config from the updated dict
record = ModelConfigFactory.make_config(record_as_dict)

# If we get this far, the updated model config is valid, so we can save it to the database.
json_serialized = record.model_dump_json()

cursor.execute(
@@ -277,14 +292,19 @@ def search_by_attr(
for row in result:
try:
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
except pydantic.ValidationError:
except pydantic.ValidationError as e:
# We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to rollback from a
# newer version of the app that added a new model type.
row_data = f"{row[0][:64]}..." if len(row[0]) > 64 else row[0]
try:
name = json.loads(row[0]).get("name", "<unknown>")
except Exception:
name = "<unknown>"
self._logger.warning(
f"Found an invalid model config in the database. Ignoring this model. ({row_data})"
f"Skipping invalid model config in the database with name {name}. Ignoring this model. ({row_data})"
)
self._logger.warning(f"Validation error: {e}")
else:
results.append(model_config)

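The dump/apply/revalidate pattern in `update_model` is worth calling out: setting attributes one at a time only validates each field in isolation, while round-tripping through a dict lets the factory re-select the config class and validate the changes as a whole. A self-contained sketch of the same pattern with a toy model (names here are illustrative, not InvokeAI APIs):

```python
from pydantic import BaseModel


class WidgetConfig(BaseModel):
    name: str
    size: int


record = WidgetConfig(name="blur", size=1)

# 1. convert the existing record to a dict
record_as_dict = record.model_dump()

# 2. apply the changes to the dict
record_as_dict["size"] = 2

# 3. re-validate the whole dict; an invalid combination of changes raises
#    pydantic.ValidationError instead of leaving a half-updated record
record = WidgetConfig.model_validate(record_as_dict)
assert record.size == 2
```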
4 changes: 2 additions & 2 deletions invokeai/app/services/shared/invocation_context.py
@@ -21,7 +21,7 @@
from invokeai.app.util.step_callback import diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ModelConfigBase,
Config_Base,
)
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
@@ -558,7 +558,7 @@ def get_absolute_path(self, config_or_path: AnyModelConfig | Path | str) -> Path
The absolute path to the model.
"""

model_path = Path(config_or_path.path) if isinstance(config_or_path, ModelConfigBase) else Path(config_or_path)
model_path = Path(config_or_path.path) if isinstance(config_or_path, Config_Base) else Path(config_or_path)

if model_path.is_absolute():
return model_path.resolve()
invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py
@@ -8,7 +8,7 @@

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from invokeai.backend.model_manager.config import AnyModelConfig, AnyModelConfigValidator
from invokeai.backend.model_manager.config import AnyModelConfigValidator


class NormalizeResult(NamedTuple):
@@ -30,7 +30,7 @@ def __call__(self, cursor: sqlite3.Cursor) -> None:
for model_id, config_json in rows:
try:
# Get the model config as a pydantic object
config = self._load_model_config(config_json)
config = AnyModelConfigValidator.validate_json(config_json)
except ValidationError:
# This could happen if the config schema changed in a way that makes old configs invalid. Unlikely
# for users, more likely for devs testing out migration paths.
@@ -216,11 +216,6 @@ def _prune_empty_directories(self) -> None:

self._logger.info("Pruned %d empty directories under %s", len(removed_dirs), self._models_dir)

def _load_model_config(self, config_json: str) -> AnyModelConfig:
# The typing of the validator says it returns Unknown, but it's really a AnyModelConfig. This utility function
# just makes that clear.
return AnyModelConfigValidator.validate_json(config_json)


def build_migration_22(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""Builds the migration object for migrating from version 21 to version 22.
8 changes: 8 additions & 0 deletions invokeai/app/util/custom_openapi.py
@@ -12,6 +12,7 @@
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.model_manager.config import AnyModelConfigValidator
from invokeai.backend.util.logging import InvokeAILogger

logger = InvokeAILogger.get_logger()
@@ -115,6 +116,13 @@ def openapi() -> dict[str, Any]:
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
move_defs_to_top_level(openapi_schema, additional_schemas[1])

any_model_config_schema = AnyModelConfigValidator.json_schema(
mode="serialization",
ref_template="#/components/schemas/{model}",
)
move_defs_to_top_level(openapi_schema, any_model_config_schema)
openapi_schema["components"]["schemas"]["AnyModelConfig"] = any_model_config_schema

if post_transform is not None:
openapi_schema = post_transform(openapi_schema)

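The `json_schema` call above renders the config union with refs pointing at `#/components/schemas/...`, so its `$defs` can be merged into the OpenAPI document by `move_defs_to_top_level`. A minimal sketch of the same mechanism with a stand-in union; `Cat`/`Dog` are placeholders for the real config classes, and `AnyModelConfigValidator` is assumed to be a pydantic `TypeAdapter` over the config union:

```python
from typing import Literal, Union

from pydantic import BaseModel, TypeAdapter


class Cat(BaseModel):
    type: Literal["cat"] = "cat"


class Dog(BaseModel):
    type: Literal["dog"] = "dog"


# Stand-in for AnyModelConfigValidator.
AnyPetValidator = TypeAdapter(Union[Cat, Dog])

schema = AnyPetValidator.json_schema(
    mode="serialization",
    ref_template="#/components/schemas/{model}",
)
# schema["$defs"] now holds Cat and Dog, referenced via
# "#/components/schemas/Cat" etc., ready to be lifted to the top level
# of the OpenAPI document.
print(schema)
```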
4 changes: 2 additions & 2 deletions invokeai/backend/flux/controlnet/state_dict_utils.py
@@ -5,7 +5,7 @@
from invokeai.backend.flux.model import FluxParams


def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
def is_state_dict_xlabs_controlnet(sd: dict[str | int, Any]) -> bool:
"""Is the state dict for an XLabs ControlNet model?

This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
Expand All @@ -25,7 +25,7 @@ def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
return False


def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
def is_state_dict_instantx_controlnet(sd: dict[str | int, Any]) -> bool:
"""Is the state dict for an InstantX ControlNet model?

This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.