Skip to content

Commit

Permalink
changes in report page and add protected mode to record matching
Browse files Browse the repository at this point in the history
  • Loading branch information
dayesouza committed Apr 16, 2024
1 parent e9dcaeb commit 1bfabd7
Show file tree
Hide file tree
Showing 19 changed files with 51 additions and 40 deletions.
5 changes: 5 additions & 0 deletions app/components/app_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import components.app_user as au
import components.app_terminator as at
import components.app_openai as ao
import components.app_mode as am

def load_multipage_app():
#Load user if logged in
Expand All @@ -17,5 +18,9 @@ def load_multipage_app():
app_openai = ao.app_openai()
app_openai.api_info()

#Protected mode
app_mode = am.app_mode()
app_mode.config()

add_styles()

20 changes: 20 additions & 0 deletions app/components/app_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
import streamlit as st

from util.session_variables import SessionVariables

class app_mode:
    # Class-level default; every instance overwrites this in __init__.
    sv = None

    def __init__(self, sv = None):
        """Keep the caller-supplied session variables, or build the 'home' set."""
        self.sv = sv if sv is not None else SessionVariables('home')

    def config(self):
        """Render the sidebar 'Protected mode' toggle and persist any change.

        A rerun is triggered so the rest of the page immediately reflects
        the new protected-mode state.
        """
        toggled = st.sidebar.toggle("Protected mode", value=self.sv.protected_mode.value, help="Prevent entity identification on screen.")
        if toggled != self.sv.protected_mode.value:
            self.sv.protected_mode.value = toggled
            st.rerun()

1 change: 1 addition & 0 deletions app/util/session_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ def __init__(self, prefix = ''):
self.username = sv.SessionVariable('')
self.generation_model = sv.SessionVariable('gpt-4-turbo')
self.embedding_model = sv.SessionVariable('text-embedding-ada-002')
self.protected_mode = sv.SessionVariable(False)
19 changes: 7 additions & 12 deletions app/util/ui_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,17 @@ def report_download_ui(report_var, name):
with c2:
add_download_pdf(f'{name}.pdf', report_data, f'Download AI {spaced_name} as PDF')

def generative_ai_component(system_prompt_var, instructions_var, variables):
def generative_ai_component(system_prompt_var, variables):
st.markdown('##### Generative AI instructions')
with st.expander('Edit AI System Prompt (advanced)', expanded=False):
with st.expander('Edit AI System Prompt (advanced)', expanded=True):
instructions_text = st.text_area('Contents of System Prompt used to generate AI outputs.', value=system_prompt_var.value["user_prompt"], height=200)
if system_prompt_var.value["user_prompt"] != instructions_text:
system_prompt_var.value["user_prompt"] = instructions_text
st.rerun()
reset_prompt = st.button('Reset to default')

value_area = st.text_area('Instructions (optional - use to guide output)', value=instructions_var.value, height=100)
instructions_var.value = value_area
variables['instructions'] = instructions_var.value

st.warning('This app uses AI and may not be error-free. Please verify critical details independently.')

