Skip to content

Commit

Permalink
add openai change between azure or openai
Browse files Browse the repository at this point in the history
  • Loading branch information
dayesouza committed Apr 16, 2024
1 parent 1bfabd7 commit 42661e0
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 46 deletions.
57 changes: 22 additions & 35 deletions app/pages/Settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from components.app_loader import load_multipage_app
from util.openai_instance import get_key_env, key, _OpenAI
from util.openai_instance import get_key_env, key, _OpenAI, openai_endpoint_key, openai_version_key, openai_type_key
from util.SecretsHandler import SecretsHandler
import streamlit as st
import time
Expand Down Expand Up @@ -43,44 +43,31 @@ def main():
st.warning("No OpenAI key found in the environment. Please insert one above.")
elif not secret_input and not secret:
st.info("Using key from the environment.")



st.divider()

# st.header("OpenAI Models")
# st.markdown("Select the OpenAI models you want to use. This modification will be valid for this session only.")
# with st.spinner("Fetching models..."):
# try:
# openai_models = openai.client().models.list()
# #Removes deprecated models:
# # Source: https://platform.openai.com/docs/models/gpt-3-5-turbo
# deprecated = [
# 'gpt-3.5-turbo-16k',
# 'gpt-3.5-turbo-0613',
# 'gpt-3.5-turbo-16k-0613'
# ]
# openai_models = [x for x in openai_models if x.id not in deprecated]
# # order by id string, reversed so new ones are showed first
# openai_models = sorted(openai_models, key=lambda x: x.id, reverse=True)
# except Exception as e:
# st.error(f"Invalid key. Please check your OpenAI key. {e}")
# return
# gpt_list = [x.id for x in openai_models if 'gpt' in x.id]
# print('gpt_list', gpt_list)
# embeddings_list = [x.id for x in openai_models if 'embedding' in x.id]
# index_model = gpt_list.index(sv.generation_model.value) if sv.generation_model.value in gpt_list else 0
# model_change = st.selectbox("OpenAI generative model", gpt_list, index=index_model)
# st.caption("Note that not all models will have the same token limit, so the information on the workflow screen might be inaccurate.")
# if model_change != sv.generation_model.value:
# sv.generation_model.value = model_change
# st.rerun()
st.header("OpenAI Type")
st.markdown("Select the OpenAI type you want to use.")
types = ["OpenAI", "Azure OpenAI"]
type_input = st.radio("OpenAI Type", types, index=types.index(openai.get_openai_type()) or 0)
type = openai.get_openai_type()
if type != type_input:
on_change(secrets_handler, openai_type_key, type_input)()
st.rerun()

# embedding_model = embeddings_list.index(sv.embedding_model.value) if sv.embedding_model.value in embeddings_list else 0
# embedding_model_change = st.selectbox("OpenAI embedding model", embeddings_list, index=embedding_model)
# if embedding_model_change != sv.embedding_model.value:
# sv.embedding_model.value = embedding_model_change
# st.rerun()
if type_input == "Azure OpenAI":
col1, col2 = st.columns(2)
with col1:
endpoint = st.text_input("Azure OpenAI Endpoint", type="password", value=openai.get_azure_openai_endpoint())
if endpoint != openai.get_azure_openai_endpoint():
on_change(secrets_handler, openai_endpoint_key, endpoint)()
st.rerun()

with col2:
version = st.text_input("Azure OpenAI Version", value=openai.get_azure_openai_version())
if version != openai.get_azure_openai_version():
on_change(secrets_handler, openai_version_key, version)()
st.rerun()

if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion app/util/AI_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,5 @@ def generate_text_from_message_list(messages, placeholder=None, prefix='', model
placeholder.markdown(prefix + response, unsafe_allow_html=True)
except Exception as e:
print(f'Error generating from message list: {e}')
raise Exception(f'Problem in OpenAI response. {e.message}')
raise Exception(f'Problem in OpenAI response. {e}')
return response
37 changes: 27 additions & 10 deletions app/util/openai_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from openai import AzureOpenAI

key = 'openaikey'
openai_type_key = 'openai_typekey'
openai_version_key = 'openai_versionkey'
openai_endpoint_key = 'openai_endpointkey'
class _OpenAI:
_instance = None
_key = None
Expand All @@ -22,20 +25,34 @@ def client(self):
if api_key != self._key:
self._key = api_key
try:
self._instance = get_openai_type(self._key)
self._instance = self.get_openai_api()
except Exception as e:
raise Exception(f'OpenAI client not created: {e}')


return self._instance

def get_openai_type(self):
environ = os.environ['OPENAI_TYPE'] if 'OPENAI_TYPE' in os.environ else None
secret = self._secrets.get_secret(openai_type_key)
return secret if len(secret) > 0 else environ

def get_azure_openai_version(self):
environ = os.environ['AZURE_OPENAI_VERSION'] if 'AZURE_OPENAI_VERSION' in os.environ else None
secret = self._secrets.get_secret(openai_version_key)
return secret if len(secret) > 0 else environ

def get_azure_openai_endpoint(self):
environ = os.environ['AZURE_OPENAI_ENDPOINT'] if 'AZURE_OPENAI_ENDPOINT' in os.environ else None
secret = self._secrets.get_secret(openai_endpoint_key)
return secret if len(secret) > 0 else environ

def get_openai_api(self):
if self.get_openai_type() == "Azure OpenAI":
print('self._key', self._key)
return AzureOpenAI(api_key=self._key, azure_endpoint=self.get_azure_openai_endpoint(), api_version=self.get_azure_openai_version())
else:
return OpenAI(api_key=self._key)

def get_key_env():
return os.environ['OPENAI_API_KEY'] if 'OPENAI_API_KEY' in os.environ else ''

def get_openai_type(key, endpoint = None, api_version = None):
if 'OPENAI_TYPE' in os.environ and os.environ['OPENAI_TYPE'] == "AZURE":
endpoint = os.environ['AZURE_OPENAI_ENDPOINT'] if 'AZURE_OPENAI_ENDPOINT' in os.environ else None
api_version = os.environ['AZURE_OPENAI_VERSION'] if 'AZURE_OPENAI_VERSION' in os.environ else None
return AzureOpenAI(api_key=key, azure_endpoint=endpoint, api_version=api_version)
else:
return OpenAI(api_key=key)
return os.environ['OPENAI_API_KEY'] if 'OPENAI_API_KEY' in os.environ else ''

0 comments on commit 42661e0

Please sign in to comment.