Skip to content

Commit

Permalink
add login with azure and db with user
Browse files Browse the repository at this point in the history
  • Loading branch information
dayesouza committed Mar 25, 2024
1 parent 58fde25 commit 698c56a
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 19 deletions.
29 changes: 29 additions & 0 deletions app/Home.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,37 @@
import streamlit as st
import util.mermaid as mermaid
from streamlit_javascript import st_javascript
import util.session_variables

def get_user(sv):
if sv.mode.value != 'cloud':
return
css='''
[data-testid="stSidebarNavItems"] {
max-height: 100vh
}
'''
st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)
js_code = """await fetch("/.auth/me")
.then(function(response) {return response.json();})
"""
return_value = st_javascript(js_code)

username = None
if return_value == 0:
pass # this is the result before the actual value is returned
elif isinstance(return_value, list) and len(return_value) > 0:
username = return_value[0]["user_id"]
sv.username.value = username
st.sidebar.write(f"Logged in as {username}")
else:
st.warning(f"Could not directly read username from azure active directory: {return_value}.")

def main():
st.set_page_config(layout="wide", initial_sidebar_state="expanded", page_title='Intelligence Toolkit | Home')
sv = util.session_variables.SessionVariables('home')
get_user(sv)

transparency_faq = open('./app/TransparencyFAQ.md', 'r').read()
st.markdown(transparency_faq + '\n\n' + f"""\
#### Which Intelligence Toolkit workflow is right for me and my data?
Expand Down
24 changes: 24 additions & 0 deletions app/util/Database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os
import duckdb

class Database:
def __init__(self, cache, db_name) -> None:
if not os.path.exists(cache):
os.makedirs(cache)

db_path = os.path.join(cache, f'{db_name}.db')
self.connection = duckdb.connect(database=db_path)

def create_table(self, name, attributes = []):
self.connection.execute(f"CREATE TABLE IF NOT EXISTS {name} ({', '.join(attributes)})")

def select_embedding_from_hash(self, hash_text, username = ''):
return self.connection.execute(f"SELECT embedding FROM embeddings WHERE hash_text = '{hash_text}' and username = '{username}'").fetchone()

def insert_into_embeddings(self, hash_text, embedding, username = ''):
self.connection.execute(f"INSERT INTO embeddings VALUES ('{username}','{hash_text}', {embedding})")

def execute(self, query):
return self.connection.execute(query)


31 changes: 15 additions & 16 deletions app/util/Embedder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from openai import OpenAI
import tiktoken
import os
import numpy as np
from util.Database import Database
import util.session_variables

gen_model = 'gpt-4-turbo-preview'
embed_model = 'text-embedding-3-small'
Expand All @@ -10,33 +11,32 @@
max_input_tokens = 128000
default_temperature = 0
max_embed_tokens = 8191
import duckdb


client = OpenAI()
encoder = tiktoken.get_encoding(text_encoder)

class Embedder:
def __init__(self, cache, model=embed_model, encoder=text_encoder, max_tokens=max_embed_tokens) -> None:
sv = util.session_variables.SessionVariables('home')
self.username = sv.username.value
self.model = model
self.encoder = tiktoken.get_encoding(encoder)
self.max_tokens = max_tokens
if not os.path.exists(cache):
os.makedirs(cache)
self.connection = duckdb.connect(database=f'{cache}\\embeddings.db')
self.connection.execute("CREATE TABLE IF NOT EXISTS embeddings (hash_text STRING, embedding DOUBLE[])")

self.connection = Database(cache, 'embeddings')
self.connection.create_table('embeddings', ['username STRING','hash_text STRING', 'embedding DOUBLE[]'])

def encode_all(self, texts):
final_embeddings = [None] * len(texts)
new_texts = []
for ix, text in enumerate(texts):
text = text.replace("\n", " ")
hsh = hash(text)
exists = self.connection.execute(f"SELECT embedding FROM embeddings WHERE hash_text = '{hsh}'").fetchone()
if not exists:
embeddings = self.connection.select_embedding_from_hash(hsh)
if not embeddings:
new_texts.append((ix, text))
else:
final_embeddings[ix] = np.array(exists)
final_embeddings[ix] = np.array(embeddings)
print(f'Got {len(new_texts)} new texts')
# split into batches of 2000
for i in range(0, len(new_texts), 2000):
Expand All @@ -45,26 +45,25 @@ def encode_all(self, texts):
embeddings = [x.embedding for x in client.embeddings.create(input = batch_texts, model=self.model).data]
for j, (ix, text) in enumerate(batch):
hsh = hash(text)
self.connection.execute(f"INSERT INTO embeddings VALUES ('{hsh}', {embeddings[j]})")
self.connection.insert_into_embeddings(hsh, embeddings[j])
final_embeddings[ix] = np.array(embeddings[j])
return np.array(final_embeddings)

def encode(self, text):
text = text.replace("\n", " ")
hsh = hash(text)
exists = self.connection.execute(f"SELECT embedding FROM embeddings WHERE hash_text = '{hsh}'").fetchone()
embeddings = self.connection.select_embedding_from_hash(hsh)

if exists:
return np.array(exists[0])
# return [float(x) for x in open(path, 'r').read().split('\n') if len(x) > 0]
if embeddings:
return np.array(embeddings[0])
else:
tokens = len(self.encoder.encode(text))
if tokens > self.max_tokens:
text = text[:self.max_tokens]
print('Truncated text to max tokens')
try:
embedding = client.embeddings.create(input = [text], model=self.model).data[0].embedding
self.connection.execute(f"INSERT INTO embeddings VALUES ('{hsh}', {embedding})")
self.connection.insert_into_embeddings(hsh, embedding)
return np.array(embedding)
except:
print(f'Error embedding text: {text}')
Expand Down
14 changes: 14 additions & 0 deletions app/util/session_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from util.session_variable import SessionVariable
import pandas as pd
import util.session_variable as sv
import os

class SessionVariables:

def __init__(self, prefix):
self.narrative_input_df = SessionVariable(pd.DataFrame(), prefix)
self.mode = sv.SessionVariable(os.environ.get("MODE", "dev"))
self.username = sv.SessionVariable('')



2 changes: 1 addition & 1 deletion app/workflows/question_answering/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cache_dir = '\\cache\\question_answering'
cache_dir = '.\\cache\\question_answering'

chunk_size = 5000
chunk_overlap = 0
Expand Down
2 changes: 1 addition & 1 deletion app/workflows/record_matching/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
list_sep = '; '
max_rows_to_show = 1000
entity_label = 'Entity'
cache_dir = '\\cache\\record_matching'
cache_dir = '.\\cache\\record_matching'
outputs_dir = f'{cache_dir}\\outputs'

intro = """ \
Expand Down
2 changes: 1 addition & 1 deletion app/workflows/risk_networks/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
list_sep = '; '
max_rows_to_show = 1000
entity_label = 'ENTITY'
cache_dir = '\\cache\\risk_networks'
cache_dir = '.\\cache\\risk_networks'
outputs_dir = f'{cache_dir}\\outputs'

intro = """ \
Expand Down
Binary file modified requirements.txt
Binary file not shown.

0 comments on commit 698c56a

Please sign in to comment.