azure-ft-example #578

Merged · 2 commits · Nov 12, 2024
376 changes: 376 additions & 0 deletions colabs/azure/azure_gpt_medical_notes.ipynb
@@ -0,0 +1,376 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/wandb/examples/blob/master/colabs/azure/azure_gpt_medical_notes.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict, List, Literal, Optional, Tuple\n",
"\n",
"import instructor\n",
"import openai\n",
"import pandas as pd\n",
"import weave\n",
"from pydantic import BaseModel, Field\n",
"from set_env import set_env\n",
"import json\n",
"import asyncio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"set_env(\"OPENAI_API_KEY\")\n",
"set_env(\"WANDB_API_KEY\")\n",
"set_env(\"AZURE_OPENAI_ENDPOINT\")\n",
"set_env(\"AZURE_OPENAI_API_KEY\")\n",
"print(\"Env set\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from utils.config import ENTITY, WEAVE_PROJECT"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"weave.init(f\"{ENTITY}/{WEAVE_PROJECT}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"N_SAMPLES = 67"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client = openai.OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_medical_data(url: str, num_samples: int = N_SAMPLES) -> Tuple[pd.DataFrame, pd.DataFrame]:\n",
" \"\"\"\n",
" Load medical data and split into train and test sets\n",
" \n",
" Args:\n",
" url: URL of the CSV file\n",
" num_samples: Number of samples to load\n",
" \n",
" Returns:\n",
" Tuple of (train_df, test_df)\n",
" \"\"\"\n",
" df = pd.read_csv(url)\n",
" df = df.sample(n=num_samples, random_state=42) # Sample and shuffle data\n",
" \n",
" # Split into 80% train, 20% test\n",
" train_size = int(0.8 * len(df))\n",
" train_df = df[:train_size]\n",
" test_df = df[train_size:]\n",
" \n",
" return train_df, test_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"medical_dataset_url = \"https://raw.githubusercontent.com/wyim/aci-bench/main/data/challenge_data/train.csv\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_df, test_df = load_medical_data(medical_dataset_url)\n",
"train_samples = train_df.to_dict(\"records\")\n",
"test_samples = test_df.to_dict(\"records\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_samples[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_samples[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def convert_to_jsonl(df: pd.DataFrame, output_file: str = \"medical_conversations.jsonl\"):\n",
" \"\"\"\n",
" Convert medical dataset to JSONL format with conversation structure\n",
" \n",
" Args:\n",
" df: DataFrame to convert\n",
" output_file: Output JSONL filename\n",
" \"\"\"\n",
" \n",
" with open(output_file, 'w', encoding='utf-8') as f:\n",
" for _, row in df.iterrows():\n",
" # Create the conversation structure\n",
" conversation = {\n",
" \"messages\": [\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a medical scribe assistant. Your task is to accurately document medical conversations between doctors and patients, creating detailed medical notes that capture all relevant clinical information.\"\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": row['dialogue']\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": row['note']\n",
" }\n",
" ]\n",
" }\n",
" \n",
" # Write as JSON line\n",
" json_line = json.dumps(conversation, ensure_ascii=False)\n",
" f.write(json_line + '\\n')\n",
" \n",
" print(f\"Converted {len(df)} records to {output_file}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"convert_to_jsonl(train_df, \"medical_conversations_train.jsonl\")\n",
"convert_to_jsonl(test_df, \"medical_conversations_test.jsonl\")"
]
},
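{
"cell_type": "markdown",
"metadata": {},
"source": [
"The JSONL files written above follow the chat format that Azure OpenAI fine-tuning expects. The fine-tuned deployment used later in this notebook was created ahead of time, so the next cell is only a rough sketch of how such a job could be submitted with the same SDK; the base model name is an assumption and the job is disabled by default.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only: upload the training file and submit an Azure OpenAI\n",
"# fine-tuning job. The base model name is an assumption; the deployment used later\n",
"# in this notebook was created ahead of time, so this cell does nothing by default.\n",
"import os\n",
"from openai import AzureOpenAI\n",
"\n",
"RUN_FINE_TUNING = False  # flip to True to actually submit a job\n",
"\n",
"if RUN_FINE_TUNING:\n",
"    ft_client = AzureOpenAI(\n",
"        azure_endpoint=os.getenv(\"AZURE_OPENAI_ENDPOINT\"),\n",
"        api_key=os.getenv(\"AZURE_OPENAI_API_KEY\"),\n",
"        api_version=\"2024-02-01\",\n",
"    )\n",
"    training_file = ft_client.files.create(\n",
"        file=open(\"medical_conversations_train.jsonl\", \"rb\"), purpose=\"fine-tune\"\n",
"    )\n",
"    job = ft_client.fine_tuning.jobs.create(\n",
"        training_file=training_file.id, model=\"gpt-35-turbo-0125\"\n",
"    )\n",
"    print(job.id)"
]
},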
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from utils.prompts import medical_task, medical_system_prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def format_dialogue(dialogue: str):\n",
" dialogue = dialogue.replace(\"\\n\", \" \")\n",
" transcript = f\"Dialogue: {dialogue}\"\n",
" return transcript\n",
"\n",
"\n",
"@weave.op()\n",
"def process_medical_record(dialogue: str) -> Dict:\n",
" transcript = format_dialogue(dialogue)\n",
" prompt = medical_task.format(transcript=transcript)\n",
"\n",
" response = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo\",\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": medical_system_prompt},\n",
" {\"role\": \"user\", \"content\": prompt},\n",
" ],\n",
" )\n",
"\n",
" extracted_info = response.choices[0].message.content\n",
"\n",
" return {\n",
" \"input\": transcript,\n",
" \"output\": extracted_info,\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define the LLM scoring function\n",
"@weave.op()\n",
"async def medical_note_accuracy(note: str, output: dict) -> dict:\n",
" scoring_prompt = \"\"\"Compare the generated medical note with the ground truth note and evaluate accuracy.\n",
" Score as 1 if the generated note captures the key medical information accurately, 0 if not.\n",
" Output in valid JSON format with just a \"score\" field.\n",
" \n",
" Ground Truth Note:\n",
" {ground_truth}\n",
" \n",
" Generated Note:\n",
" {generated}\"\"\"\n",
" \n",
" prompt = scoring_prompt.format(\n",
" ground_truth=note,\n",
" generated=output['output']\n",
" )\n",
" \n",
" response = client.chat.completions.create(\n",
" model=\"gpt-4o\",\n",
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" response_format={ \"type\": \"json_object\" }\n",
" )\n",
" return json.loads(response.choices[0].message.content)"
]
},
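{
"cell_type": "markdown",
"metadata": {},
"source": [
"`weave.Evaluation` matches scorer arguments to dataset columns by name, so the scorer above receives each row's `note` along with the model's output dict. As an optional sanity check (not part of the original flow, and assuming the kernel supports top-level `await`), it can be called directly on a single prediction:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check on a single sample; assumes top-level await works in this kernel.\n",
"sample = test_samples[0]\n",
"prediction = process_medical_record(sample[\"dialogue\"])\n",
"print(await medical_note_accuracy(note=sample[\"note\"], output=prediction))"
]
},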
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create evaluation for test samples\n",
"test_evaluation = weave.Evaluation(\n",
" name='medical_record_extraction_test',\n",
" dataset=test_samples,\n",
" scorers=[medical_note_accuracy]\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" in_jupyter = True\n",
"except ImportError:\n",
" in_jupyter = False\n",
"if in_jupyter:\n",
" import nest_asyncio\n",
"\n",
" nest_asyncio.apply()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_results = asyncio.run(test_evaluation.evaluate(process_medical_record))\n",
"print(f\"Completed test evaluation\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from openai import AzureOpenAI\n",
"\n",
"# Initialize Azure client\n",
"azure_client = AzureOpenAI(\n",
" azure_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\"), \n",
" api_key=os.getenv(\"AZURE_OPENAI_API_KEY\"), \n",
" api_version=\"2024-02-01\"\n",
")\n",
"\n",
"@weave.op()\n",
"def process_medical_record_azure(dialogue: str) -> Dict:\n",
"\n",
" response = azure_client.chat.completions.create(\n",
" model=\"gpt-35-turbo-0125-ft-d30b3aee14864c29acd9ac54eb92457f\",\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a medical scribe assistant. Your task is to accurately document medical conversations between doctors and patients, creating detailed medical notes that capture all relevant clinical information.\"},\n",
" {\"role\": \"user\", \"content\": dialogue},\n",
" ],\n",
" )\n",
"\n",
" extracted_info = response.choices[0].message.content\n",
"\n",
" return {\n",
" \"input\": dialogue,\n",
" \"output\": extracted_info,\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_results_azure = asyncio.run(test_evaluation.evaluate(process_medical_record_azure))\n",
"print(f\"Completed test evaluation\")"
]
},
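{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each `evaluate` call returns a summary dict and logs the run to Weave, where the two models can be compared per example. A rough side-by-side print of the summaries:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print the summaries returned by the two evaluation runs for a quick comparison.\n",
"print(\"baseline gpt-3.5-turbo:\", test_results)\n",
"print(\"fine-tuned Azure model:\", test_results_azure)"
]
},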
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}