Commit 543f83a
refactor: update ZhipuChatModel to use BaseChatOpenAI and improve token counting
--bug=1061305 --user=刘瑞斌 [Application] After enabling tools in AI chat, some models (Zhipu) do not count tokens https://www.tapd.cn/62980211/s/1791683
1 parent ef162bd commit 543f83a

File tree

1 file changed: +23 -75 lines

  • apps/models_provider/impl/zhipu_model_provider/model

apps/models_provider/impl/zhipu_model_provider/model/llm.py

Lines changed: 23 additions & 75 deletions
@@ -7,27 +7,21 @@
 @desc:
 """
 
-import json
-from collections.abc import Iterator
-from typing import Any, Dict, List, Optional
+from typing import Dict, List
 
-from langchain_community.chat_models import ChatZhipuAI
-from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
-    _convert_delta_to_message_chunk
-from langchain_core.callbacks import (
-    CallbackManagerForLLMRun,
-)
-from langchain_core.messages import (
-    AIMessageChunk,
-    BaseMessage
-)
-from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.messages import BaseMessage, get_buffer_string
 
+from common.config.tokenizer_manage_config import TokenizerManage
 from models_provider.base_model_provider import MaxKBBaseModel
+from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
 
-class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
-    optional_params: dict
+def custom_get_token_ids(text: str):
+    tokenizer = TokenizerManage.get_tokenizer()
+    return tokenizer.encode(text)
+
+
+class ZhipuChatModel(MaxKBBaseModel, BaseChatOpenAI):
 
     @staticmethod
     def is_cache_model():
@@ -39,69 +33,23 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
         zhipuai_chat = ZhipuChatModel(
             api_key=model_credential.get('api_key'),
             model=model_name,
+            base_url='https://open.bigmodel.cn/api/paas/v4',
+            extra_body=optional_params,
             streaming=model_kwargs.get('streaming', False),
-            optional_params=optional_params,
-            **optional_params,
+            custom_get_token_ids=custom_get_token_ids
         )
         return zhipuai_chat
 
-    usage_metadata: dict = {}
-
-    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
-        return self.usage_metadata
-
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        return self.usage_metadata.get('prompt_tokens', 0)
+        try:
+            return super().get_num_tokens_from_messages(messages)
+        except Exception as e:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
 
     def get_num_tokens(self, text: str) -> int:
-        return self.usage_metadata.get('completion_tokens', 0)
-
-    def _stream(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        """Stream the chat response in chunks."""
-        if self.zhipuai_api_key is None:
-            raise ValueError("Did not find zhipuai_api_key.")
-        if self.zhipuai_api_base is None:
-            raise ValueError("Did not find zhipu_api_base.")
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True}
-        _truncate_params(payload)
-        headers = {
-            "Authorization": _get_jwt_token(self.zhipuai_api_key),
-            "Accept": "application/json",
-        }
-
-        default_chunk_class = AIMessageChunk
-        import httpx
-
-        with httpx.Client(headers=headers, timeout=60) as client:
-            with connect_sse(
-                    client, "POST", self.zhipuai_api_base, json=payload
-            ) as event_source:
-                for sse in event_source.iter_sse():
-                    chunk = json.loads(sse.data)
-                    if len(chunk["choices"]) == 0:
-                        continue
-                    choice = chunk["choices"][0]
-                    generation_info = {}
-                    if "usage" in chunk:
-                        generation_info = chunk["usage"]
-                    self.usage_metadata = generation_info
-                    chunk = _convert_delta_to_message_chunk(
-                        choice["delta"], default_chunk_class
-                    )
-                    finish_reason = choice.get("finish_reason", None)
-
-                    chunk = ChatGenerationChunk(
-                        message=chunk, generation_info=generation_info
-                    )
-                    yield chunk
-                    if run_manager:
-                        run_manager.on_llm_new_token(chunk.text, chunk=chunk)
-                    if finish_reason is not None:
-                        break
+        try:
+            return super().get_num_tokens(text)
+        except Exception as e:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
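
For context on the try/except above: current langchain-openai releases raise NotImplementedError from get_num_tokens_from_messages for model names they do not recognize (such as Zhipu's glm-* models), so the subclass catches the error and counts with the local tokenizer instead, while get_num_tokens is routed through custom_get_token_ids whenever that field is set on a langchain model. Below is a minimal, runnable sketch of the same pattern, using langchain-core's FakeListChatModel as a stand-in for the MaxKB-internal BaseChatOpenAI; the character-level encoder is a hypothetical substitute for TokenizerManage.get_tokenizer().encode, so only the control flow mirrors the commit.

# Sketch of the commit's token-counting pattern (stand-ins noted inline).
from typing import List

from langchain_core.language_models import FakeListChatModel
from langchain_core.messages import BaseMessage, HumanMessage, get_buffer_string


def custom_get_token_ids(text: str) -> List[int]:
    # Hypothetical local tokenizer: one "token" per character.
    return [ord(ch) for ch in text]


class CountingChatModel(FakeListChatModel):
    # FakeListChatModel stands in for BaseChatOpenAI here.
    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        try:
            # Preferred path: whatever counting the parent class provides.
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            # Fallback path from the diff: render each message to text and
            # count it with the local tokenizer instead.
            return sum(
                len(custom_get_token_ids(get_buffer_string([m]))) for m in messages
            )


llm = CountingChatModel(responses=["ok"], custom_get_token_ids=custom_get_token_ids)
# langchain-core routes get_num_tokens through custom_get_token_ids when it
# is set, which is why the commit passes the callable into the constructor.
print(llm.get_num_tokens("hello"))  # 5, via the custom callable
print(llm.get_num_tokens_from_messages([HumanMessage(content="hi")]))  # counts "Human: hi"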
