# Reconstructed from a whitespace-collapsed unified diff against
# ``ctlearn/core/model.py``.  Besides the classes below, the patch adds the two
# TensorFlow import lines and appends "Attention_ResNet" to the module's
# ``__all__`` list (only partially visible in the diff).
#
# NOTE(review): ``CTLearnModel``, ``validate_trait_dict``,
# ``build_fully_connect_head`` and the ``*_squeeze_excite_block`` helpers are
# not defined in the visible part of the patch; presumably they are defined in
# or imported by the rest of ``model.py`` -- confirm before merging.

from abc import abstractmethod

import keras
import tensorflow as tf
from ctapipe.core import Component
from ctapipe.core.traits import Bool, Int, CaselessStrEnum, List, Dict, Unicode, Path
from tensorflow.keras.layers import Layer, Concatenate, Input


class ChannelAttention(Layer):
    """
    Channel attention gate.

    The input is globally average-pooled, passed through two bias-free dense
    layers (ReLU, then sigmoid) and the resulting per-channel weights are
    multiplied back onto the input.

    Parameters
    ----------
    channel_attention_reduction : int, optional
        Divisor applied to the number of input channels to obtain the width
        of the first dense layer. Default is 16.
    name : str, optional
        Name of the layer. Default is None.
    **kwargs
        Forwarded to the Keras ``Layer`` constructor (e.g. ``trainable``,
        ``dtype``) so that the layer can be re-created from its config.
    """

    def __init__(self, channel_attention_reduction=16, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.channel_attention_reduction = channel_attention_reduction

    def build(self, input_shape):
        channels = input_shape[-1]
        # With fewer channels than the reduction factor the integer division
        # yields 0 units, which is not a usable dense layer; keep at least 1.
        hidden_units = max(1, channels // self.channel_attention_reduction)
        self.global_avg_pool = keras.layers.GlobalAveragePooling2D()
        self.fc1 = keras.layers.Dense(
            hidden_units, activation="relu", use_bias=False
        )
        self.fc2 = keras.layers.Dense(
            channels, activation="sigmoid", use_bias=False
        )
        # Created once here instead of instantiating a new layer on every call.
        self.reshape = keras.layers.Reshape((1, 1, channels))
        super().build(input_shape)

    def call(self, x):
        weights = self.global_avg_pool(x)
        weights = self.fc1(weights)
        weights = self.fc2(weights)
        # Broadcast the per-channel weights over the two spatial axes.
        return x * self.reshape(weights)

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config()
        config.update(
            {"channel_attention_reduction": self.channel_attention_reduction}
        )
        return config


class SpatialAttention(Layer):
    """
    Spatial attention gate.

    A single trainable convolution kernel maps the input to a one-channel map;
    its sigmoid is multiplied back onto the input.

    Parameters
    ----------
    se_kernel_size : int, optional
        Height and width of the convolution kernel. Default is 1.
    name : str, optional
        Name of the layer. Default is None.
    **kwargs
        Forwarded to the Keras ``Layer`` constructor (e.g. ``trainable``,
        ``dtype``) so that the layer can be re-created from its config.
    """

    def __init__(self, se_kernel_size=1, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.se_kernel_size = se_kernel_size

    def build(self, input_shape):
        # Kernel of shape (k, k, in_channels, 1): the convolution collapses
        # all input channels into a single attention map.
        self.kernel = self.add_weight(
            name="kernel",
            shape=(self.se_kernel_size, self.se_kernel_size, input_shape[-1], 1),
            initializer="uniform",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        # 'SAME' padding with unit strides keeps the spatial size, so the mask
        # can be multiplied element-wise onto the input.
        attention = tf.nn.sigmoid(
            tf.nn.conv2d(x, self.kernel, strides=[1, 1], padding="SAME")
        )
        return x * attention

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config()
        config.update({"se_kernel_size": self.se_kernel_size})
        return config


class CBAM(Layer):
    """
    Channel attention followed by spatial attention.

    Parameters
    ----------
    channel_attention_reduction : int, optional
        Reduction factor passed to ``ChannelAttention``. Default is 16.
    se_kernel_size : int, optional
        Kernel size passed to ``SpatialAttention``. Default is 7.
    name : str, optional
        Name of the layer. Default is None.
    **kwargs
        Forwarded to the Keras ``Layer`` constructor (e.g. ``trainable``,
        ``dtype``) so that the layer can be re-created from its config.
    """

    def __init__(
        self, channel_attention_reduction=16, se_kernel_size=7, name=None, **kwargs
    ):
        super().__init__(name=name, **kwargs)
        # Stored so that ``get_config`` can report them.
        self.channel_attention_reduction = channel_attention_reduction
        self.se_kernel_size = se_kernel_size
        self.channel_attention = ChannelAttention(channel_attention_reduction)
        self.spatial_attention = SpatialAttention(se_kernel_size)

    def call(self, x):
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config()
        config.update(
            {
                "channel_attention_reduction": self.channel_attention_reduction,
                "se_kernel_size": self.se_kernel_size,
            }
        )
        return config


class Attention_ResNet(CTLearnModel):
    """
    ``Attention_ResNet`` is a residual neural network with attention layers.

    This class extends ``CTLearnModel`` by building a ResNet backbone in which
    a spatial, channel or combined (CBAM) attention layer is applied to the
    input of the residual stack and, optionally, after every residual stage.
    """

    name = Unicode(
        "Attention_ThinResNet",
        help="Name of the model backbone.",
    ).tag(config=True)

    init_layer = Dict(
        default_value=None,
        allow_none=True,
        help=(
            "Parameters for the first convolutional layer. "
            "E.g. ``{'filters': 64, 'kernel_size': 7, 'strides': 2}``."
        ),
    ).tag(config=True)

    init_max_pool = Dict(
        default_value=None,
        allow_none=True,
        help=(
            "Parameters for the first max pooling layer. "
            "E.g. ``{'size': 3, 'strides': 2}``."
        ),
    ).tag(config=True)

    residual_block_type = CaselessStrEnum(
        ["basic", "bottleneck"],
        default_value="bottleneck",
        allow_none=False,
        help="Type of residual block to use.",
    ).tag(config=True)

    architecture = List(
        trait=Dict(),
        default_value=[
            {'filters': 48, 'blocks': 2},
            {'filters': 96, 'blocks': 3},
            {'filters': 128, 'blocks': 3},
            {'filters': 256, 'blocks': 3},
        ],
        allow_none=False,
        help=(
            "List of dicts containing the number of filters and residual blocks. "
            "E.g. ``[{'filters': 12, 'blocks': 2}, ...]``."
        ),
    ).tag(config=True)

    se_kernel_size = Int(
        default_value=1,
        allow_none=False,
        help="Kernel size of the Spatial Attention layer",
    ).tag(config=True)

    channel_attention_reduction = Int(
        default_value=16,
        allow_none=False,
        help="Reduction size of the Channel Attention layer",
    ).tag(config=True)

    attention_type = CaselessStrEnum(
        ["spatial", "channel", "both"],
        default_value="spatial",
        allow_none=False,
        help="Type of attention layer(s) to apply",
    ).tag(config=True)

    attention_location = CaselessStrEnum(
        ["initial", "after_conv"],
        default_value="initial",
        allow_none=False,
        help="Location of the attention layer(s) to apply",
    ).tag(config=True)

    def __init__(
        self,
        input_shape,
        tasks,
        config=None,
        parent=None,
        **kwargs,
    ):
        """
        Validate the configuration and build the full Keras model.

        Parameters
        ----------
        input_shape : tuple
            Shape handed to ``_build_backbone``.
        tasks : list
            Tasks used to validate ``self.head`` and to build the head.
        config, parent, **kwargs
            Forwarded unchanged to the parent constructor.
        """
        super().__init__(
            config=config,
            parent=parent,
            **kwargs,
        )

        # Validate the architecture trait
        for layer in self.architecture:
            validate_trait_dict(layer, ["filters", "blocks"])
        # Validate the initial layers trait
        if self.init_layer is not None:
            validate_trait_dict(self.init_layer, ["filters", "kernel_size", "strides"])
        if self.init_max_pool is not None:
            validate_trait_dict(self.init_max_pool, ["size", "strides"])

        # Construct the name of the backbone model by appending "_block" to the model name
        self.backbone_name = self.name + "_block"

        # Build the ResNet model backbone
        self.backbone_model, self.input_layer = self._build_backbone(input_shape)
        backbone_output = self.backbone_model(self.input_layer)
        # Validate the head trait with the provided tasks
        # NOTE(review): ``self.head`` is not declared in this class; presumably
        # it is a trait of ``CTLearnModel`` -- confirm.
        validate_trait_dict(self.head, tasks)
        # Build the fully connected head depending on the tasks
        self.logits = build_fully_connect_head(backbone_output, self.head, tasks)

        self.model = keras.Model(self.input_layer, self.logits, name="CTLearn_model")
        # ``summary`` prints by itself and returns None, so it must not be
        # wrapped in ``print`` (that emitted a stray "None").
        self.model.summary(expand_nested=True)

    def _build_backbone(self, input_shape):
        """
        Build the ResNet model backbone.

        Function to build the backbone of the ResNet model using the specified parameters.

        Parameters
        ----------
        input_shape : tuple
            Shape of one input sample; it is passed as ``shape`` to the Keras
            ``Input`` layer.

        Returns
        -------
        backbone_model : keras.Model
            Keras model object representing the ResNet backbone.
        network_input : keras.Input
            Keras input layer object for the backbone model.
        """
        # Define the input layer from the input shape
        network_input = Input(shape=input_shape, name="input")
        # Keep ``network_input`` bound to the real input tensor and thread the
        # intermediate results through ``x``.  Rebinding ``network_input`` made
        # the backbone model (and the returned input) start *after* the
        # padding / convolution / pooling layers, dropping them from the model.
        x = network_input

        # Apply initial padding if specified
        # NOTE(review): ``self.init_padding`` is not declared in this class;
        # presumably it is a trait of ``CTLearnModel`` -- confirm.
        if self.init_padding > 0:
            # ``ZeroPadding2D`` takes no ``kernel_size`` / ``strides`` arguments.
            x = keras.layers.ZeroPadding2D(
                padding=self.init_padding,
                name=self.backbone_name + "_padding",
            )(x)
        # Apply initial convolutional layer if specified
        if self.init_layer is not None:
            x = keras.layers.Conv2D(
                filters=self.init_layer["filters"],
                kernel_size=self.init_layer["kernel_size"],
                strides=self.init_layer["strides"],
                name=self.backbone_name + "_conv1_conv",
            )(x)
        # Apply max pooling if specified
        if self.init_max_pool is not None:
            x = keras.layers.MaxPool2D(
                pool_size=self.init_max_pool["size"],
                strides=self.init_max_pool["strides"],
                name=self.backbone_name + "_pool1_pool",
            )(x)

        # Build the residual blocks
        # NOTE(review): ``self.attention`` is not declared in this class;
        # presumably it is provided by ``CTLearnModel`` -- confirm.
        engine_output = self._stacked_res_blocks(
            x,
            architecture=self.architecture,
            residual_block_type=self.residual_block_type,
            attention=self.attention,
            name=self.backbone_name,
        )

        # Apply global average pooling as the final layer of the backbone
        network_output = keras.layers.GlobalAveragePooling2D(
            name=self.backbone_name + "_global_avgpool"
        )(engine_output)

        # Create the backbone model
        backbone_model = keras.Model(
            network_input, network_output, name=self.backbone_name
        )
        return backbone_model, network_input

    def _make_attention_layer(self, prefix="", suffix=""):
        """Create the attention layer selected by ``attention_type``, named ``prefix + KIND + suffix``."""
        if self.attention_type == "spatial":
            return SpatialAttention(self.se_kernel_size, prefix + "SPATIAL" + suffix)
        if self.attention_type == "channel":
            return ChannelAttention(
                self.channel_attention_reduction, prefix + "CHANNEL" + suffix
            )
        return CBAM(
            self.channel_attention_reduction,
            self.se_kernel_size,
            prefix + "CBAM" + suffix,
        )

    def _stacked_res_blocks(self, inputs, architecture, residual_block_type, attention, name=None):
        """
        Build a stack of residual blocks for the CTLearn model.

        An attention layer (chosen by ``attention_type``) is always applied to
        ``inputs`` first.  When ``attention_location`` is ``'after_conv'`` a
        further attention layer is applied after every residual stage.

        Parameters
        ----------
        inputs : keras.layers.Layer
            Input Keras layer to the residual blocks.
        architecture : list of dict
            List of dictionaries containing the architecture of the ResNet model, which includes:
            - Number of filters for the convolutional layers in the residual blocks.
            - Number of residual blocks to stack.
        residual_block_type : str
            Type of residual block to use. Options are 'basic' or 'bottleneck'.
        attention : dict
            Dictionary containing the configuration parameters for the
            squeeze-excite attention applied inside the residual blocks.
        name : str, optional
            Label for the model; used as a prefix for layer names.

        Returns
        -------
        x : keras.layers.Layer
            Output Keras layer after passing through the stack of residual blocks.
        """
        # Get hyperparameters for the model architecture
        filters_list = [layer["filters"] for layer in architecture]
        blocks_list = [layer["blocks"] for layer in architecture]

        after_each_stage = self.attention_location == "after_conv"

        # Attention on the stack input (applied for every attention location)
        x = self._make_attention_layer()(inputs)

        # First stage keeps the spatial resolution (stride 1)
        x = self._stack_fn(
            x,
            filters_list[0],
            blocks_list[0],
            residual_block_type,
            stride=1,
            attention=attention,
            name=name + "_conv2",
        )
        if after_each_stage:
            x = self._make_attention_layer(name + "_", "_conv2")(x)

        # Remaining stages use the default stride of ``_stack_fn``
        for i, (filters, blocks) in enumerate(zip(filters_list[1:], blocks_list[1:])):
            stage = "_conv" + str(i + 3)
            x = self._stack_fn(
                x,
                filters,
                blocks,
                residual_block_type,
                attention=attention,
                name=name + stage,
            )
            if after_each_stage:
                # Uses the same "_" separator for every attention type (the
                # CBAM branch previously used "-").
                x = self._make_attention_layer(name + "_", stage)(x)

        return x

    def _stack_fn(
        self,
        inputs,
        filters,
        blocks,
        residual_block_type,
        stride=2,
        attention=None,
        name=None,
    ):
        """
        Stack residual blocks for the CTLearn model.

        The first block uses ``stride`` and a convolutional shortcut; the
        remaining ``blocks - 1`` blocks use an identity shortcut.

        Parameters
        ----------
        inputs : keras.layers.Layer
            Input tensor to the residual blocks.
        filters : int
            Number of filters for the bottleneck layer in a block.
        blocks : int
            Number of residual blocks to stack.
        residual_block_type : str
            Type of residual block ('basic' or 'bottleneck').
        stride : int, optional
            Stride for the first layer in the first block. Default is 2.
        attention : dict, optional
            Configuration parameters for the attention mechanism. Default is None.
        name : str, optional
            Label for the stack. Default is None.

        Returns
        -------
        keras.layers.Layer
            Output tensor for the stacked blocks.
        """
        res_blocks = {
            "basic": self._basic_residual_block,
            "bottleneck": self._bottleneck_residual_block,
        }
        build_block = res_blocks[residual_block_type]

        x = build_block(
            inputs,
            filters,
            stride=stride,
            attention=attention,
            name=name + "_block1",
        )
        for i in range(2, blocks + 1):
            x = build_block(
                x,
                filters,
                conv_shortcut=False,
                attention=attention,
                name=name + "_block" + str(i),
            )

        return x

    @staticmethod
    def _apply_se_attention(x, attention, name):
        """Apply the squeeze-excite block selected by ``attention['mechanism']``; return ``x`` unchanged otherwise."""
        if attention is None:
            return x
        mechanism = attention["mechanism"]
        if mechanism == "Dual-SE":
            return dual_squeeze_excite_block(
                x, attention["reduction_ratio"], name=name + "_dse"
            )
        if mechanism == "Channel-SE":
            return channel_squeeze_excite_block(
                x, attention["reduction_ratio"], name=name + "_cse"
            )
        if mechanism == "Spatial-SE":
            return spatial_squeeze_excite_block(x, name=name + "_sse")
        return x

    def _basic_residual_block(
        self,
        inputs,
        filters,
        kernel_size=3,
        stride=1,
        conv_shortcut=True,
        attention=None,
        name=None,
    ):
        """
        Build a basic residual block for the CTLearn model.

        The block consists of two convolutional layers, an optional
        squeeze-excite attention step and a shortcut connection that is added
        to the result before a final ReLU.

        Parameters
        ----------
        inputs : keras.layers.Layer
            Input tensor to the residual block.
        filters : int
            Number of filters for the convolutional layers.
        kernel_size : int, optional
            Size of the convolutional kernel. Default is 3.
        stride : int, optional
            Stride of the first convolution and of the shortcut convolution. Default is 1.
        conv_shortcut : bool, optional
            Whether to use a convolutional layer for the shortcut connection. Default is True.
        attention : dict, optional
            Configuration parameters for the attention mechanism. Default is None.
        name : str, optional
            Name for the residual block. Default is None.

        Returns
        -------
        keras.layers.Layer
            Output tensor after applying the residual block.
        """
        if conv_shortcut:
            shortcut = keras.layers.Conv2D(
                filters=filters, kernel_size=1, strides=stride, name=name + "_0_conv"
            )(inputs)
        else:
            shortcut = inputs

        x = keras.layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=stride,
            padding="same",
            activation="relu",
            name=name + "_1_conv",
        )(inputs)
        x = keras.layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            padding="same",
            activation="relu",
            name=name + "_2_conv",
        )(x)

        # Attention mechanism
        x = self._apply_se_attention(x, attention, name)

        x = keras.layers.Add(name=name + "_add")([shortcut, x])
        x = keras.layers.ReLU(name=name + "_out")(x)

        return x

    def _bottleneck_residual_block(
        self,
        inputs,
        filters,
        kernel_size=3,
        stride=1,
        conv_shortcut=True,
        attention=None,
        name=None,
    ):
        """
        Build a bottleneck residual block for the CTLearn model.

        The block consists of three convolutional layers: a 1x1 convolution
        with ``filters`` filters, a ``kernel_size`` convolution with
        ``filters`` filters and a 1x1 convolution with ``4 * filters``
        filters, followed by an optional squeeze-excite attention step and a
        shortcut connection that is added before a final ReLU.

        Parameters
        ----------
        inputs : keras.layers.Layer
            Input tensor to the residual block.
        filters : int
            Number of filters for the first two convolutional layers; the
            last convolution and the shortcut use ``4 * filters``.
        kernel_size : int, optional
            Size of the middle convolutional kernel. Default is 3.
        stride : int, optional
            Stride of the first convolution and of the shortcut convolution. Default is 1.
        conv_shortcut : bool, optional
            Whether to use a convolutional layer for the shortcut connection. Default is True.
        attention : dict, optional
            Configuration parameters for the attention mechanism. Default is None.
        name : str, optional
            Name for the residual block. Default is None.

        Returns
        -------
        output : keras.layers.Layer
            Output layer of the residual block.
        """
        if conv_shortcut:
            shortcut = keras.layers.Conv2D(
                filters=4 * filters,
                kernel_size=1,
                strides=stride,
                name=name + "_0_conv",
            )(inputs)
        else:
            shortcut = inputs

        x = keras.layers.Conv2D(
            filters=filters,
            kernel_size=1,
            strides=stride,
            activation="relu",
            name=name + "_1_conv",
        )(inputs)
        x = keras.layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            padding="same",
            activation="relu",
            name=name + "_2_conv",
        )(x)
        x = keras.layers.Conv2D(
            filters=4 * filters, kernel_size=1, name=name + "_3_conv"
        )(x)

        # Attention mechanism
        x = self._apply_se_attention(x, attention, name)

        x = keras.layers.Add(name=name + "_add")([shortcut, x])
        x = keras.layers.ReLU(name=name + "_out")(x)

        return x