full_prompt = ' '.join([
system_prompt_var.value["report_prompt"],
system_prompt_var.value["user_prompt"],
Expand All @@ -85,22 +83,19 @@ def generative_ai_component(system_prompt_var, instructions_var, variables):
st.warning(message)
return generate, messages, reset_prompt

def generative_batch_ai_component(system_prompt_var, instructions_var, variables, batch_name, batch_val, batch_size):
def generative_batch_ai_component(system_prompt_var, variables, batch_name, batch_val, batch_size):
st.markdown('##### Generative AI instructions')
with st.expander('Edit AI System Prompt (advanced)', expanded=False):
with st.expander('Edit AI System Prompt (advanced)', expanded=True):
instructions_text = st.text_area('Contents of System Prompt used to generate AI outputs.', value=system_prompt_var.value["user_prompt"], height=200)
system_prompt_var.value["user_prompt"] = instructions_text
reset_prompt = st.button('Reset to default')

value_area = st.text_area('Instructions (optional - use to guide output)', value=instructions_var.value, height=100)
instructions_var.value = value_area

st.warning('This app uses AI and may not be error-free. Please verify critical details independently.')
batch_offset = 0
batch_count_raw = (len(batch_val) // batch_size)
batch_count_remaining = (len(batch_val) % batch_size)
batch_count = batch_count_raw + 1 if batch_count_remaining != 0 else batch_count_raw
batch_messages = []
variables['instructions'] = instructions_var.value

full_prompt = ' '.join([
system_prompt_var.value["report_prompt"],
Expand Down
4 changes: 0 additions & 4 deletions app/workflows/attribute_patterns/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@
{attribute_counts}
Additional instructions:
{instructions}
"""

user_prompt = """\
Expand Down
1 change: 0 additions & 1 deletion app/workflows/attribute_patterns/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(self, prefix):
self.attribute_min_count = SessionVariable(0, prefix)
self.attribute_suppress_zeros = SessionVariable(False, prefix)
self.attribute_last_suppress_zeros = SessionVariable(False, prefix)
self.attribute_instructions = SessionVariable('', prefix)
self.attribute_system_prompt = SessionVariable(prompts.list_prompts, prefix)
self.attribute_final_df = SessionVariable(pd.DataFrame(), prefix)
self.attribute_report = SessionVariable('', prefix)
Expand Down
2 changes: 1 addition & 1 deletion app/workflows/attribute_patterns/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def create():
'attribute_counts': sv.attribute_selected_pattern_att_counts.value.to_csv(index=False)
}

generate, messages, reset = util.ui_components.generative_ai_component(sv.attribute_system_prompt, sv.attribute_instructions, variables)
generate, messages, reset = util.ui_components.generative_ai_component(sv.attribute_system_prompt, variables)
if reset:
sv.attribute_system_prompt.value["user_prompt"] = prompts.user_prompt
st.rerun()
Expand Down
3 changes: 0 additions & 3 deletions app/workflows/group_narratives/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
{dataset}
Additional instructions:
{instructions}
"""

user_prompt = """\
Expand Down
1 change: 0 additions & 1 deletion app/workflows/group_narratives/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def __init__(self, prefix):
self.narrative_description = SessionVariable('', prefix)
self.narrative_top_groups = SessionVariable(0, prefix)
self.narrative_top_attributes = SessionVariable(0, prefix)
self.narrative_instructions = SessionVariable('', prefix)
self.narrative_report = SessionVariable('', prefix)
self.narrative_system_prompt = SessionVariable(prompts.list_prompts, prefix)
self.narrative_subject_identifier = SessionVariable('', prefix)
2 changes: 1 addition & 1 deletion app/workflows/group_narratives/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def create():
'dataset': fdf.to_csv(index=False, encoding='utf-8-sig'),
'filters': filter_description
}
generate, messages, reset = util.ui_components.generative_ai_component(sv.narrative_system_prompt, sv.narrative_instructions, variables)
generate, messages, reset = util.ui_components.generative_ai_component(sv.narrative_system_prompt, variables)
if reset:
sv.narrative_system_prompt.value["user_prompt"] = prompts.user_prompt
st.rerun()
Expand Down
4 changes: 0 additions & 4 deletions app/workflows/question_answering/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,6 @@
{outline}
Additional instructions:
{instructions}
"""

list_prompts = {
Expand Down
1 change: 0 additions & 1 deletion app/workflows/question_answering/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,4 @@ def __init__(self, prefix):
self.answering_matches = SessionVariable('', prefix)
self.answering_source_diversity = SessionVariable(1, prefix)
self.answering_question_history = SessionVariable([], prefix)
self.answering_instructions = SessionVariable('', prefix)
self.answering_system_prompt = SessionVariable(prompts.list_prompts, prefix)
2 changes: 1 addition & 1 deletion app/workflows/question_answering/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def create():
'outline': sv.answering_matches.value,
'source_diversity': sv.answering_source_diversity.value
}
generate, messages, reset = util.ui_components.generative_ai_component(sv.answering_system_prompt, sv.answering_instructions, variables)
generate, messages, reset = util.ui_components.generative_ai_component(sv.answering_system_prompt, variables)
if reset:
sv.answering_system_prompt.value["user_prompt"] = prompts.user_prompt
st.rerun()
Expand Down
4 changes: 0 additions & 4 deletions app/workflows/record_matching/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
If names are in a language other than English, consider whether the English translations are generic descriptive terms (less likely to be related) or distinctive (more likely to be related).
Additional instructions:
{instructions}
=== TASK ===
Group data:
Expand Down
1 change: 0 additions & 1 deletion app/workflows/record_matching/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,4 @@ def __init__(self, prefix):
self.matching_last_sentence_pair_embedding_threshold = SessionVariable(0.05, prefix)
self.matching_evaluations = SessionVariable(pl.DataFrame(), prefix)
self.matching_system_prompt = SessionVariable(prompts.list_prompts, prefix)
self.matching_instructions = SessionVariable('', prefix)

15 changes: 14 additions & 1 deletion app/workflows/record_matching/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
import workflows.record_matching.functions as functions
import workflows.record_matching.config as config
import workflows.record_matching.variables as vars
import util.session_variables as home_vars
import util.Embedder
import util.ui_components

embedder = util.Embedder.create_embedder(config.cache_dir)

def create():
sv = vars.SessionVariables('record_matching')
sv_home = home_vars.SessionVariables('home')

if not os.path.exists(config.outputs_dir):
os.makedirs(config.outputs_dir)
Expand Down Expand Up @@ -138,6 +140,7 @@ def att_ui(i):
record_distance = st.number_input('Matching record distance (max)', min_value=0.0, max_value=1.0, step=0.01, value=sv.matching_sentence_pair_embedding_threshold.value, help='The maximum cosine distance between two records in the embedding space for them to be considered a match. Lower values will result in fewer closer matches overall.')
with b2:
name_similarity = st.number_input('Matching name similarity (min)', min_value=0.0, max_value=1.0, step=0.01, value=sv.matching_sentence_pair_jaccard_threshold.value, help='The minimum Jaccard similarity between the character trigrams of the names of two records for them to be considered a match. Higher values will result in fewer closer name matches.')

if st.button('Detect record groups', use_container_width=True):
if record_distance != sv.matching_sentence_pair_embedding_threshold.value:
sv.matching_sentence_pair_embedding_threshold.value = record_distance
Expand Down Expand Up @@ -291,6 +294,16 @@ def att_ui(i):
sv.matching_matches_df.value = sv.matching_matches_df.value.sort(by=['Name similarity', 'Group ID'], descending=[False, False])
# # keep all records linked to a group ID if any record linked to that ID has dataset GD or ILM
# sv.matching_matches_df.value = sv.matching_matches_df.value.filter(pl.col('Group ID').is_in(sv.matching_matches_df.value.filter(pl.col('Dataset').is_in(['GD', 'ILM']))['Group ID'].unique()))
data = sv.matching_matches_df.value
unique_names = data['Entity name'].unique()
#verify if the names are already in this format: Entity_1, Entity_2, etc
pattern = f'^Entity_\d+$'
matches = unique_names.str.contains(pattern)
all_matches = matches.all()
if not all_matches and sv_home.protected_mode.value:
for i, name in enumerate(unique_names, start=1):
data = data.with_columns(data['Entity name'].replace(name, 'Entity_{}'.format(i)))
sv.matching_matches_df.value = data
st.rerun()
if len(sv.matching_matches_df.value) > 0:
st.markdown(f'Identified **{len(sv.matching_matches_df.value)}** record groups.')
Expand All @@ -305,7 +318,7 @@ def att_ui(i):
with b1:
batch_size = 100
data = sv.matching_matches_df.value.drop(['Entity ID', 'Dataset', 'Name similarity']).to_pandas()
generate, batch_messages, reset = util.ui_components.generative_batch_ai_component(sv.matching_system_prompt, sv.matching_instructions, {}, 'data', data, batch_size)
generate, batch_messages, reset = util.ui_components.generative_batch_ai_component(sv.matching_system_prompt, {}, 'data', data, batch_size)
if reset:
sv.matching_system_prompt.value["user_prompt"] = prompts.user_prompt
st.rerun()
Expand Down
3 changes: 0 additions & 3 deletions app/workflows/risk_networks/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
- Evaluate the likelihood that different entity nodes are in fact the same real-world entity.
- If there is a selected entity and there are risk flags in the network, evaluate the risk exposure for the selected entity.
Additional instructions:
{instructions}
=== TASK ===
Expand Down
1 change: 0 additions & 1 deletion app/workflows/risk_networks/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __init__(self, prefix):
self.network_attributes_list = SessionVariable([], prefix)
self.network_additional_trimmed_attributes = SessionVariable([], prefix)
self.network_system_prompt = SessionVariable(prompts.list_prompts, prefix)
self.network_instructions = SessionVariable('', prefix)
self.network_report = SessionVariable('', prefix)
self.network_merged_links_df = SessionVariable([], prefix)
self.network_merged_nodes_df = SessionVariable([], prefix)
Expand Down
2 changes: 1 addition & 1 deletion app/workflows/risk_networks/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def create():
'network_edges': sv.network_merged_links_df.value.to_csv(index=False)
}
sv.network_system_prompt.value = prompts.list_prompts
generate, messages, reset = util.ui_components.generative_ai_component(sv.network_system_prompt, sv.network_instructions, variables)
generate, messages, reset = util.ui_components.generative_ai_component(sv.network_system_prompt, variables)
if reset:
sv.network_system_prompt.value["user_prompt"] = prompts.user_prompt
st.rerun()
Expand Down

0 comments on commit 1bfabd7

Please sign in to comment.