diff --git a/app/components/app_loader.py b/app/components/app_loader.py
index 402e8fc6..5525c2fe 100644
--- a/app/components/app_loader.py
+++ b/app/components/app_loader.py
@@ -3,6 +3,7 @@
 import components.app_user as au
 import components.app_terminator as at
 import components.app_openai as ao
+import components.app_mode as am
 
 def load_multipage_app():
     #Load user if logged in
@@ -17,5 +18,9 @@ def load_multipage_app():
     app_openai = ao.app_openai()
     app_openai.api_info()
 
+    #Protected mode
+    app_mode = am.app_mode()
+    app_mode.config()
+
     add_styles()
diff --git a/app/components/app_mode.py b/app/components/app_mode.py
new file mode 100644
index 00000000..ab8c9bcf
--- /dev/null
+++ b/app/components/app_mode.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
+import streamlit as st
+
+from util.session_variables import SessionVariables
+
+class app_mode:
+    sv = None
+
+    def __init__(self, sv = None):
+        if sv is not None:
+            self.sv = sv
+        else:
+            self.sv = SessionVariables('home')
+
+    def config(self):
+        mode = st.sidebar.toggle("Protected mode", value=self.sv.protected_mode.value, help="Prevent entity identification on screen.")
+        if mode != self.sv.protected_mode.value:
+            self.sv.protected_mode.value = mode
+            st.rerun()
+
\ No newline at end of file
diff --git a/app/util/session_variables.py b/app/util/session_variables.py
index a7135e92..e68485ae 100644
--- a/app/util/session_variables.py
+++ b/app/util/session_variables.py
@@ -12,3 +12,4 @@ def __init__(self, prefix = ''):
         self.username = sv.SessionVariable('')
         self.generation_model = sv.SessionVariable('gpt-4-turbo')
         self.embedding_model = sv.SessionVariable('text-embedding-ada-002')
+        self.protected_mode = sv.SessionVariable(False)
diff --git a/app/util/ui_components.py b/app/util/ui_components.py
index c87f133b..f1d4bd16 100644
--- a/app/util/ui_components.py
+++ b/app/util/ui_components.py
@@ -52,19 +52,17 @@ def report_download_ui(report_var, name):
     with c2:
         add_download_pdf(f'{name}.pdf', report_data, f'Download AI {spaced_name} as PDF')
 
-def generative_ai_component(system_prompt_var, instructions_var, variables):
+def generative_ai_component(system_prompt_var, variables):
     st.markdown('##### Generative AI instructions')
-    with st.expander('Edit AI System Prompt (advanced)', expanded=False):
+    with st.expander('Edit AI System Prompt (advanced)', expanded=True):
         instructions_text = st.text_area('Contents of System Prompt used to generate AI outputs.', value=system_prompt_var.value["user_prompt"], height=200)
         if system_prompt_var.value["user_prompt"] != instructions_text:
             system_prompt_var.value["user_prompt"] = instructions_text
             st.rerun()
         reset_prompt = st.button('Reset to default')
