forked from AntoniZap/IBM-Chatbot
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm.py
More file actions
76 lines (70 loc) · 2.4 KB
/
llm.py
File metadata and controls
76 lines (70 loc) · 2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
from langchain_core.runnables import RunnablePassthrough
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains.combine_documents import create_stuff_documents_chain
from local import resolve
from config import options
#setup function for llama
def setup_llama():
    """Build a LlamaCpp LLM from the GGUF model at $LLAMA_MODEL_PATH.

    The llama.cpp client is given an on-disk prompt cache so repeated
    prompts skip re-evaluation across runs.

    Returns:
        A configured ``LlamaCpp`` instance.

    Raises:
        RuntimeError: if the LLAMA_MODEL_PATH environment variable is unset.
    """
    from langchain_community.llms import LlamaCpp
    from llama_cpp.llama_cache import LlamaDiskCache

    model_path = os.getenv('LLAMA_MODEL_PATH')
    if not model_path:
        # Fail fast with a clear message instead of a deep validation error.
        raise RuntimeError("LLAMA_MODEL_PATH environment variable is not set")
    llm = LlamaCpp(
        model_path=model_path,
        verbose=False,
        n_ctx=1024,  # context window size in tokens
    )
    # Persist evaluated prompts to disk between runs.
    llm.client.set_cache(LlamaDiskCache())
    return llm
#setup function for chatgpt
def setup_chatgpt():
    """Construct an OpenAI chat model with a mildly creative temperature."""
    from langchain_openai import ChatOpenAI
    return ChatOpenAI(temperature=0.6)
#setup function for ai21
def setup_ai21():
    """Construct an AI21 completion model with deterministic (zero) temperature."""
    from langchain_community.llms import AI21
    return AI21(temperature=0)
#retrieves the llm based on the llm_choice parameter
def get_raw_llm(llm_choice):
    """Return the raw (un-wrapped) LLM for *llm_choice*, building it on first use.

    The constructed model is cached in the module-level ``raw_llms`` dict so
    each backend is only set up once. Lookup is case-insensitive.

    Args:
        llm_choice: one of "llama", "ai21", or "chatgpt" (any casing).

    Raises:
        KeyError: if *llm_choice* names no known backend.
    """
    llm_choice = llm_choice.lower()
    try:
        return raw_llms[llm_choice]
    except KeyError:
        # Map each supported backend name to its lazy setup function.
        setups = {
            "llama": setup_llama,
            "ai21": setup_ai21,
            "chatgpt": setup_chatgpt,
        }
        try:
            setup = setups[llm_choice]
        except KeyError:
            # Preserve the KeyError contract but say what was wrong.
            raise KeyError(f"unknown llm choice: {llm_choice!r}") from None
        raw_llms[llm_choice] = setup()
        return raw_llms[llm_choice]
#safer alternative to get_raw_llm which checks if the llm is already setup first
def get_llm(llm_choice):
    """Return a ready-to-use document chain for *llm_choice*, building it on first use.

    Wraps the raw LLM in a stuff-documents chain whose system prompt comes
    from the localized "system_prompt" resource, and caches the result in the
    module-level ``llms`` dict. The chain exposes the model reply under the
    "answer" key via ``RunnablePassthrough.assign``.

    Raises:
        KeyError: propagated from get_raw_llm for an unknown choice.
    """
    llm_choice = llm_choice.lower()
    try:
        return llms[llm_choice]
    except KeyError:
        llm = get_raw_llm(llm_choice)
        system_prompt = resolve(options["language"], "system_prompt")
        # "{{context}}" renders as the literal "{context}" so LangChain can
        # fill it with the retrieved documents. (Fixed: the original wrote
        # "\\{{context}}", injecting a stray backslash into the prompt.)
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", f"{system_prompt}\n{{context}}\n----------"),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )
        document_chain = create_stuff_documents_chain(llm, prompt)
        chain = RunnablePassthrough.assign(answer=document_chain)
        llms[llm_choice] = chain
        return llms[llm_choice]
# Lazily-populated caches keyed by lowercased llm choice.
llms = {}      # filled by get_llm: choice -> wrapped document chain
raw_llms = {}  # filled by get_raw_llm: choice -> raw LLM instance