Skip to content

Commit

Permalink
azure-ft-example (#578)
Browse files Browse the repository at this point in the history
* azure-ft-example

* Auto-clean notebooks

---------

Co-authored-by: GitHub Action <[email protected]>
  • Loading branch information
ash0ts and actions-user authored Nov 12, 2024
1 parent ed05814 commit b7a0f20
Showing 1 changed file with 376 additions and 0 deletions.
376 changes: 376 additions & 0 deletions colabs/azure/azure_gpt_medical_notes.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,376 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/wandb/examples/blob/master/colabs/azure/azure_gpt_medical_notes.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict, List, Literal, Optional, Tuple\n",
"\n",
"import instructor\n",
"import openai\n",
"import pandas as pd\n",
"import weave\n",
"from pydantic import BaseModel, Field\n",
"from set_env import set_env\n",
"import json\n",
"import asyncio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"set_env(\"OPENAI_API_KEY\")\n",
"set_env(\"WANDB_API_KEY\")\n",
"set_env(\"AZURE_OPENAI_ENDPOINT\")\n",
"set_env(\"AZURE_OPENAI_API_KEY\")\n",
"print(\"Env set\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from utils.config import ENTITY, WEAVE_PROJECT"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"weave.init(f\"{ENTITY}/{WEAVE_PROJECT}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"N_SAMPLES = 67"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client = openai.OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_medical_data(url: str, num_samples: int = N_SAMPLES) -> Tuple[pd.DataFrame, pd.DataFrame]:\n",
" \"\"\"\n",
" Load medical data and split into train and test sets\n",
" \n",
" Args:\n",
" url: URL of the CSV file\n",
" num_samples: Number of samples to load\n",
" \n",
" Returns:\n",
" Tuple of (train_df, test_df)\n",
" \"\"\"\n",
" df = pd.read_csv(url)\n",
" df = df.sample(n=num_samples, random_state=42) # Sample and shuffle data\n",
" \n",
" # Split into 80% train, 20% test\n",
" train_size = int(0.8 * len(df))\n",
" train_df = df[:train_size]\n",
" test_df = df[train_size:]\n",
" \n",
" return train_df, test_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"medical_dataset_url = \"https://raw.githubusercontent.com/wyim/aci-bench/main/data/challenge_data/train.csv\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_df, test_df = load_medical_data(medical_dataset_url)\n",
"train_samples = train_df.to_dict(\"records\")\n",
"test_samples = test_df.to_dict(\"records\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_samples[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_samples[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def convert_to_jsonl(df: pd.DataFrame, output_file: str = \"medical_conversations.jsonl\"):\n",
" \"\"\"\n",
" Convert medical dataset to JSONL format with conversation structure\n",
" \n",
" Args:\n",
" df: DataFrame to convert\n",
" output_file: Output JSONL filename\n",
" \"\"\"\n",
" \n",
" with open(output_file, 'w', encoding='utf-8') as f:\n",
" for _, row in df.iterrows():\n",
" # Create the conversation structure\n",
" conversation = {\n",
" \"messages\": [\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a medical scribe assistant. Your task is to accurately document medical conversations between doctors and patients, creating detailed medical notes that capture all relevant clinical information.\"\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": row['dialogue']\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": row['note']\n",
" }\n",
" ]\n",
" }\n",
" \n",
" # Write as JSON line\n",
" json_line = json.dumps(conversation, ensure_ascii=False)\n",
" f.write(json_line + '\\n')\n",
" \n",
" print(f\"Converted {len(df)} records to {output_file}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"convert_to_jsonl(train_df, \"medical_conversations_train.jsonl\")\n",
"convert_to_jsonl(test_df, \"medical_conversations_test.jsonl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from utils.prompts import medical_task, medical_system_prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def format_dialogue(dialogue: str):\n",
" dialogue = dialogue.replace(\"\\n\", \" \")\n",
" transcript = f\"Dialogue: {dialogue}\"\n",
" return transcript\n",
"\n",
"\n",
"@weave.op()\n",
"def process_medical_record(dialogue: str) -> Dict:\n",
" transcript = format_dialogue(dialogue)\n",
" prompt = medical_task.format(transcript=transcript)\n",
"\n",
" response = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo\",\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": medical_system_prompt},\n",
" {\"role\": \"user\", \"content\": prompt},\n",
" ],\n",
" )\n",
"\n",
" extracted_info = response.choices[0].message.content\n",
"\n",
" return {\n",
" \"input\": transcript,\n",
" \"output\": extracted_info,\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define the LLM scoring function\n",
"@weave.op()\n",
"async def medical_note_accuracy(note: str, output: dict) -> dict:\n",
" scoring_prompt = \"\"\"Compare the generated medical note with the ground truth note and evaluate accuracy.\n",
" Score as 1 if the generated note captures the key medical information accurately, 0 if not.\n",
" Output in valid JSON format with just a \"score\" field.\n",
" \n",
" Ground Truth Note:\n",
" {ground_truth}\n",
" \n",
" Generated Note:\n",
" {generated}\"\"\"\n",
" \n",
" prompt = scoring_prompt.format(\n",
" ground_truth=note,\n",
" generated=output['output']\n",
" )\n",
" \n",
" response = client.chat.completions.create(\n",
" model=\"gpt-4o\",\n",
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" response_format={ \"type\": \"json_object\" }\n",
" )\n",
" return json.loads(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create evaluation for test samples\n",
"test_evaluation = weave.Evaluation(\n",
" name='medical_record_extraction_test',\n",
" dataset=test_samples,\n",
" scorers=[medical_note_accuracy]\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" in_jupyter = True\n",
"except ImportError:\n",
" in_jupyter = False\n",
"if in_jupyter:\n",
" import nest_asyncio\n",
"\n",
" nest_asyncio.apply()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_results = asyncio.run(test_evaluation.evaluate(process_medical_record))\n",
"print(f\"Completed test evaluation\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from openai import AzureOpenAI\n",
"\n",
"# Initialize Azure client\n",
"azure_client = AzureOpenAI(\n",
" azure_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\"), \n",
" api_key=os.getenv(\"AZURE_OPENAI_API_KEY\"), \n",
" api_version=\"2024-02-01\"\n",
")\n",
"\n",
"@weave.op()\n",
"def process_medical_record_azure(dialogue: str) -> Dict:\n",
"\n",
" response = azure_client.chat.completions.create(\n",
" model=\"gpt-35-turbo-0125-ft-d30b3aee14864c29acd9ac54eb92457f\",\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a medical scribe assistant. Your task is to accurately document medical conversations between doctors and patients, creating detailed medical notes that capture all relevant clinical information.\"},\n",
" {\"role\": \"user\", \"content\": dialogue},\n",
" ],\n",
" )\n",
"\n",
" extracted_info = response.choices[0].message.content\n",
"\n",
" return {\n",
" \"input\": dialogue,\n",
" \"output\": extracted_info,\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_results_azure = asyncio.run(test_evaluation.evaluate(process_medical_record_azure))\n",
"print(f\"Completed test evaluation\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit b7a0f20

Please sign in to comment.