Skip to content

Commit dafc98f

Browse files
authored
BUG: fix models missing from the manage-cache listing (#4329)
1 parent 5654a9c commit dafc98f

File tree

1 file changed

+115
-49
lines changed

1 file changed

+115
-49
lines changed

xinference/core/worker.py

Lines changed: 115 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
Optional,
3636
Set,
3737
Tuple,
38+
Type,
3839
Union,
3940
no_type_check,
4041
)
@@ -242,6 +243,28 @@ async def recover_sub_pool(self, address):
242243
def default_uid(cls) -> str:
243244
return "worker"
244245

246+
def _get_spec_dicts_with_cache_status(
247+
self, model_family: Any, cache_manager_cls: Type
248+
) -> Tuple[List[dict], List[str]]:
249+
"""
250+
Build model_specs with cache_status and collect download_hubs.
251+
"""
252+
253+
specs: List[dict] = []
254+
download_hubs: List[str] = []
255+
for spec in model_family.model_specs:
256+
model_hub = spec.model_hub
257+
if model_hub not in download_hubs:
258+
download_hubs.append(model_hub)
259+
260+
family_copy = model_family.copy()
261+
family_copy.model_specs = [spec]
262+
cache_manager = cache_manager_cls(family_copy)
263+
specs.append(
264+
{**spec.dict(), "cache_status": cache_manager.get_cache_status()}
265+
)
266+
return specs, download_hubs
267+
245268
async def __post_create__(self):
246269
from ..model.audio import (
247270
CustomAudioModelFamilyV2,
@@ -813,21 +836,18 @@ def sort_helper(item):
813836

814837
if model_type == "LLM":
815838
from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families
839+
from ..model.llm.cache_manager import LLMCacheManager
816840

817841
# Add built-in LLM families
818842
for family in BUILTIN_LLM_FAMILIES:
819843
if detailed:
820-
# Remove duplicate hubs while preserving order
821-
seen_hubs = set()
822-
download_hubs = []
823-
for spec in family.model_specs:
824-
if spec.model_hub not in seen_hubs:
825-
seen_hubs.add(spec.model_hub)
826-
download_hubs.append(spec.model_hub)
827-
844+
specs, download_hubs = self._get_spec_dicts_with_cache_status(
845+
family, LLMCacheManager
846+
)
828847
ret.append(
829848
{
830849
**family.dict(),
850+
"model_specs": specs,
831851
"is_builtin": True,
832852
"download_hubs": download_hubs,
833853
}
@@ -838,17 +858,13 @@ def sort_helper(item):
838858
# Add user-defined LLM families
839859
for family in get_user_defined_llm_families():
840860
if detailed:
841-
# Remove duplicate hubs while preserving order
842-
seen_hubs = set()
843-
download_hubs = []
844-
for spec in family.model_specs:
845-
if spec.model_hub not in seen_hubs:
846-
seen_hubs.add(spec.model_hub)
847-
download_hubs.append(spec.model_hub)
848-
861+
specs, download_hubs = self._get_spec_dicts_with_cache_status(
862+
family, LLMCacheManager
863+
)
849864
ret.append(
850865
{
851866
**family.dict(),
867+
"model_specs": specs,
852868
"is_builtin": False,
853869
"download_hubs": download_hubs,
854870
}
@@ -860,23 +876,20 @@ def sort_helper(item):
860876
return ret
861877
elif model_type == "embedding":
862878
from ..model.embedding import BUILTIN_EMBEDDING_MODELS
879+
from ..model.embedding.cache_manager import EmbeddingCacheManager
863880
from ..model.embedding.custom import get_user_defined_embeddings
864881

865882
# Add built-in embedding models
866883
for model_name, family_list in BUILTIN_EMBEDDING_MODELS.items():
867884
for family in family_list:
868885
if detailed:
869-
# Remove duplicate hubs while preserving order
870-
seen_hubs = set()
871-
download_hubs = []
872-
for spec in family.model_specs:
873-
if spec.model_hub not in seen_hubs:
874-
seen_hubs.add(spec.model_hub)
875-
download_hubs.append(spec.model_hub)
876-
886+
specs, download_hubs = self._get_spec_dicts_with_cache_status(
887+
family, EmbeddingCacheManager
888+
)
877889
ret.append(
878890
{
879891
**family.dict(),
892+
"model_specs": specs,
880893
"is_builtin": True,
881894
"download_hubs": download_hubs,
882895
}
@@ -887,17 +900,13 @@ def sort_helper(item):
887900
# Add user-defined embedding models
888901
for model_spec in get_user_defined_embeddings():
889902
if detailed:
890-
# Remove duplicate hubs while preserving order
891-
seen_hubs = set()
892-
download_hubs = []
893-
for spec in model_spec.model_specs:
894-
if spec.model_hub not in seen_hubs:
895-
seen_hubs.add(spec.model_hub)
896-
download_hubs.append(spec.model_hub)
897-
903+
specs, download_hubs = self._get_spec_dicts_with_cache_status(
904+
model_spec, EmbeddingCacheManager
905+
)
898906
ret.append(
899907
{
900908
**model_spec.dict(),
909+
"model_specs": specs,
901910
"is_builtin": False,
902911
"download_hubs": download_hubs,
903912
}
@@ -911,15 +920,26 @@ def sort_helper(item):
911920
return ret
912921
elif model_type == "image":
913922
from ..model.image import BUILTIN_IMAGE_MODELS
923+
from ..model.image.cache_manager import ImageCacheManager
914924
from ..model.image.custom import get_user_defined_images
915925

916926
# Add built-in image models (BUILTIN_IMAGE_MODELS contains model_name -> families list)
917927
for model_name, families in BUILTIN_IMAGE_MODELS.items():
918928
for family in families:
919929
if detailed:
930+
cache_manager = ImageCacheManager(family)
931+
model_specs = [
932+
{
933+
"model_format": "pytorch",
934+
"model_hub": family.model_hub,
935+
"model_id": family.model_id,
936+
"cache_status": cache_manager.get_cache_status(),
937+
}
938+
]
920939
ret.append(
921940
{
922941
**family.dict(),
942+
"model_specs": model_specs,
923943
"is_builtin": True,
924944
"download_hubs": [family.model_hub],
925945
}
@@ -930,9 +950,19 @@ def sort_helper(item):
930950
# Add user-defined image models
931951
for model_spec in get_user_defined_images():
932952
if detailed:
953+
cache_manager = ImageCacheManager(model_spec)
954+
model_specs = [
955+
{
956+
"model_format": "pytorch",
957+
"model_hub": model_spec.model_hub,
958+
"model_id": model_spec.model_id,
959+
"cache_status": cache_manager.get_cache_status(),
960+
}
961+
]
933962
ret.append(
934963
{
935964
**model_spec.dict(),
965+
"model_specs": model_specs,
936966
"is_builtin": False,
937967
"download_hubs": [model_spec.model_hub],
938968
}
@@ -947,14 +977,25 @@ def sort_helper(item):
947977
elif model_type == "audio":
948978
from ..model.audio import BUILTIN_AUDIO_MODELS
949979
from ..model.audio.custom import get_user_defined_audios
980+
from ..model.cache_manager import CacheManager
950981

951982
# Add built-in audio models (BUILTIN_AUDIO_MODELS contains model_name -> families list)
952983
for model_name, families in BUILTIN_AUDIO_MODELS.items():
953984
for family in families:
954985
if detailed:
986+
audio_cache_manager = CacheManager(family)
987+
model_specs = [
988+
{
989+
"model_format": "pytorch",
990+
"model_hub": family.model_hub,
991+
"model_id": family.model_id,
992+
"cache_status": audio_cache_manager.get_cache_status(),
993+
}
994+
]
955995
ret.append(
956996
{
957997
**family.dict(),
998+
"model_specs": model_specs,
958999
"is_builtin": True,
9591000
"download_hubs": [family.model_hub],
9601001
}
@@ -965,9 +1006,19 @@ def sort_helper(item):
9651006
# Add user-defined audio models
9661007
for model_spec in get_user_defined_audios():
9671008
if detailed:
1009+
audio_cache_manager = CacheManager(model_spec)
1010+
model_specs = [
1011+
{
1012+
"model_format": "pytorch",
1013+
"model_hub": model_spec.model_hub,
1014+
"model_id": model_spec.model_id,
1015+
"cache_status": audio_cache_manager.get_cache_status(),
1016+
}
1017+
]
9681018
ret.append(
9691019
{
9701020
**model_spec.dict(),
1021+
"model_specs": model_specs,
9711022
"is_builtin": False,
9721023
"download_hubs": [model_spec.model_hub],
9731024
}
@@ -980,15 +1031,26 @@ def sort_helper(item):
9801031
ret.sort(key=sort_helper)
9811032
return ret
9821033
elif model_type == "video":
1034+
from ..model.cache_manager import CacheManager
9831035
from ..model.video import BUILTIN_VIDEO_MODELS
9841036

9851037
# Add built-in video models (BUILTIN_VIDEO_MODELS contains model_name -> families list)
9861038
for model_name, families in BUILTIN_VIDEO_MODELS.items():
9871039
for family in families:
9881040
if detailed:
1041+
video_cache_manager = CacheManager(family)
1042+
model_specs = [
1043+
{
1044+
"model_format": "pytorch",
1045+
"model_hub": family.model_hub,
1046+
"model_id": family.model_id,
1047+
"cache_status": video_cache_manager.get_cache_status(),
1048+
}
1049+
]
9891050
ret.append(
9901051
{
9911052
**family.dict(),
1053+
"model_specs": model_specs,
9921054
"is_builtin": True,
9931055
"download_hubs": [family.model_hub],
9941056
}
@@ -1000,23 +1062,20 @@ def sort_helper(item):
10001062
return ret
10011063
elif model_type == "rerank":
10021064
from ..model.rerank import BUILTIN_RERANK_MODELS
1065+
from ..model.rerank.cache_manager import RerankCacheManager
10031066
from ..model.rerank.custom import get_user_defined_reranks
10041067

10051068
# Add built-in rerank models (BUILTIN_RERANK_MODELS contains model_name -> family_list list)
10061069
for model_name, family_list in BUILTIN_RERANK_MODELS.items():
10071070
for family in family_list:
10081071
if detailed:
1009-
# Remove duplicate hubs while preserving order
1010-
seen_hubs = set()
1011-
download_hubs = []
1012-
for spec in family.model_specs:
1013-
if spec.model_hub not in seen_hubs:
1014-
seen_hubs.add(spec.model_hub)
1015-
download_hubs.append(spec.model_hub)
1016-
1072+
specs, download_hubs = self._get_spec_dicts_with_cache_status(
1073+
family, RerankCacheManager
1074+
)
10171075
ret.append(
10181076
{
10191077
**family.dict(),
1078+
"model_specs": specs,
10201079
"is_builtin": True,
10211080
"download_hubs": download_hubs,
10221081
}
@@ -1027,17 +1086,13 @@ def sort_helper(item):
10271086
# Add user-defined rerank models
10281087
for model_spec in get_user_defined_reranks():
10291088
if detailed:
1030-
# Remove duplicate hubs while preserving order
1031-
seen_hubs = set()
1032-
download_hubs = []
1033-
for spec in model_spec.model_specs:
1034-
if spec.model_hub not in seen_hubs:
1035-
seen_hubs.add(spec.model_hub)
1036-
download_hubs.append(spec.model_hub)
1037-
1089+
specs, download_hubs = self._get_spec_dicts_with_cache_status(
1090+
model_spec, RerankCacheManager
1091+
)
10381092
ret.append(
10391093
{
10401094
**model_spec.dict(),
1095+
"model_specs": specs,
10411096
"is_builtin": False,
10421097
"download_hubs": download_hubs,
10431098
}
@@ -1055,7 +1110,18 @@ def sort_helper(item):
10551110
ret = []
10561111

10571112
for model_spec in get_flexible_models():
1058-
ret.append({"model_name": model_spec.model_name, "is_builtin": False})
1113+
if detailed:
1114+
ret.append(
1115+
{
1116+
**model_spec.dict(),
1117+
"cache_status": True,
1118+
"is_builtin": False,
1119+
}
1120+
)
1121+
else:
1122+
ret.append(
1123+
{"model_name": model_spec.model_name, "is_builtin": False}
1124+
)
10591125

10601126
ret.sort(key=sort_helper)
10611127
return ret

0 commit comments

Comments
 (0)