Skip to content

Commit

Permalink
Add batches to text generation
Browse files Browse the repository at this point in the history
  • Loading branch information
Darren Edge committed Nov 10, 2024
1 parent 7e868c5 commit 6fa8ee3
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions intelligence_toolkit/generate_mock_data/text_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,28 @@ async def generate_text_data(
generation_guidance: str = "",
temperature: float = 0.5,
df_update_callback=None,
parallelism: int = 10,
):
df = pd.DataFrame(columns=["mock_text"])
tasks = []
for text in input_texts:
tasks.append(
asyncio.create_task(
_generate_text_async(
ai_configuration=ai_configuration,
input_text=text,
generation_guidance=generation_guidance,
temperature=temperature,
)
)
)
generated_texts = await tqdm_asyncio.gather(*tasks)
df = pd.DataFrame(generated_texts, columns=["mock_text"])
if df_update_callback is not None:
df_update_callback(df)

generated_texts = []
# batch the input_texts into groups of parallelism
batches = [
input_texts[i : i + parallelism]
for i in range(0, len(input_texts), parallelism)
]
for batch in batches:
tasks = [
asyncio.create_task(_generate_text_async(
ai_configuration=ai_configuration,
input_text=text,
generation_guidance=generation_guidance,
temperature=temperature,
)) for text in batch]
new_generated_texts = await tqdm_asyncio.gather(*tasks)
generated_texts.extend(new_generated_texts)
df = pd.DataFrame(generated_texts, columns=["mock_text"])
if df_update_callback is not None:
df_update_callback(df)
return generated_texts, df


Expand Down

0 comments on commit 6fa8ee3

Please sign in to comment.