77 @desc:
88"""
99
10- import json
11- from collections .abc import Iterator
12- from typing import Any , Dict , List , Optional
10+ from typing import Dict , List
1311
14- from langchain_community .chat_models import ChatZhipuAI
15- from langchain_community .chat_models .zhipuai import _truncate_params , _get_jwt_token , connect_sse , \
16- _convert_delta_to_message_chunk
17- from langchain_core .callbacks import (
18- CallbackManagerForLLMRun ,
19- )
20- from langchain_core .messages import (
21- AIMessageChunk ,
22- BaseMessage
23- )
24- from langchain_core .outputs import ChatGenerationChunk
12+ from langchain_core .messages import BaseMessage , get_buffer_string
2513
14+ from common .config .tokenizer_manage_config import TokenizerManage
2615from models_provider .base_model_provider import MaxKBBaseModel
16+ from models_provider .impl .base_chat_open_ai import BaseChatOpenAI
2717
2818
29- class ZhipuChatModel (MaxKBBaseModel , ChatZhipuAI ):
30- optional_params : dict
def custom_get_token_ids(text: str):
    """Encode *text* into token ids with the shared project tokenizer."""
    return TokenizerManage.get_tokenizer().encode(text)
22+
23+
24+ class ZhipuChatModel (MaxKBBaseModel , BaseChatOpenAI ):
3125
3226 @staticmethod
3327 def is_cache_model ():
@@ -39,69 +33,23 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3933 zhipuai_chat = ZhipuChatModel (
4034 api_key = model_credential .get ('api_key' ),
4135 model = model_name ,
36+ base_url = 'https://open.bigmodel.cn/api/paas/v4' ,
37+ extra_body = optional_params ,
4238 streaming = model_kwargs .get ('streaming' , False ),
43- optional_params = optional_params ,
44- ** optional_params ,
39+ custom_get_token_ids = custom_get_token_ids
4540 )
4641 return zhipuai_chat
4742
48- usage_metadata : dict = {}
49-
50- def get_last_generation_info (self ) -> Optional [Dict [str , Any ]]:
51- return self .usage_metadata
52-
5343 def get_num_tokens_from_messages (self , messages : List [BaseMessage ]) -> int :
54- return self .usage_metadata .get ('prompt_tokens' , 0 )
44+ try :
45+ return super ().get_num_tokens_from_messages (messages )
46+ except Exception as e :
47+ tokenizer = TokenizerManage .get_tokenizer ()
48+ return sum ([len (tokenizer .encode (get_buffer_string ([m ]))) for m in messages ])
5549
5650 def get_num_tokens (self , text : str ) -> int :
57- return self .usage_metadata .get ('completion_tokens' , 0 )
58-
59- def _stream (
60- self ,
61- messages : List [BaseMessage ],
62- stop : Optional [List [str ]] = None ,
63- run_manager : Optional [CallbackManagerForLLMRun ] = None ,
64- ** kwargs : Any ,
65- ) -> Iterator [ChatGenerationChunk ]:
66- """Stream the chat response in chunks."""
67- if self .zhipuai_api_key is None :
68- raise ValueError ("Did not find zhipuai_api_key." )
69- if self .zhipuai_api_base is None :
70- raise ValueError ("Did not find zhipu_api_base." )
71- message_dicts , params = self ._create_message_dicts (messages , stop )
72- payload = {** params , ** kwargs , ** self .optional_params , "messages" : message_dicts , "stream" : True }
73- _truncate_params (payload )
74- headers = {
75- "Authorization" : _get_jwt_token (self .zhipuai_api_key ),
76- "Accept" : "application/json" ,
77- }
78-
79- default_chunk_class = AIMessageChunk
80- import httpx
81-
82- with httpx .Client (headers = headers , timeout = 60 ) as client :
83- with connect_sse (
84- client , "POST" , self .zhipuai_api_base , json = payload
85- ) as event_source :
86- for sse in event_source .iter_sse ():
87- chunk = json .loads (sse .data )
88- if len (chunk ["choices" ]) == 0 :
89- continue
90- choice = chunk ["choices" ][0 ]
91- generation_info = {}
92- if "usage" in chunk :
93- generation_info = chunk ["usage" ]
94- self .usage_metadata = generation_info
95- chunk = _convert_delta_to_message_chunk (
96- choice ["delta" ], default_chunk_class
97- )
98- finish_reason = choice .get ("finish_reason" , None )
99-
100- chunk = ChatGenerationChunk (
101- message = chunk , generation_info = generation_info
102- )
103- yield chunk
104- if run_manager :
105- run_manager .on_llm_new_token (chunk .text , chunk = chunk )
106- if finish_reason is not None :
107- break
51+ try :
52+ return super ().get_num_tokens (text )
53+ except Exception as e :
54+ tokenizer = TokenizerManage .get_tokenizer ()
55+ return len (tokenizer .encode (text ))
0 commit comments