diff --git a/pydeepflow/__init__.py b/pydeepflow/__init__.py
index c2c4211..2711150 100644
--- a/pydeepflow/__init__.py
+++ b/pydeepflow/__init__.py
@@ -11,6 +11,7 @@
 from .gridSearch import GridSearchCV
 from .validation import ModelValidator
 from .introspection import ANNIntrospector, CNNIntrospector, ModelSummaryFormatter, create_introspector
+from .preprocessing import ImageDataGenerator
 
 # Try to import optional components
 try:
@@ -45,6 +46,7 @@
     "CNNIntrospector",
     "ModelSummaryFormatter",
     "create_introspector",
+    "ImageDataGenerator",
 ]
 
 # Add optional components to __all__ if available
diff --git a/pydeepflow/model.py b/pydeepflow/model.py
index 0b4d9c8..8c676e1 100644
--- a/pydeepflow/model.py
+++ b/pydeepflow/model.py
@@ -22,6 +22,7 @@
 from tqdm import tqdm
 from pydeepflow.validation import ModelValidator
 from pydeepflow.introspection import create_introspector, ModelSummaryFormatter
+from pydeepflow.preprocessing import ImageDataGenerator  # Added for data augmentation
 
 # ====================================================================
 # IM2COL / COL2IM helper functions (USER'S TESTED WORKING VERSIONS)
@@ -293,7 +294,14 @@ def backward(self, dOut):
 class MaxPooling2D:
     """A Max Pooling layer for 2D inputs."""
     def __init__(self, pool_size=(2, 2), stride=2):
-        self.pool_height, self.pool_width = pool_size
+
+        if isinstance(pool_size, int):
+            self.pool_height = self.pool_width = pool_size
+        elif isinstance(pool_size, (tuple, list)) and len(pool_size) == 2:
+            self.pool_height, self.pool_width = pool_size
+        else:
+            raise ValueError(f"Invalid pool_size '{pool_size}'. Must be int or tuple/list of 2 ints.")
+
         self.stride = stride
         self.cache = {}
         self.params = {}
@@ -332,7 +340,14 @@ def backward(self, dOut):
 class AveragePooling2D:
     """An Average Pooling layer for 2D inputs."""
     def __init__(self, pool_size=(2, 2), stride=2):
-        self.pool_height, self.pool_width = pool_size
+
+        if isinstance(pool_size, int):
+            self.pool_height = self.pool_width = pool_size
+        elif isinstance(pool_size, (tuple, list)) and len(pool_size) == 2:
+            self.pool_height, self.pool_width = pool_size
+        else:
+            raise ValueError(f"Invalid pool_size '{pool_size}'. Must be int or tuple/list of 2 ints.")
+
         self.stride = stride
         self.cache = {}
         self.params = {}
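Both pooling constructors now normalize pool_size the same way. A minimal sketch of the accepted forms, assuming MaxPooling2D is importable from pydeepflow.model as defined above:

    from pydeepflow.model import MaxPooling2D

    # An int configures a square window; a 2-tuple/list sets (height, width) explicitly.
    pool_int = MaxPooling2D(pool_size=2, stride=2)
    pool_tup = MaxPooling2D(pool_size=(2, 2), stride=2)
    assert (pool_int.pool_height, pool_int.pool_width) == (pool_tup.pool_height, pool_tup.pool_width) == (2, 2)

    # Anything else is rejected up front rather than failing later during the forward pass.
    try:
        MaxPooling2D(pool_size=(2, 2, 2))
    except ValueError as e:
        print(e)  # Invalid pool_size '(2, 2, 2)'. Must be int or tuple/list of 2 ints.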
@@ -647,13 +662,16 @@ def backpropagation(self, X, y, activations, Z_values, learning_rate, clip_value
         for i in range(len(self.weights)):
             self.weights[i] -= learning_rate * self.regularization.apply_l2_regularization(self.weights[i], learning_rate, X.shape)
 
-    def fit(self, epochs, learning_rate=0.01, lr_scheduler=None, early_stop=None, X_val=None, y_val=None, checkpoint=None, verbose=False, clipping_threshold=None):
+    def fit(self, epochs=100, learning_rate=0.01, generator=None, steps_per_epoch=None, lr_scheduler=None,
+            early_stop=None, X_val=None, y_val=None, checkpoint=None, verbose=False, clipping_threshold=None):
         """
-        Trains the neural network model.
+        Trains the neural network model using either a static dataset or a data generator.
 
         Args:
-            epochs (int): The number of epochs to train the model.
+            epochs (int): The number of epochs to train the model. Defaults to 100.
             learning_rate (float, optional): The learning rate for the optimizer. Defaults to 0.01.
+            generator (iterator, optional): An iterator yielding (X_batch, y_batch) tuples, such as the generator returned by ImageDataGenerator.flow(). If provided, training draws batches from it instead of using the full training set each epoch. Defaults to None.
+            steps_per_epoch (int, optional): Batches per epoch when using a generator. Required if generator is provided. Defaults to None.
             lr_scheduler (object, optional): A learning rate scheduler. Defaults to None.
             early_stop (object, optional): An early stopping callback. Defaults to None.
             X_val (np.ndarray, optional): Validation features. Defaults to None.
@@ -661,82 +679,118 @@
             checkpoint (object, optional): A model checkpointing callback. Defaults to None.
             verbose (bool, optional): If True, prints training progress. Defaults to False.
             clipping_threshold (float, optional): The value for gradient clipping. Defaults to None.
-
-        Raises:
-            AssertionError: If early stopping is enabled but no validation set is provided.
         """
-        if early_stop:
-            assert X_val is not None and y_val is not None, "Validation set is required for early stopping"
+        if early_stop and (X_val is None or y_val is None):
+            raise ValueError("Validation set (X_val, y_val) is required for early stopping.")
 
-        for epoch in tqdm(range(epochs), desc="Training Progress", ncols=100, ascii="░▒█", colour='green', disable=not verbose):
-            start_time = time.time()
+        # ImageDataGenerator itself is not an iterator; fit() consumes batches via
+        # next(), so accept anything that supports the iterator protocol -- e.g.
+        # the generator returned by ImageDataGenerator.flow().
+        if generator is not None and not hasattr(generator, "__next__"):
+            raise TypeError("`generator` must be an iterator yielding (X_batch, y_batch) tuples, e.g. the generator returned by ImageDataGenerator.flow(), or None.")
 
-            # Adjust the learning rate using the scheduler if provided
-            if lr_scheduler is not None:
-                current_lr = lr_scheduler.get_lr(epoch)
-            else:
-                current_lr = learning_rate
+        if generator is not None and steps_per_epoch is None:
+            raise ValueError("`steps_per_epoch` must be specified when using a generator.")
 
-            # Forward and Backpropagation
-            self.training = True
-            activations, Z_values = self.forward_propagation(self.X_train)
-            self.backpropagation(self.X_train, self.y_train, activations, Z_values, current_lr, clip_value=clipping_threshold)
+        # Determine number of epochs based on input
+        num_epochs = epochs if epochs is not None else 100  # Default epochs if none provided
 
-            self.training = False
+        for epoch in tqdm(range(num_epochs), desc="Training Progress", ncols=100, ascii="░▒█", colour='green',
+                          disable=not verbose):
+            start_time = time.time()
+            current_lr = lr_scheduler.get_lr(epoch) if lr_scheduler else learning_rate
+
+            epoch_train_loss = 0.0
+            epoch_train_accuracy = 0.0
+
+            # --- Training Loop ---
+            if generator:
+                # Training with Data Generator
+                for step in range(steps_per_epoch):
+                    X_batch, y_batch = next(generator)
+                    X_batch_device = self.device.array(X_batch)
+                    y_batch_device = self.device.array(y_batch)
+
+                    self.training = True
+                    activations, Z_values = self.forward_propagation(X_batch_device)
+                    self.backpropagation(X_batch_device, y_batch_device, activations, Z_values, current_lr,
+                                         clip_value=clipping_threshold)
+                    self.training = False
+
+                    # Calculate batch metrics
+                    loss = self.loss_func(y_batch_device, activations[-1], self.device)
+                    # Convert to numpy for accuracy calculation if needed
+                    preds_np = self.device.asnumpy(activations[-1])
+                    y_batch_np = self.device.asnumpy(y_batch_device)
+                    if self.output_activation == 'sigmoid':
+                        accuracy = np.mean((preds_np >= 0.5).astype(int) == y_batch_np)
+                    else:
+                        accuracy = np.mean(np.argmax(preds_np, axis=1) == np.argmax(y_batch_np, axis=1))
+
+                    epoch_train_loss += loss
+                    epoch_train_accuracy += accuracy
+
+                # Average metrics over steps
+                train_loss = epoch_train_loss / steps_per_epoch
+                train_accuracy = epoch_train_accuracy / steps_per_epoch
 
-            # Compute training loss and accuracy
-            train_loss = self.loss_func(self.y_train, activations[-1], self.device)
-            train_accuracy = np.mean((activations[-1] >= 0.5).astype(int) == self.y_train) if self.output_activation == 'sigmoid' else np.mean(np.argmax(activations[-1], axis=1) == np.argmax(self.y_train, axis=1))
+            else:
+                # Standard Training (without generator)
+                self.training = True
+                activations, Z_values = self.forward_propagation(self.X_train)
+                self.backpropagation(self.X_train, self.y_train, activations, Z_values, current_lr,
+                                     clip_value=clipping_threshold)
+                self.training = False
+
+                train_loss = self.loss_func(self.y_train, activations[-1], self.device)
+                preds_np = self.device.asnumpy(activations[-1])
+                y_train_np = self.device.asnumpy(self.y_train)
+                if self.output_activation == 'sigmoid':
+                    train_accuracy = np.mean((preds_np >= 0.5).astype(int) == y_train_np)
+                else:
+                    train_accuracy = np.mean(np.argmax(preds_np, axis=1) == np.argmax(y_train_np, axis=1))
+
+            # Store history
+            self.history['train_loss'].append(train_loss)
+            self.history['train_accuracy'].append(train_accuracy)
 
-            # # Debugging output
-            # print(f"Computed Train Loss: {train_loss}, Train Accuracy: {train_accuracy}")
+            # --- Validation Step ---
+            val_loss = None
+            val_accuracy = None
+            if X_val is not None and y_val is not None:
+                X_val_device = self.device.array(X_val)
+                y_val_device = self.device.array(y_val)
+                val_activations, _ = self.forward_propagation(X_val_device)
+                val_loss = self.loss_func(y_val_device, val_activations[-1], self.device)
 
-            if train_loss is None or train_accuracy is None:
-                print("Warning: train_loss or train_accuracy is None!")
-                continue  # Skip this epoch if values are not valid
+                val_preds_np = self.device.asnumpy(val_activations[-1])
+                # Ensure y_val is numpy for comparison
+                y_val_np = y_val if isinstance(y_val, np.ndarray) else self.device.asnumpy(y_val_device)
 
-            # Validation step
-            val_loss = val_accuracy = None
-            if X_val is not None and y_val is not None:
-                val_activations, _ = self.forward_propagation(self.device.array(X_val))
-                val_loss = self.loss_func(self.device.array(y_val), val_activations[-1], self.device)
-                val_accuracy = np.mean((val_activations[-1] >= 0.5).astype(int) == y_val) if self.output_activation == 'sigmoid' else np.mean(np.argmax(val_activations[-1], axis=1) == np.argmax(y_val, axis=1))
+                if self.output_activation == 'sigmoid':
+                    val_accuracy = np.mean((val_preds_np >= 0.5).astype(int) == y_val_np)
+                else:
+                    val_accuracy = np.mean(np.argmax(val_preds_np, axis=1) == np.argmax(y_val_np, axis=1))
 
-            # Store training history for plotting
-            self.history['train_loss'].append(train_loss)
-            self.history['train_accuracy'].append(train_accuracy)
-            if val_loss is not None:
                 self.history['val_loss'].append(val_loss)
                 self.history['val_accuracy'].append(val_accuracy)
 
-            # Checkpoint saving logic
-            if checkpoint is not None and X_val is not None:
-                if checkpoint.should_save(epoch, val_loss):
-                    checkpoint.save_weights(epoch, self.weights, self.biases, val_loss)
-
-            if verbose and (epoch % 10 == 0):
-                # Display progress on the same line
-                sys.stdout.write(
-                    f"\rEpoch {epoch + 1}/{epochs} | "
-                    f"Train Loss: {train_loss:.4f} | "
-                    f"Accuracy: {train_accuracy:.2f}% | "
-                    f"Val Loss: {val_loss:.4f} | "
-                    f"Val Accuracy: {val_accuracy:.2f}% | "
-                    f"Learning Rate: {current_lr:.6f} "
-                )
-                sys.stdout.flush()
-
-            # Early stopping
-            if early_stop:
-                early_stop(val_loss)
-                if early_stop.early_stop:
-                    print('\n', "#" * 150, '\n\n', "early stop at - "
-                          f"Epoch {epoch + 1}/{epochs} Train Loss: {train_loss:.4f} Accuracy: {train_accuracy * 100:.2f}% "
-                          f"Val Loss: {val_loss:.4f} Val Accuracy: {val_accuracy * 100:.2f}% "
-                          f"Learning Rate: {current_lr:.6f}", '\n\n', "#" * 150)
-                    break
-
-        print("Training Completed!")
+                # Checkpoint saving
+                if checkpoint is not None:
+                    if checkpoint.should_save(epoch, val_loss):
+                        checkpoint.save_weights(epoch, self.weights, self.biases, val_loss)
+
+                # Early stopping check
+                if early_stop:
+                    early_stop(val_loss)
+                    if early_stop.early_stop:
+                        print(f"\nEarly stopping triggered at epoch {epoch + 1}")
+                        break  # Exit epoch loop
+
+            # --- Verbose Output ---
+            if verbose:  # Print at the end of each epoch
+                log_msg = f"Epoch {epoch + 1}/{num_epochs} - loss: {train_loss:.4f} - accuracy: {train_accuracy:.4f}"
+                if val_loss is not None:
+                    log_msg += f" - val_loss: {val_loss:.4f} - val_accuracy: {val_accuracy:.4f}"
+                log_msg += f" - lr: {current_lr:.6f}"
+                # Use tqdm.write instead of sys.stdout.write so the progress bar is not corrupted
+                tqdm.write(log_msg)
+
+        print("\nTraining Completed!")
 
     def predict(self, X):
         """
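Because flow() (added below in preprocessing.py) yields batches indefinitely, steps_per_epoch is what defines an epoch. A common choice, an assumption here rather than anything this diff enforces, is one full pass over the training set:

    import math

    n_samples, batch_size = 1000, 32
    steps_per_epoch = math.ceil(n_samples / batch_size)
    print(steps_per_epoch)  # 32 -> 31 full batches plus one final batch of 8

With that value, each epoch consumes every sample exactly once; ImageDataGenerator.flow() reshuffles its indices after each full pass.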
@@ -1372,17 +1426,57 @@ def __init__(self, layers_list, X_train, Y_train, loss='categorical_crossentropy
             self.trainable_params.extend([conv_layer.params['W'], conv_layer.params['b']])
 
         elif layer_type == 'maxpool':
-            pool_size, stride = layer_config.get('pool_size', (2, 2)), layer_config.get('stride', 2)
-            self.layers_list.append(MaxPooling2D(pool_size=pool_size, stride=stride))
+            # Get pool_size and stride from config, providing defaults
+            pool_size_config = layer_config.get('pool_size', (2, 2))
+            stride = layer_config.get('stride', 2)
+
+            # Add the layer instance
+            self.layers_list.append(MaxPooling2D(pool_size=pool_size_config, stride=stride))
+
+            # Calculate output shape
             H, W, C = current_input_shape
-            out_h, out_w = (H - pool_size[0])//stride + 1, (W - pool_size[1])//stride + 1
+
+            # Normalize pool_size (int or tuple) before computing the output shape
+            if isinstance(pool_size_config, int):
+                pool_h = pool_w = pool_size_config
+            elif isinstance(pool_size_config, (tuple, list)) and len(pool_size_config) == 2:
+                pool_h, pool_w = pool_size_config
+            else:
+                raise ValueError(
+                    f"Invalid pool_size '{pool_size_config}' for MaxPooling2D. Must be int or tuple/list of 2 ints.")
+
+            # Calculate output height and width
+            out_h = (H - pool_h) // stride + 1
+            out_w = (W - pool_w) // stride + 1
+
+            # Update current shape for the next layer
             current_input_shape = (out_h, out_w, C)
 
         elif layer_type == 'avgpool':
-            pool_size, stride = layer_config.get('pool_size', (2, 2)), layer_config.get('stride', 2)
-            self.layers_list.append(AveragePooling2D(pool_size=pool_size, stride=stride))
+            # Get pool_size and stride from config, providing defaults
+            pool_size_config = layer_config.get('pool_size', (2, 2))
+            stride = layer_config.get('stride', 2)
+
+            # Add the layer instance
+            self.layers_list.append(AveragePooling2D(pool_size=pool_size_config, stride=stride))
+
+            # Calculate output shape
             H, W, C = current_input_shape
-            out_h, out_w = (H - pool_size[0])//stride + 1, (W - pool_size[1])//stride + 1
+
+            # Normalize pool_size (int or tuple) before computing the output shape
+            if isinstance(pool_size_config, int):
+                pool_h = pool_w = pool_size_config
+            elif isinstance(pool_size_config, (tuple, list)) and len(pool_size_config) == 2:
+                pool_h, pool_w = pool_size_config
+            else:
+                raise ValueError(
+                    f"Invalid pool_size '{pool_size_config}' for AveragePooling2D. Must be int or tuple/list of 2 ints.")
+
+            # Calculate output height and width
+            out_h = (H - pool_h) // stride + 1
+            out_w = (W - pool_w) // stride + 1
+
+            # Update current shape for the next layer
             current_input_shape = (out_h, out_w, C)
 
         elif layer_type == 'flatten':
@@ -1533,18 +1627,20 @@ def backpropagation(self, X, y, A_values, Z_values, learning_rate, clip_value=No
 
     # --- Methods copied from Multi_Layer_ANN and adapted for CNN structure ---
-
-    def fit(self, epochs, learning_rate=0.01, lr_scheduler=None, early_stop=None, X_val=None, y_val=None, checkpoint=None, verbose=False, clipping_threshold=None):
+
+    def fit(self, epochs=50, learning_rate=0.01, generator=None, steps_per_epoch=None, lr_scheduler=None, early_stop=None, X_val=None, y_val=None,
+            checkpoint=None, verbose=False, clipping_threshold=None):
+        # NOTE: generator and steps_per_epoch are accepted here for API parity with
+        # Multi_Layer_ANN.fit; this CNN training loop does not consume them yet.
         # Validate training hyperparameters
         validator = ModelValidator(device=None)
         batch_size = validator.validate_training_hyperparameters(
             learning_rate, epochs, 32, self.X_train  # Default batch_size=32 for CNN
         )
-
+
         if early_stop:
             assert X_val is not None and y_val is not None, "Validation set required for early stopping"
 
-        for epoch in tqdm(range(epochs), desc="Training Progress", ncols=100, ascii="░▒█", colour='green', disable=not verbose):
+        for epoch in tqdm(range(epochs), desc="Training Progress", ncols=100, ascii="░▒█", colour='green',
+                          disable=not verbose):
             start_time = time.time()
 
             if lr_scheduler is not None:
@@ -1554,12 +1650,15 @@
 
             self.training = True
             activations, Z_values = self.forward_propagation(self.X_train)
-            self.backpropagation(self.X_train, self.y_train, activations, Z_values, current_lr, clip_value=clipping_threshold)
+            self.backpropagation(self.X_train, self.y_train, activations, Z_values, current_lr,
+                                 clip_value=clipping_threshold)
             self.training = False
 
             # metrics
             train_loss = self.loss_func(self.y_train, activations[-1], self.device)
-            train_accuracy = np.mean((activations[-1] >= 0.5).astype(int) == self.y_train) if self.output_activation == 'sigmoid' else np.mean(np.argmax(activations[-1], axis=1) == np.argmax(self.y_train, axis=1))
+            if self.output_activation == 'sigmoid':
+                train_accuracy = np.mean((activations[-1] >= 0.5).astype(int) == self.y_train)
+            else:
+                train_accuracy = np.mean(np.argmax(activations[-1], axis=1) == np.argmax(self.y_train, axis=1))
 
             if train_loss is None or train_accuracy is None:
                 print("Warning: train_loss or train_accuracy is None!")
@@ -1569,7 +1668,9 @@
             if X_val is not None and y_val is not None:
                 val_activations, _ = self.forward_propagation(self.device.array(X_val))
                 val_loss = self.loss_func(self.device.array(y_val), val_activations[-1], self.device)
-                val_accuracy = np.mean((val_activations[-1] >= 0.5).astype(int) == y_val) if self.output_activation == 'sigmoid' else np.mean(np.argmax(val_activations[-1], axis=1) == np.argmax(y_val, axis=1))
+                if self.output_activation == 'sigmoid':
+                    val_accuracy = np.mean((val_activations[-1] >= 0.5).astype(int) == y_val)
+                else:
+                    val_accuracy = np.mean(np.argmax(val_activations[-1], axis=1) == np.argmax(y_val, axis=1))
 
             self.history['train_loss'].append(train_loss)
             self.history['train_accuracy'].append(train_accuracy)
@@ -1595,7 +1696,8 @@
                 early_stop(val_loss)
                 if early_stop.early_stop:
                     print('\n', "#" * 80)
-                    print(f"Early stop at epoch {epoch + 1}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.4f}")
+                    print(
+                        f"Early stop at epoch {epoch + 1}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.4f}")
                     print('#' * 80)
                     break
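The constructor-side fix above duplicates the layers' own pool_size normalization so the computed feature-map shape always matches. A standalone check of the output-size formula used there, with pooled_shape as a hypothetical helper (not part of the codebase):

    def pooled_shape(h, w, pool_size, stride=2):
        # Mirrors the constructor logic: accept an int or a (height, width) pair.
        if isinstance(pool_size, int):
            pool_h = pool_w = pool_size
        else:
            pool_h, pool_w = pool_size
        return (h - pool_h) // stride + 1, (w - pool_w) // stride + 1

    assert pooled_shape(28, 28, 2) == pooled_shape(28, 28, (2, 2)) == (14, 14)
    assert pooled_shape(32, 32, (3, 3)) == (15, 15)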
diff --git a/pydeepflow/preprocessing.py b/pydeepflow/preprocessing.py
new file mode 100644
index 0000000..4b0627f
--- /dev/null
+++ b/pydeepflow/preprocessing.py
@@ -0,0 +1,123 @@
+import numpy as np
+from scipy.ndimage import rotate, shift, zoom
+
+
+class ImageDataGenerator:
+    """
+    Generates batches of tensor image data with real-time data augmentation.
+    Assumes image format is (height, width, channels).
+    """
+
+    def __init__(self,
+                 rotation_range=0.,
+                 width_shift_range=0.,
+                 height_shift_range=0.,
+                 zoom_range=0.,
+                 horizontal_flip=False,
+                 vertical_flip=False):
+        """
+        Initializes the ImageDataGenerator.
+
+        Args:
+            rotation_range (float): Range (in degrees, 0 to 180) for random rotations.
+            width_shift_range (float): Fraction of total width (0 to 1) for random horizontal shifts.
+            height_shift_range (float): Fraction of total height (0 to 1) for random vertical shifts.
+            zoom_range (float or tuple/list): Range for random zoom.
+                - If float, zoom will be [1-zoom_range, 1+zoom_range].
+                - If tuple/list [lower, upper], zoom will be [lower, upper].
+            horizontal_flip (bool): Whether to randomly flip inputs horizontally.
+            vertical_flip (bool): Whether to randomly flip inputs vertically.
+        """
+        self.rotation_range = rotation_range
+        self.width_shift_range = width_shift_range
+        self.height_shift_range = height_shift_range
+
+        if isinstance(zoom_range, (float, int)):
+            self.zoom_range = [1 - zoom_range, 1 + zoom_range]
+        elif len(zoom_range) == 2:
+            self.zoom_range = [zoom_range[0], zoom_range[1]]
+        else:
+            raise ValueError("`zoom_range` should be a float or "
+                             "a tuple or list of two floats. "
+                             f"Received: {zoom_range}")
+
+        self.horizontal_flip = horizontal_flip
+        self.vertical_flip = vertical_flip
+
+    def _apply_random_transform(self, x):
+        """Applies a random transformation to a single image x."""
+        img_h, img_w, img_c = x.shape
+
+        # Rotation
+        if self.rotation_range > 0:
+            theta = np.random.uniform(-self.rotation_range, self.rotation_range)
+            x = rotate(x, theta, reshape=False, order=1, mode='constant', cval=0.)
+
+        # Height Shift
+        if self.height_shift_range > 0:
+            ty = np.random.uniform(-self.height_shift_range, self.height_shift_range) * img_h
+            x = shift(x, (ty, 0, 0), order=1, mode='constant', cval=0.)
+
+        # Width Shift
+        if self.width_shift_range > 0:
+            tx = np.random.uniform(-self.width_shift_range, self.width_shift_range) * img_w
+            x = shift(x, (0, tx, 0), order=1, mode='constant', cval=0.)
+
+        # Zoom
+        if self.zoom_range[0] != 1.0 or self.zoom_range[1] != 1.0:
+            zx = zy = np.random.uniform(self.zoom_range[0], self.zoom_range[1])
+            zoomed_x = zoom(x, (zy, zx, 1), order=1)
+            zh, zw = zoomed_x.shape[:2]
+
+            if zy < 1.0:  # Zoom out - Pad
+                h_pad = (img_h - zh) // 2
+                w_pad = (img_w - zw) // 2
+                padded_x = np.zeros_like(x)
+                padded_x[h_pad:h_pad + zh, w_pad:w_pad + zw, :] = zoomed_x
+                x = padded_x
+            else:  # Zoom in - Crop
+                h_crop = (zh - img_h) // 2
+                w_crop = (zw - img_w) // 2
+                x = zoomed_x[h_crop:h_crop + img_h, w_crop:w_crop + img_w, :]
+
+        # Horizontal Flip
+        if self.horizontal_flip:
+            if np.random.random() < 0.5:
+                x = np.fliplr(x)
+
+        # Vertical Flip
+        if self.vertical_flip:
+            if np.random.random() < 0.5:
+                x = np.flipud(x)
+
+        return x
+
+    def flow(self, X, y, batch_size=32):
+        """
+        Generates batches of augmented data indefinitely.
+
+        Args:
+            X (np.ndarray): Input data (N, H, W, C).
+            y (np.ndarray): Target data (N, ...).
+            batch_size (int): Size of the batches to generate.
+
+        Yields:
+            tuple: A tuple (X_batch_augmented, y_batch).
+        """
+        n_samples = X.shape[0]
+        indices = np.arange(n_samples)
+
+        while True:
+            # Shuffle indices at the start of each epoch pass
+            np.random.shuffle(indices)
+
+            for start_idx in range(0, n_samples, batch_size):
+                end_idx = min(start_idx + batch_size, n_samples)
+                batch_indices = indices[start_idx:end_idx]
+
+                X_batch = X[batch_indices]
+                y_batch = y[batch_indices]
+
+                X_batch_augmented = np.array([self._apply_random_transform(img) for img in X_batch])
+
+                yield X_batch_augmented, y_batch
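Taken together, a minimal end-to-end sketch of the new training path. The data is synthetic and the model lines are left commented because the Multi_Layer_ANN constructor is not part of this diff; only ImageDataGenerator and the new fit() arguments are exercised:

    import numpy as np
    from pydeepflow.preprocessing import ImageDataGenerator

    # Synthetic stand-in data: 100 RGB images at 32x32 with 10 one-hot classes.
    X_train = np.random.rand(100, 32, 32, 3).astype(np.float32)
    y_train = np.eye(10)[np.random.randint(0, 10, size=100)]

    datagen = ImageDataGenerator(rotation_range=15,
                                 width_shift_range=0.1,
                                 height_shift_range=0.1,
                                 zoom_range=0.1,
                                 horizontal_flip=True)

    batch_size = 32
    gen = datagen.flow(X_train, y_train, batch_size=batch_size)
    steps = int(np.ceil(len(X_train) / batch_size))

    X_batch, y_batch = next(gen)  # flow() yields (augmented images, labels) indefinitely
    assert X_batch.shape == (batch_size, 32, 32, 3) and y_batch.shape == (batch_size, 10)

    # model = Multi_Layer_ANN(...)  # constructor not shown in this diff
    # model.fit(epochs=20, learning_rate=0.01, generator=gen, steps_per_epoch=steps, verbose=True)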