Skip to content

Commit 0bc8753

Browse files
committed
update vislog to be a bit fancier; adds GPT-2/3 performance baselines for HellaSwag
1 parent 3bce68b commit 0bc8753

File tree

1 file changed

+45
-22
lines changed

1 file changed

+45
-22
lines changed

dev/vislog.ipynb

+45-22
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
},
1010
{
1111
"cell_type": "code",
12-
"execution_count": 15,
12+
"execution_count": null,
1313
"metadata": {},
1414
"outputs": [],
1515
"source": [
@@ -61,63 +61,86 @@
6161
"metadata": {},
6262
"outputs": [],
6363
"source": [
64-
"sz = \"350M\"\n",
64+
"import numpy as np\n",
65+
"\n",
66+
"sz = \"124M\"\n",
6567
"loss_baseline = {\n",
6668
" \"124M\": 3.424958,\n",
6769
" \"350M\": 3.083089,\n",
6870
" \"774M\": 3.000580,\n",
6971
" \"1558M\": 2.831273,\n",
7072
"}[sz]\n",
71-
"hella_baseline = {\n",
73+
"hella2_baseline = { # for GPT-2\n",
7274
" \"124M\": 0.294463,\n",
7375
" \"350M\": 0.375224,\n",
7476
" \"774M\": 0.431986,\n",
7577
" \"1558M\": 0.488946,\n",
7678
"}[sz]\n",
77-
"\n",
79+
"hella3_baseline = { # for GPT-3\n",
80+
" \"124M\": 0.337,\n",
81+
" \"350M\": 0.436,\n",
82+
" \"774M\": 0.510,\n",
83+
" \"1558M\": 0.547,\n",
84+
"}[sz]\n",
7885
"# assumes each model run is stored in this way\n",
79-
"logfile = f\"../log{sz}/main.log\"\n",
86+
"logfile = f\"../log_gpt2_{sz}/main.log\"\n",
8087
"streams = parse_logfile(logfile)\n",
8188
"\n",
89+
"# optional function that smooths out the loss some\n",
90+
"def smooth_moving_average(signal, window_size):\n",
91+
" if signal.ndim != 1:\n",
92+
" raise ValueError(\"smooth_moving_average only accepts 1D arrays.\")\n",
93+
" if signal.size < window_size:\n",
94+
" raise ValueError(\"Input vector needs to be bigger than window size.\")\n",
95+
" if window_size < 3:\n",
96+
" return signal\n",
97+
"\n",
98+
" s = np.pad(signal, (window_size//2, window_size-1-window_size//2), mode='edge')\n",
99+
" w = np.ones(window_size) / window_size\n",
100+
" smoothed_signal = np.convolve(s, w, mode='valid')\n",
101+
" return smoothed_signal\n",
102+
"\n",
82103
"plt.figure(figsize=(16, 6))\n",
83104
"\n",
84105
"# Panel 1: losses: both train and val\n",
85106
"plt.subplot(121)\n",
86107
"xs, ys = streams[\"trl\"] # training loss\n",
108+
"ys = np.array(ys)\n",
109+
"# smooth out ys using a rolling window\n",
110+
"# ys = smooth_moving_average(ys, 21) # optional\n",
87111
"plt.plot(xs, ys, label=f'llm.c ({sz}) train loss')\n",
88112
"print(\"Min Train Loss:\", min(ys))\n",
89113
"xs, ys = streams[\"tel\"] # validation loss\n",
90114
"plt.plot(xs, ys, label=f'llm.c ({sz}) val loss')\n",
91115
"# horizontal line at GPT-2 baseline\n",
116+
"# we don't have GPT-3 loss on this dataset because the weights were never released\n",
92117
"if loss_baseline is not None:\n",
93118
" plt.axhline(y=loss_baseline, color='r', linestyle='--', label=f\"OpenAI GPT-2 ({sz}) checkpoint val loss\")\n",
94119
"plt.xlabel(\"steps\")\n",
95120
"plt.ylabel(\"loss\")\n",
96121
"plt.yscale('log')\n",
122+
"plt.ylim(top=4.0)\n",
97123
"plt.legend()\n",
98124
"plt.title(\"Loss\")\n",
99125
"print(\"Min Validation Loss:\", min(ys))\n",
100126
"\n",
101127
"# Panel 2: HellaSwag eval\n",
102128
"plt.subplot(122)\n",
103-
"xs, ys = streams[\"eval\"] # HellaSwag eval\n",
104-
"plt.plot(xs, ys, label=f\"llm.c ({sz})\")\n",
105-
"# horizontal line at GPT-2 baseline\n",
106-
"if hella_baseline:\n",
107-
" plt.axhline(y=hella_baseline, color='r', linestyle='--', label=f\"OpenAI GPT-2 ({sz}) checkpoint\")\n",
108-
"plt.xlabel(\"steps\")\n",
109-
"plt.ylabel(\"accuracy\")\n",
110-
"plt.legend()\n",
111-
"plt.title(\"HellaSwag eval\")\n",
112-
"print(\"Max Hellaswag eval:\", max(ys))"
129+
"if \"eval\" in streams:\n",
130+
" xs, ys = streams[\"eval\"] # HellaSwag eval\n",
131+
" ys = np.array(ys)\n",
132+
" plt.plot(xs, ys, label=f\"llm.c ({sz})\")\n",
133+
" # horizontal line at GPT-2/3 baselines\n",
134+
" if hella2_baseline:\n",
135+
" plt.axhline(y=hella2_baseline, color='r', linestyle='--', label=f\"OpenAI GPT-2 ({sz}) checkpoint\")\n",
136+
" if hella3_baseline:\n",
137+
" plt.axhline(y=hella3_baseline, color='g', linestyle='--', label=f\"OpenAI GPT-3 ({sz}) checkpoint\")\n",
138+
" plt.xlabel(\"steps\")\n",
139+
" plt.ylabel(\"accuracy\")\n",
140+
" plt.legend()\n",
141+
" plt.title(\"HellaSwag eval\")\n",
142+
" print(\"Max Hellaswag eval:\", max(ys))\n"
113143
]
114-
},
115-
{
116-
"cell_type": "code",
117-
"execution_count": null,
118-
"metadata": {},
119-
"outputs": [],
120-
"source": []
121144
}
122145
],
123146
"metadata": {

0 commit comments

Comments (0)