Skip to content

Commit efa5f20

Browse files
committed
update model
1 parent f552543 commit efa5f20

23 files changed

+6499
-554
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/graphs
22
/graphs-gpt4
3+
/ehr_training_result
34
.exp_data
45
/output
56
.ipynb_checkpoints
11.8 KB
Binary file not shown.

__pycache__/graphcare.cpython-38.pyc

15.1 KB
Binary file not shown.

data_prepare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import csv
22
from pyhealth.datasets import MIMIC3Dataset, MIMIC4Dataset
3-
from GraphCare.task_fn import drug_recommendation_fn, drug_recommendation_mimic4_fn, mortality_prediction_mimic3_fn, readmission_prediction_mimic3_fn, length_of_stay_prediction_mimic3_fn, length_of_stay_prediction_mimic4_fn, mortality_prediction_mimic4_fn, readmission_prediction_mimic4_fn
3+
from graphcare_.task_fn import drug_recommendation_fn, drug_recommendation_mimic4_fn, mortality_prediction_mimic3_fn, readmission_prediction_mimic3_fn, length_of_stay_prediction_mimic3_fn, length_of_stay_prediction_mimic4_fn, mortality_prediction_mimic4_fn, readmission_prediction_mimic4_fn
44
import pickle
55
import json
66
from pyhealth.tokenizer import Tokenizer
@@ -524,4 +524,4 @@ def main():
524524

525525

526526
if __name__ == "__main__":
527-
main()
527+
main()

drug_rec_ehr.ipynb

Lines changed: 206 additions & 379 deletions
Large diffs are not rendered by default.

drug_rec_ehr_feat.ipynb

Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 2,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"task = \"lenofstay\"\n",
10+
"\n",
11+
"ratios = [\n",
12+
" 0.1,\n",
13+
" 0.2,\n",
14+
" 0.3,\n",
15+
" 0.4,\n",
16+
" 0.5,\n",
17+
" 0.7,\n",
18+
" 0.9,\n",
19+
"]"
20+
]
21+
},
22+
{
23+
"cell_type": "code",
24+
"execution_count": 3,
25+
"metadata": {},
26+
"outputs": [],
27+
"source": [
28+
"from pyhealth.datasets import split_by_patient, get_dataloader\n",
29+
"import pickle\n",
30+
"\n",
31+
"with open(f'/data/pj20/exp_data/ccscm_ccsproc/sample_dataset_mimic3_{task}_th015.pkl', 'rb') as f:\n",
32+
" sample_dataset = pickle.load(f)\n",
33+
"\n",
34+
"train_dataset, _, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1], train_ratio=1.0, seed=528)\n",
35+
"train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True)\n",
36+
"test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False)"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": 4,
42+
"metadata": {},
43+
"outputs": [],
44+
"source": [
45+
"from pyhealth.trainer import Trainer\n",
46+
"import torch\n",
47+
"from pyhealth.models import Transformer, RETAIN, SafeDrug, MICRON, CNN, RNN, GAMENet\n",
48+
"from collections import defaultdict\n",
49+
"\n",
50+
"\n",
51+
"for ratio in ratios:\n",
52+
" with open(f'/data/pj20/exp_data/ccscm_ccsproc_atc3/val_dataset_mimic3_{task}_th015_{1-ratio}.pkl', 'rb') as f:\n",
53+
" val_dataset = pickle.load(f)\n",
54+
" val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False)\n"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": 7,
60+
"metadata": {},
61+
"outputs": [
62+
{
63+
"name": "stderr",
64+
"output_type": "stream",
65+
"text": [
66+
"GAMENet(\n",
67+
" (embeddings): ModuleDict(\n",
68+
" (conditions): Embedding(283, 128, padding_idx=0)\n",
69+
" (procedures): Embedding(223, 128, padding_idx=0)\n",
70+
" )\n",
71+
" (cond_rnn): GRU(128, 128, batch_first=True)\n",
72+
" (proc_rnn): GRU(128, 128, batch_first=True)\n",
73+
" (query): Sequential(\n",
74+
" (0): ReLU()\n",
75+
" (1): Linear(in_features=256, out_features=128, bias=True)\n",
76+
" )\n",
77+
" (gamenet): GAMENetLayer(\n",
78+
" (ehr_gcn): GCN(\n",
79+
" (gcn1): GCNLayer()\n",
80+
" (dropout_layer): Dropout(p=0.5, inplace=False)\n",
81+
" (gcn2): GCNLayer()\n",
82+
" )\n",
83+
" (ddi_gcn): GCN(\n",
84+
" (gcn1): GCNLayer()\n",
85+
" (dropout_layer): Dropout(p=0.5, inplace=False)\n",
86+
" (gcn2): GCNLayer()\n",
87+
" )\n",
88+
" (fc): Linear(in_features=384, out_features=197, bias=True)\n",
89+
" (bce_loss_fn): BCEWithLogitsLoss()\n",
90+
" )\n",
91+
")\n",
92+
"Metrics: ['pr_auc_samples', 'roc_auc_samples', 'f1_samples', 'jaccard_samples']\n",
93+
"Device: cuda:1\n",
94+
"\n",
95+
"Training:\n",
96+
"Batch size: 64\n",
97+
"Optimizer: <class 'torch.optim.adam.Adam'>\n",
98+
"Optimizer params: {'lr': 0.001}\n",
99+
"Weight decay: 0.0\n",
100+
"Max grad norm: None\n",
101+
"Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7fb588a29b50>\n",
102+
"Monitor: pr_auc_samples\n",
103+
"Monitor criterion: max\n",
104+
"Epochs: 5\n",
105+
"\n",
106+
"Epoch 0 / 5: 100%|██████████| 1/1 [00:00<00:00, 3.45it/s]\n",
107+
"--- Train epoch-0, step-1 ---\n",
108+
"loss: 0.6954\n",
109+
"Evaluation: 100%|██████████| 68/68 [00:00<00:00, 121.22it/s]\n",
110+
"--- Eval epoch-0, step-1 ---\n",
111+
"pr_auc_samples: 0.2212\n",
112+
"roc_auc_samples: 0.5977\n",
113+
"f1_samples: 0.2464\n",
114+
"jaccard_samples: 0.1441\n",
115+
"loss: 0.6834\n",
116+
"New best pr_auc_samples score (0.2212) at epoch-0, step-1\n",
117+
"\n",
118+
"Epoch 1 / 5: 100%|██████████| 1/1 [00:00<00:00, 90.69it/s]\n",
119+
"--- Train epoch-1, step-2 ---\n",
120+
"loss: 0.6839\n",
121+
"Evaluation: 100%|██████████| 68/68 [00:00<00:00, 155.81it/s]\n",
122+
"--- Eval epoch-1, step-2 ---\n",
123+
"pr_auc_samples: 0.3191\n",
124+
"roc_auc_samples: 0.6721\n",
125+
"f1_samples: 0.3108\n",
126+
"jaccard_samples: 0.1885\n",
127+
"loss: 0.6718\n",
128+
"New best pr_auc_samples score (0.3191) at epoch-1, step-2\n",
129+
"\n",
130+
"Epoch 2 / 5: 100%|██████████| 1/1 [00:00<00:00, 86.69it/s]\n",
131+
"--- Train epoch-2, step-3 ---\n",
132+
"loss: 0.6737\n",
133+
"Evaluation: 100%|██████████| 68/68 [00:00<00:00, 153.50it/s]\n",
134+
"--- Eval epoch-2, step-3 ---\n",
135+
"pr_auc_samples: 0.4212\n",
136+
"roc_auc_samples: 0.7142\n",
137+
"f1_samples: 0.3806\n",
138+
"jaccard_samples: 0.2418\n",
139+
"loss: 0.6606\n",
140+
"New best pr_auc_samples score (0.4212) at epoch-2, step-3\n",
141+
"\n",
142+
"Epoch 3 / 5: 100%|██████████| 1/1 [00:00<00:00, 85.59it/s]\n",
143+
"--- Train epoch-3, step-4 ---\n",
144+
"loss: 0.6613\n",
145+
"Evaluation: 100%|██████████| 68/68 [00:00<00:00, 149.41it/s]\n",
146+
"--- Eval epoch-3, step-4 ---\n",
147+
"pr_auc_samples: 0.4770\n",
148+
"roc_auc_samples: 0.7327\n",
149+
"f1_samples: 0.4432\n",
150+
"jaccard_samples: 0.2942\n",
151+
"loss: 0.6491\n",
152+
"New best pr_auc_samples score (0.4770) at epoch-3, step-4\n",
153+
"\n",
154+
"Epoch 4 / 5: 100%|██████████| 1/1 [00:00<00:00, 84.91it/s]\n",
155+
"--- Train epoch-4, step-5 ---\n",
156+
"loss: 0.6454\n",
157+
"Evaluation: 100%|██████████| 68/68 [00:00<00:00, 150.65it/s]\n",
158+
"--- Eval epoch-4, step-5 ---\n",
159+
"pr_auc_samples: 0.4981\n",
160+
"roc_auc_samples: 0.7424\n",
161+
"f1_samples: 0.4729\n",
162+
"jaccard_samples: 0.3208\n",
163+
"loss: 0.6370\n",
164+
"New best pr_auc_samples score (0.4981) at epoch-4, step-5\n",
165+
"Loaded best model\n",
166+
"Evaluation: 100%|██████████| 68/68 [00:00<00:00, 152.59it/s]\n"
167+
]
168+
}
169+
],
170+
"source": [
171+
"from pyhealth.trainer import Trainer\n",
172+
"import torch\n",
173+
"from pyhealth.models import Transformer, RETAIN, SafeDrug, MICRON, CNN, RNN, GAMENet\n",
174+
"from collections import defaultdict\n",
175+
"\n",
176+
"results = defaultdict(list)\n",
177+
"\n",
178+
"for i in range(1):\n",
179+
" for model_ in [\n",
180+
" # Transformer, \n",
181+
" # RETAIN,\n",
182+
" # SafeDrug,\n",
183+
" # MICRON,\n",
184+
" GAMENet\n",
185+
" ]:\n",
186+
" try:\n",
187+
" model = model_(\n",
188+
" dataset=sample_dataset,\n",
189+
" feature_keys=[\"conditions\", \"procedures\"],\n",
190+
" label_key=\"drugs\",\n",
191+
" mode=\"multilabel\",\n",
192+
" )\n",
193+
" except:\n",
194+
" model = model_(dataset=sample_dataset)\n",
195+
"\n",
196+
" device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')\n",
197+
"\n",
198+
" ## binary\n",
199+
" # trainer = Trainer(model=model, device=device, metrics=[\"pr_auc\", \"roc_auc\", \"accuracy\", \"f1\", \"jaccard\"])\n",
200+
" # trainer.train(\n",
201+
" # train_dataloader=train_loader,\n",
202+
" # val_dataloader=val_loader,\n",
203+
" # epochs=5,\n",
204+
" # monitor=\"accuracy\",\n",
205+
" # )\n",
206+
"\n",
207+
" ## multi-label\n",
208+
" trainer = Trainer(model=model, device=device, metrics=[\"pr_auc_samples\", \"roc_auc_samples\", \"f1_samples\", \"jaccard_samples\"])\n",
209+
" trainer.train(\n",
210+
" train_dataloader=train_loader,\n",
211+
" val_dataloader=val_loader,\n",
212+
" epochs=5,\n",
213+
" monitor=\"pr_auc_samples\",\n",
214+
" )\n",
215+
"\n",
216+
" ## multi-class\n",
217+
" # trainer = Trainer(model=model, device=device, metrics=[\"roc_auc_weighted_ovr\", \"cohen_kappa\", \"accuracy\", \"f1_weighted\"])\n",
218+
" # trainer.train(\n",
219+
" # train_dataloader=train_loader,\n",
220+
" # val_dataloader=val_loader,\n",
221+
" # epochs=5,\n",
222+
" # monitor=\"roc_auc_weighted_ovr\",\n",
223+
" # )\n",
224+
"\n",
225+
" results[model_.__name__].append(trainer.evaluate(val_loader))"
226+
]
227+
},
228+
{
229+
"cell_type": "code",
230+
"execution_count": 12,
231+
"metadata": {},
232+
"outputs": [],
233+
"source": [
234+
"avg_results = defaultdict(dict)\n",
235+
"\n",
236+
"for k, v in results.items():\n",
237+
" for k_, v_ in v[0].items():\n",
238+
" avg_results[k][k_] = sum([vv[k_] for vv in v]) / len(v)"
239+
]
240+
},
241+
{
242+
"cell_type": "code",
243+
"execution_count": 13,
244+
"metadata": {},
245+
"outputs": [],
246+
"source": [
247+
"import numpy as np\n",
248+
"# calculate standard deviation\n",
249+
"variation_results = defaultdict(dict)\n",
250+
"\n",
251+
"for k, v in results.items():\n",
252+
" for k_, v_ in v[0].items():\n",
253+
" variation_results[k][k_] = np.std([vv[k_] for vv in v])"
254+
]
255+
},
256+
{
257+
"cell_type": "code",
258+
"execution_count": 14,
259+
"metadata": {},
260+
"outputs": [
261+
{
262+
"data": {
263+
"text/plain": [
264+
"defaultdict(dict,\n",
265+
" {'GAMENet': {'pr_auc_samples': 0.4980838198236469,\n",
266+
" 'roc_auc_samples': 0.7424090396318291,\n",
267+
" 'f1_samples': 0.4728838360695048,\n",
268+
" 'jaccard_samples': 0.32078592771277264,\n",
269+
" 'loss': 0.6370396333582261}})"
270+
]
271+
},
272+
"execution_count": 14,
273+
"metadata": {},
274+
"output_type": "execute_result"
275+
}
276+
],
277+
"source": [
278+
"avg_results"
279+
]
280+
},
281+
{
282+
"cell_type": "code",
283+
"execution_count": 11,
284+
"metadata": {},
285+
"outputs": [
286+
{
287+
"data": {
288+
"text/plain": [
289+
"defaultdict(dict,\n",
290+
" {'GAMENet': {'pr_auc_samples': 0.0,\n",
291+
" 'roc_auc_samples': 0.0,\n",
292+
" 'f1_samples': 0.0,\n",
293+
" 'jaccard_samples': 0.0,\n",
294+
" 'loss': 0.0}})"
295+
]
296+
},
297+
"execution_count": 11,
298+
"metadata": {},
299+
"output_type": "execute_result"
300+
}
301+
],
302+
"source": [
303+
"variation_results"
304+
]
305+
},
306+
{
307+
"cell_type": "code",
308+
"execution_count": null,
309+
"metadata": {},
310+
"outputs": [],
311+
"source": []
312+
},
313+
{
314+
"cell_type": "code",
315+
"execution_count": null,
316+
"metadata": {},
317+
"outputs": [],
318+
"source": []
319+
}
320+
],
321+
"metadata": {
322+
"kernelspec": {
323+
"display_name": "Python 3.8.13 ('kgc')",
324+
"language": "python",
325+
"name": "python3"
326+
},
327+
"language_info": {
328+
"codemirror_mode": {
329+
"name": "ipython",
330+
"version": 3
331+
},
332+
"file_extension": ".py",
333+
"mimetype": "text/x-python",
334+
"name": "python",
335+
"nbconvert_exporter": "python",
336+
"pygments_lexer": "ipython3",
337+
"version": "3.8.13"
338+
},
339+
"orig_nbformat": 4,
340+
"vscode": {
341+
"interpreter": {
342+
"hash": "3d0509d9aa81f2882b18eeb72d4d23c32cae9029e9b99f63cde94ba86c35ac78"
343+
}
344+
}
345+
},
346+
"nbformat": 4,
347+
"nbformat_minor": 2
348+
}

0 commit comments

Comments
 (0)