"""
This script is adapted from the Streamlit conversational-app tutorial and the
interactive generation code of ChatGLM2 and transformers. We mainly modified
part of the generation logic to fit our model.
Please refer to these links for more information:
    1. Streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
    2. ChatGLM2: https://github.com/THUDM/ChatGLM2-6B
    3. transformers: https://github.com/huggingface/transformers
"""
import os
import sys
from dataclasses import asdict

import streamlit as st
import torch
# Needed by load_model() below; modelscope mirrors the transformers
# AutoModelForCausalLM / AutoTokenizer API.
from modelscope import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging

from rag.interface import (GenerationConfig,
                           generate_interactive,
                           generate_interactive_rag_stream,
                           generate_interactive_rag)

# Optional speech-input dependencies:
# from audiorecorder import audiorecorder
# from whisper_app import run_whisper
# from download import finetuned
logger = logging.get_logger(__name__)

# Workaround for platforms whose bundled sqlite3 is too old (e.g. for chromadb):
# __import__('pysqlite3')
# sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
# global variables
enable_rag = True
streaming = False
user_avatar = "images/user.png"
robot_avatar = "images/robot.png"
user_prompt = "<|User|>:{user}\n"
robot_prompt = "<|Bot|>:{robot}<eoa>\n"
cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"
# Path to the fine-tuned chat model. "./model" is a placeholder; point this at
# your local checkpoint or set the LLM_MODEL_PATH environment variable.
llm_model_path = os.environ.get("LLM_MODEL_PATH", "./model")
# speech
audio_save_path = "/tmp/audio.wav"
whisper_model_scale = "medium"

def on_btn_click():
    """Clear the chat history stored in st.session_state when the button
    is clicked."""
    # pop() instead of del, so clicking with no history does not raise KeyError
    st.session_state.pop("messages", None)


@st.cache_resource
def load_model():
    """Load the pretrained model and tokenizer once per process.

    @st.cache_resource keeps a single copy across Streamlit reruns.

    Returns:
        model: pretrained transformers causal LM, in bfloat16 on the GPU.
        tokenizer: the matching tokenizer.
    """
    model = (
        AutoModelForCausalLM.from_pretrained(llm_model_path, trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained(llm_model_path, trust_remote_code=True)
    return model, tokenizer

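
# Note on load_model() above: bfloat16 + .cuda() assumes an NVIDIA GPU with
# bf16 support. A minimal fallback sketch for other machines (an assumption,
# not part of the original app):
#   model = AutoModelForCausalLM.from_pretrained(llm_model_path,
#                                                trust_remote_code=True)
#   if torch.cuda.is_available():
#       model = model.to(torch.float16).cuda()
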
def prepare_generation_config():
    """Build the sidebar controls and the generation config.

    Returns:
        Tuple[GenerationConfig, Optional[str]]: the generation config and the
        speech string (None when no speech was recorded).
    """
    with st.sidebar:
        # 1. Max length of the generated text
        max_length = st.slider("Max Length", min_value=32,
                               max_value=2048, value=2048)

        # 2. Clear history.
        st.button("Clear Chat History", on_click=on_btn_click)

        # 3. Enable RAG
        global enable_rag
        enable_rag = st.checkbox("Enable RAG", value=True)

        # 4. Streaming
        global streaming
        streaming = st.checkbox("Streaming", value=False)

        # 5. Speech input
        # audio = audiorecorder("Record", "Stop record")
        # speech_string = None
        # if len(audio) > 0:
        #     audio.export(audio_save_path, format="wav")
        #     speech_string = run_whisper(
        #         whisper_model_scale, "cuda",
        #         audio_save_path)

    # Nucleus sampling with a mild repetition penalty; temperature 0.8 keeps
    # recipe answers varied without drifting off-topic.
    generation_config = GenerationConfig(
        max_length=max_length, top_p=0.8, temperature=0.8, repetition_penalty=1.002)

    return generation_config, None
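
# Sketch of what the returned config carries with the defaults above:
#   GenerationConfig(max_length=2048, top_p=0.8, temperature=0.8,
#                    repetition_penalty=1.002)
# process_user_input() below expands it via asdict() into keyword arguments
# for generate_interactive.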


def combine_history(prompt):
    """Combine the stored chat history with the new user input into a single
    conversation string for the model.

    Args:
        prompt (str): the user's new input.

    Returns:
        str: the assembled conversation history.
    """
    messages = st.session_state.messages
    # System prompt (kept in Chinese for the model): you are a chef familiar
    # with many dishes; answer recipe questions with the ingredients and the
    # concrete steps, and redirect users who ask about anything other than recipes.
    total_prompt = "您是一个厨师,熟悉很多菜的制作方法。用户会问你哪些菜怎么制作,您可以用自己的专业知识答复他。回答的内容一般包含两块:这道菜需要哪些食材,这道菜具体是怎么做出来的。如果用户没有问菜谱相关的问题,就提醒他对菜谱的相关问题进行提问。"
    for message in messages:
        cur_content = message["content"]
        if message["role"] == "user":
            cur_prompt = user_prompt.replace("{user}", cur_content)
        elif message["role"] == "robot":
            cur_prompt = robot_prompt.replace("{robot}", cur_content)
        else:
            raise RuntimeError(f"Unexpected message role: {message['role']}")
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
    return total_prompt
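
# Illustrative example: with an empty history,
# combine_history("酸菜鱼怎么做?") returns
#   "<system prompt><|User|>:酸菜鱼怎么做?<eoh>\n<|Bot|>:"
# and each stored exchange contributes one
# "<|User|>:...\n" / "<|Bot|>:...<eoa>\n" pair before the final query.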


def process_user_input(prompt,
                       model,
                       tokenizer,
                       generation_config):
    """Handle one user input: route it to the right generator and render the reply.

    Args:
        prompt (str): the user's input.
        model: the loaded model object.
        tokenizer: the tokenizer object.
        generation_config (GenerationConfig): generation parameters.
    """
    # Check if the user input contains certain keywords
    keywords = ["怎么做", "做法", "菜谱"]  # "how to make", "method", "recipe"
    contains_keywords = any(keyword in prompt for keyword in keywords)

    # Display user message in chat message container
    with st.chat_message("user", avatar=user_avatar):
        st.markdown(prompt)
    real_prompt = combine_history(prompt)

    # Add user message to chat history
    st.session_state.messages.append(
        {"role": "user", "content": prompt, "avatar": user_avatar})

    # If keywords are not present, display a canned reply immediately
    if not contains_keywords:
        canned_reply = ("我是食神周星星的唯一传人张小白,我什么菜都会做,包括黑暗料理,"
                        "您可以问我什么菜怎么做——比如酸菜鱼怎么做?我会告诉你具体的做法。")
        with st.chat_message("robot", avatar=robot_avatar):
            st.markdown(canned_reply)
        # Add robot response to chat history
        st.session_state.messages.append(
            {"role": "robot", "content": canned_reply, "avatar": robot_avatar})
    else:
        # Generate robot response
        with st.chat_message("robot", avatar=robot_avatar):
            message_placeholder = st.empty()
            if enable_rag:
                if streaming:
                    generator = generate_interactive_rag_stream(
                        model=model,
                        tokenizer=tokenizer,
                        prompt=prompt,
                        history=real_prompt,
                        verbose=False
                    )
                    for cur_response in generator:
                        cur_response = cur_response.replace('\\n', '\n')
                        message_placeholder.markdown(cur_response + "▌")
                    message_placeholder.markdown(cur_response)
                else:
                    cur_response = generate_interactive_rag(
                        model=model,
                        tokenizer=tokenizer,
                        prompt=prompt,
                        history=real_prompt,
                        verbose=False
                    )
                    message_placeholder.markdown(cur_response)
            else:
                generator = generate_interactive(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=real_prompt,
                    # Model-specific extra end-of-answer token id
                    # (was 103028 for the earlier base model).
                    additional_eos_token_id=92542,
                    **asdict(generation_config),
                )
                for cur_response in generator:
                    cur_response = cur_response.replace('\\n', '\n')
                    message_placeholder.markdown(cur_response + "▌")
                message_placeholder.markdown(cur_response)
        # Add robot response to chat history
        st.session_state.messages.append(
            {"role": "robot", "content": cur_response, "avatar": robot_avatar})
        # Release cached GPU memory between turns
        torch.cuda.empty_cache()


def main():
    print("Torch version:")
    print(torch.__version__)
    print("Torch CUDA available:")
    print(torch.cuda.is_available())

    st.title("食神2——菜谱小助手 by 张小白")
    model, tokenizer = load_model()
    generation_config, speech_prompt = prepare_generation_config()

    # 1. Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # 2. Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"], avatar=message.get("avatar")):
            st.markdown(message["content"])

    # 3. Process text input
    if text_prompt := st.chat_input("What is up?"):
        process_user_input(text_prompt, model, tokenizer, generation_config)

    # 4. Process speech input
    if speech_prompt is not None:
        process_user_input(speech_prompt, model, tokenizer, generation_config)


if __name__ == "__main__":
    main()
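
# Launch with Streamlit's CLI, e.g. (file name assumed):
#   streamlit run app.py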