Commit

Darren Edge committed Apr 17, 2024
2 parents b675b13 + 4273845 commit b80f48b
Showing 6 changed files with 130 additions and 40 deletions.
1 change: 1 addition & 0 deletions app/workflows/attribute_patterns/workflow.py
@@ -132,6 +132,7 @@ def create():
     sv.attribute_selected_pattern.value = selected_pattern
     sv.attribute_selected_pattern_period.value = selected_pattern_period
     sv.attribute_report.value = ''
+    sv.attribute_report_validation.value = {}
     st.rerun()

 st.markdown('**Selected pattern: ' + selected_pattern + ' (' + selected_pattern_period + ')**')
1 change: 1 addition & 0 deletions app/workflows/question_answering/workflow.py
@@ -64,6 +64,7 @@ def create():
 sv.answering_next_q_id.value = 1
 sv.answering_surface_questions.value = {}
 sv.answering_deeper_questions.value = {}
+sv.answering_report_validation.value = {}
 sv.answering_target_matches.value = answering_target_matches
 sv.answering_source_diversity.value = answering_source_diversity
 sv.answering_last_lazy_question.value = question
43 changes: 28 additions & 15 deletions app/workflows/record_matching/workflow.py
@@ -144,6 +144,8 @@ def att_ui(i):
 name_similarity = st.number_input('Matching name similarity (min)', min_value=0.0, max_value=1.0, step=0.01, value=sv.matching_sentence_pair_jaccard_threshold.value, help='The minimum Jaccard similarity between the character trigrams of the names of two records for them to be considered a match. Higher values will result in fewer closer name matches.')

 if st.button('Detect record groups', use_container_width=True):
+    sv.matching_evaluations.value = pl.DataFrame()
+    sv.matching_report_validation.value = {}
     if record_distance != sv.matching_sentence_pair_embedding_threshold.value:
         sv.matching_sentence_pair_embedding_threshold.value = record_distance
     if name_similarity != sv.matching_sentence_pair_jaccard_threshold.value:
@@ -296,24 +298,21 @@ def att_ui(i):
     sv.matching_matches_df.value = sv.matching_matches_df.value.sort(by=['Name similarity', 'Group ID'], descending=[False, False])
     # # keep all records linked to a group ID if any record linked to that ID has dataset GD or ILM
     # sv.matching_matches_df.value = sv.matching_matches_df.value.filter(pl.col('Group ID').is_in(sv.matching_matches_df.value.filter(pl.col('Dataset').is_in(['GD', 'ILM']))['Group ID'].unique()))
-    data = sv.matching_matches_df.value
-    unique_names = data['Entity name'].unique()
-    #verify if the names are already in this format: Entity_1, Entity_2, etc
-    pattern = f'^Entity_\d+$'
-    matches = unique_names.str.contains(pattern)
-    all_matches = matches.all()
-    if not all_matches and sv_home.protected_mode.value:
-        for i, name in enumerate(unique_names, start=1):
-            data = data.with_columns(data['Entity name'].replace(name, 'Entity_{}'.format(i)))
-    sv.matching_matches_df.value = data

     st.rerun()
 if len(sv.matching_matches_df.value) > 0:
     st.markdown(f'Identified **{len(sv.matching_matches_df.value)}** record groups.')
 with c2:
+    data = sv.matching_matches_df.value
     st.markdown('##### Record groups')
     if len(sv.matching_matches_df.value) > 0:
-        st.dataframe(sv.matching_matches_df.value, height=700, use_container_width=True, hide_index=True)
-        st.download_button('Download record groups', data=sv.matching_matches_df.value.write_csv(), file_name='record_groups.csv', mime='text/csv')
+        if sv_home.protected_mode.value:
+            unique_names = sv.matching_matches_df.value['Entity name'].unique()
+            for i, name in enumerate(unique_names, start=1):
+                data = data.with_columns(data['Entity name'].replace(name, 'Entity_{}'.format(i)))
+
+        st.dataframe(data, height=700, use_container_width=True, hide_index=True)
+        st.download_button('Download record groups', data=data.write_csv(), file_name='record_groups.csv', mime='text/csv')

 with evaluate_tab:
     b1, b2 = st.columns([2, 3])
@@ -336,17 +335,31 @@ def att_ui(i):
     response = util.AI_API.generate_text_from_message_list(messages, placeholder, prefix=prefix)
     if len(response.strip()) > 0:
         prefix = prefix + response + '\n'
+
     result = prefix.replace('```\n', '').strip()
-    sv.matching_evaluations.value = pl.read_csv(io.StringIO(result), read_csv_options={"truncate_ragged_lines": True})
-    validation, messages_to_llm = util.ui_components.validate_ai_report(messages, sv.matching_evaluations.value)
+    if sv_home.protected_mode.value:
+        unique_names = sv.matching_matches_df.value['Entity name'].unique()
+        for i, name in enumerate(unique_names, start=1):
+            result = result.replace(name, 'Entity_{}'.format(i))
+
+    csv = pl.read_csv(io.StringIO(result))
+    sv.matching_evaluations.value = csv.drop_nulls()
+
+    # get 30 random rows to evaluate
+    data_to_validate = sv.matching_evaluations.value
+    if len(sv.matching_evaluations.value) > 30:
+        data_to_validate = sv.matching_evaluations.value.sample(n=30)
+
+    validation, messages_to_llm = util.ui_components.validate_ai_report(messages, data_to_validate)
     sv.matching_report_validation.value = json.loads(validation)
     sv.matching_report_validation_messages.value = messages_to_llm
     st.rerun()
 else:
     if len(sv.matching_evaluations.value) == 0:
         gen_placeholder.warning('Press the Generate button to create an AI report for the current record matches.')
         placeholder.empty()

 if len(sv.matching_evaluations.value) > 0:
     st.dataframe(sv.matching_evaluations.value.to_pandas(), height=700, use_container_width=True, hide_index=True)
     jdf = sv.matching_matches_df.value.join(sv.matching_evaluations.value, on='Group ID', how='inner')
@@ -361,6 +374,6 @@ def att_ui(i):
 obj = json.dumps({
     "message": sv.matching_report_validation_messages.value,
     "result": sv.matching_report_validation.value,
-    "report": sv.matching_evaluations.value
+    "report": pd.DataFrame(sv.matching_evaluations.value).to_json()
 }, indent=4)
 st.download_button('Download faithfulness evaluation', use_container_width=True, data=str(obj), file_name=f'matching_{get_current_time}_messages.json', mime='text/json')
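Taken together, the record_matching changes defer protected-mode anonymization to display and evaluation time, parse the model's CSV output with polars, and cap faithfulness validation at 30 sampled rows. A minimal standalone sketch of the same pattern, assuming a recent polars; the frame and CSV below are toy data, not the app's:

import io
import polars as pl

# Toy stand-in for sv.matching_matches_df.value; illustrative only.
matches = pl.DataFrame({
    'Group ID': [1, 1, 2],
    'Entity name': ['Acme Ltd', 'ACME Limited', 'Globex'],
})

# Protected mode: map each distinct entity name to Entity_<n> at display
# time, leaving the underlying match data untouched.
mapping = {name: 'Entity_{}'.format(i)
           for i, name in enumerate(matches['Entity name'].unique(), start=1)}
display_df = matches.with_columns(pl.col('Entity name').replace(mapping))

# Evaluation: parse the LLM's CSV output, drop incomplete rows, and
# validate at most 30 randomly sampled rows to bound cost.
llm_csv = 'Group ID,Relatedness\n1,close match\n2,unrelated\n'
evaluations = pl.read_csv(io.StringIO(llm_csv)).drop_nulls()
to_validate = evaluations.sample(n=30) if len(evaluations) > 30 else evaluations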
6 changes: 3 additions & 3 deletions app/workflows/risk_networks/functions.py
@@ -4,7 +4,7 @@
 import colorsys
 import numpy as np
 from collections import defaultdict
-from streamlit_agraph import Config, Edge, Node, agraph
+from streamlit_agraph import Config, Edge, Node

 import workflows.risk_networks.config as config

@@ -61,6 +61,7 @@ def get_type_color(node_type, is_flagged, attribute_types):
     comm = G.nodes[node]['network'] if 'network' in G.nodes[node] else ''
     label = '\n'.join(vals) + '\n(' + config.list_sep.join(atts) + ')'
     d_risk = G.nodes[node]['flags']
+
     nodes.append(
         Node(
             title=node + f'\nFlags: {d_risk}',
@@ -82,8 +83,7 @@ def get_type_color(node_type, is_flagged, attribute_types):
         physics=True,
         hierarchical=False
     )
-    return_value = agraph(nodes=nodes, edges=edges, config=g_config) # type: ignore
-    return return_value
+    return nodes, edges, g_config # type: ignore

 def merge_nodes(G, can_merge_fn):
     nodes = list(G.nodes()) # may change during iteration
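This change converts the graph helper from render-and-return into a pure builder that hands nodes, edges, and Config back to the caller. A hedged sketch of what a call site might look like after the refactor; build_graph below is a stand-in, not the repo's function:

import streamlit as st
from streamlit_agraph import Config, Edge, Node, agraph

def build_graph():
    # Stand-in builder: returns drawable data without rendering it.
    nodes = [Node(id='n1', label='Entity A'), Node(id='n2', label='Entity B')]
    edges = [Edge(source='n1', target='n2')]
    g_config = Config(width=700, height=700, directed=False,
                      physics=True, hierarchical=False)
    return nodes, edges, g_config

# The caller now owns rendering, so the same data can be rebuilt or cached
# without forcing a draw on every helper call.
nodes, edges, g_config = build_graph()
selected = agraph(nodes=nodes, edges=edges, config=g_config)
if selected is not None:
    st.write('Selected node: ' + str(selected))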
3 changes: 2 additions & 1 deletion app/workflows/risk_networks/variables.py
@@ -57,4 +57,5 @@ def __init__(self, prefix):
         self.network_mean_flagged_flags = SessionVariable(0, prefix)
         self.network_risk_exposure = SessionVariable('', prefix)
         self.network_last_show_entities = SessionVariable(False, prefix)
-        self.network_last_show_groups = SessionVariable(False, prefix)
+        self.network_last_show_groups = SessionVariable(False, prefix)
+        self.network_attributes_protected = SessionVariable([], prefix)
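The new network_attributes_protected variable defaults to an empty list. SessionVariable is the repo's wrapper around Streamlit session state; its definition is not shown in this diff, but a minimal sketch of how such a wrapper might look (keying scheme assumed, the real class may differ):

import itertools
import streamlit as st

_counter = itertools.count()

class SessionVariable:
    # Sketch only: a defaulted, prefixed slot in st.session_state, matching
    # the SessionVariable(default, prefix) calls above. The real class's
    # key-naming scheme is not visible here; a creation-order counter
    # stands in and assumes a stable declaration order across reruns.
    def __init__(self, default, prefix=''):
        self.key = '{}_{}'.format(prefix, next(_counter))
        if self.key not in st.session_state:
            st.session_state[self.key] = default

    @property
    def value(self):
        return st.session_state[self.key]

    @value.setter
    def value(self, v):
        st.session_state[self.key] = v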