diff --git a/apps/models_provider/impl/zhipu_model_provider/model/llm.py b/apps/models_provider/impl/zhipu_model_provider/model/llm.py index dbbb9b67147..308d1ca35da 100644 --- a/apps/models_provider/impl/zhipu_model_provider/model/llm.py +++ b/apps/models_provider/impl/zhipu_model_provider/model/llm.py @@ -7,27 +7,21 @@ @desc: """ -import json -from collections.abc import Iterator -from typing import Any, Dict, List, Optional +from typing import Dict, List -from langchain_community.chat_models import ChatZhipuAI -from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \ - _convert_delta_to_message_chunk -from langchain_core.callbacks import ( - CallbackManagerForLLMRun, -) -from langchain_core.messages import ( - AIMessageChunk, - BaseMessage -) -from langchain_core.outputs import ChatGenerationChunk +from langchain_core.messages import BaseMessage, get_buffer_string +from common.config.tokenizer_manage_config import TokenizerManage from models_provider.base_model_provider import MaxKBBaseModel +from models_provider.impl.base_chat_open_ai import BaseChatOpenAI -class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI): - optional_params: dict +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + +class ZhipuChatModel(MaxKBBaseModel, BaseChatOpenAI): @staticmethod def is_cache_model(): @@ -39,69 +33,23 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** zhipuai_chat = ZhipuChatModel( api_key=model_credential.get('api_key'), model=model_name, + base_url='https://open.bigmodel.cn/api/paas/v4', + extra_body=optional_params, streaming=model_kwargs.get('streaming', False), - optional_params=optional_params, - **optional_params, + custom_get_token_ids=custom_get_token_ids ) return zhipuai_chat - usage_metadata: dict = {} - - def get_last_generation_info(self) -> Optional[Dict[str, Any]]: - return self.usage_metadata - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - return self.usage_metadata.get('prompt_tokens', 0) + try: + return super().get_num_tokens_from_messages(messages) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) def get_num_tokens(self, text: str) -> int: - return self.usage_metadata.get('completion_tokens', 0) - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - """Stream the chat response in chunks.""" - if self.zhipuai_api_key is None: - raise ValueError("Did not find zhipuai_api_key.") - if self.zhipuai_api_base is None: - raise ValueError("Did not find zhipu_api_base.") - message_dicts, params = self._create_message_dicts(messages, stop) - payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True} - _truncate_params(payload) - headers = { - "Authorization": _get_jwt_token(self.zhipuai_api_key), - "Accept": "application/json", - } - - default_chunk_class = AIMessageChunk - import httpx - - with httpx.Client(headers=headers, timeout=60) as client: - with connect_sse( - client, "POST", self.zhipuai_api_base, json=payload - ) as event_source: - for sse in event_source.iter_sse(): - chunk = json.loads(sse.data) - if len(chunk["choices"]) == 0: - continue - choice = chunk["choices"][0] - generation_info = {} - if "usage" in chunk: - generation_info = chunk["usage"] - self.usage_metadata = generation_info - chunk = _convert_delta_to_message_chunk( - choice["delta"], default_chunk_class - ) - finish_reason = choice.get("finish_reason", None) - - chunk = ChatGenerationChunk( - message=chunk, generation_info=generation_info - ) - yield chunk - if run_manager: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) - if finish_reason is not None: - break + try: + return super().get_num_tokens(text) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text))