-
Notifications
You must be signed in to change notification settings - Fork 291
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
376 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,376 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from typing import Dict, List, Literal, Optional, Tuple\n", | ||
"\n", | ||
"import instructor\n", | ||
"import openai\n", | ||
"import pandas as pd\n", | ||
"import weave\n", | ||
"from pydantic import BaseModel, Field\n", | ||
"from set_env import set_env\n", | ||
"import json\n", | ||
"import asyncio" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"set_env(\"OPENAI_API_KEY\")\n", | ||
"set_env(\"WANDB_API_KEY\")\n", | ||
"set_env(\"AZURE_OPENAI_ENDPOINT\")\n", | ||
"set_env(\"AZURE_OPENAI_API_KEY\")\n", | ||
"print(\"Env set\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from utils.config import ENTITY, WEAVE_PROJECT" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"weave.init(f\"{ENTITY}/{WEAVE_PROJECT}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"N_SAMPLES = 67" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"client = openai.OpenAI()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def load_medical_data(url: str, num_samples: int = N_SAMPLES) -> Tuple[pd.DataFrame, pd.DataFrame]:\n", | ||
" \"\"\"\n", | ||
" Load medical data and split into train and test sets\n", | ||
" \n", | ||
" Args:\n", | ||
" url: URL of the CSV file\n", | ||
" num_samples: Number of samples to load\n", | ||
" \n", | ||
" Returns:\n", | ||
" Tuple of (train_df, test_df)\n", | ||
" \"\"\"\n", | ||
" df = pd.read_csv(url)\n", | ||
" df = df.sample(n=num_samples, random_state=42) # Sample and shuffle data\n", | ||
" \n", | ||
" # Split into 80% train, 20% test\n", | ||
" train_size = int(0.8 * len(df))\n", | ||
" train_df = df[:train_size]\n", | ||
" test_df = df[train_size:]\n", | ||
" \n", | ||
" return train_df, test_df" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"medical_dataset_url = \"https://raw.githubusercontent.com/wyim/aci-bench/main/data/challenge_data/train.csv\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train_df, test_df = load_medical_data(medical_dataset_url)\n", | ||
"train_samples = train_df.to_dict(\"records\")\n", | ||
"test_samples = test_df.to_dict(\"records\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train_samples[0]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"test_samples[0]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def convert_to_jsonl(df: pd.DataFrame, output_file: str = \"medical_conversations.jsonl\"):\n", | ||
" \"\"\"\n", | ||
" Convert medical dataset to JSONL format with conversation structure\n", | ||
" \n", | ||
" Args:\n", | ||
" df: DataFrame to convert\n", | ||
" output_file: Output JSONL filename\n", | ||
" \"\"\"\n", | ||
" \n", | ||
" with open(output_file, 'w', encoding='utf-8') as f:\n", | ||
" for _, row in df.iterrows():\n", | ||
" # Create the conversation structure\n", | ||
" conversation = {\n", | ||
" \"messages\": [\n", | ||
" {\n", | ||
" \"role\": \"system\",\n", | ||
" \"content\": \"You are a medical scribe assistant. Your task is to accurately document medical conversations between doctors and patients, creating detailed medical notes that capture all relevant clinical information.\"\n", | ||
" },\n", | ||
" {\n", | ||
" \"role\": \"user\",\n", | ||
" \"content\": row['dialogue']\n", | ||
" },\n", | ||
" {\n", | ||
" \"role\": \"assistant\",\n", | ||
" \"content\": row['note']\n", | ||
" }\n", | ||
" ]\n", | ||
" }\n", | ||
" \n", | ||
" # Write as JSON line\n", | ||
" json_line = json.dumps(conversation, ensure_ascii=False)\n", | ||
" f.write(json_line + '\\n')\n", | ||
" \n", | ||
" print(f\"Converted {len(df)} records to {output_file}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"convert_to_jsonl(train_df, \"medical_conversations_train.jsonl\")\n", | ||
"convert_to_jsonl(test_df, \"medical_conversations_test.jsonl\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from utils.prompts import medical_task, medical_system_prompt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def format_dialogue(dialogue: str):\n", | ||
" dialogue = dialogue.replace(\"\\n\", \" \")\n", | ||
" transcript = f\"Dialogue: {dialogue}\"\n", | ||
" return transcript\n", | ||
"\n", | ||
"\n", | ||
"@weave.op()\n", | ||
"def process_medical_record(dialogue: str) -> Dict:\n", | ||
" transcript = format_dialogue(dialogue)\n", | ||
" prompt = medical_task.format(transcript=transcript)\n", | ||
"\n", | ||
" response = client.chat.completions.create(\n", | ||
" model=\"gpt-3.5-turbo\",\n", | ||
" messages=[\n", | ||
" {\"role\": \"system\", \"content\": medical_system_prompt},\n", | ||
" {\"role\": \"user\", \"content\": prompt},\n", | ||
" ],\n", | ||
" )\n", | ||
"\n", | ||
" extracted_info = response.choices[0].message.content\n", | ||
"\n", | ||
" return {\n", | ||
" \"input\": transcript,\n", | ||
" \"output\": extracted_info,\n", | ||
" }" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Define the LLM scoring function\n", | ||
"@weave.op()\n", | ||
"async def medical_note_accuracy(note: str, output: dict) -> dict:\n", | ||
" scoring_prompt = \"\"\"Compare the generated medical note with the ground truth note and evaluate accuracy.\n", | ||
" Score as 1 if the generated note captures the key medical information accurately, 0 if not.\n", | ||
" Output in valid JSON format with just a \"score\" field.\n", | ||
" \n", | ||
" Ground Truth Note:\n", | ||
" {ground_truth}\n", | ||
" \n", | ||
" Generated Note:\n", | ||
" {generated}\"\"\"\n", | ||
" \n", | ||
" prompt = scoring_prompt.format(\n", | ||
" ground_truth=note,\n", | ||
" generated=output['output']\n", | ||
" )\n", | ||
" \n", | ||
" response = client.chat.completions.create(\n", | ||
" model=\"gpt-4o\",\n", | ||
" messages=[{\"role\": \"user\", \"content\": prompt}],\n", | ||
" response_format={ \"type\": \"json_object\" }\n", | ||
" )\n", | ||
" return json.loads(response.choices[0].message.content)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create evaluation for test samples\n", | ||
"test_evaluation = weave.Evaluation(\n", | ||
" name='medical_record_extraction_test',\n", | ||
" dataset=test_samples,\n", | ||
" scorers=[medical_note_accuracy]\n", | ||
")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 18, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"try:\n", | ||
" in_jupyter = True\n", | ||
"except ImportError:\n", | ||
" in_jupyter = False\n", | ||
"if in_jupyter:\n", | ||
" import nest_asyncio\n", | ||
"\n", | ||
" nest_asyncio.apply()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"test_results = asyncio.run(test_evaluation.evaluate(process_medical_record))\n", | ||
"print(f\"Completed test evaluation\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 20, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"from openai import AzureOpenAI\n", | ||
"\n", | ||
"# Initialize Azure client\n", | ||
"azure_client = AzureOpenAI(\n", | ||
" azure_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\"), \n", | ||
" api_key=os.getenv(\"AZURE_OPENAI_API_KEY\"), \n", | ||
" api_version=\"2024-02-01\"\n", | ||
")\n", | ||
"\n", | ||
"@weave.op()\n", | ||
"def process_medical_record_azure(dialogue: str) -> Dict:\n", | ||
"\n", | ||
" response = azure_client.chat.completions.create(\n", | ||
" model=\"gpt-35-turbo-0125-ft-d30b3aee14864c29acd9ac54eb92457f\",\n", | ||
" messages=[\n", | ||
" {\"role\": \"system\", \"content\": \"You are a medical scribe assistant. Your task is to accurately document medical conversations between doctors and patients, creating detailed medical notes that capture all relevant clinical information.\"},\n", | ||
" {\"role\": \"user\", \"content\": dialogue},\n", | ||
" ],\n", | ||
" )\n", | ||
"\n", | ||
" extracted_info = response.choices[0].message.content\n", | ||
"\n", | ||
" return {\n", | ||
" \"input\": dialogue,\n", | ||
" \"output\": extracted_info,\n", | ||
" }" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 21, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"test_results_azure = asyncio.run(test_evaluation.evaluate(process_medical_record_azure))\n", | ||
"print(f\"Completed test evaluation\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |