Skip to content

Commit d39ceb4

Browse files
authored
azure-ft-example
1 parent ed05814 commit d39ceb4

File tree

1 file changed

+376
-0
lines changed

1 file changed

+376
-0
lines changed
+376
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"from typing import Dict, List, Literal, Optional, Tuple\n",
10+
"\n",
11+
"import instructor\n",
12+
"import openai\n",
13+
"import pandas as pd\n",
14+
"import weave\n",
15+
"from pydantic import BaseModel, Field\n",
16+
"from set_env import set_env\n",
17+
"import json\n",
18+
"import asyncio"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"metadata": {},
25+
"outputs": [],
26+
"source": [
27+
# Load the credentials this notebook needs into the process environment.
# NOTE(review): set_env is a project-local helper (from set_env import set_env);
# presumably it reads a .env file or prompts — confirm against utils.
set_env("OPENAI_API_KEY")
set_env("WANDB_API_KEY")
set_env("AZURE_OPENAI_ENDPOINT")
set_env("AZURE_OPENAI_API_KEY")
print("Env set")
32+
]
33+
},
34+
{
35+
"cell_type": "code",
36+
"execution_count": 3,
37+
"metadata": {},
38+
"outputs": [],
39+
"source": [
40+
"from utils.config import ENTITY, WEAVE_PROJECT"
41+
]
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": null,
46+
"metadata": {},
47+
"outputs": [],
48+
"source": [
49+
"weave.init(f\"{ENTITY}/{WEAVE_PROJECT}\")"
50+
]
51+
},
52+
{
53+
"cell_type": "code",
54+
"execution_count": 5,
55+
"metadata": {},
56+
"outputs": [],
57+
"source": [
58+
"N_SAMPLES = 67"
59+
]
60+
},
61+
{
62+
"cell_type": "code",
63+
"execution_count": 6,
64+
"metadata": {},
65+
"outputs": [],
66+
"source": [
67+
"client = openai.OpenAI()"
68+
]
69+
},
70+
{
71+
"cell_type": "code",
72+
"execution_count": 7,
73+
"metadata": {},
74+
"outputs": [],
75+
"source": [
76+
"def load_medical_data(url: str, num_samples: int = N_SAMPLES) -> Tuple[pd.DataFrame, pd.DataFrame]:\n",
77+
" \"\"\"\n",
78+
" Load medical data and split into train and test sets\n",
79+
" \n",
80+
" Args:\n",
81+
" url: URL of the CSV file\n",
82+
" num_samples: Number of samples to load\n",
83+
" \n",
84+
" Returns:\n",
85+
" Tuple of (train_df, test_df)\n",
86+
" \"\"\"\n",
87+
" df = pd.read_csv(url)\n",
88+
" df = df.sample(n=num_samples, random_state=42) # Sample and shuffle data\n",
89+
" \n",
90+
" # Split into 80% train, 20% test\n",
91+
" train_size = int(0.8 * len(df))\n",
92+
" train_df = df[:train_size]\n",
93+
" test_df = df[train_size:]\n",
94+
" \n",
95+
" return train_df, test_df"
96+
]
97+
},
98+
{
99+
"cell_type": "code",
100+
"execution_count": 8,
101+
"metadata": {},
102+
"outputs": [],
103+
"source": [
104+
"medical_dataset_url = \"https://raw.githubusercontent.com/wyim/aci-bench/main/data/challenge_data/train.csv\""
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": 9,
110+
"metadata": {},
111+
"outputs": [],
112+
"source": [
113+
# Download, sample, and split the data; convert to row dicts so Weave's
# Evaluation can iterate records by field name.
train_df, test_df = load_medical_data(medical_dataset_url)
train_samples = train_df.to_dict("records")
test_samples = test_df.to_dict("records")
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": null,
121+
"metadata": {},
122+
"outputs": [],
123+
"source": [
124+
"train_samples[0]"
125+
]
126+
},
127+
{
128+
"cell_type": "code",
129+
"execution_count": null,
130+
"metadata": {},
131+
"outputs": [],
132+
"source": [
133+
"test_samples[0]"
134+
]
135+
},
136+
{
137+
"cell_type": "code",
138+
"execution_count": 12,
139+
"metadata": {},
140+
"outputs": [],
141+
"source": [
142+
def convert_to_jsonl(df: pd.DataFrame, output_file: str = "medical_conversations.jsonl"):
    """
    Convert medical dataset to JSONL format with conversation structure.

    Each row becomes one chat-format fine-tuning example:
    system instruction, user dialogue, assistant note.

    Args:
        df: DataFrame to convert (needs 'dialogue' and 'note' columns)
        output_file: Output JSONL filename
    """
    # The system instruction is identical for every example, so build it once.
    system_message = {
        "role": "system",
        "content": "You are a medical scribe assistant. Your task is to accurately document medical conversations between doctors and patients, creating detailed medical notes that capture all relevant clinical information."
    }

    with open(output_file, 'w', encoding='utf-8') as sink:
        for dialogue, note in zip(df['dialogue'], df['note']):
            record = {
                "messages": [
                    system_message,
                    {"role": "user", "content": dialogue},
                    {"role": "assistant", "content": note},
                ]
            }
            # One JSON object per line, preserving non-ASCII characters.
            sink.write(json.dumps(record, ensure_ascii=False) + '\n')

    print(f"Converted {len(df)} records to {output_file}")
176+
]
177+
},
178+
{
179+
"cell_type": "code",
180+
"execution_count": null,
181+
"metadata": {},
182+
"outputs": [],
183+
"source": [
184+
# Materialize the fine-tuning files: one JSONL per split.
convert_to_jsonl(train_df, "medical_conversations_train.jsonl")
convert_to_jsonl(test_df, "medical_conversations_test.jsonl")
186+
]
187+
},
188+
{
189+
"cell_type": "code",
190+
"execution_count": 14,
191+
"metadata": {},
192+
"outputs": [],
193+
"source": [
194+
"from utils.prompts import medical_task, medical_system_prompt"
195+
]
196+
},
197+
{
198+
"cell_type": "code",
199+
"execution_count": 15,
200+
"metadata": {},
201+
"outputs": [],
202+
"source": [
203+
def format_dialogue(dialogue: str):
    """Collapse newlines to spaces and label the text as a dialogue transcript."""
    single_line = " ".join(dialogue.split("\n"))
    return f"Dialogue: {single_line}"
207+
"\n",
208+
"\n",
209+
@weave.op()
def process_medical_record(dialogue: str) -> Dict:
    """Generate a medical note for one dialogue with the baseline model.

    Args:
        dialogue: Raw doctor/patient dialogue text.

    Returns:
        Dict with the formatted "input" transcript and the generated
        "output" note.
    """
    transcript = format_dialogue(dialogue)
    messages = [
        {"role": "system", "content": medical_system_prompt},
        {"role": "user", "content": medical_task.format(transcript=transcript)},
    ]

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
    )
    note = response.choices[0].message.content

    return {
        "input": transcript,
        "output": note,
    }
228+
]
229+
},
230+
{
231+
"cell_type": "code",
232+
"execution_count": 16,
233+
"metadata": {},
234+
"outputs": [],
235+
"source": [
236+
# Define the LLM scoring function
@weave.op()
async def medical_note_accuracy(note: str, output: dict) -> dict:
    """LLM-judge scorer comparing a generated note with the ground truth.

    Args:
        note: Ground-truth note from the dataset record.
        output: Prediction dict; its "output" key holds the generated note.

    Returns:
        The judge's parsed JSON verdict, e.g. {"score": 1}.
    """
    # NOTE(review): this coroutine calls the *sync* OpenAI client, so it
    # blocks the event loop while scoring; an async client would overlap calls.
    scoring_prompt = """Compare the generated medical note with the ground truth note and evaluate accuracy.
    Score as 1 if the generated note captures the key medical information accurately, 0 if not.
    Output in valid JSON format with just a "score" field.
    
    Ground Truth Note:
    {ground_truth}
    
    Generated Note:
    {generated}"""

    judge_messages = [
        {
            "role": "user",
            "content": scoring_prompt.format(ground_truth=note, generated=output["output"]),
        }
    ]
    # json_object mode forces the judge to emit parseable JSON.
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=judge_messages,
        response_format={ "type": "json_object" }
    )
    return json.loads(response.choices[0].message.content)
260+
]
261+
},
262+
{
263+
"cell_type": "code",
264+
"execution_count": 17,
265+
"metadata": {},
266+
"outputs": [],
267+
"source": [
268+
# Create evaluation for test samples
# Weave matches each record's fields to scorer parameters by name, so the
# 'note' column feeds medical_note_accuracy's `note` argument and the
# prediction op's return dict arrives as `output`.
test_evaluation = weave.Evaluation(
    name='medical_record_extraction_test',
    dataset=test_samples,
    scorers=[medical_note_accuracy]
)
274+
]
275+
},
276+
{
277+
"cell_type": "code",
278+
"execution_count": 18,
279+
"metadata": {},
280+
"outputs": [],
281+
"source": [
282+
# Jupyter's kernel already owns a running event loop, so the asyncio.run()
# calls below need nest_asyncio to allow nesting. The original probe
# (`try: in_jupyter = True`) could never raise ImportError — the except was
# dead and in_jupyter was always True, making the nest_asyncio import
# unconditional. Probe the import itself and degrade gracefully instead.
try:
    import nest_asyncio
except ImportError:
    # Plain-script execution: asyncio.run works without patching.
    in_jupyter = False
else:
    in_jupyter = True
    nest_asyncio.apply()
290+
]
291+
},
292+
{
293+
"cell_type": "code",
294+
"execution_count": null,
295+
"metadata": {},
296+
"outputs": [],
297+
"source": [
298+
# Baseline: evaluate the vanilla gpt-3.5-turbo pipeline on the test split.
test_results = asyncio.run(test_evaluation.evaluate(process_medical_record))
# Fixed F541: the original print used an f-string with no placeholders.
print("Completed test evaluation")
300+
]
301+
},
302+
{
303+
"cell_type": "code",
304+
"execution_count": 20,
305+
"metadata": {},
306+
"outputs": [],
307+
"source": [
308+
import os
from openai import AzureOpenAI

# Client for the Azure-hosted fine-tuned deployment; endpoint and key come
# from the environment variables loaded earlier in the notebook.
azure_client = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_version="2024-02-01",
)
317+
"\n",
318+
@weave.op()
def process_medical_record_azure(
    dialogue: str,
    model: str = "gpt-35-turbo-0125-ft-d30b3aee14864c29acd9ac54eb92457f",
) -> Dict:
    """Generate a medical note with the fine-tuned Azure OpenAI deployment.

    Args:
        dialogue: Raw doctor/patient dialogue text.
        model: Azure deployment name. Defaults to the fine-tuned gpt-3.5
            deployment; parameterized (previously hard-coded) so the op can
            be reused against other deployments without editing the body.

    Returns:
        Dict with the original "input" dialogue and the generated
        "output" note.
    """
    response = azure_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a medical scribe assistant. Your task is to accurately document medical conversations between doctors and patients, creating detailed medical notes that capture all relevant clinical information."},
            {"role": "user", "content": dialogue},
        ],
    )

    extracted_info = response.choices[0].message.content

    return {
        "input": dialogue,
        "output": extracted_info,
    }
335+
]
336+
},
337+
{
338+
"cell_type": "code",
339+
"execution_count": 21,
340+
"metadata": {},
341+
"outputs": [],
342+
"source": [
343+
# Evaluate the fine-tuned Azure deployment on the same test split for comparison.
test_results_azure = asyncio.run(test_evaluation.evaluate(process_medical_record_azure))
# Fixed F541: the original print used an f-string with no placeholders.
print("Completed test evaluation")
345+
]
346+
},
347+
{
348+
"cell_type": "code",
349+
"execution_count": null,
350+
"metadata": {},
351+
"outputs": [],
352+
"source": []
353+
}
354+
],
355+
"metadata": {
356+
"kernelspec": {
357+
"display_name": ".venv",
358+
"language": "python",
359+
"name": "python3"
360+
},
361+
"language_info": {
362+
"codemirror_mode": {
363+
"name": "ipython",
364+
"version": 3
365+
},
366+
"file_extension": ".py",
367+
"mimetype": "text/x-python",
368+
"name": "python",
369+
"nbconvert_exporter": "python",
370+
"pygments_lexer": "ipython3",
371+
"version": "3.11.9"
372+
}
373+
},
374+
"nbformat": 4,
375+
"nbformat_minor": 2
376+
}

0 commit comments

Comments
 (0)