Commit 68b2bf0: Merge remote-tracking branch 'origin/main'
dayesouza committed Apr 17, 2024 (2 parents: 3b377f5 + 7b54c54)

Showing 10 changed files with 48 additions and 26 deletions.
2 changes: 1 addition & 1 deletion app/util/Embedder.py
@@ -56,7 +56,7 @@ def encode_all(self, texts):
hsh = hash(text)
list_all_embeddings.append((hsh, embeddings[j]))
final_embeddings[ix] = np.array(embeddings[j])
- self.connection.insert_multiple_into_embeddings(list_all_embeddings)
+ # self.connection.insert_multiple_into_embeddings(list_all_embeddings)
pb.empty()
return np.array(final_embeddings)
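For context, encode_all implements a hash-keyed embedding cache, and the commit disables only the persistence step. Below is a minimal sketch of that pattern; model and the in-memory cache dict are assumed stand-ins for the app's real objects, not its actual API:

    import numpy as np

    def encode_all(texts, model, cache):
        # Return embeddings for texts, reusing cached vectors keyed by hash(text).
        final_embeddings = [None] * len(texts)
        to_encode = []  # (position, text) pairs missing from the cache
        for ix, text in enumerate(texts):
            hsh = hash(text)
            if hsh in cache:
                final_embeddings[ix] = cache[hsh]
            else:
                to_encode.append((ix, text))
        if to_encode:
            embeddings = model.encode([t for _, t in to_encode])
            new_rows = []
            for (ix, text), emb in zip(to_encode, embeddings):
                hsh = hash(text)
                vec = np.array(emb)
                cache[hsh] = vec
                final_embeddings[ix] = vec
                new_rows.append((hsh, emb))
            # This is the persistence step the commit comments out:
            # connection.insert_multiple_into_embeddings(new_rows)
        return np.array(final_embeddings)

If the insert is ever re-enabled, note that Python's built-in hash() is randomized per process for str values unless PYTHONHASHSEED is fixed, so it is an unstable key for a cache that outlives the process; a content hash such as hashlib.sha256(text.encode()).hexdigest() would be stable.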

30 changes: 28 additions & 2 deletions app/util/ui_components.py
@@ -246,7 +246,8 @@ def prepare_input_df(workflow, input_df_var, processed_df_var, output_df_var, id
selected_cols = [col for col in input_df_var.value.columns.values if st.session_state[f'{workflow}_{col}'] == True]
processed_df_var.value = processed_df_var.value[['Subject ID']].copy()
for col in selected_cols:
- processed_df_var.value[col] = st.session_state[f'{workflow}_binned_df'][col]
+ if col in st.session_state[f'{workflow}_binned_df'].columns.values:
+     processed_df_var.value[col] = st.session_state[f'{workflow}_binned_df'][col]

if selected_cols != st.session_state[f'{workflow}_last_attributes']:
processed_df_var.value = util.df_functions.fix_null_ints(processed_df_var.value)
@@ -386,6 +387,25 @@ def convert(x):
st.session_state[f'{workflow}_last_attributes'] = [] # hack to force second rerun and show any changes from binning
st.rerun()

+ with st.expander('Expand compound values', expanded=False):
+     options = [x for x in processed_df_var.value.columns.values if x != 'Subject ID']
+     selected_compound_cols = st.multiselect('Select compound columns to expand', options, help='Select the columns you want to expand into separate columns. If you do not select any columns, no expansion will be performed.')
+     col_delimiter = st.text_input('Column delimiter', value='', help='The character used to separate values in compound columns. If the delimiter is not present in a cell, the cell will be left unchanged.')
+     if st.button('Expand selected columns', key='expand_compound'):
+         bdf = st.session_state[f'{workflow}_binned_df']
+         for col in selected_compound_cols:
+             if col_delimiter != '':
+                 # add each value as a separate column with a 1 if the value is present in the compound column and None otherwise
+                 values = processed_df_var.value[col].apply(lambda x: [y.strip() for y in x.split(col_delimiter)] if type(x) == str else [])
+                 unique_values = set([v for vals in values for v in vals])
+                 for val in unique_values:
+                     bdf[col+'_'+val] = values.apply(lambda x: 1 if val in x else None)
+                     processed_df_var.value[col+'_'+val] = bdf[col+'_'+val]
+                 bdf.drop(columns=[col], inplace=True)
+                 processed_df_var.value.drop(columns=[col], inplace=True)
+         st.rerun()
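To make the new expander concrete, here is a sketch of the same one-hot expansion on made-up data (the 'Hobbies' column and ';' delimiter are hypothetical):

    import pandas as pd

    df = pd.DataFrame({'Subject ID': [1, 2],
                       'Hobbies': ['reading;hiking', 'hiking']})

    # split each compound cell into a list of trimmed values
    values = df['Hobbies'].apply(
        lambda x: [y.strip() for y in x.split(';')] if isinstance(x, str) else [])

    # one indicator column per distinct value: 1 if present, None otherwise
    for val in sorted({v for vals in values for v in vals}):
        df['Hobbies_' + val] = values.apply(lambda x: 1 if val in x else None)

    df = df.drop(columns=['Hobbies'])
    # Result: Hobbies_hiking -> [1, 1], Hobbies_reading -> [1, None]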


with st.expander('Suppress insignificant attribute values', expanded=False):
if f'{workflow}_min_count' not in st.session_state.keys():
st.session_state[f'{workflow}_min_count'] = 0
@@ -400,13 +420,19 @@ def convert(x):

# remove any values that are less than the minimum count
if bdf[col].dtype == 'str':
+     print(f'Processing {col} as string')
    bdf[col] = bdf[col].apply(lambda x: '' if x in value_counts and value_counts[x] < min_value else str(x))
elif bdf[col].dtype == 'float64':
+     print(f'Processing {col} as float')
    bdf[col] = bdf[col].apply(lambda x: np.nan if x in value_counts and value_counts[x] < min_value else x)
elif bdf[col].dtype == 'int64':
+     print(f'Processing {col} as int')
    bdf[col] = bdf[col].apply(lambda x: -sys.maxsize if x in value_counts and value_counts[x] < min_value else x)
    bdf[col] = bdf[col].astype('Int64')
    bdf[col] = bdf[col].replace(-sys.maxsize, np.nan)
+ else:
+     print(f'Processing {col} as string')
+     bdf[col] = bdf[col].apply(lambda x: '' if x in value_counts and value_counts[x] < min_value else str(x))

if f'{workflow}_suppress_zeros' not in st.session_state.keys():
st.session_state[f'{workflow}_suppress_zeros'] = False
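Two subtleties in the dtype dispatch above are worth noting. First, pandas string columns report dtype object, not 'str', so the first branch will rarely match; the newly added else branch is what actually handles strings. Second, a plain int64 column cannot hold NaN, hence the -sys.maxsize sentinel followed by a cast to the nullable Int64 dtype. A minimal sketch of that sentinel trick on made-up data:

    import sys
    import numpy as np
    import pandas as pd

    s = pd.Series([5, 1, 7], dtype='int64')
    min_value = 2

    # int64 cannot represent missing values, so mark suppressed entries with a sentinel
    s = s.apply(lambda x: -sys.maxsize if x < min_value else x)
    s = s.astype('Int64')                # nullable integer dtype
    s = s.replace(-sys.maxsize, np.nan)  # the suppressed 1 becomes <NA>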
@@ -472,6 +498,6 @@ def convert(x):

def validate_ai_report(messages, result, show_status = True):
if show_status:
- st.status('Validating AI report and generating groundedness score...', expanded=False, state='running')
+ st.status('Validating AI report and generating faithfulness score...', expanded=False, state='running')
validation, messages_to_llm = util.AI_API.validate_report(messages, result)
return re.sub(r"```json\n|\n```", "", validation), messages_to_llm
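The re.sub on the last line strips the markdown code fences that the model sometimes wraps around its JSON output. A small illustration (the payload is made up):

    import re

    raw = '```json\n{"score": 4}\n```'
    clean = re.sub(r"```json\n|\n```", "", raw)
    # clean == '{"score": 4}'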
2 changes: 1 addition & 1 deletion app/workflows/attribute_patterns/workflow.py
@@ -218,4 +218,4 @@ def create():
"result": sv.attribute_report_validation.value,
"report": report_data
}, indent=4)
- st.download_button('Download validation prompt', use_container_width=True, data=str(obj), file_name=f'attr_pattern_{get_current_time}_messages.json', mime='text/json')
+ st.download_button('Download faithfulness evaluation', use_container_width=True, data=str(obj), file_name=f'attr_pattern_{get_current_time}_messages.json', mime='text/json')
5 changes: 3 additions & 2 deletions app/workflows/data_synthesis/workflow.py
@@ -49,11 +49,12 @@ def create():
# distinct_counts.append(1)
# else:
distinct_counts.append(len(distinct_values))
+ distinct_counts.sort()
+ common_level = max(distinct_counts[int(len(distinct_counts) * 0.5)], len(distinct_counts))
overall_att_count = sum(distinct_counts)
# calculate number of pairs of column values using combinatorics
num_observed_pairs = 0
num_common_pairs = 0
- common_level = int((overall_att_count / num_cols) * math.sqrt(num_cols)) if num_cols > 0 else 0
for ix, ci in enumerate(att_cols):
for jx, cj in enumerate(att_cols[ix+1:]):
groups = wdf[[ci, cj]].dropna().groupby([ci, cj]).size()
@@ -67,7 +68,7 @@
st.markdown(f'### Synthesizability summary')
st.markdown(f'Number of selected columns: **{num_cols}**', help='This is the number of columns you selected for processing. The more columns you select, the harder it will be to synthesize data.')
st.markdown(f'Number of distinct attribute values: **{overall_att_count}**', help='This is the total number of distinct attribute values across all selected columns. The more distinct values, the harder it will be to synthesize data.')
- st.markdown(f'Common pair threshold: **{common_level}**', help='This is the minimum number of records that must appear in a pair of column values for the pair to be considered common. The higher this number, the harder it will be to synthesize data. The value is set as int((overall_att_count / num_cols) * math.sqrt(num_cols)).')
+ st.markdown(f'Common pair threshold: **{common_level}**', help='This is the minimum number of records that must appear in a pair of column values for the pair to be considered common. The higher this number, the harder it will be to synthesize data. The value is set as max(median value count, num selected columns).')
st.markdown(f'Estimated synthesizability score: **{round(coverage, 4)}**', help=f'We define synthesizability as the proportion of observed pairs of values across selected columns that are common, appearing at least as many times as the number of columns. In this case, {num_common_pairs}/{num_observed_pairs} pairs appear at least {num_cols} times. The intuition here is that all combinations of attribute values in a synthetic record must be composed from common attribute pairs. **Rule of thumb**: Aim for a synthesizability score of **0.5** or higher.')
with generate_tab:
if len(sv.synthesis_sensitive_df.value) == 0:
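For readers following the arithmetic, here is a sketch of the pair-counting logic behind these numbers on made-up data; wdf and att_cols stand in for the app's working DataFrame and selected columns. The commit changes common_level from a formula over the total attribute count to max(upper-median distinct-value count, number of selected columns), matching the updated help text.

    import pandas as pd

    wdf = pd.DataFrame({'age': ['20-30', '20-30', '30-40', '20-30'],
                        'city': ['A', 'A', 'B', 'A']})
    att_cols = ['age', 'city']

    distinct_counts = sorted(wdf[c].nunique() for c in att_cols)  # [2, 2]
    common_level = max(distinct_counts[int(len(distinct_counts) * 0.5)], len(att_cols))  # 2

    num_observed_pairs = 0
    num_common_pairs = 0
    for ix, ci in enumerate(att_cols):
        for cj in att_cols[ix + 1:]:
            groups = wdf[[ci, cj]].dropna().groupby([ci, cj]).size()
            num_observed_pairs += len(groups)
            num_common_pairs += int((groups >= common_level).sum())

    coverage = num_common_pairs / num_observed_pairs if num_observed_pairs else 0
    # ('20-30', 'A') appears 3 times, ('30-40', 'B') once -> coverage = 1/2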
1 change: 1 addition & 0 deletions app/workflows/group_narratives/prompts.py
@@ -26,6 +26,7 @@

user_prompt = """\
The report should be structured in markdown and use plain English accessible to non-native speakers and non-technical audiences.
+ Where possible, the text should add numeric counts, ranks, and deltas in parentheses to support its claims, but should avoid using complex column names directly.
"""
9 changes: 5 additions & 4 deletions app/workflows/group_narratives/workflow.py
@@ -48,7 +48,7 @@ def create():
groups = st.multiselect('Compare groups of records with different combinations of these attributes:', sorted_cols, default=sv.narrative_groups.value)
aggregates = st.multiselect('Using counts of these attributes:', sorted_cols, default=sv.narrative_aggregates.value)
temporal_options = [''] + sorted_cols
- temporal = st.selectbox('Across levels of this temporal/ordinal attribute (optional):', temporal_options, index=temporal_options.index(sv.narrative_temporal.value))
+ temporal = st.selectbox('Across windows of this temporal/ordinal attribute (optional):', temporal_options, index=temporal_options.index(sv.narrative_temporal.value))

model = st.button('Create summary', disabled=len(groups) == 0 or len(aggregates) == 0)

@@ -79,12 +79,13 @@ def create():
# narrow df for model
id_vars = groups + [temporal] if temporal != '' else groups
ndf = wdf.melt(id_vars=id_vars, value_vars=aggregates, var_name='Attribute', value_name='Value')
- ndf['Attribute Value'] = str(ndf['Attribute']) + ':' + str(ndf['Value'])
+ ndf.dropna(subset=['Value'], inplace=True)
+ ndf['Attribute Value'] = ndf.apply(lambda x : str(x['Attribute']) + ':' + str(x['Value']), axis=1)
temporal_atts = []

# create group df
gdf = wdf.melt(id_vars=groups, value_vars=['Subject ID'], var_name='Attribute', value_name='Value')

gdf['Attribute Value'] = gdf['Attribute'] + ':' + gdf['Value']
gdf = gdf.groupby(groups).size().reset_index(name='Group Count')
# Add group ranks
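The replaced melt line fixed a subtle bug: str(ndf['Attribute']) stringifies the whole Series (its printed repr), not each element, so every row received the same garbage concatenation. The row-wise apply concatenates per row. A small illustration:

    import pandas as pd

    ndf = pd.DataFrame({'Attribute': ['age', 'city'], 'Value': [30, 'A']})

    broken = str(ndf['Attribute']) + ':' + str(ndf['Value'])  # one long repr string
    fixed = ndf.apply(lambda x: str(x['Attribute']) + ':' + str(x['Value']), axis=1)
    # fixed.tolist() == ['age:30', 'city:A']

A vectorized ndf['Attribute'].astype(str) + ':' + ndf['Value'].astype(str) would give the same result without a per-row Python call.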
@@ -234,4 +235,4 @@ def create():
"result": sv.narrative_report_validation.value,
"report": sv.narrative_report.value
}, indent=4)
- st.download_button('Download validation prompt', use_container_width=True, data=str(obj), file_name=f'narrative_{get_current_time}_messages.json', mime='text/json')
+ st.download_button('Download faithfulness evaluation', use_container_width=True, data=str(obj), file_name=f'narrative_{get_current_time}_messages.json', mime='text/json')
2 changes: 1 addition & 1 deletion app/workflows/question_answering/workflow.py
@@ -260,4 +260,4 @@ def create():
"result": sv.answering_report_validation.value,
"report": sv.answering_lazy_answer_text.value
}, indent=4)
- st.download_button('Download validation prompt', use_container_width=True, data=str(obj), file_name=f'qa_{get_current_time}_messages.json', mime='text/json')
+ st.download_button('Download faithfulness evaluation', use_container_width=True, data=str(obj), file_name=f'qa_{get_current_time}_messages.json', mime='text/json')
11 changes: 2 additions & 9 deletions app/workflows/record_matching/workflow.py
@@ -335,14 +335,7 @@ def att_ui(i):
prefix = prefix + response + '\n'

result = prefix.replace('```\n', '').strip()

- if sv_home.protected_mode.value:
-     unique_names = sv.matching_matches_df.value['Entity name'].unique()
-     for i, name in enumerate(unique_names, start=1):
-         #search for unique_names in result and change for its Entity_
-         result = result.replace(name, 'Entity_{}'.format(i))
-
- sv.matching_evaluations.value = pl.read_csv(io.StringIO(result))
+ sv.matching_evaluations.value = pl.read_csv(io.StringIO(result), read_csv_options={"truncate_ragged_lines": True})

validation, messages_to_llm = util.ui_components.validate_ai_report(messages, sv.matching_evaluations.value)
sv.matching_report_validation.value = json.loads(validation)
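One caution on the new pl.read_csv call: in the polars releases I am aware of, truncate_ragged_lines is a direct keyword argument of read_csv (a read_csv_options dict is the pattern used by pl.read_excel), so the form in this diff may raise a TypeError at runtime. A sketch of the direct form, under that assumption:

    import io
    import polars as pl

    raw = 'a,b\n1,2\n3,4,5\n'  # second data row has one field too many
    df = pl.read_csv(io.StringIO(raw), truncate_ragged_lines=True)
    # the extra '5' is dropped instead of raising an error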
@@ -369,4 +362,4 @@ def att_ui(i):
"result": sv.matching_report_validation.value,
"report": sv.matching_evaluations.value
}, indent=4)
- st.download_button('Download validation prompt', use_container_width=True, data=str(obj), file_name=f'matching_{get_current_time}_messages.json', mime='text/json')
+ st.download_button('Download faithfulness evaluation', use_container_width=True, data=str(obj), file_name=f'matching_{get_current_time}_messages.json', mime='text/json')
10 changes: 5 additions & 5 deletions app/workflows/risk_networks/prompts.py
@@ -43,15 +43,15 @@
"""

user_prompt = """\
The report should be structured in markdown and use plain English accessible to non-native speakers and non-technical audiences.
The report should be structured in markdown and use plain English accessible to non-native speakers and non-technical audiences.
Begin your response with the heading:
Begin your response with the heading:
"##### Evaluation of <Entity ID> in Network <Network ID>"
"##### Evaluation of <Entity ID> in Network <Network ID>"
if there is a selected entity, or else:
if there is a selected entity, or else:
"##### Evaluation of Entity Network <Network ID>"
"##### Evaluation of Entity Network <Network ID>"
"""

list_prompts = {
2 changes: 1 addition & 1 deletion app/workflows/risk_networks/workflow.py
@@ -616,4 +616,4 @@ def create():
"result": sv.network_report_validation.value,
"report": sv.network_report.value
}, indent=4)
- st.download_button('Download validation prompt', use_container_width=True, data=str(obj), file_name=f'networks_{get_current_time}_messages.json', mime='text/json')
+ st.download_button('Download faithfulness evaluation', use_container_width=True, data=str(obj), file_name=f'networks_{get_current_time}_messages.json', mime='text/json')
