Fix data generation batch sizes
Darren Edge committed Nov 10, 2024
1 parent 41c9621 commit d618a1a
Showing 3 changed files with 25 additions and 16 deletions.
2 changes: 0 additions & 2 deletions app/util/schema_ui.py
@@ -19,7 +19,6 @@ def build_schema_ui(global_schema, last_filename):
jsn = loads(file.read())
for k, v in jsn.items():
global_schema[k] = v
print(f'Loaded schema: {global_schema}')
st.markdown('### Edit data schema')
generate_form_from_json_schema(
global_schema=global_schema,
@@ -255,7 +254,6 @@ def create_enum_ui(field_location, key, key_with_prefix, value):
value['enum'].pop(i)
st.rerun()
new_enum_value = st.text_input(f'New value', key=f'{key_with_prefix}_new_enum_{"_".join([str(x) for x in value["enum"]])}', value="")
print(new_enum_value)
if new_enum_value != "" and new_enum_value not in value['enum']:
if value['type'] == 'string':
value['enum'].append(new_enum_value)
1 change: 0 additions & 1 deletion app/workflows/generate_mock_data/workflow.py
@@ -79,7 +79,6 @@ async def create(sv: bds_variables.SessionVariables, workflow: None):
dl_placeholders.append(dl_placeholder)

def on_dfs_update(path_to_df):
print(path_to_df)
for ix, record_array in enumerate(sv.record_arrays.value):
with df_placeholders[ix]:
df = path_to_df[record_array]
38 changes: 25 additions & 13 deletions intelligence_toolkit/generate_mock_data/data_generator.py
@@ -23,30 +23,43 @@ async def generate_data(
callback_batch,
parallel_batches=5,
):
num_iterations = num_records_overall // (records_per_batch * parallel_batches)
record_arrays = extract_array_fields(data_schema)
primary_record_array = record_arrays[0]
generated_objects = []
first_object = generate_unseeded_data(
ai_configuration=ai_configuration,
generation_guidance=generation_guidance,
primary_record_array=primary_record_array,
total_records=parallel_batches,
total_records=records_per_batch,
data_schema=data_schema,
temperature=temperature,
)
first_object_json = loads(first_object)
current_object_json = {}
try:
first_object_json = loads(first_object)
except Exception as e:
msg = f"AI did not return a valid JSON response. Please try again. {e}"
raise ValueError(msg) from e
generated_objects.append(first_object_json)
current_object_json = first_object_json.copy()
dfs = {}
for i in range(num_iterations):
if i == 0:
sample_records = sample_from_record_array(
first_object_json, primary_record_array, records_per_batch
)
else:
sample_records = sample_from_record_array(
current_object_json, primary_record_array, parallel_batches
)
for record_array in record_arrays:
df = extract_df(current_object_json, record_array)
dfs[".".join(record_array)] = df
if df_update_callback is not None:
df_update_callback(dfs)

num_records = records_per_batch
while num_records < num_records_overall:
remainder = num_records_overall - num_records
required = remainder / records_per_batch
if not required.is_integer():
required += 1
batches = min(parallel_batches, required)
sample_records = sample_from_record_array(
current_object_json, primary_record_array, batches
)
num_records += records_per_batch * parallel_batches
# Use each as seed for parallel gen
new_objects = await generate_seeded_data(
ai_configuration=ai_configuration,
@@ -62,7 +75,6 @@
)

for new_object in new_objects:
print(new_object)
try:
new_object_json = loads(new_object)
except Exception as e:
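The reworked loop in data_generator.py replaces the fixed num_iterations count with a running num_records total, so the final iteration only requests as many seeded batches as are still needed to reach num_records_overall. Below is a minimal sketch of that batch-count arithmetic, reusing the parameter names from the hunk above but substituting math.ceil for the committed float-based rounding; the actual generation calls are elided.

import math

def batch_schedule(num_records_overall, records_per_batch, parallel_batches=5):
    # Yield the number of seeded batches to request on each iteration,
    # assuming the unseeded first object already covers one batch of records.
    num_records = records_per_batch
    while num_records < num_records_overall:
        remainder = num_records_overall - num_records
        required = math.ceil(remainder / records_per_batch)  # batches still needed
        yield min(parallel_batches, required)
        # Mirrors the hunk: advance by a full round even if fewer batches were requested.
        num_records += records_per_batch * parallel_batches

# Example: 100 records overall, 10 per batch, 5 parallel batches
# -> yields 5 then 4, i.e. 10 + 50 + 40 = 100 records in total.
print(list(batch_schedule(100, 10, 5)))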
