Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 115 additions & 49 deletions xinference/core/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Optional,
Set,
Tuple,
Type,
Union,
no_type_check,
)
Expand Down Expand Up @@ -242,6 +243,28 @@ async def recover_sub_pool(self, address):
def default_uid(cls) -> str:
return "worker"

def _get_spec_dicts_with_cache_status(
self, model_family: Any, cache_manager_cls: Type
) -> Tuple[List[dict], List[str]]:
"""
Build model_specs with cache_status and collect download_hubs.
"""

specs: List[dict] = []
download_hubs: List[str] = []
for spec in model_family.model_specs:
model_hub = spec.model_hub
if model_hub not in download_hubs:
download_hubs.append(model_hub)

family_copy = model_family.copy()
family_copy.model_specs = [spec]
cache_manager = cache_manager_cls(family_copy)
specs.append(
{**spec.dict(), "cache_status": cache_manager.get_cache_status()}
)
return specs, download_hubs

async def __post_create__(self):
from ..model.audio import (
CustomAudioModelFamilyV2,
Expand Down Expand Up @@ -813,21 +836,18 @@ def sort_helper(item):

if model_type == "LLM":
from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families
from ..model.llm.cache_manager import LLMCacheManager

# Add built-in LLM families
for family in BUILTIN_LLM_FAMILIES:
if detailed:
# Remove duplicate hubs while preserving order
seen_hubs = set()
download_hubs = []
for spec in family.model_specs:
if spec.model_hub not in seen_hubs:
seen_hubs.add(spec.model_hub)
download_hubs.append(spec.model_hub)

specs, download_hubs = self._get_spec_dicts_with_cache_status(
family, LLMCacheManager
)
ret.append(
{
**family.dict(),
"model_specs": specs,
"is_builtin": True,
"download_hubs": download_hubs,
}
Expand All @@ -838,17 +858,13 @@ def sort_helper(item):
# Add user-defined LLM families
for family in get_user_defined_llm_families():
if detailed:
# Remove duplicate hubs while preserving order
seen_hubs = set()
download_hubs = []
for spec in family.model_specs:
if spec.model_hub not in seen_hubs:
seen_hubs.add(spec.model_hub)
download_hubs.append(spec.model_hub)

specs, download_hubs = self._get_spec_dicts_with_cache_status(
family, LLMCacheManager
)
ret.append(
{
**family.dict(),
"model_specs": specs,
"is_builtin": False,
"download_hubs": download_hubs,
}
Expand All @@ -860,23 +876,20 @@ def sort_helper(item):
return ret
elif model_type == "embedding":
from ..model.embedding import BUILTIN_EMBEDDING_MODELS
from ..model.embedding.cache_manager import EmbeddingCacheManager
from ..model.embedding.custom import get_user_defined_embeddings

# Add built-in embedding models
for model_name, family_list in BUILTIN_EMBEDDING_MODELS.items():
for family in family_list:
if detailed:
# Remove duplicate hubs while preserving order
seen_hubs = set()
download_hubs = []
for spec in family.model_specs:
if spec.model_hub not in seen_hubs:
seen_hubs.add(spec.model_hub)
download_hubs.append(spec.model_hub)

specs, download_hubs = self._get_spec_dicts_with_cache_status(
family, EmbeddingCacheManager
)
ret.append(
{
**family.dict(),
"model_specs": specs,
"is_builtin": True,
"download_hubs": download_hubs,
}
Expand All @@ -887,17 +900,13 @@ def sort_helper(item):
# Add user-defined embedding models
for model_spec in get_user_defined_embeddings():
if detailed:
# Remove duplicate hubs while preserving order
seen_hubs = set()
download_hubs = []
for spec in model_spec.model_specs:
if spec.model_hub not in seen_hubs:
seen_hubs.add(spec.model_hub)
download_hubs.append(spec.model_hub)

specs, download_hubs = self._get_spec_dicts_with_cache_status(
model_spec, EmbeddingCacheManager
)
ret.append(
{
**model_spec.dict(),
"model_specs": specs,
"is_builtin": False,
"download_hubs": download_hubs,
}
Expand All @@ -911,15 +920,26 @@ def sort_helper(item):
return ret
elif model_type == "image":
from ..model.image import BUILTIN_IMAGE_MODELS
from ..model.image.cache_manager import ImageCacheManager
from ..model.image.custom import get_user_defined_images

# Add built-in image models (BUILTIN_IMAGE_MODELS contains model_name -> families list)
for model_name, families in BUILTIN_IMAGE_MODELS.items():
for family in families:
if detailed:
cache_manager = ImageCacheManager(family)
model_specs = [
{
"model_format": "pytorch",
"model_hub": family.model_hub,
"model_id": family.model_id,
"cache_status": cache_manager.get_cache_status(),
}
]
ret.append(
{
**family.dict(),
"model_specs": model_specs,
"is_builtin": True,
"download_hubs": [family.model_hub],
}
Expand All @@ -930,9 +950,19 @@ def sort_helper(item):
# Add user-defined image models
for model_spec in get_user_defined_images():
if detailed:
cache_manager = ImageCacheManager(model_spec)
model_specs = [
{
"model_format": "pytorch",
"model_hub": model_spec.model_hub,
"model_id": model_spec.model_id,
"cache_status": cache_manager.get_cache_status(),
}
]
ret.append(
{
**model_spec.dict(),
"model_specs": model_specs,
"is_builtin": False,
"download_hubs": [model_spec.model_hub],
}
Expand All @@ -947,14 +977,25 @@ def sort_helper(item):
elif model_type == "audio":
from ..model.audio import BUILTIN_AUDIO_MODELS
from ..model.audio.custom import get_user_defined_audios
from ..model.cache_manager import CacheManager

# Add built-in audio models (BUILTIN_AUDIO_MODELS contains model_name -> families list)
for model_name, families in BUILTIN_AUDIO_MODELS.items():
for family in families:
if detailed:
audio_cache_manager = CacheManager(family)
model_specs = [
{
"model_format": "pytorch",
"model_hub": family.model_hub,
"model_id": family.model_id,
"cache_status": audio_cache_manager.get_cache_status(),
}
]
ret.append(
{
**family.dict(),
"model_specs": model_specs,
"is_builtin": True,
"download_hubs": [family.model_hub],
}
Expand All @@ -965,9 +1006,19 @@ def sort_helper(item):
# Add user-defined audio models
for model_spec in get_user_defined_audios():
if detailed:
audio_cache_manager = CacheManager(model_spec)
model_specs = [
{
"model_format": "pytorch",
"model_hub": model_spec.model_hub,
"model_id": model_spec.model_id,
"cache_status": audio_cache_manager.get_cache_status(),
}
]
ret.append(
{
**model_spec.dict(),
"model_specs": model_specs,
"is_builtin": False,
"download_hubs": [model_spec.model_hub],
}
Expand All @@ -980,15 +1031,26 @@ def sort_helper(item):
ret.sort(key=sort_helper)
return ret
elif model_type == "video":
from ..model.cache_manager import CacheManager
from ..model.video import BUILTIN_VIDEO_MODELS

# Add built-in video models (BUILTIN_VIDEO_MODELS contains model_name -> families list)
for model_name, families in BUILTIN_VIDEO_MODELS.items():
for family in families:
if detailed:
video_cache_manager = CacheManager(family)
model_specs = [
{
"model_format": "pytorch",
"model_hub": family.model_hub,
"model_id": family.model_id,
"cache_status": video_cache_manager.get_cache_status(),
}
]
ret.append(
{
**family.dict(),
"model_specs": model_specs,
"is_builtin": True,
"download_hubs": [family.model_hub],
}
Expand All @@ -1000,23 +1062,20 @@ def sort_helper(item):
return ret
elif model_type == "rerank":
from ..model.rerank import BUILTIN_RERANK_MODELS
from ..model.rerank.cache_manager import RerankCacheManager
from ..model.rerank.custom import get_user_defined_reranks

# Add built-in rerank models (BUILTIN_RERANK_MODELS contains model_name -> family_list list)
for model_name, family_list in BUILTIN_RERANK_MODELS.items():
for family in family_list:
if detailed:
# Remove duplicate hubs while preserving order
seen_hubs = set()
download_hubs = []
for spec in family.model_specs:
if spec.model_hub not in seen_hubs:
seen_hubs.add(spec.model_hub)
download_hubs.append(spec.model_hub)

specs, download_hubs = self._get_spec_dicts_with_cache_status(
family, RerankCacheManager
)
ret.append(
{
**family.dict(),
"model_specs": specs,
"is_builtin": True,
"download_hubs": download_hubs,
}
Expand All @@ -1027,17 +1086,13 @@ def sort_helper(item):
# Add user-defined rerank models
for model_spec in get_user_defined_reranks():
if detailed:
# Remove duplicate hubs while preserving order
seen_hubs = set()
download_hubs = []
for spec in model_spec.model_specs:
if spec.model_hub not in seen_hubs:
seen_hubs.add(spec.model_hub)
download_hubs.append(spec.model_hub)

specs, download_hubs = self._get_spec_dicts_with_cache_status(
model_spec, RerankCacheManager
)
ret.append(
{
**model_spec.dict(),
"model_specs": specs,
"is_builtin": False,
"download_hubs": download_hubs,
}
Expand All @@ -1055,7 +1110,18 @@ def sort_helper(item):
ret = []

for model_spec in get_flexible_models():
ret.append({"model_name": model_spec.model_name, "is_builtin": False})
if detailed:
ret.append(
{
**model_spec.dict(),
"cache_status": True,
"is_builtin": False,
}
)
else:
ret.append(
{"model_name": model_spec.model_name, "is_builtin": False}
)

ret.sort(key=sort_helper)
return ret
Expand Down
Loading