diff --git a/pydeepflow/model.py b/pydeepflow/model.py
index c7197c0..eed726c 100644
--- a/pydeepflow/model.py
+++ b/pydeepflow/model.py
@@ -291,6 +291,192 @@ def backward(self, dOut):
         return dOut.reshape(original_shape)
+class MaxPooling2D:
+    """
+    A Max Pooling layer for 2D inputs (e.g., images).
+
+    This layer down-samples the input by taking the maximum value over a pooling window.
+
+    Attributes:
+        pool_height (int): The height of the pooling window.
+        pool_width (int): The width of the pooling window.
+        stride (int): The step size of the pooling window.
+        cache (dict): A dictionary to store information needed for backpropagation,
+                      including the input shape and the indices of the max values.
+    """
+    def __init__(self, pool_size=(2, 2), stride=2):
+        """
+        Initializes the MaxPooling2D layer.
+
+        Args:
+            pool_size (tuple, optional): The size of the pooling window. Defaults to (2, 2).
+            stride (int, optional): The stride of the pooling operation. Defaults to 2.
+        """
+        self.pool_height, self.pool_width = pool_size
+        self.stride = stride
+        self.cache = {}
+
+    def forward(self, X):
+        """
+        Performs the forward pass of the max pooling layer.
+
+        Args:
+            X (np.ndarray): The input data of shape (N, H, W, C), where N is the batch size,
+                            H is the height, W is the width, and C is the number of channels.
+
+        Returns:
+            np.ndarray: The output of the max pooling layer.
+        """
+        N, H, W, C = X.shape
+        out_h = (H - self.pool_height) // self.stride + 1
+        out_w = (W - self.pool_width) // self.stride + 1
+
+        out = np.zeros((N, out_h, out_w, C))
+        max_indices = np.zeros_like(out, dtype=int)
+
+        for i in range(out_h):
+            for j in range(out_w):
+                h_start = i * self.stride
+                h_end = h_start + self.pool_height
+                w_start = j * self.stride
+                w_end = w_start + self.pool_width
+
+                window = X[:, h_start:h_end, w_start:w_end, :]
+                out[:, i, j, :] = np.max(window, axis=(1, 2))
+
+                # Store the flat within-window indices of the max values for backpropagation
+                reshaped_window = window.reshape(N, -1, C)
+                max_indices[:, i, j, :] = np.argmax(reshaped_window, axis=1)
+
+        self.cache = {'X_shape': X.shape, 'max_indices': max_indices}
+        return out
+
+    def backward(self, dOut):
+        """
+        Performs the backward pass of the max pooling layer.
+
+        Args:
+            dOut (np.ndarray): The gradient of the loss with respect to the output of this layer.
+
+        Returns:
+            np.ndarray: The gradient of the loss with respect to the input of this layer.
+        """
+        X_shape = self.cache['X_shape']
+        max_indices = self.cache['max_indices']
+        N, H, W, C = X_shape
+        _, out_h, out_w, _ = dOut.shape
+
+        dX = np.zeros(X_shape)
+
+        for i in range(out_h):
+            for j in range(out_w):
+                h_start = i * self.stride
+                h_end = h_start + self.pool_height
+                w_start = j * self.stride
+                w_end = w_start + self.pool_width
+
+                # Gradient for the current window, shaped (N, 1, 1, C) so it
+                # broadcasts over the pooling window for any batch size
+                grad = dOut[:, i, j, :][:, np.newaxis, np.newaxis, :]
+
+                # Flat within-window indices of the max values
+                indices = max_indices[:, i, j, :]
+
+                # Build a one-hot mask that routes the gradient to the max positions
+                mask = np.zeros((N, self.pool_height * self.pool_width, C))
+                n_idx, c_idx = np.indices((N, C))
+                mask[n_idx, indices, c_idx] = 1
+
+                # Reshape the mask to the window shape and add the gradients
+                dX[:, h_start:h_end, w_start:w_end, :] += mask.reshape(N, self.pool_height, self.pool_width, C) * grad
+
+        return dX
+
+
+class AveragePooling2D:
+    """
+    An Average Pooling layer for 2D inputs.
+
+    This layer down-samples the input by taking the average value over a pooling window.
+
+    Attributes:
+        pool_height (int): The height of the pooling window.
+        pool_width (int): The width of the pooling window.
+        stride (int): The step size of the pooling window.
+        cache (dict): A dictionary to store the input shape for backpropagation.
+    """
+    def __init__(self, pool_size=(2, 2), stride=2):
+        """
+        Initializes the AveragePooling2D layer.
+
+        Args:
+            pool_size (tuple, optional): The size of the pooling window. Defaults to (2, 2).
+            stride (int, optional): The stride of the pooling operation. Defaults to 2.
+        """
+        self.pool_height, self.pool_width = pool_size
+        self.stride = stride
+        self.cache = {}
+
+    def forward(self, X):
+        """
+        Performs the forward pass of the average pooling layer.
+
+        Args:
+            X (np.ndarray): The input data of shape (N, H, W, C).
+
+        Returns:
+            np.ndarray: The output of the average pooling layer.
+        """
+        self.cache['X_shape'] = X.shape
+        N, H, W, C = X.shape
+        out_h = (H - self.pool_height) // self.stride + 1
+        out_w = (W - self.pool_width) // self.stride + 1
+
+        out = np.zeros((N, out_h, out_w, C))
+
+        for i in range(out_h):
+            for j in range(out_w):
+                h_start = i * self.stride
+                h_end = h_start + self.pool_height
+                w_start = j * self.stride
+                w_end = w_start + self.pool_width
+
+                window = X[:, h_start:h_end, w_start:w_end, :]
+                out[:, i, j, :] = np.mean(window, axis=(1, 2))
+
+        return out
+
+    def backward(self, dOut):
+        """
+        Performs the backward pass of the average pooling layer.
+
+        Args:
+            dOut (np.ndarray): The gradient of the loss with respect to the output of this layer.
+
+        Returns:
+            np.ndarray: The gradient of the loss with respect to the input of this layer.
+        """
+        X_shape = self.cache['X_shape']
+        N, H, W, C = X_shape
+        _, out_h, out_w, _ = dOut.shape
+
+        dX = np.zeros(X_shape)
+        pool_area = self.pool_height * self.pool_width
+
+        for i in range(out_h):
+            for j in range(out_w):
+                h_start = i * self.stride
+                h_end = h_start + self.pool_height
+                w_start = j * self.stride
+                w_end = w_start + self.pool_width
+
+                # Distribute the gradient equally over the pooling window
+                grad = dOut[:, i, j, :][:, np.newaxis, np.newaxis, :]
+                dX[:, h_start:h_end, w_start:w_end, :] += grad / pool_area
+
+        return dX
+
+
 # ====================================================================
 # Multi_Layer_ANN (Dense-only training logic) - UNMODIFIED
 # ====================================================================
@@ -1295,6 +1479,29 @@ def __init__(self, layers_list, X_train, Y_train, activations, loss='categorical
             if layer_type == 'conv':
                 if len(current_input_shape) != 3:
                     raise ValueError("ConvLayer requires 4D input (N, H, W, C). Check previous layer configuration.")
+
+            elif layer_type == 'maxpool':
+                pool_size = layer_config.get('pool_size', (2, 2))
+                stride = layer_config.get('stride', 2)
+                pool_layer = MaxPooling2D(pool_size=pool_size, stride=stride)
+                self.layers_list.append(pool_layer)
+                H, W, C = current_input_shape
+                out_h = (H - pool_size[0]) // stride + 1
+                out_w = (W - pool_size[1]) // stride + 1
+                current_input_shape = (out_h, out_w, C)
+                continue  # pooling layers have no weights; skip the parameter setup below
+
+            elif layer_type == 'avgpool':
+                pool_size = layer_config.get('pool_size', (2, 2))
+                stride = layer_config.get('stride', 2)
+                pool_layer = AveragePooling2D(pool_size=pool_size, stride=stride)
+                self.layers_list.append(pool_layer)
+                H, W, C = current_input_shape
+                out_h = (H - pool_size[0]) // stride + 1
+                out_w = (W - pool_size[1]) // stride + 1
+                current_input_shape = (out_h, out_w, C)
+                continue  # pooling layers have no weights; skip the parameter setup below
+
             in_c = current_input_shape[-1]
             out_c = layer_config['out_channels']
@@ -1376,8 +1581,7 @@ def forward_propagation(self, X):
         A_values = [X]  # Stores all activation outputs (A0 = Input, A1, A2...)
         for layer_idx, layer in enumerate(self.layers_list):
-            if isinstance(layer, ConvLayer):
-                # --- ConvLayer forward ---
+            if isinstance(layer, (ConvLayer, MaxPooling2D, AveragePooling2D)):
                 current_activation = layer.forward(current_activation)
                 A_values.append(current_activation)
             elif isinstance(layer, Flatten):
@@ -1423,13 +1627,16 @@ def backpropagation(self, X, y, A_values, Z_values, learning_rate, clip_value=No
         for i in reversed(range(len(self.layers_list))):
             layer = self.layers_list[i]
             if isinstance(layer, ConvLayer):
                 # --- ConvLayer backward: computes dW, db, and dIn for previous layer ---
                 dIn = layer.backward(dOut)  # Populates layer.grads
                 # Store Conv gradients (W then b) - insert at beginning to maintain order
                 grads_to_update.insert(0, layer.grads['db'])  # Insert bias first
                 grads_to_update.insert(0, layer.grads['dW'])  # Insert weights second
                 dOut = dIn  # Gradient for the previous layer (Flatten or another Conv)
+            elif isinstance(layer, (MaxPooling2D, AveragePooling2D)):
+                # Pooling layers have no trainable parameters; just route the gradient through
+                dOut = layer.backward(dOut)
             elif isinstance(layer, Flatten):
                 # --- Flatten backward: reshapes gradient for previous Conv layer ---
diff --git a/tests/test_layers.py b/tests/test_layers.py
new file mode 100644
index 0000000..90ead2b
--- /dev/null
+++ b/tests/test_layers.py
@@ -0,0 +1,45 @@
+import unittest
+import numpy as np
+from pydeepflow.model import MaxPooling2D, AveragePooling2D
+
+class TestPoolingLayers(unittest.TestCase):
+
+    def test_max_pooling_forward(self):
+        pool = MaxPooling2D(pool_size=(2, 2), stride=2)
+        X = np.arange(16).reshape(1, 4, 4, 1)
+        out = pool.forward(X)
+        expected_out = np.array([[[[ 5.], [ 7.]], [[13.], [15.]]]])
+        self.assertEqual(out.shape, (1, 2, 2, 1))
+        np.testing.assert_array_almost_equal(out, expected_out)
+
+    def test_max_pooling_backward(self):
+        pool = MaxPooling2D(pool_size=(2, 2), stride=2)
+        X = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32).reshape(1, 4, 4, 1)
+        pool.forward(X)
+        dOut = np.ones((1, 2, 2, 1))
+        dX = pool.backward(dOut)
+        # The gradient flows only to the max of each 2x2 window (6, 8, 14, 16)
+        expected_dX = np.array([[0, 0, 0, 0], [0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]], dtype=np.float32).reshape(1, 4, 4, 1)
+        np.testing.assert_array_almost_equal(dX, expected_dX)
+
+    def test_avg_pooling_forward(self):
+        pool = AveragePooling2D(pool_size=(2, 2), stride=2)
+        X = np.arange(16).reshape(1, 4, 4, 1)
+        out = pool.forward(X)
+        expected_out = np.array([[[[ 2.5], [ 4.5]], [[10.5], [12.5]]]])
+        self.assertEqual(out.shape, (1, 2, 2, 1))
+        np.testing.assert_array_almost_equal(out, expected_out)
+
+    def test_avg_pooling_backward(self):
+        pool = AveragePooling2D(pool_size=(2, 2), stride=2)
+        X = np.arange(16).reshape(1, 4, 4, 1)
+        pool.forward(X)
+        dOut = np.ones((1, 2, 2, 1))
+        dX = pool.backward(dOut)
+        # Each input position receives 1/4 of the corresponding output gradient
+        expected_dX = np.ones((4, 4)) * 0.25
+        expected_dX = expected_dX.reshape(1, 4, 4, 1)
+        np.testing.assert_array_almost_equal(dX, expected_dX)
+
+if __name__ == '__main__':
+    unittest.main()
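
Note on test coverage: all four unit tests use a batch size of 1, so they would not catch shape or broadcasting mistakes in the pooling backward passes that only appear for larger batches. A finite-difference gradient check with N > 1, along the lines of the sketch below, is a useful supplement. This is illustrative only and not part of the patch; it assumes the layers are importable from pydeepflow.model exactly as in tests/test_layers.py.

    import numpy as np
    from pydeepflow.model import MaxPooling2D

    # Sketch of a finite-difference check for MaxPooling2D.backward.
    # Uses N=2 and C=3 to exercise batched, multi-channel broadcasting.
    rng = np.random.default_rng(0)
    X = rng.standard_normal((2, 4, 4, 3))   # batch of 2, 3 channels
    R = rng.standard_normal((2, 2, 2, 3))   # random upstream gradient

    pool = MaxPooling2D(pool_size=(2, 2), stride=2)
    pool.forward(X)                         # populates the cache
    dX = pool.backward(R)                   # analytic gradient of sum(forward(X) * R)

    # Numerical gradient via central differences
    eps = 1e-6
    num_dX = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        X[idx] += eps
        f_plus = np.sum(pool.forward(X) * R)
        X[idx] -= 2 * eps
        f_minus = np.sum(pool.forward(X) * R)
        X[idx] += eps                       # restore the original value
        num_dX[idx] = (f_plus - f_minus) / (2 * eps)

    print(np.max(np.abs(dX - num_dX)))      # expect agreement to ~1e-8 or better

The same loop works unchanged for AveragePooling2D by swapping the class, since both layers expose the same forward/backward interface.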