-    value_area = st.text_area('Instructions (optional - use to guide output)', value=instructions_var.value, height=100)
-    instructions_var.value = value_area
-    variables['instructions'] = instructions_var.value
-
+    st.warning('This app uses AI and may not be error-free. Please verify critical details independently.')
+
     full_prompt = ' '.join([
         system_prompt_var.value["report_prompt"],
         system_prompt_var.value["user_prompt"],
@@ -85,22 +83,19 @@ def generative_ai_component(system_prompt_var, instructions_var, variables):
         st.warning(message)
     return generate, messages, reset_prompt
 
-def generative_batch_ai_component(system_prompt_var, instructions_var, variables, batch_name, batch_val, batch_size):
+def generative_batch_ai_component(system_prompt_var, variables, batch_name, batch_val, batch_size):
     st.markdown('##### Generative AI instructions')
-    with st.expander('Edit AI System Prompt (advanced)', expanded=False):
+    with st.expander('Edit AI System Prompt (advanced)', expanded=True):
         instructions_text = st.text_area('Contents of System Prompt used to generate AI outputs.', value=system_prompt_var.value["user_prompt"], height=200)
         system_prompt_var.value["user_prompt"] = instructions_text
         reset_prompt = st.button('Reset to default')
-    value_area = st.text_area('Instructions (optional - use to guide output)', value=instructions_var.value, height=100)
-    instructions_var.value = value_area
-
+    st.warning('This app uses AI and may not be error-free. Please verify critical details independently.')
     batch_offset = 0
     batch_count_raw = (len(batch_val) // batch_size)
     batch_count_remaining = (len(batch_val) % batch_size)
     batch_count = batch_count_raw + 1 if batch_count_remaining != 0 else batch_count_raw
     batch_messages = []
-    variables['instructions'] = instructions_var.value
 
     full_prompt = ' '.join([
         system_prompt_var.value["report_prompt"],
diff --git a/app/workflows/attribute_patterns/prompts.py b/app/workflows/attribute_patterns/prompts.py
index 0690ae2f..602f7705 100644
--- a/app/workflows/attribute_patterns/prompts.py
+++ b/app/workflows/attribute_patterns/prompts.py
@@ -27,10 +27,6 @@
 {attribute_counts}
 
-
-Additional instructions:
-
-{instructions}
 """
 
 user_prompt = """\
diff --git a/app/workflows/attribute_patterns/variables.py b/app/workflows/attribute_patterns/variables.py
index 5e05dd8a..c9d09120 100644
--- a/app/workflows/attribute_patterns/variables.py
+++ b/app/workflows/attribute_patterns/variables.py
@@ -27,7 +27,6 @@ def __init__(self, prefix):
         self.attribute_min_count = SessionVariable(0, prefix)
         self.attribute_suppress_zeros = SessionVariable(False, prefix)
         self.attribute_last_suppress_zeros = SessionVariable(False, prefix)
-        self.attribute_instructions = SessionVariable('', prefix)
         self.attribute_system_prompt = SessionVariable(prompts.list_prompts, prefix)
         self.attribute_final_df = SessionVariable(pd.DataFrame(), prefix)
         self.attribute_report = SessionVariable('', prefix)
diff --git a/app/workflows/attribute_patterns/workflow.py b/app/workflows/attribute_patterns/workflow.py
index 45755ca2..8b4cc82e 100644
--- a/app/workflows/attribute_patterns/workflow.py
+++ b/app/workflows/attribute_patterns/workflow.py
@@ -160,7 +160,7 @@ def create():
                 'attribute_counts': sv.attribute_selected_pattern_att_counts.value.to_csv(index=False)
             }
 
-            generate, messages, reset = util.ui_components.generative_ai_component(sv.attribute_system_prompt, sv.attribute_instructions, variables)
+            generate, messages, reset = util.ui_components.generative_ai_component(sv.attribute_system_prompt, variables)
             if reset:
                 sv.attribute_system_prompt.value["user_prompt"] = prompts.user_prompt
                 st.rerun()
diff --git a/app/workflows/group_narratives/prompts.py b/app/workflows/group_narratives/prompts.py
index 580b68aa..8253f0d9 100644
--- a/app/workflows/group_narratives/prompts.py
+++ b/app/workflows/group_narratives/prompts.py
@@ -22,9 +22,6 @@
 {dataset}
 
-Additional instructions:
-
-{instructions}
 """
 
 user_prompt = """\
diff --git a/app/workflows/group_narratives/variables.py b/app/workflows/group_narratives/variables.py
index 6efb67c5..82d97f47 100644
--- a/app/workflows/group_narratives/variables.py
+++ b/app/workflows/group_narratives/variables.py
@@ -20,7 +20,6 @@ def __init__(self, prefix):
         self.narrative_description = SessionVariable('', prefix)
         self.narrative_top_groups = SessionVariable(0, prefix)
         self.narrative_top_attributes = SessionVariable(0, prefix)
-        self.narrative_instructions = SessionVariable('', prefix)
         self.narrative_report = SessionVariable('', prefix)
         self.narrative_system_prompt = SessionVariable(prompts.list_prompts, prefix)
         self.narrative_subject_identifier = SessionVariable('', prefix)
diff --git a/app/workflows/group_narratives/workflow.py b/app/workflows/group_narratives/workflow.py
index 72058b2a..1cae6fa3 100644
--- a/app/workflows/group_narratives/workflow.py
+++ b/app/workflows/group_narratives/workflow.py
@@ -190,7 +190,7 @@ def create():
                 'dataset': fdf.to_csv(index=False, encoding='utf-8-sig'),
                 'filters': filter_description
             }
-            generate, messages, reset = util.ui_components.generative_ai_component(sv.narrative_system_prompt, sv.narrative_instructions, variables)
+            generate, messages, reset = util.ui_components.generative_ai_component(sv.narrative_system_prompt, variables)
             if reset:
                 sv.narrative_system_prompt.value["user_prompt"] = prompts.user_prompt
                 st.rerun()
diff --git a/app/workflows/question_answering/prompts.py b/app/workflows/question_answering/prompts.py
index 3453f060..2cb91afe 100644
--- a/app/workflows/question_answering/prompts.py
+++ b/app/workflows/question_answering/prompts.py
@@ -61,10 +61,6 @@
 {outline}
 
-Additional instructions:
-
-{instructions}
-
 """
 
 list_prompts = {
diff --git a/app/workflows/question_answering/variables.py b/app/workflows/question_answering/variables.py
index fc31a815..853345ca 100644
--- a/app/workflows/question_answering/variables.py
+++ b/app/workflows/question_answering/variables.py
@@ -35,5 +35,4 @@ def __init__(self, prefix):
         self.answering_matches = SessionVariable('', prefix)
         self.answering_source_diversity = SessionVariable(1, prefix)
         self.answering_question_history = SessionVariable([], prefix)
-        self.answering_instructions = SessionVariable('', prefix)
         self.answering_system_prompt = SessionVariable(prompts.list_prompts, prefix)
diff --git a/app/workflows/question_answering/workflow.py b/app/workflows/question_answering/workflow.py
index 3cbcf5eb..3d7e2da2 100644
--- a/app/workflows/question_answering/workflow.py
+++ b/app/workflows/question_answering/workflow.py
@@ -206,7 +206,7 @@ def create():
                 'outline': sv.answering_matches.value,
                 'source_diversity': sv.answering_source_diversity.value
             }
-            generate, messages, reset = util.ui_components.generative_ai_component(sv.answering_system_prompt, sv.answering_instructions, variables)
+            generate, messages, reset = util.ui_components.generative_ai_component(sv.answering_system_prompt, variables)
             if reset:
                 sv.answering_system_prompt.value["user_prompt"] = prompts.user_prompt
                 st.rerun()
diff --git a/app/workflows/record_matching/prompts.py b/app/workflows/record_matching/prompts.py
index 8fe56eea..decd7558 100644
--- a/app/workflows/record_matching/prompts.py
+++ b/app/workflows/record_matching/prompts.py
@@ -14,10 +14,6 @@
 If names are in a language other than English, consider whether the English translations are generic descriptive terms (less likely to be related) or distinctive (more likely to be related).
 
-Additional instructions:
-
-{instructions}
-
 === TASK ===
 
 Group data:
diff --git a/app/workflows/record_matching/variables.py b/app/workflows/record_matching/variables.py
index 4be0a1d7..c050a013 100644
--- a/app/workflows/record_matching/variables.py
+++ b/app/workflows/record_matching/variables.py
@@ -19,5 +19,4 @@ def __init__(self, prefix):
         self.matching_last_sentence_pair_embedding_threshold = SessionVariable(0.05, prefix)
         self.matching_evaluations = SessionVariable(pl.DataFrame(), prefix)
         self.matching_system_prompt = SessionVariable(prompts.list_prompts, prefix)
-        self.matching_instructions = SessionVariable('', prefix)
diff --git a/app/workflows/record_matching/workflow.py b/app/workflows/record_matching/workflow.py
index 6898e67e..ec4aa8f2 100644
--- a/app/workflows/record_matching/workflow.py
+++ b/app/workflows/record_matching/workflow.py
@@ -13,6 +13,7 @@
 import workflows.record_matching.functions as functions
 import workflows.record_matching.config as config
 import workflows.record_matching.variables as vars
+import util.session_variables as home_vars
 import util.Embedder
 import util.ui_components
 
@@ -20,6 +21,7 @@
 def create():
     sv = vars.SessionVariables('record_matching')
+    sv_home = home_vars.SessionVariables('home')
 
     if not os.path.exists(config.outputs_dir):
         os.makedirs(config.outputs_dir)
@@ -138,6 +140,7 @@ def att_ui(i):
         record_distance = st.number_input('Matching record distance (max)', min_value=0.0, max_value=1.0, step=0.01, value=sv.matching_sentence_pair_embedding_threshold.value, help='The maximum cosine distance between two records in the embedding space for them to be considered a match. Lower values will result in fewer closer matches overall.')
     with b2:
         name_similarity = st.number_input('Matching name similarity (min)', min_value=0.0, max_value=1.0, step=0.01, value=sv.matching_sentence_pair_jaccard_threshold.value, help='The minimum Jaccard similarity between the character trigrams of the names of two records for them to be considered a match. Higher values will result in fewer closer name matches.')
+
    if st.button('Detect record groups', use_container_width=True):
        if record_distance != sv.matching_sentence_pair_embedding_threshold.value:
            sv.matching_sentence_pair_embedding_threshold.value = record_distance
@@ -291,6 +294,16 @@ def att_ui(i):
            sv.matching_matches_df.value = sv.matching_matches_df.value.sort(by=['Name similarity', 'Group ID'], descending=[False, False])
            # # keep all records linked to a group ID if any record linked to that ID has dataset GD or ILM
            # sv.matching_matches_df.value = sv.matching_matches_df.value.filter(pl.col('Group ID').is_in(sv.matching_matches_df.value.filter(pl.col('Dataset').is_in(['GD', 'ILM']))['Group ID'].unique()))
+            data = sv.matching_matches_df.value
+            unique_names = data['Entity name'].unique()
+            # verify whether the names are already in the format Entity_1, Entity_2, etc.
+            pattern = r'^Entity_\d+$'
+            matches = unique_names.str.contains(pattern)
+            all_matches = matches.all()
+            if not all_matches and sv_home.protected_mode.value:
+                for i, name in enumerate(unique_names, start=1):
+                    data = data.with_columns(data['Entity name'].replace(name, 'Entity_{}'.format(i)))
+                sv.matching_matches_df.value = data
            st.rerun()
    if len(sv.matching_matches_df.value) > 0:
        st.markdown(f'Identified **{len(sv.matching_matches_df.value)}** record groups.')
@@ -305,7 +318,7 @@ def att_ui(i):
        with b1:
            batch_size = 100
            data = sv.matching_matches_df.value.drop(['Entity ID', 'Dataset', 'Name similarity']).to_pandas()
-            generate, batch_messages, reset = util.ui_components.generative_batch_ai_component(sv.matching_system_prompt, sv.matching_instructions, {}, 'data', data, batch_size)
+            generate, batch_messages, reset = util.ui_components.generative_batch_ai_component(sv.matching_system_prompt, {}, 'data', data, batch_size)
            if reset:
                sv.matching_system_prompt.value["user_prompt"] = prompts.user_prompt
                st.rerun()
diff --git a/app/workflows/risk_networks/prompts.py b/app/workflows/risk_networks/prompts.py
index 8c9ca1b0..06ae32ab 100644
--- a/app/workflows/risk_networks/prompts.py
+++ b/app/workflows/risk_networks/prompts.py
@@ -16,9 +16,6 @@
 - Evaluate the likelihood that different entity nodes are in fact the same real-world entity.
 - If there is a selected entity and there are risk flags in the network, evaluate the risk exposure for the selected entity.
 
-Additional instructions:
-
-{instructions}
 
 === TASK ===
diff --git a/app/workflows/risk_networks/variables.py b/app/workflows/risk_networks/variables.py
index 6ac70601..689c6529 100644
--- a/app/workflows/risk_networks/variables.py
+++ b/app/workflows/risk_networks/variables.py
@@ -47,7 +47,6 @@ def __init__(self, prefix):
         self.network_attributes_list = SessionVariable([], prefix)
         self.network_additional_trimmed_attributes = SessionVariable([], prefix)
         self.network_system_prompt = SessionVariable(prompts.list_prompts, prefix)
-        self.network_instructions = SessionVariable('', prefix)
         self.network_report = SessionVariable('', prefix)
         self.network_merged_links_df = SessionVariable([], prefix)
         self.network_merged_nodes_df = SessionVariable([], prefix)
diff --git a/app/workflows/risk_networks/workflow.py b/app/workflows/risk_networks/workflow.py
index f41327fa..a6138064 100644
--- a/app/workflows/risk_networks/workflow.py
+++ b/app/workflows/risk_networks/workflow.py
@@ -485,7 +485,7 @@ def create():
                     'network_edges': sv.network_merged_links_df.value.to_csv(index=False)
                 }
                 sv.network_system_prompt.value = prompts.list_prompts
-                generate, messages, reset = util.ui_components.generative_ai_component(sv.network_system_prompt, sv.network_instructions, variables)
+                generate, messages, reset = util.ui_components.generative_ai_component(sv.network_system_prompt, variables)
                 if reset:
                     sv.network_system_prompt.value["user_prompt"] = prompts.user_prompt
                     st.rerun()