diff --git a/query_data.py b/query_data.py index 43ed9a5e6..b80cf9bab 100644 --- a/query_data.py +++ b/query_data.py @@ -5,7 +5,13 @@ from langchain_openai import ChatOpenAI from langchain.prompts import ChatPromptTemplate +import openai +from dotenv import load_dotenv +import os + +load_dotenv() CHROMA_PATH = "chroma" +openai.api_key = os.environ["OPENAI_API_KEY"] PROMPT_TEMPLATE = """ Answer the question based only on the following context: @@ -41,7 +47,7 @@ def main(): print(prompt) model = ChatOpenAI() - response_text = model.predict(prompt) + response_text = model.invoke(prompt) sources = [doc.metadata.get("source", None) for doc, _score in results] formatted_response = f"Response: {response_text}\nSources: {sources}"