
Commit 125acb1

Data management more efficient in jupyter notebooks
1 parent 84cf555 commit 125acb1

File tree

4 files changed: +478 −316 lines

pain_emnist_tf.ipynb

Lines changed: 302 additions & 97 deletions
Large diffs are not rendered by default.

pain_tf.ipynb

Lines changed: 57 additions & 74 deletions
@@ -11,7 +11,12 @@
     "**Update:** August 21, 2024\n",
     "\n",
     "**Implementation:** PAIN architecture presented in:\n",
-    "**[IEEE publication](https://ieeexplore.ieee.org/abstract/document/8682767)**"
+    "**[IEEE publication](https://ieeexplore.ieee.org/abstract/document/8682767)**\n",
+    "\n",
+    "***NOTE:*** Code is set up to reconstruct 7x7 compressed noisy images. Training parameters may need to be adjusted for the 4x4, 14x14, and 28x28 cases.\n",
+    "\n",
+    "Python Version: 3.11.5 \n",
+    "Jupyter Notebook Version: 6.5.4"
    ]
   },
   {
@@ -55,10 +60,16 @@
     "\n",
     "# Dimension of the compressed/noisy images (width=height) \n",
     "# cmp_dim = 4:(for 4x4), 7:(7x7), 14:(14x14), or 28:(28x28)\n",
-    "cmp_dim = 14\n",
+    "cmp_dim = 7\n",
     "\n",
     "# Dimension of output; originals are 28 x 28\n",
-    "out_dim = 28 "
+    "out_dim = 28 \n",
+    "\n",
+    "# Number of training epochs\n",
+    "num_epochs = 25\n",
+    "\n",
+    "# Training batch size\n",
+    "batch_size = 250"
    ]
   },
   {
@@ -122,35 +133,24 @@
     "# Create compressed noisy data\n",
     "def create_training_test_data(clean_train, clean_test, cmp_dim, out_dim):\n",
     "    \n",
-    "    # Step 0: Create Training and Validation Sets\n",
-    "    clean_test, clean_valid = split_dataset_rnd(clean_test)\n",
-    "    \n",
-    "    # Step 1: Compress images using median in sliding window\n",
+    "    # Compress images using median in sliding window\n",
     "    cmp_train = down_sample_list(clean_train, cmp_dim)\n",
     "    cmp_test = down_sample_list(clean_test, cmp_dim)\n",
-    "    cmp_valid = down_sample_list(clean_valid, cmp_dim)\n",
     "\n",
-    "    # Step 2: Add Poisson noise to compressed images\n",
+    "    # Add Poisson noise to compressed images\n",
     "    noisy_train_ = np.random.poisson(lam=cmp_train)\n",
     "    noisy_test_ = np.random.poisson(lam=cmp_test)\n",
-    "    noisy_valid_ = np.random.poisson(lam=cmp_valid)\n",
     "    \n",
+    "    # Clip to [0,255]\n",
     "    noisy_train = np.clip(noisy_train_,0,255)\n",
     "    noisy_test = np.clip(noisy_test_,0,255)\n",
-    "    noisy_valid = np.clip(noisy_valid_,0,255)\n",
     "    \n",
-    "    # Step 3: Reshape Input Arrays and scale from [0, 255] to [0,1]\n",
+    "    # Reshape compressed/noisy arrays and scale to [0,1]\n",
     "    noisy_train = np.array([matrix.reshape(cmp_dim**2,) for matrix in noisy_train/255])\n",
     "    noisy_test = np.array([matrix.reshape(cmp_dim**2,) for matrix in noisy_test/255])\n",
-    "    noisy_valid = np.array([matrix.reshape(cmp_dim**2,) for matrix in noisy_valid/255])\n",
-    "    \n",
-    "    # Step 4: Reshape Expected Ouput Arraus and scale from [0, 255] to [0,1]\n",
-    "    clean_train = np.array([matrix.reshape(out_dim**2,) for matrix in clean_train/255])\n",
-    "    clean_test = np.array([matrix.reshape(out_dim**2,) for matrix in clean_test/255])\n",
-    "    clean_valid = np.array([matrix.reshape(out_dim**2,) for matrix in clean_valid/255])\n",
     "    \n",
-    "    # Step 5: Return training data\n",
-    "    return [clean_train, clean_test, clean_valid, noisy_train, noisy_test, noisy_valid]\n",
+    "    # Return training data\n",
+    "    return noisy_train, noisy_test\n",
     "\n",
     "# PAIN architecture builder function\n",
     "def build_PAIN(in_dim, out_dim, enc_dim = 256):\n",
@@ -188,57 +188,43 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Check compression dimension variable:\n",
+    "if cmp_dim not in [4, 7, 14, 28]:\n",
+    "    raise ValueError(f'cmp_dim = {cmp_dim}, not handled. The cmp_dim value must be 4, 7, 14, or 28.')\n",
+    "\n",
     "directory = 'training_data'\n",
     "\n",
     "# If the directory does not exist, create it\n",
     "if not os.path.exists(directory):\n",
     "    os.makedirs(directory)\n",
     "\n",
     "# Load the MNIST Dataset\n",
-    "file = f'mnist_training_data_cmp' # Do not modify\n",
-    "\n",
-    "# If the training data does not exist, create it\n",
-    "if not os.path.exists(f'{directory}/{file}_{cmp_dim}x{cmp_dim}.h5'):\n",
-    "    # Load MNIST\n",
-    "    (clean_train, _), (clean_test, _) = tf.keras.datasets.mnist.load_data()\n",
+    "dat_file = f'{directory}/mnist_{cmp_dim}x{cmp_dim}_train.h5'\n",
     "    \n",
-    "    # Create compressed/noisy data, test set, and validation set\n",
-    "    train_data = create_training_test_data(clean_train, clean_test, cmp_dim, out_dim)\n",
-    "    \n",
-    "    # MNIST data\n",
-    "    clean_train = train_data[0]\n",
-    "    clean_test = train_data[1]\n",
-    "    clean_valid = train_data[2]\n",
-    "    \n",
-    "    # Compressed noisy signals\n",
-    "    noisy_train = train_data[3]\n",
-    "    noisy_test = train_data[4]\n",
-    "    noisy_valid = train_data[5]\n",
-    "    \n",
-    "    # Save with compression\n",
-    "    with h5py.File(f'{directory}/{file}_{cmp_dim}x{cmp_dim}.h5', 'w') as f:\n",
-    "        # Save original images\n",
-    "        f.create_dataset('clean_train', data=clean_train, compression='gzip')\n",
-    "        f.create_dataset('clean_test', data=clean_test, compression='gzip')\n",
-    "        f.create_dataset('clean_valid', data=clean_valid, compression='gzip')\n",
+    "# Load MNIST\n",
+    "(clean_train, _), (clean_test, _) = tf.keras.datasets.mnist.load_data()\n",
+    "\n",
+    "# If the compressed/noisy data does not exist, create it\n",
+    "if not os.path.exists(dat_file):\n",
+    "    # Create compressed/noisy training/test data\n",
+    "    noisy_train, noisy_test = create_training_test_data(clean_train, clean_test, cmp_dim, out_dim)\n",
     "    \n",
+    "    # Save compressed/noisy training/test data\n",
+    "    with h5py.File(dat_file, 'w') as f:\n",
     "        # Save compressed noisy images\n",
     "        f.create_dataset('noisy_train', data=noisy_train, compression='gzip')\n",
     "        f.create_dataset('noisy_test', data=noisy_test, compression='gzip')\n",
-    "        f.create_dataset('noisy_valid', data=noisy_valid, compression='gzip')\n",
     "    \n",
-    "else: # If data exists, load it\n",
-    "    with h5py.File(f'{directory}/{file}_{cmp_dim}x{cmp_dim}.h5', 'r') as dat_file:\n",
-    "\n",
-    "        # Load original MNIST images\n",
-    "        clean_train = dat_file['clean_train'][:]\n",
-    "        clean_test = dat_file['clean_test'][:]\n",
-    "        clean_valid = dat_file['clean_valid'][:]\n",
-    "\n",
-    "        # Load compressed noisy images\n",
+    "else: \n",
+    "    # Load compressed/noisy training data\n",
+    "    with h5py.File(dat_file, 'r') as dat_file:\n",
+    "        # Load compressed noisy training/test images\n",
     "        noisy_train = dat_file['noisy_train'][:]\n",
     "        noisy_test = dat_file['noisy_test'][:]\n",
-    "        noisy_valid = dat_file['noisy_valid'][:]"
+    "    \n",
+    "# Prepare original mnist data for model training\n",
+    "clean_train = np.array([matrix.reshape(out_dim**2,) for matrix in clean_train/255])\n",
+    "clean_test = np.array([matrix.reshape(out_dim**2,) for matrix in clean_test/255])"
    ]
   },
   {
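The rewritten cell above is the heart of the commit: the cache file now stores only the compressed/noisy arrays, while the clean MNIST images are always reloaded from Keras and reshaped on the fly, which shrinks the .h5 file and removes a stale-clean-data case. A minimal sketch of the same load-or-create idiom, generalized (the helper name `load_or_create` is hypothetical; the file path and dataset keys follow the diff):

```python
import os
import h5py

def load_or_create(path, make_arrays):
    # Hypothetical helper distilling the notebook's caching pattern:
    # build the arrays once, persist them gzip-compressed in HDF5, and
    # on later runs read them back instead of regenerating them.
    if not os.path.exists(path):
        arrays = make_arrays()
        with h5py.File(path, 'w') as f:
            for name, arr in arrays.items():
                f.create_dataset(name, data=arr, compression='gzip')
        return arrays
    with h5py.File(path, 'r') as f:
        return {name: f[name][:] for name in f}

# Usage mirroring the cell above (names taken from the diff):
# data = load_or_create(
#     f'training_data/mnist_{cmp_dim}x{cmp_dim}_train.h5',
#     lambda: dict(zip(['noisy_train', 'noisy_test'],
#                      create_training_test_data(clean_train, clean_test,
#                                                cmp_dim, out_dim))))
# noisy_train, noisy_test = data['noisy_train'], data['noisy_test']
```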
@@ -254,9 +240,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Display: training set\n",
-    "#######################\n",
-    "\n",
     "# Create a 2 by 4 subplot handle\n",
     "fig, axes = plt.subplots(2, 4, figsize=(9, 4.5))\n",
     "axes = axes.flatten()\n",
@@ -311,14 +294,14 @@
     "# Create & Compile the PAIN model\n",
     "PAIN = build_PAIN(in_dim=cmp_dim, out_dim=out_dim)\n",
     "\n",
-    "# Create an RMSProp optimizer with a specific learning rate\n",
-    "RMSp = tf.keras.optimizers.RMSprop(learning_rate=0.05)\n",
+    "# Create an optimizer with a specific learning rate\n",
+    "tf_opt = tf.keras.optimizers.Adam(learning_rate=0.01)\n",
     "\n",
     "# Compile the model\n",
-    "PAIN.compile(optimizer=RMSp, loss='mean_squared_error')\n",
+    "PAIN.compile(optimizer=tf_opt, loss='mean_squared_error')\n",
     "\n",
     "# Train model and save fitting history\n",
-    "fit_history = PAIN.fit(noisy_train, clean_train, epochs=120, batch_size=250, validation_data=(noisy_test, clean_test))"
+    "fit_history = PAIN.fit(noisy_train, clean_train, epochs=num_epochs, batch_size=batch_size, validation_split=0.2)"
    ]
   },
   {
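Two training changes land here: RMSprop at learning rate 0.05 gives way to Adam at 0.01, and validation now comes from a 20% split of the training arrays rather than from the test set, so `noisy_test`/`clean_test` stay untouched until evaluation. A minimal sketch of the new setup, with a stand-in two-layer model in place of the paper's PAIN architecture (built by `build_PAIN`, which this diff does not show):

```python
import tensorflow as tf

# Stand-in model: 49 inputs (7x7 compressed image) to 784 outputs (28x28).
# This is NOT the PAIN architecture, only a placeholder with the same I/O.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(49,)),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(784, activation='sigmoid'),
])

# Adam with an explicit learning rate, as in the updated cell
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
              loss='mean_squared_error')

# Keras takes the validation_split fraction from the END of the training
# arrays (before shuffling), so the test set is no longer consumed during
# training the way validation_data=(noisy_test, clean_test) consumed it.
# fit_history = model.fit(noisy_train, clean_train, epochs=num_epochs,
#                         batch_size=batch_size, validation_split=0.2)
```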
@@ -403,7 +386,7 @@
     "axes[8].set_ylabel(f'Original\\n{out_dim} x {out_dim}\\nMNIST',fontdict={'fontsize': 12, 'fontfamily': 'serif'})\n",
     "\n",
     "# Adjust layout to decrease padding between subplots\n",
-    "plt.subplots_adjust(wspace=0.1, hspace=0.00)\n",
+    "plt.subplots_adjust(wspace=0.1, hspace=0)\n",
     "\n",
     "# Save results \n",
     "plt.savefig(f'{imdir}/{imfile}_{cmp_dim}x{cmp_dim}_train_PAIN.png')\n",
@@ -416,7 +399,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Apply PAIN to Validation Set and Display Output"
+    "## Apply PAIN to Test Set and Display Output"
    ]
   },
   {
@@ -425,8 +408,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Apply PAIN to all validation data inputs\n",
-    "pred_valid_out = PAIN.predict(noisy_valid[0:101])"
+    "# Apply PAIN to the first 101 test inputs\n",
+    "pred_test_out = PAIN.predict(noisy_test[0:101])"
    ]
   },
   {
@@ -440,26 +423,26 @@
     "axes = axes.flatten()\n",
     "\n",
     "# Add title\n",
-    "fig.suptitle('Application of PAIN Architecture\\n(MNIST Validation Set)',fontsize=20,fontweight='bold', fontfamily='serif')\n",
+    "fig.suptitle('Application of PAIN Architecture\\n(MNIST Test Set)',fontsize=20,fontweight='bold', fontfamily='serif')\n",
     "\n",
-    "# Shift window through validation dataset \n",
+    "# Shift window through test dataset \n",
     "sft_idx = 0\n",
     "\n",
     "# Plot noisy and decompressed realizations in each subplot\n",
     "for idx in range(4):\n",
     "    # Plot noisy\n",
-    "    axes[idx].imshow(noisy_valid[idx+sft_idx].reshape(cmp_dim,cmp_dim),cmap='gray')\n",
+    "    axes[idx].imshow(noisy_test[idx+sft_idx].reshape(cmp_dim,cmp_dim),cmap='gray')\n",
     "    axes[idx].set_xticks([]) # Remove xticklabels\n",
     "    axes[idx].set_yticks([]) # Remove yticklabels\n",
     "    axes[idx].set_xlabel('⇩',fontdict={'fontsize': 25, 'fontweight': 'bold', 'fontfamily': 'serif', 'color':'blue'})\n",
     "    \n",
     "    # Plot decompressed with PAIN\n",
-    "    axes[idx+4].imshow(pred_valid_out[idx+sft_idx].reshape(28,28),cmap='gray')\n",
+    "    axes[idx+4].imshow(pred_test_out[idx+sft_idx].reshape(28,28),cmap='gray')\n",
     "    axes[idx+4].set_xticks([]) # Remove xticklabels\n",
     "    axes[idx+4].set_yticks([]) # Remove yticklabels\n",
     "    \n",
     "    # Plot original\n",
-    "    axes[idx+8].imshow(clean_valid[idx+sft_idx].reshape(28,28),cmap='gray')\n",
+    "    axes[idx+8].imshow(clean_test[idx+sft_idx].reshape(28,28),cmap='gray')\n",
     "    axes[idx+8].set_xticks([]) # Remove xticklabels\n",
     "    axes[idx+8].set_yticks([]) # Remove yticklabels\n",
     "    \n",
469452
"axes[8].set_ylabel(f'(Original)\\n{out_dim} x {out_dim}\\nMNIST',fontdict={'fontsize': 12, 'fontfamily': 'serif'})\n",
470453
"\n",
471454
"# Adjust layout to decrease padding between subplots\n",
472-
"plt.subplots_adjust(wspace=0.1, hspace=0.25)\n",
455+
"plt.subplots_adjust(wspace=0.1, hspace=0)\n",
473456
"\n",
474457
"# Save results \n",
475-
"plt.savefig(f'{imdir}/{imfile}_{cmp_dim}x{cmp_dim}_valid_PAIN.png')\n",
458+
"plt.savefig(f'{imdir}/{imfile}_{cmp_dim}x{cmp_dim}_test_PAIN.png')\n",
476459
"\n",
477460
"# Display the figure\n",
478461
"plt.show()"
