Skip to content

Commit e188317

Browse files
author
zhanghui-china
committed
merge text2image into main
1 parent 15eaec7 commit e188317

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+120665
-138
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ Demo 访问地址:https://openxlab.org.cn/apps/detail/zhanghui-china/shishen20
328328
329329
## 项目参与人员(排名不分先后)
330330
331-
1.张小白,项目策划、测试和打杂。现为某IT公司数据工程师,华为云HCDE(原华为云MVP),2020年华为云社区十佳博主,2022年昇腾社区优秀开发者,2022年华为云社区年度优秀版主,MindSpore布道师,DataWhale优秀学习者, [知乎](https://www.zhihu.com/people/zhanghui_china)
331+
1.张小白,项目策划、测试和打杂。南京大学本科毕业,现为某IT公司数据工程师,华为云HCDE(原华为云MVP),2020年华为云社区十佳博主,2022年昇腾社区优秀开发者,2022年华为云社区年度优秀版主,MindSpore布道师,DataWhale优秀学习者, [知乎](https://www.zhihu.com/people/zhanghui_china)
332332
333333
2.sole fish:语音输入 [github](https://github.com/YanxingLiu)
334334
@@ -342,7 +342,7 @@ Demo 访问地址:https://openxlab.org.cn/apps/detail/zhanghui-china/shishen20
342342
343343
7.刘光磊:图标设计,前端优化 [github](https://github.com/Mrguanglei)
344344
345-
8.喵喵咪:数据集准备 [github](https://github.com/miyc1996)
345+
8.喵喵咪:数据集准备,后续本地小模型部署测试,北京航空航天大学硕士,现为上海某国企工程师。 [github](https://github.com/miyc1996)
346346
347347
9.王巍龙:数据集,微调
348348

README_EN.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ Demo link:https://openxlab.org.cn/apps/detail/zhanghui-china/nlp_shishen3
311311

312312
7.Liu Guanglei:icon design, front-end optimization [github](https://github.com/Mrguanglei)
313313

314-
8.Miao Miaomi:datasets [github](https://github.com/miyc1996)
314+
8.Miao Miaomi:datasets @Beijing University of Aeronautics and Astronautics [github](https://github.com/miyc1996)
315315

316316
9.Wang Weilong:datasets, fine-tuning
317317

@@ -323,3 +323,7 @@ Demo link:https://openxlab.org.cn/apps/detail/zhanghui-china/nlp_shishen3
323323
## License
324324

325325
The project follows [Apache License 2.0](LICENSE.txt)
326+
327+
## Star History
328+
329+
[![Star History Chart](https://api.star-history.com/svg?repos=SmartFlowAI/TheGodOfCookery&type=Date)](https://star-history.com/#SmartFlowAI/TheGodOfCookery&Date)

app-enhanced-rag-test.py

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
"""
2+
This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers.
3+
We mainly modified part of the code logic to adapt to the generation of our model.
4+
Please refer to these links below for more information:
5+
1. streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
6+
2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
7+
3. transformers: https://github.com/huggingface/transformers
8+
"""
9+
import os
10+
import sys
11+
from dataclasses import asdict
12+
13+
import streamlit as st
14+
import torch
15+
from langchain_community.llms.tongyi import Tongyi
16+
# from audiorecorder import audiorecorder
17+
# from modelscope import AutoModelForCausalLM, AutoTokenizer
18+
from transformers.utils import logging
19+
20+
from rag.CookMasterLLM import CookMasterLLM
21+
from rag.interface import (GenerationConfig,
22+
generate_interactive,
23+
generate_interactive_rag_stream,
24+
generate_interactive_rag)
25+
26+
# from whisper_app import run_whisper
27+
# from download import finetuned
28+
29+
# Module-level logger obtained through the transformers logging helper.
logger = logging.get_logger(__name__)

# Optional sqlite3 replacement shim, currently disabled:
# __import__('pysqlite3')
# sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

# ---- UI toggles -----------------------------------------------------------
# Defaults only; prepare_generation_config() rebinds both from the sidebar.
enable_rag, streaming = True, False

# ---- Chat avatars ---------------------------------------------------------
user_avatar, robot_avatar = "images/user.png", "images/robot.png"

# ---- Prompt templates -----------------------------------------------------
# The {user} / {robot} placeholders are filled in with str.replace().
user_prompt = "<|User|>:{user}\n"
robot_prompt = "<|Bot|>:{robot}<eoa>\n"
cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"

# ---- Speech input settings ------------------------------------------------
audio_save_path = "/tmp/audio.wav"
whisper_model_scale = "medium"
46+
47+
def on_btn_click():
    """Button callback that clears the conversation.

    Removes the ``messages`` entry from ``st.session_state``.

    Returns:
        None
    """
    del st.session_state.messages
58+
59+
60+
@st.cache_resource
def load_model(model_path=None):
    """Load the causal language model and its tokenizer.

    The function is wrapped in ``st.cache_resource``.

    Neither ``AutoModelForCausalLM`` / ``AutoTokenizer`` nor
    ``llm_model_path`` is imported or defined at module level in this file
    (the corresponding import is commented out), so this function resolves
    both locally instead of failing with ``NameError``.

    Args:
        model_path: Optional checkpoint directory or model identifier. When
            omitted, the module-level ``llm_model_path`` is used if one is
            defined, otherwise the ``LLM_MODEL_PATH`` environment variable.

    Returns:
        tuple: ``(model, tokenizer)``. The model is converted to
        ``torch.bfloat16`` and moved to CUDA.

    Raises:
        RuntimeError: If no model path can be resolved from any source.
    """
    # Function-scope import: the classes are not imported at module level.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    resolved_path = (
        model_path
        or globals().get("llm_model_path")
        or os.environ.get("LLM_MODEL_PATH")
    )
    if not resolved_path:
        raise RuntimeError(
            "No model path configured: pass model_path, define llm_model_path "
            "or set the LLM_MODEL_PATH environment variable."
        )

    # NOTE(review): trust_remote_code=True lets the checkpoint run its own
    # Python code; only point this at checkpoints you trust.
    model = (
        AutoModelForCausalLM.from_pretrained(resolved_path, trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained(resolved_path, trust_remote_code=True)
    return model, tokenizer
79+
80+
81+
def prepare_generation_config():
    """Draw the sidebar controls and build the generation settings.

    Side effects:
        Rebinds the module-level ``enable_rag`` and ``streaming`` flags to
        the current state of their sidebar checkboxes.

    Returns:
        tuple: ``(generation_config, speech_string)``. ``speech_string`` is
        always ``None`` here because the speech-input path (audio recorder
        plus whisper transcription) is disabled.
    """
    global enable_rag, streaming

    with st.sidebar:
        # 1. Upper bound on the generated text length.
        max_length = st.slider(
            "Max Length", min_value=32, max_value=2048, value=2048)

        # 2. Button that wipes the stored conversation.
        st.button("Clear Chat History", on_click=on_btn_click)

        # 3. Retrieval-augmented generation toggle.
        enable_rag = st.checkbox("Enable RAG", value=True)

        # 4. Token-by-token streaming toggle.
        streaming = st.checkbox("Streaming", value=False)

        # 5. Speech input is disabled; see the return value below.

    sampling_options = {
        "max_length": max_length,
        "top_p": 0.8,
        "temperature": 0.8,
        "repetition_penalty": 1.002,
    }
    return GenerationConfig(**sampling_options), None
122+
123+
124+
def combine_history(prompt):
    """Assemble the full prompt from the stored conversation and a new query.

    The result is the fixed system instruction, followed by every message in
    ``st.session_state.messages`` rendered with the template for its role,
    followed by ``prompt`` rendered with the current-query template.

    Args:
        prompt (str): The new user query.

    Returns:
        str: The concatenated prompt text.

    Raises:
        RuntimeError: If a stored message has a role other than ``"user"``
            or ``"robot"``.
    """
    # role -> (template, placeholder to substitute)
    role_templates = {
        "user": (user_prompt, "{user}"),
        "robot": (robot_prompt, "{robot}"),
    }

    segments = [
        "您是一个厨师,熟悉很多菜的制作方法。用户会问你哪些菜怎么制作,您可以用自己的专业知识答复他。回答的内容一般包含两块:这道菜需要哪些食材,这道菜具体是怎么做出来的。如果用户没有问菜谱相关的问题,就提醒他对菜谱的相关问题进行提问。"
    ]
    for entry in st.session_state.messages:
        text = entry["content"]
        role = entry["role"]
        if role not in role_templates:
            raise RuntimeError
        template, slot = role_templates[role]
        segments.append(template.replace(slot, text))

    segments.append(cur_query_prompt.replace("{user}", prompt))
    return "".join(segments)
147+
148+
149+
# Canned reply shown when the query contains none of the recipe keywords.
_KEYWORD_HINT_REPLY = "我是食神周星星的唯一传人张小白,我什么菜都会做,包括黑暗料理,您可以问我什么菜怎么做———比如酸菜鱼怎么做?,我会告诉你具体的做法。"


def _render_stream(generator, placeholder):
    """Stream partial replies into *placeholder* and return the final text.

    Starts from an empty string so that a generator which yields nothing
    produces an empty reply and does not leave the result unbound.
    """
    final_text = ""
    for partial in generator:
        # Turn literal backslash-n sequences into real line breaks.
        final_text = partial.replace('\\n', '\n')
        # Trailing block character acts as a typing cursor while streaming.
        placeholder.markdown(final_text + "▌")
    placeholder.markdown(final_text)
    return final_text


def process_user_input(prompt,
                       model,
                       tokenizer,
                       generation_config):
    """Handle one user turn: echo it, produce a reply, and record both.

    If the query contains none of the recipe keywords, a fixed hint message
    is shown and no model is called. Otherwise the reply comes from one of
    three paths chosen by the module-level ``enable_rag`` and ``streaming``
    flags: streaming RAG, non-streaming RAG, or plain interactive generation.

    Args:
        prompt (str): The user's query text.
        model: Model object forwarded unchanged to the generation helpers.
        tokenizer: Tokenizer object forwarded unchanged to the generation
            helpers.
        generation_config: Dataclass instance; its fields are expanded with
            ``asdict`` as keyword arguments for plain (non-RAG) generation.

    Returns:
        None. The user message and the reply are appended to
        ``st.session_state.messages``.
    """
    keywords = ("怎么做", "做法", "菜谱")
    contains_keywords = any(keyword in prompt for keyword in keywords)

    # Echo the user message in the chat area.
    with st.chat_message("user", avatar=user_avatar):
        st.markdown(prompt)

    # Build the prompt before recording the new message: combine_history()
    # appends the current query itself, so it must not be in the history yet.
    real_prompt = combine_history(prompt)
    st.session_state.messages.append(
        {"role": "user", "content": prompt, "avatar": user_avatar})

    if not contains_keywords:
        # Not a recipe question: answer with the canned hint.
        with st.chat_message("robot", avatar=robot_avatar):
            st.markdown(_KEYWORD_HINT_REPLY)
        st.session_state.messages.append(
            {"role": "robot",
             "content": _KEYWORD_HINT_REPLY,
             "avatar": robot_avatar})
        return

    with st.chat_message("robot", avatar=robot_avatar):
        message_placeholder = st.empty()
        if enable_rag and streaming:
            cur_response = _render_stream(
                generate_interactive_rag_stream(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=prompt,
                    history=real_prompt,
                    verbose=False,
                ),
                message_placeholder,
            )
        elif enable_rag:
            cur_response = generate_interactive_rag(
                model=model,
                tokenizer=tokenizer,
                prompt=prompt,
                history=real_prompt,
                verbose=False,
            )
            message_placeholder.markdown(cur_response)
        else:
            cur_response = _render_stream(
                generate_interactive(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=real_prompt,
                    # additional_eos_token_id=103028,
                    additional_eos_token_id=92542,
                    **asdict(generation_config),
                ),
                message_placeholder,
            )

    # Record the reply in the chat history.
    st.session_state.messages.append(
        {"role": "robot", "content": cur_response, "avatar": robot_avatar})
    torch.cuda.empty_cache()
233+
234+
235+
def main():
    """Entry point of the Streamlit page.

    Prints the torch version and CUDA availability, loads the model, draws
    the sidebar, replays the stored conversation, and then processes a new
    text prompt (and a speech prompt, when one is provided).
    """
    # Same output as two separate prints per item: label, newline, value.
    print("Torch version:", torch.__version__, sep="\n")
    print("Torch support GPU: ", torch.cuda.is_available(), sep="\n")

    st.title("食神2——菜谱小助手 by 张小白")
    model, tokenizer = load_model()
    generation_config, speech_prompt = prepare_generation_config()

    # 1. Make sure the chat history exists.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # 2. Re-render the stored conversation on every rerun.
    for past in st.session_state.messages:
        with st.chat_message(past["role"], avatar=past.get("avatar")):
            st.markdown(past["content"])

    # 3. Text input.
    text_prompt = st.chat_input("What is up?")
    if text_prompt:
        process_user_input(text_prompt, model, tokenizer, generation_config)

    # 4. Speech input.
    if speech_prompt is not None:
        process_user_input(speech_prompt, model, tokenizer, generation_config)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)