From a88bb97a2e2f7fbf169f9b316e4d0cc5927e53b8 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 21 Sep 2023 11:00:06 +0000
Subject: [PATCH 1/9] split MLP into symbolic style and vanilla style

---
 docs/zh/api/arch.md    |  1 +
 ppsci/arch/__init__.py |  2 +
 ppsci/arch/mlp.py      | 85 +++++++++++++++++++++++++++++-------------
 3 files changed, 63 insertions(+), 25 deletions(-)

diff --git a/docs/zh/api/arch.md b/docs/zh/api/arch.md
index 30358350ee..0f08760dc0 100644
--- a/docs/zh/api/arch.md
+++ b/docs/zh/api/arch.md
@@ -6,6 +6,7 @@
       members:
         - Arch
         - MLP
+        - FullyConnectedLayers
         - DeepONet
         - LorenzEmbedding
         - RosslerEmbedding
diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py
index 986e318933..2282e21bbf 100644
--- a/ppsci/arch/__init__.py
+++ b/ppsci/arch/__init__.py
@@ -16,6 +16,7 @@
 from ppsci.arch.base import Arch  # isort:skip
 from ppsci.arch.mlp import MLP  # isort:skip
+from ppsci.arch.mlp import FullyConnectedLayers  # isort:skip
 from ppsci.arch.deeponet import DeepONet  # isort:skip
 from ppsci.arch.embedding_koopman import LorenzEmbedding  # isort:skip
 from ppsci.arch.embedding_koopman import RosslerEmbedding  # isort:skip
@@ -32,6 +33,7 @@
 __all__ = [
     "Arch",
     "MLP",
+    "FullyConnectedLayers",
     "DeepONet",
     "LorenzEmbedding",
     "RosslerEmbedding",
diff --git a/ppsci/arch/mlp.py b/ppsci/arch/mlp.py
index c874669d2e..cfeaee8fec 100644
--- a/ppsci/arch/mlp.py
+++ b/ppsci/arch/mlp.py
@@ -50,43 +50,34 @@ def forward(self, input):
         return nn.functional.linear(input, weight, self.bias)
 
 
-class MLP(base.Arch):
-    """Multi layer perceptron network.
+class FullyConnectedLayers(base.Arch):
+    """Fully Connected Layers, core implementation of MLP.
 
     Args:
-        input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z").
-        output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w").
+        input_dim (int): Number of input's dimension.
+        output_dim (int): Number of output's dimension.
         num_layers (int): Number of hidden layers.
         hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size.
            An integer for all layers, or a list of integers specifying each layer's size.
         activation (str, optional): Name of activation function. Defaults to "tanh".
         skip_connection (bool, optional): Whether to use skip connection. Defaults to False.
         weight_norm (bool, optional): Whether to apply weight norm on parameter(s). Defaults to False.
-        input_dim (Optional[int]): Number of input's dimension. Defaults to None.
-        output_dim (Optional[int]): Number of output's dimension. Defaults to None.
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128)
+        >>> model = ppsci.arch.FullyConnectedLayers(3, 4, num_layers=5, hidden_size=128)
     """
 
     def __init__(
         self,
-        input_keys: Tuple[str, ...],
-        output_keys: Tuple[str, ...],
+        input_dim: int,
+        output_dim: int,
         num_layers: int,
         hidden_size: Union[int, Tuple[int, ...]],
         activation: str = "tanh",
         skip_connection: bool = False,
         weight_norm: bool = False,
-        input_dim: Optional[int] = None,
-        output_dim: Optional[int] = None,
     ):
-        super().__init__()
-        self.input_keys = input_keys
-        self.output_keys = output_keys
-        self.linears = []
-        self.acts = []
         if isinstance(hidden_size, (tuple, list)):
             if num_layers is not None:
                 raise ValueError(
@@ -104,8 +95,11 @@ def __init__(
                 f"but got {type(hidden_size)}"
             )
 
+        self.linears = nn.LayerList()
+        self.acts = nn.LayerList()
+
         # initialize FC layer(s)
-        cur_size = len(self.input_keys) if input_dim is None else input_dim
+        cur_size = input_dim
         for i, _size in enumerate(hidden_size):
             self.linears.append(
                 WeightNormLinear(cur_size, _size)
@@ -128,16 +122,11 @@ def __init__(
 
             cur_size = _size
 
-        self.linears = nn.LayerList(self.linears)
-        self.acts = nn.LayerList(self.acts)
-        self.last_fc = nn.Linear(
-            cur_size,
-            len(self.output_keys) if output_dim is None else output_dim,
-        )
+        self.last_fc = nn.Linear(cur_size, output_dim)
 
         self.skip_connection = skip_connection
 
-    def forward_tensor(self, x):
+    def forward(self, x):
         y = x
         skip = None
         for i, linear in enumerate(self.linears):
@@ -154,12 +143,58 @@ def forward_tensor(self, x):
 
         return y
 
+
+class MLP(FullyConnectedLayers):
+    """Multi layer perceptron network derived from FullyConnectedLayers.
+    Which accepts input/output string key(s) for symbolic computation.
+
+    Args:
+        input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z").
+        output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w").
+        num_layers (int): Number of hidden layers.
+        hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size.
+            An integer for all layers, or a list of integers specifying each layer's size.
+        activation (str, optional): Name of activation function. Defaults to "tanh".
+        skip_connection (bool, optional): Whether to use skip connection. Defaults to False.
+        weight_norm (bool, optional): Whether to apply weight norm on parameter(s). Defaults to False.
+        input_dim (Optional[int]): Number of input's dimension. Defaults to None.
+        output_dim (Optional[int]): Number of output's dimension. Defaults to None.
+
+    Examples:
+        >>> import ppsci
+        >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), num_layers=5, hidden_size=128)
+    """
+
+    def __init__(
+        self,
+        input_keys: Tuple[str, ...],
+        output_keys: Tuple[str, ...],
+        num_layers: int,
+        hidden_size: Union[int, Tuple[int, ...]],
+        activation: str = "tanh",
+        skip_connection: bool = False,
+        weight_norm: bool = False,
+        input_dim: Optional[int] = None,
+        output_dim: Optional[int] = None,
+    ):
+        self.input_keys = input_keys
+        self.output_keys = output_keys
+        super().__init__(
+            len(input_keys) if not input_dim else input_dim,
+            len(output_keys) if not output_dim else output_dim,
+            num_layers,
+            hidden_size,
+            activation,
+            skip_connection,
+            weight_norm,
+        )
+
     def forward(self, x):
         if self._input_transform is not None:
             x = self._input_transform(x)
 
         y = self.concat_to_tensor(x, self.input_keys, axis=-1)
-        y = self.forward_tensor(y)
+        y = super().forward(x)
         y = self.split_to_dict(y, self.output_keys, axis=-1)
 
         if self._output_transform is not None:
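
A minimal usage sketch of the split this patch introduces (illustrative only — note that `FullyConnectedLayers.__init__` only gains its `super().__init__()` call, and `MLP.forward` its correct `super().forward(y)` argument, in PATCH 5, so this runs against the state after that fix):

    import paddle
    import ppsci

    # vanilla style: plain tensor in, plain tensor out
    layer = ppsci.arch.FullyConnectedLayers(3, 4, num_layers=5, hidden_size=128)
    u = layer(paddle.rand([16, 3]))  # -> tensor of shape [16, 4]

    # symbolic style: string-keyed dicts in and out
    model = ppsci.arch.MLP(("x", "y"), ("u", "v"), num_layers=5, hidden_size=128)
    out = model({"x": paddle.rand([16, 1]), "y": paddle.rand([16, 1])})  # dict with keys ("u", "v")
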
From 8cddc54a8de08ee96354829329ec3a6bec948841 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 21 Sep 2023 11:15:27 +0000
Subject: [PATCH 2/9] split AFNO into symbolic style and vanilla style

---
 docs/zh/api/arch.md    |  1 +
 ppsci/arch/__init__.py |  2 +
 ppsci/arch/afno.py     | 86 ++++++++++++++++++++++++++++++++++--------
 3 files changed, 74 insertions(+), 15 deletions(-)

diff --git a/docs/zh/api/arch.md b/docs/zh/api/arch.md
index 0f08760dc0..b0193f3cfa 100644
--- a/docs/zh/api/arch.md
+++ b/docs/zh/api/arch.md
@@ -15,6 +15,7 @@
         - Discriminator
         - PhysformerGPT2
         - ModelList
+        - AdaptiveFourierLayers
         - AFNONet
         - PrecipNet
       show_root_heading: false
diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py
index 2282e21bbf..a40ec485b8 100644
--- a/ppsci/arch/__init__.py
+++ b/ppsci/arch/__init__.py
@@ -26,6 +26,7 @@
 from ppsci.arch.physx_transformer import PhysformerGPT2  # isort:skip
 from ppsci.arch.model_list import ModelList  # isort:skip
 from ppsci.arch.afno import AFNONet  # isort:skip
+from ppsci.arch.afno import AdaptiveFourierLayers  # isort:skip
 from ppsci.arch.afno import PrecipNet  # isort:skip
 from ppsci.utils import logger  # isort:skip
@@ -43,6 +44,7 @@
     "PhysformerGPT2",
     "ModelList",
     "AFNONet",
+    "AdaptiveFourierLayers",
     "PrecipNet",
     "build_model",
 ]
diff --git a/ppsci/arch/afno.py b/ppsci/arch/afno.py
index 6f820bd4d8..0f2397fd55 100644
--- a/ppsci/arch/afno.py
+++ b/ppsci/arch/afno.py
@@ -391,12 +391,10 @@ def forward(self, x):
         return x
 
 
-class AFNONet(base.Arch):
-    """Adaptive Fourier Neural Network.
+class AdaptiveFourierLayers(base.Arch):
+    """Adaptive Fourier Neural Operators Network, core implementation of AFNO.
 
     Args:
-        input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
-        output_keys (Tuple[str, ...]): Name of output keys, such as ("output",).
         img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440).
         patch_size (Tuple[int, ...], optional): Patch size. Defaults to (8, 8).
         in_channels (int, optional): The input tensor channels. Defaults to 20.
@@ -413,13 +411,11 @@ class AFNONet(base.Arch):
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.AFNONet(("input", ), ("output", ))
+        >>> model = ppsci.arch.AdaptiveFourierLayers()
     """
 
     def __init__(
         self,
-        input_keys: Tuple[str, ...],
-        output_keys: Tuple[str, ...],
         img_size: Tuple[int, ...] = (720, 1440),
         patch_size: Tuple[int, ...] = (8, 8),
         in_channels: int = 20,
@@ -435,9 +431,6 @@ def __init__(
         num_timestamps: int = 1,
     ):
         super().__init__()
-        self.input_keys = input_keys
-        self.output_keys = output_keys
-
         self.img_size = img_size
         self.patch_size = patch_size
         self.in_channels = in_channels
@@ -505,7 +498,7 @@ def _init_weights(self, m):
         elif isinstance(m, nn.Conv2D):
             initializer.conv_init_(m)
 
-    def forward_tensor(self, x):
+    def forward(self, x):
         B = x.shape[0]
         x = self.patch_embed(x)
         x = x + self.pos_embed
@@ -529,10 +522,68 @@ def forward_tensor(self, x):
 
         return x
 
-    def split_to_dict(
-        self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]
+
+class AFNONet(AdaptiveFourierLayers):
+    """Adaptive Fourier Neural Operators Network.
+    Which accepts input/output string key(s) for symbolic computation.
+
+    Args:
+        input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
+        output_keys (Tuple[str, ...]): Name of output keys, such as ("output",).
+        img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440).
+        patch_size (Tuple[int, ...], optional): Patch size. Defaults to (8, 8).
+        in_channels (int, optional): The input tensor channels. Defaults to 20.
+        out_channels (int, optional): The output tensor channels. Defaults to 20.
+        embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768.
+        depth (int, optional): Number of transformer depth. Defaults to 12.
+        mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0.
+        drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0.
+        drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0.
+        num_blocks (int, optional): Number of blocks. Defaults to 8.
+        sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01.
+        hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0.
+        num_timestamps (int, optional): Number of timestamp. Defaults to 1.
+
+    Examples:
+        >>> import ppsci
+        >>> model = ppsci.arch.AFNONet(("input", ), ("output", ))
+    """
+
+    def __init__(
+        self,
+        input_keys: Tuple[str, ...],
+        output_keys: Tuple[str, ...],
+        img_size: Tuple[int, ...] = (720, 1440),
+        patch_size: Tuple[int, ...] = (8, 8),
+        in_channels: int = 20,
+        out_channels: int = 20,
+        embed_dim: int = 768,
+        depth: int = 12,
+        mlp_ratio: float = 4.0,
+        drop_rate: float = 0.0,
+        drop_path_rate: float = 0.0,
+        num_blocks: int = 8,
+        sparsity_threshold: float = 0.01,
+        hard_thresholding_fraction: float = 1.0,
+        num_timestamps: int = 1,
     ):
-        return {key: data_tensors[i] for i, key in enumerate(keys)}
+        self.input_keys = input_keys
+        self.output_keys = output_keys
+        super().__init__(
+            img_size,
+            patch_size,
+            in_channels,
+            out_channels,
+            embed_dim,
+            depth,
+            mlp_ratio,
+            drop_rate,
+            drop_path_rate,
+            num_blocks,
+            sparsity_threshold,
+            hard_thresholding_fraction,
+            num_timestamps,
+        )
 
     def forward(self, x):
         if self._input_transform is not None:
@@ -543,7 +594,7 @@ def forward(self, x):
         y = []
         input = x_tensor
         for _ in range(self.num_timestamps):
-            out = self.forward_tensor(input)
+            out = super().forward(input)
             y.append(out)
             input = out
         y = self.split_to_dict(y, self.output_keys)
@@ -552,6 +603,11 @@ def forward(self, x):
             y = self._output_transform(x, y)
         return y
 
+    def split_to_dict(
+        self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]
+    ):
+        return {key: data_tensors[i] for i, key in enumerate(keys)}
+
 
 class PrecipNet(base.Arch):
     """Precipitation Network.
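
The AFNO split follows the same pattern. A small-scale sketch (sizes are hypothetical and chosen only to keep the tensors small; whether the internals accept arbitrary `img_size`/`embed_dim` combinations is not shown in this hunk, so treat this as a sketch rather than a verified smoke test):

    import paddle
    import ppsci

    layer = ppsci.arch.AdaptiveFourierLayers(
        img_size=(32, 64), patch_size=(8, 8),
        in_channels=3, out_channels=3, embed_dim=64, depth=2, num_blocks=4,
    )
    y = layer(paddle.rand([1, 3, 32, 64]))  # tensor in, tensor out

    model = ppsci.arch.AFNONet(
        ("input",), ("output",),
        img_size=(32, 64), patch_size=(8, 8),
        in_channels=3, out_channels=3, embed_dim=64, depth=2, num_blocks=4,
    )
    out = model({"input": paddle.rand([1, 3, 32, 64])})  # dict keyed by "output"
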
From 3f2c5d09c0e6fb03442cb36e3ae692c87a01503b Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 21 Sep 2023 11:46:35 +0000
Subject: [PATCH 3/9] split DeepOnet into symbolic style and vanilla style

---
 docs/zh/api/arch.md    |   1 +
 ppsci/arch/__init__.py |   2 +
 ppsci/arch/deeponet.py | 144 ++++++++++++++++++++++++++++++-----------
 3 files changed, 111 insertions(+), 36 deletions(-)

diff --git a/docs/zh/api/arch.md b/docs/zh/api/arch.md
index b0193f3cfa..123623486f 100644
--- a/docs/zh/api/arch.md
+++ b/docs/zh/api/arch.md
@@ -8,6 +8,7 @@
         - MLP
         - FullyConnectedLayers
         - DeepONet
+        - DeepOperatorLayers
         - LorenzEmbedding
         - RosslerEmbedding
         - CylinderEmbedding
diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py
index a40ec485b8..88f655a9bb 100644
--- a/ppsci/arch/__init__.py
+++ b/ppsci/arch/__init__.py
@@ -18,6 +18,7 @@
 from ppsci.arch.mlp import MLP  # isort:skip
 from ppsci.arch.mlp import FullyConnectedLayers  # isort:skip
 from ppsci.arch.deeponet import DeepONet  # isort:skip
+from ppsci.arch.deeponet import DeepOperatorLayers  # isort:skip
 from ppsci.arch.embedding_koopman import LorenzEmbedding  # isort:skip
 from ppsci.arch.embedding_koopman import RosslerEmbedding  # isort:skip
 from ppsci.arch.embedding_koopman import CylinderEmbedding  # isort:skip
@@ -36,6 +37,7 @@
     "MLP",
     "FullyConnectedLayers",
     "DeepONet",
+    "DeepOperatorLayers",
     "LorenzEmbedding",
     "RosslerEmbedding",
     "CylinderEmbedding",
diff --git a/ppsci/arch/deeponet.py b/ppsci/arch/deeponet.py
index 374c09cd39..d0725565f6 100644
--- a/ppsci/arch/deeponet.py
+++ b/ppsci/arch/deeponet.py
@@ -25,15 +25,13 @@
 from ppsci.arch import mlp
 
 
-class DeepONet(base.Arch):
-    """Deep operator network.
+class DeepOperatorLayers(base.Arch):
+    """Deep operator network, core implementation of `DeepONet`.
 
     [Lu et al. Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nat Mach Intell, 2021.](https://doi.org/10.1038/s42256-021-00302-5)
 
     Args:
-        u_key (str): Name of function data for input function u(x).
-        y_key (str): Name of location data for input function G(u).
-        G_key (str): Output name of predicted G(u)(y).
+        trunck_dim (int): Dimension of sampled u(x) (1 for scalar function, >1 for vector function).
         num_loc (int): Number of sampled u(x), i.e. `m` in paper.
         num_features (int): Number of features extracted from u(x), same for y.
         branch_num_layers (int): Number of hidden layers of branch net.
@@ -52,8 +50,8 @@ class DeepONet(base.Arch):
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.DeepONet(
-        ...     "u", "y", "G",
+        >>> model = ppsci.arch.DeepOperatorLayers(
+        ...     1,
         ...     100, 40,
         ...     1, 1,
         ...     40, 40,
         ...     branch_activation="relu", trunk_activation="relu",
         ...     use_bias=True,
         ... )
     """
 
     def __init__(
         self,
-        u_key: str,
-        y_key: str,
-        G_key: str,
+        trunck_dim: int,
         num_loc: int,
         num_features: int,
         branch_num_layers: int,
@@ -82,33 +78,25 @@ def __init__(
         use_bias: bool = True,
     ):
         super().__init__()
-        self.u_key = u_key
-        self.y_key = y_key
-        self.input_keys = (u_key, y_key)
-        self.output_keys = (G_key,)
-
-        self.branch_net = mlp.MLP(
-            (self.u_key,),
-            ("b",),
+        self.trunck_dim = trunck_dim
+        self.branch_net = mlp.FullyConnectedLayers(
+            num_loc,
+            num_features,
             branch_num_layers,
             branch_hidden_size,
             branch_activation,
             branch_skip_connection,
             branch_weight_norm,
-            input_dim=num_loc,
-            output_dim=num_features,
         )
 
-        self.trunk_net = mlp.MLP(
-            (self.y_key,),
-            ("t",),
+        self.trunk_net = mlp.FullyConnectedLayers(
+            trunck_dim,
+            num_features,
             trunk_num_layers,
             trunk_hidden_size,
             trunk_activation,
             trunk_skip_connection,
             trunk_weight_norm,
-            input_dim=1,
-            output_dim=num_features,
         )
         self.trunk_act = act_mod.get_activation(trunk_activation)
 
@@ -120,28 +108,112 @@ def __init__(
             attr=nn.initializer.Constant(0.0),
         )
 
-    def forward(self, x):
-        if self._input_transform is not None:
-            x = self._input_transform(x)
-
+    def forward(self, u, y):
         # Branch net to encode the input function
-        u_features = self.branch_net(x)[self.branch_net.output_keys[0]]
+        u_features = self.branch_net(u)
 
         # Trunk net to encode the domain of the output function
-        y_features = self.trunk_net(x)
-        y_features = self.trunk_act(y_features[self.trunk_net.output_keys[0]])
+        y_features = self.trunk_net(y)
+        y_features = self.trunk_act(y_features)
 
         # Dot product
         G_u = paddle.einsum("bi,bi->b", u_features, y_features)  # [batch_size, ]
-        G_u = paddle.reshape(G_u, [-1, 1])  # reshape [batch_size, ] to [batch_size, 1]
+        G_u = paddle.reshape(
+            G_u, [-1, self.trunck_dim]
+        )  # reshape [batch_size, ] to [batch_size, trunck_dim]
 
         # Add bias
         if self.use_bias:
             G_u += self.b
 
-        result_dict = {
-            self.output_keys[0]: G_u,
-        }
+        return G_u
+
+
+class DeepONet(DeepOperatorLayers):
+    """Deep operator network.
+    Different from `DeepOperatorLayers`, this class accepts input/output string key(s) for symbolic computation.
+
+    [Lu et al. Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nat Mach Intell, 2021.](https://doi.org/10.1038/s42256-021-00302-5)
+
+    Args:
+        u_key (str): Name of function data for input function u(x).
+        y_key (str): Name of location data for input function G(u).
+        G_key (str): Output name of predicted G(u)(y).
+        num_loc (int): Number of sampled u(x), i.e. `m` in paper.
+        num_features (int): Number of features extracted from u(x), same for y.
+        branch_num_layers (int): Number of hidden layers of branch net.
+        trunk_num_layers (int): Number of hidden layers of trunk net.
+        branch_hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size of branch net.
+            An integer for all layers, or a list of integers specifying each layer's size.
+        trunk_hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size of trunk net.
+            An integer for all layers, or a list of integers specifying each layer's size.
+        branch_skip_connection (bool, optional): Whether to use skip connection for branch net. Defaults to False.
+        trunk_skip_connection (bool, optional): Whether to use skip connection for trunk net. Defaults to False.
+        branch_activation (str, optional): Name of activation function. Defaults to "tanh".
+        trunk_activation (str, optional): Name of activation function. Defaults to "tanh".
+        branch_weight_norm (bool, optional): Whether to apply weight norm on parameter(s) for branch net. Defaults to False.
+        trunk_weight_norm (bool, optional): Whether to apply weight norm on parameter(s) for trunk net. Defaults to False.
+        use_bias (bool, optional): Whether to add bias on predicted G(u)(y). Defaults to True.
+
+    Examples:
+        >>> import ppsci
+        >>> model = ppsci.arch.DeepONet(
+        ...     "u", "y", "G",
+        ...     100, 40,
+        ...     1, 1,
+        ...     40, 40,
+        ...     branch_activation="relu", trunk_activation="relu",
+        ...     use_bias=True,
+        ... )
+    """
+
+    def __init__(
+        self,
+        u_key: str,
+        y_key: str,
+        G_key: str,
+        num_loc: int,
+        num_features: int,
+        branch_num_layers: int,
+        trunk_num_layers: int,
+        branch_hidden_size: Union[int, Tuple[int, ...]],
+        trunk_hidden_size: Union[int, Tuple[int, ...]],
+        branch_skip_connection: bool = False,
+        trunk_skip_connection: bool = False,
+        branch_activation: str = "tanh",
+        trunk_activation: str = "tanh",
+        branch_weight_norm: bool = False,
+        trunk_weight_norm: bool = False,
+        use_bias: bool = True,
+    ):
+        self.input_keys = (u_key, y_key)
+        self.output_keys = (G_key,)
+
+        super().__init__(
+            1,
+            num_loc,
+            num_features,
+            branch_num_layers,
+            trunk_num_layers,
+            branch_hidden_size,
+            trunk_hidden_size,
+            branch_skip_connection,
+            trunk_skip_connection,
+            branch_activation,
+            trunk_activation,
+            branch_weight_norm,
+            trunk_weight_norm,
+            use_bias,
+        )
+
+    def forward(self, x):
+        if self._input_transform is not None:
+            x = self._input_transform(x)
+
+        G_u = super().forward(x[self.input_keys[0]], x[self.input_keys[1]])
+        result_dict = {self.output_keys[0]: G_u}
 
         if self._output_transform is not None:
             result_dict = self._output_transform(x, result_dict)
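
With the keys stripped out, the operator core can be driven directly with the branch input u and the trunk input y. A sketch following the docstring's m=100 sampling (argument order as in the signature above):

    import paddle
    import ppsci

    layer = ppsci.arch.DeepOperatorLayers(
        1,        # trunck_dim
        100, 40,  # num_loc, num_features
        1, 1,     # branch_num_layers, trunk_num_layers
        40, 40,   # branch_hidden_size, trunk_hidden_size
    )
    G_u = layer(paddle.rand([8, 100]), paddle.rand([8, 1]))  # -> [8, 1]
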
From 771594b16852067d23ae754afee52168b37f1c7f Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 21 Sep 2023 12:01:13 +0000
Subject: [PATCH 4/9] split PrecipNet into symbolic style and vanilla style

---
 docs/zh/api/arch.md    |  1 +
 ppsci/arch/__init__.py |  1 +
 ppsci/arch/afno.py     | 89 +++++++++++++++++++++++++++++++++++-------
 3 files changed, 77 insertions(+), 14 deletions(-)

diff --git a/docs/zh/api/arch.md b/docs/zh/api/arch.md
index 123623486f..b78a34b69b 100644
--- a/docs/zh/api/arch.md
+++ b/docs/zh/api/arch.md
@@ -19,5 +19,6 @@
         - AdaptiveFourierLayers
         - AFNONet
         - PrecipNet
+        - PrecipLayers
       show_root_heading: false
       heading_level: 3
diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py
index 88f655a9bb..cae59c5d80 100644
--- a/ppsci/arch/__init__.py
+++ b/ppsci/arch/__init__.py
@@ -48,6 +48,7 @@
     "AFNONet",
     "AdaptiveFourierLayers",
     "PrecipNet",
+    "PrecipLayers",
     "build_model",
 ]
diff --git a/ppsci/arch/afno.py b/ppsci/arch/afno.py
index 0f2397fd55..0698a338dd 100644
--- a/ppsci/arch/afno.py
+++ b/ppsci/arch/afno.py
@@ -525,7 +525,7 @@ class AFNONet(AdaptiveFourierLayers):
 
 class AFNONet(AdaptiveFourierLayers):
     """Adaptive Fourier Neural Operators Network.
-    Which accepts input/output string key(s) for symbolic computation.
+    Different from `AdaptiveFourierLayers`, this class accepts input/output string key(s) for symbolic computation.
 
     Args:
         input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
@@ -609,8 +609,8 @@ def split_to_dict(
         return {key: data_tensors[i] for i, key in enumerate(keys)}
 
 
-class PrecipNet(base.Arch):
-    """Precipitation Network.
+class PrecipLayers(base.Arch):
+    """Precipitation Network, core implementation of PrecipNet.
 
     Args:
         input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
@@ -638,8 +638,6 @@ class PrecipNet(base.Arch):
 
     def __init__(
         self,
-        input_keys: Tuple[str, ...],
-        output_keys: Tuple[str, ...],
         wind_model: base.Arch,
         img_size: Tuple[int, ...] = (720, 1440),
         patch_size: Tuple[int, ...] = (8, 8),
         in_channels: int = 20,
         out_channels: int = 1,
         embed_dim: int = 768,
         depth: int = 12,
         mlp_ratio: float = 4.0,
         drop_rate: float = 0.0,
         drop_path_rate: float = 0.0,
         num_blocks: int = 8,
         sparsity_threshold: float = 0.01,
         hard_thresholding_fraction: float = 1.0,
         num_timestamps=1,
     ):
         super().__init__()
-        self.input_keys = input_keys
-        self.output_keys = output_keys
-
         self.img_size = img_size
         self.patch_size = patch_size
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.embed_dim = embed_dim
         self.num_blocks = num_blocks
         self.num_timestamps = num_timestamps
-        self.backbone = AFNONet(
-            ("input",),
-            ("output",),
+        self.backbone = AdaptiveFourierLayers(
             img_size=img_size,
             patch_size=patch_size,
             in_channels=in_channels,
@@ -702,13 +695,80 @@ def _init_weights(self, m):
         elif isinstance(m, nn.Conv2D):
             initializer.conv_init_(m)
 
-    def forward_tensor(self, x):
-        x = self.backbone.forward_tensor(x)
+    def forward(self, x):
+        x = self.backbone.forward(x)
         x = self.ppad(x)
         x = self.conv(x)
         x = self.act(x)
         return x
 
+
+class PrecipNet(PrecipLayers):
+    """Precipitation Network.
+    Different from `PrecipLayers`, this class accepts input/output string key(s) for symbolic computation.
+
+    Args:
+        input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
+        output_keys (Tuple[str, ...]): Name of output keys, such as ("output",).
+        wind_model (base.Arch): Wind model.
+        img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440).
+        patch_size (Tuple[int, ...], optional): Patch size. Defaults to (8, 8).
+        in_channels (int, optional): The input tensor channels. Defaults to 20.
+        out_channels (int, optional): The output tensor channels. Defaults to 1.
+        embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768.
+        depth (int, optional): Number of transformer depth. Defaults to 12.
+        mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0.
+        drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0.
+        drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0.
+        num_blocks (int, optional): Number of blocks. Defaults to 8.
+        sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01.
+        hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0.
+        num_timestamps (int, optional): Number of timestamp. Defaults to 1.
+
+    Examples:
+        >>> import ppsci
+        >>> wind_model = ppsci.arch.AFNONet(("input", ), ("output", ))
+        >>> model = ppsci.arch.PrecipNet(("input", ), ("output", ), wind_model)
+    """
+
+    def __init__(
+        self,
+        input_keys: Tuple[str, ...],
+        output_keys: Tuple[str, ...],
+        wind_model: base.Arch,
+        img_size: Tuple[int, ...] = (720, 1440),
+        patch_size: Tuple[int, ...] = (8, 8),
+        in_channels: int = 20,
+        out_channels: int = 1,
+        embed_dim: int = 768,
+        depth: int = 12,
+        mlp_ratio: float = 4.0,
+        drop_rate: float = 0.0,
+        drop_path_rate: float = 0.0,
+        num_blocks: int = 8,
+        sparsity_threshold: float = 0.01,
+        hard_thresholding_fraction: float = 1.0,
+        num_timestamps=1,
+    ):
+        self.input_keys = input_keys
+        self.output_keys = output_keys
+        super().__init__(
+            wind_model,
+            img_size,
+            patch_size,
+            in_channels,
+            out_channels,
+            embed_dim,
+            depth,
+            mlp_ratio,
+            drop_rate,
+            drop_path_rate,
+            num_blocks,
+            sparsity_threshold,
+            hard_thresholding_fraction,
+            num_timestamps,
+        )
+
     def split_to_dict(
         self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]
     ):
@@ -725,9 +785,10 @@ def forward(self, x):
         for _ in range(self.num_timestamps):
             with paddle.no_grad():
                 out_wind = self.wind_model.forward_tensor(input_wind)
-            out = self.forward_tensor(out_wind)
+            out = super().forward(out_wind)
             y.append(out)
             input_wind = out_wind
+
         y = self.split_to_dict(y, self.output_keys)
 
         if self._output_transform is not None:
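
An illustrative pairing of the two precipitation classes (a sketch only — the full-size defaults make a real forward pass memory-heavy, and the wind model is consulted only inside PrecipNet's autoregressive forward):

    import paddle
    import ppsci

    wind = ppsci.arch.AFNONet(("input",), ("output",))
    model = ppsci.arch.PrecipNet(("input",), ("output",), wind)   # symbolic, dict I/O
    layer = ppsci.arch.PrecipLayers(wind)                         # vanilla, tensor I/O
    rain = layer(paddle.rand([1, 20, 720, 1440]))                 # expected [1, 1, 720, 1440]
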
From 89823c3d41fdc0a56963a2cee14277928ee3141a Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Fri, 22 Sep 2023 04:45:04 +0000
Subject: [PATCH 5/9] fix for call super().__init__()

---
 ppsci/arch/mlp.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/ppsci/arch/mlp.py b/ppsci/arch/mlp.py
index cfeaee8fec..420275b7cd 100644
--- a/ppsci/arch/mlp.py
+++ b/ppsci/arch/mlp.py
@@ -78,6 +78,7 @@ def __init__(
         skip_connection: bool = False,
         weight_norm: bool = False,
     ):
+        super().__init__()
         if isinstance(hidden_size, (tuple, list)):
             if num_layers is not None:
                 raise ValueError(
@@ -146,7 +147,7 @@ def forward(self, x):
 
 class MLP(FullyConnectedLayers):
     """Multi layer perceptron network derived from FullyConnectedLayers.
-    Which accepts input/output string key(s) for symbolic computation.
+    Different from `FullyConnectedLayers`, this class accepts input/output string key(s) for symbolic computation.
 
     Args:
         input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z").
@@ -194,7 +195,7 @@ def forward(self, x):
             x = self._input_transform(x)
 
         y = self.concat_to_tensor(x, self.input_keys, axis=-1)
-        y = super().forward(x)
+        y = super().forward(y)
         y = self.split_to_dict(y, self.output_keys, axis=-1)
 
         if self._output_transform is not None:

From a69310eeefe171650353fe65d45753ff9fac54a8 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Fri, 22 Sep 2023 05:28:19 +0000
Subject: [PATCH 6/9] rename suffix from 'Layers' to 'Layer'

---
 docs/zh/api/arch.md    |  8 ++++----
 ppsci/arch/__init__.py | 15 ++++++++-------
 ppsci/arch/afno.py     | 16 ++++++++--------
 ppsci/arch/deeponet.py | 12 ++++++------
 ppsci/arch/mlp.py      | 12 ++++++------
 5 files changed, 32 insertions(+), 31 deletions(-)

diff --git a/docs/zh/api/arch.md b/docs/zh/api/arch.md
index b78a34b69b..9c53f77c0f 100644
--- a/docs/zh/api/arch.md
+++ b/docs/zh/api/arch.md
@@ -6,9 +6,9 @@
       members:
         - Arch
         - MLP
-        - FullyConnectedLayers
+        - FullyConnectedLayer
         - DeepONet
-        - DeepOperatorLayers
+        - DeepOperatorLayer
         - LorenzEmbedding
         - RosslerEmbedding
         - CylinderEmbedding
@@ -16,9 +16,9 @@
         - Discriminator
         - PhysformerGPT2
         - ModelList
-        - AdaptiveFourierLayers
+        - AdaptiveFourierLayer
        - AFNONet
         - PrecipNet
-        - PrecipLayers
+        - PrecipLayer
       show_root_heading: false
       heading_level: 3
diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py
index cae59c5d80..e403df8997 100644
--- a/ppsci/arch/__init__.py
+++ b/ppsci/arch/__init__.py
@@ -16,9 +16,9 @@
 from ppsci.arch.base import Arch  # isort:skip
 from ppsci.arch.mlp import MLP  # isort:skip
-from ppsci.arch.mlp import FullyConnectedLayers  # isort:skip
+from ppsci.arch.mlp import FullyConnectedLayer  # isort:skip
 from ppsci.arch.deeponet import DeepONet  # isort:skip
-from ppsci.arch.deeponet import DeepOperatorLayers  # isort:skip
+from ppsci.arch.deeponet import DeepOperatorLayer  # isort:skip
 from ppsci.arch.embedding_koopman import LorenzEmbedding  # isort:skip
 from ppsci.arch.embedding_koopman import RosslerEmbedding  # isort:skip
 from ppsci.arch.embedding_koopman import CylinderEmbedding  # isort:skip
@@ -27,17 +27,18 @@
 from ppsci.arch.physx_transformer import PhysformerGPT2  # isort:skip
 from ppsci.arch.model_list import ModelList  # isort:skip
 from ppsci.arch.afno import AFNONet  # isort:skip
-from ppsci.arch.afno import AdaptiveFourierLayers  # isort:skip
+from ppsci.arch.afno import AdaptiveFourierLayer  # isort:skip
 from ppsci.arch.afno import PrecipNet  # isort:skip
+from ppsci.arch.afno import PrecipLayer  # isort:skip
 from ppsci.utils import logger  # isort:skip
 
 
 __all__ = [
     "Arch",
     "MLP",
-    "FullyConnectedLayers",
+    "FullyConnectedLayer",
     "DeepONet",
-    "DeepOperatorLayers",
+    "DeepOperatorLayer",
     "LorenzEmbedding",
     "RosslerEmbedding",
     "CylinderEmbedding",
@@ -46,9 +47,9 @@
     "PhysformerGPT2",
     "ModelList",
     "AFNONet",
-    "AdaptiveFourierLayers",
+    "AdaptiveFourierLayer",
     "PrecipNet",
-    "PrecipLayers",
+    "PrecipLayer",
     "build_model",
 ]
diff --git a/ppsci/arch/afno.py b/ppsci/arch/afno.py
index 0698a338dd..3acf010cc3 100644
--- a/ppsci/arch/afno.py
+++ b/ppsci/arch/afno.py
@@ -391,7 +391,7 @@ def forward(self, x):
         return x
 
 
-class AdaptiveFourierLayers(base.Arch):
+class AdaptiveFourierLayer(base.Arch):
     """Adaptive Fourier Neural Operators Network, core implementation of AFNO.
 
     Args:
@@ -411,7 +411,7 @@ class AdaptiveFourierLayers(base.Arch):
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.AdaptiveFourierLayers()
+        >>> model = ppsci.arch.AdaptiveFourierLayer()
     """
 
     def __init__(
@@ -523,9 +523,9 @@ def forward(self, x):
         return x
 
 
-class AFNONet(AdaptiveFourierLayers):
+class AFNONet(AdaptiveFourierLayer):
     """Adaptive Fourier Neural Operators Network.
-    Different from `AdaptiveFourierLayers`, this class accepts input/output string key(s) for symbolic computation.
+    Different from `AdaptiveFourierLayer`, this class accepts input/output string key(s) for symbolic computation.
 
     Args:
         input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
@@ -609,7 +609,7 @@ def split_to_dict(
         return {key: data_tensors[i] for i, key in enumerate(keys)}
 
 
-class PrecipLayers(base.Arch):
+class PrecipLayer(base.Arch):
     """Precipitation Network, core implementation of PrecipNet.
 
     Args:
@@ -661,7 +661,7 @@ def __init__(
         self.embed_dim = embed_dim
         self.num_blocks = num_blocks
         self.num_timestamps = num_timestamps
-        self.backbone = AdaptiveFourierLayers(
+        self.backbone = AdaptiveFourierLayer(
             img_size=img_size,
             patch_size=patch_size,
             in_channels=in_channels,
@@ -703,9 +703,9 @@ def forward(self, x):
         return x
 
 
-class PrecipNet(PrecipLayers):
+class PrecipNet(PrecipLayer):
     """Precipitation Network.
-    Different from `PrecipLayers`, this class accepts input/output string key(s) for symbolic computation.
+    Different from `PrecipLayer`, this class accepts input/output string key(s) for symbolic computation.
 
     Args:
         input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
diff --git a/ppsci/arch/deeponet.py b/ppsci/arch/deeponet.py
index d0725565f6..29bc62ff53 100644
--- a/ppsci/arch/deeponet.py
+++ b/ppsci/arch/deeponet.py
@@ -25,7 +25,7 @@
 from ppsci.arch import mlp
 
 
-class DeepOperatorLayers(base.Arch):
+class DeepOperatorLayer(base.Arch):
     """Deep operator network, core implementation of `DeepONet`.
 
     [Lu et al. Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nat Mach Intell, 2021.](https://doi.org/10.1038/s42256-021-00302-5)
@@ -50,7 +50,7 @@ class DeepOperatorLayers(base.Arch):
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.DeepOperatorLayers(
+        >>> model = ppsci.arch.DeepOperatorLayer(
         ...     1,
         ...     100, 40,
         ...     1, 1,
@@ -79,7 +79,7 @@ def __init__(
     ):
         super().__init__()
         self.trunck_dim = trunck_dim
-        self.branch_net = mlp.FullyConnectedLayers(
+        self.branch_net = mlp.FullyConnectedLayer(
             num_loc,
             num_features,
             branch_num_layers,
@@ -89,7 +89,7 @@ def __init__(
             branch_weight_norm,
         )
 
-        self.trunk_net = mlp.FullyConnectedLayers(
+        self.trunk_net = mlp.FullyConnectedLayer(
             trunck_dim,
             num_features,
             trunk_num_layers,
@@ -129,9 +129,9 @@ def forward(self, u, y):
         return G_u
 
 
-class DeepONet(DeepOperatorLayers):
+class DeepONet(DeepOperatorLayer):
     """Deep operator network.
-    Different from `DeepOperatorLayers`, this class accepts input/output string key(s) for symbolic computation.
+    Different from `DeepOperatorLayer`, this class accepts input/output string key(s) for symbolic computation.
 
     [Lu et al. Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nat Mach Intell, 2021.](https://doi.org/10.1038/s42256-021-00302-5)
diff --git a/ppsci/arch/mlp.py b/ppsci/arch/mlp.py
index 420275b7cd..20f5e5a853 100644
--- a/ppsci/arch/mlp.py
+++ b/ppsci/arch/mlp.py
@@ -50,8 +50,8 @@ def forward(self, input):
         return nn.functional.linear(input, weight, self.bias)
 
 
-class FullyConnectedLayers(base.Arch):
-    """Fully Connected Layers, core implementation of MLP.
+class FullyConnectedLayer(base.Arch):
+    """Fully Connected Layer, core implementation of MLP.
 
     Args:
         input_dim (int): Number of input's dimension.
@@ -65,7 +65,7 @@ class FullyConnectedLayers(base.Arch):
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.FullyConnectedLayers(3, 4, num_layers=5, hidden_size=128)
+        >>> model = ppsci.arch.FullyConnectedLayer(3, 4, num_layers=5, hidden_size=128)
     """
 
     def __init__(
@@ -145,9 +145,9 @@ def forward(self, x):
         return y
 
 
-class MLP(FullyConnectedLayers):
-    """Multi layer perceptron network derived from FullyConnectedLayers.
-    Different from `FullyConnectedLayers`, this class accepts input/output string key(s) for symbolic computation.
+class MLP(FullyConnectedLayer):
+    """Multi layer perceptron network derived from FullyConnectedLayer.
+    Different from `FullyConnectedLayer`, this class accepts input/output string key(s) for symbolic computation.
 
     Args:
         input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z").
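
After the rename, the singular-named vanilla layers compose like ordinary paddle layers — the same way DeepOperatorLayer builds its branch and trunk nets. A hypothetical composite (not part of this series; names are illustrative):

    import paddle.nn as nn
    import ppsci

    class ProductNet(nn.Layer):  # hypothetical example
        def __init__(self):
            super().__init__()
            self.f = ppsci.arch.FullyConnectedLayer(2, 8, num_layers=2, hidden_size=32)
            self.g = ppsci.arch.FullyConnectedLayer(2, 8, num_layers=2, hidden_size=32)

        def forward(self, x):
            # plain tensor in, plain tensor out: no key bookkeeping needed
            return self.f(x) * self.g(x)
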
From cfab2c12709b764b32eb564741d9e1a2ff0963a2 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Fri, 22 Sep 2023 07:15:56 +0000
Subject: [PATCH 7/9] update equation UT

---
 test/equation/test_biharmonic.py        |  8 +++++---
 test/equation/test_laplace.py           |  8 +++++---
 test/equation/test_linear_elasticity.py | 14 ++++++--------
 test/equation/test_navier_stokes.py     |  8 +++++---
 test/equation/test_poisson.py           |  8 +++++---
 test/equation/test_viv.py               |  8 +++++---
 6 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/test/equation/test_biharmonic.py b/test/equation/test_biharmonic.py
index 8e1d6c2be0..74aecb700f 100644
--- a/test/equation/test_biharmonic.py
+++ b/test/equation/test_biharmonic.py
@@ -31,10 +31,12 @@ def test_biharmonic(dim):
         input_data = paddle.concat([x, y, z], axis=1)
 
     # build NN model
-    model = arch.MLP(input_dims, output_dims, 2, 16)
+    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
+    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
+    model_sym.load_dict(model.state_dict())
 
     # manually generate output
-    u = model.forward_tensor(input_data)
+    u = model(input_data)
 
     # use self-defined jacobian and hessian
     def jacobian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor":
@@ -60,7 +62,7 @@ def hessian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor":
         if isinstance(expr, sp.Basic):
             biharmonic_equation.equations[name] = ppsci.lambdify(
                 expr,
-                model,
+                model_sym,
             )
     data_dict = {
         "x": x,
diff --git a/test/equation/test_laplace.py b/test/equation/test_laplace.py
index 6c438df3e4..1f79af940c 100644
--- a/test/equation/test_laplace.py
+++ b/test/equation/test_laplace.py
@@ -28,10 +28,12 @@ def test_l1loss_mean(dim):
         input_data = paddle.concat([x, y, z], axis=1)
 
     # build NN model
-    model = arch.MLP(input_dims, output_dims, 2, 16)
+    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
+    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
+    model_sym.load_dict(model.state_dict())
 
     # manually generate output
-    u = model.forward_tensor(input_data)
+    u = model(input_data)
 
     # use self-defined jacobian and hessian
     def jacobian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor":
@@ -51,7 +53,7 @@ def hessian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor":
         if isinstance(expr, sp.Basic):
             laplace_equation.equations[name] = ppsci.lambdify(
                 expr,
-                model,
+                model_sym,
             )
 
     data_dict = {
diff --git a/test/equation/test_linear_elasticity.py b/test/equation/test_linear_elasticity.py
index 973e3df104..ed418f01a0 100644
--- a/test/equation/test_linear_elasticity.py
+++ b/test/equation/test_linear_elasticity.py
@@ -172,14 +172,12 @@ def test_linear_elasticity(E, nu, lambda_, mu, rho, dim, time):
     if dim == 3:
         input_data = paddle.concat([input_data, z], axis=1)
 
-    model = arch.MLP(input_dims, output_dims, 2, 16)
+    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
+    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
+    model_sym.load_dict(model.state_dict())
 
-    # model = nn.Sequential(
-    #     nn.Linear(input_data.shape[1], 9 if dim == 3 else 5),
-    #     nn.Tanh(),
-    # )
-
-    output = model.forward_tensor(input_data)
+    # manually generate output
+    output = model(input_data)
 
     u, v, *other_outputs = paddle.split(output, num_or_sections=output.shape[1], axis=1)
 
@@ -234,7 +232,7 @@ def test_linear_elasticity(E, nu, lambda_, mu, rho, dim, time):
         if isinstance(expr, sp.Basic):
             linear_elasticity.equations[name] = ppsci.lambdify(
                 expr,
-                model,
+                model_sym,
            )
     data_dict = {
         "t": t,
diff --git a/test/equation/test_navier_stokes.py b/test/equation/test_navier_stokes.py
index 0279374ac8..7bc55b6fef 100644
--- a/test/equation/test_navier_stokes.py
+++ b/test/equation/test_navier_stokes.py
@@ -110,10 +110,12 @@ def test_navierstokes(nu, rho, dim, time):
         input_dims = input_dims + ("z",)
     input_data = paddle.concat(inputs, axis=1)
 
-    model = arch.MLP(input_dims, output_dims, 2, 16)
+    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
+    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
+    model_sym.load_dict(model.state_dict())
 
     # manually generate output
-    output = model.forward_tensor(input_data)
+    output = model(input_data)
 
     if dim == 2:
         u, v, p = paddle.split(output, num_or_sections=len(output_dims), axis=1)
@@ -140,7 +142,7 @@ def test_navierstokes(nu, rho, dim, time):
         if isinstance(expr, sp.Basic):
             navier_stokes_equation.equations[name] = ppsci.lambdify(
                 expr,
-                model,
+                model_sym,
             )
 
     data_dict = {"x": x, "y": y, "u": u, "v": v, "p": p}
diff --git a/test/equation/test_poisson.py b/test/equation/test_poisson.py
index ca86d98db2..627b8352bf 100644
--- a/test/equation/test_poisson.py
+++ b/test/equation/test_poisson.py
@@ -42,10 +42,12 @@ def test_poisson(dim):
         input_data = paddle.concat([x, y, z], axis=1)
 
     # build NN model
-    model = arch.MLP(input_dims, output_dims, 2, 16)
+    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
+    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
+    model_sym.load_dict(model.state_dict())
 
     # manually generate output
-    p = model.forward_tensor(input_data)
+    p = model(input_data)
 
     def jacobian(y: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor:
         return paddle.grad(y, x, create_graph=True)[0]
@@ -64,7 +66,7 @@ def hessian(y: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor:
         if isinstance(expr, sp.Basic):
             poisson_equation.equations[name] = ppsci.lambdify(
                 expr,
-                model,
+                model_sym,
             )
 
     data_dict = {
diff --git a/test/equation/test_viv.py b/test/equation/test_viv.py
index 2dc979912d..8806ef9888 100644
--- a/test/equation/test_viv.py
+++ b/test/equation/test_viv.py
@@ -33,10 +33,12 @@ def test_vibration(rho, k1, k2):
     input_data = t_f
     input_dims = ("t_f",)
     output_dims = ("eta",)
-    model = arch.MLP(input_dims, output_dims, 2, 16)
+    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
+    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
+    model_sym.load_dict(model.state_dict())
 
     # manually generate output
-    eta = model.forward_tensor(input_data)
+    eta = model(input_data)
 
     def jacobian(y: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor:
         return paddle.grad(y, x, create_graph=True)[0]
@@ -56,7 +58,7 @@ def hessian(y: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor:
         if isinstance(expr, sp.Basic):
             vibration_equation.equations[name] = ppsci.lambdify(
                 expr,
-                model,
+                model_sym,
                 vibration_equation.learnable_parameters,
             )
     input_data_dict = {"t_f": t_f}
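
Every updated test above follows one weight-sharing pattern: build the vanilla layer, build its symbolic twin, then copy the state dict so both halves of each assertion see identical parameters. Distilled (argument values as in the tests; since MLP subclasses FullyConnectedLayer and adds no parameters, the state-dict keys line up):

    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
    model_sym.load_dict(model.state_dict())  # share weights between the twins
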
ppsci/arch/unetex.py | 85 +++++++++++++++-- 13 files changed, 466 insertions(+), 73 deletions(-) diff --git a/docs/zh/api/arch.md b/docs/zh/api/arch.md index e5c7ce729d..09329d6cd2 100644 --- a/docs/zh/api/arch.md +++ b/docs/zh/api/arch.md @@ -5,22 +5,29 @@ options: members: - Arch - - MLP - FullyConnectedLayer - - DeepONet - DeepOperatorLayer - - DeepPhyLSTM + - LorenzEmbeddingLayer + - RosslerEmbeddingLayer + - CylinderEmbeddingLayer + - DiscriminatorLayer + - PhysformerGPT2Layer + - AdaptiveFourierLayer + - GeneratorLayer + - PrecipLayer + - UNetExLayer + - MLP + - DeepONet - LorenzEmbedding - RosslerEmbedding - CylinderEmbedding - Generator - Discriminator + - DeepPhyLSTM - PhysformerGPT2 - ModelList - AFNONet - - AdaptiveFourierLayer - PrecipNet - - PrecipLayer - UNetEx show_root_heading: false heading_level: 3 diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py index d5395f9d49..523d3f8676 100644 --- a/ppsci/arch/__init__.py +++ b/ppsci/arch/__init__.py @@ -20,17 +20,24 @@ from ppsci.arch.deeponet import DeepONet # isort:skip from ppsci.arch.deeponet import DeepOperatorLayer # isort:skip from ppsci.arch.embedding_koopman import LorenzEmbedding # isort:skip +from ppsci.arch.embedding_koopman import LorenzEmbeddingLayer # isort:skip from ppsci.arch.embedding_koopman import RosslerEmbedding # isort:skip +from ppsci.arch.embedding_koopman import RosslerEmbeddingLayer # isort:skip from ppsci.arch.embedding_koopman import CylinderEmbedding # isort:skip +from ppsci.arch.embedding_koopman import CylinderEmbeddingLayer # isort:skip from ppsci.arch.gan import Generator # isort:skip +from ppsci.arch.gan import GeneratorLayer # isort:skip from ppsci.arch.gan import Discriminator # isort:skip +from ppsci.arch.gan import DiscriminatorLayer # isort:skip from ppsci.arch.phylstm import DeepPhyLSTM # isort:skip from ppsci.arch.physx_transformer import PhysformerGPT2 # isort:skip +from ppsci.arch.physx_transformer import PhysformerGPT2Layer # isort:skip from ppsci.arch.model_list import ModelList # isort:skip from ppsci.arch.afno import AFNONet # isort:skip from ppsci.arch.afno import AdaptiveFourierLayer # isort:skip from ppsci.arch.afno import PrecipNet # isort:skip from ppsci.arch.afno import PrecipLayer # isort:skip +from ppsci.arch.unetex import UNetExLayer # isort:skip from ppsci.arch.unetex import UNetEx # isort:skip from ppsci.utils import logger # isort:skip @@ -41,18 +48,25 @@ "FullyConnectedLayer", "DeepONet", "DeepOperatorLayer", - "DeepPhyLSTM", "LorenzEmbedding", + "LorenzEmbeddingLayer", "RosslerEmbedding", + "RosslerEmbeddingLayer", "CylinderEmbedding", + "CylinderEmbeddingLayer", "Generator", + "GeneratorLayer", "Discriminator", + "DiscriminatorLayer", + "DeepPhyLSTM", "PhysformerGPT2", + "PhysformerGPT2Layer", "ModelList", "AFNONet", "AdaptiveFourierLayer", "PrecipNet", "PrecipLayer", + "UNetExLayer", "UNetEx", "build_model", ] diff --git a/ppsci/arch/activation.py b/ppsci/arch/activation.py index 6277811227..6905856a95 100644 --- a/ppsci/arch/activation.py +++ b/ppsci/arch/activation.py @@ -24,6 +24,10 @@ from ppsci.utils import initializer from ppsci.utils import misc +__all__ = [ + "get_activation", +] + class Stan(nn.Layer): """Self-scalable Tanh. 
diff --git a/ppsci/arch/afno.py b/ppsci/arch/afno.py index 3acf010cc3..483a13ffde 100644 --- a/ppsci/arch/afno.py +++ b/ppsci/arch/afno.py @@ -30,6 +30,13 @@ from ppsci.arch import base from ppsci.utils import initializer +__all__ = [ + "AdaptiveFourierLayer", + "AFNONet", + "PrecipLayer", + "PrecipNet", +] + def drop_path( x: paddle.Tensor, @@ -613,8 +620,6 @@ class PrecipLayer(base.Arch): """Precipitation Network, core implementation of PrecipNet. Args: - input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). - output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). wind_model (base.Arch): Wind model. img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). diff --git a/ppsci/arch/base.py b/ppsci/arch/base.py index cf4d79dc58..90071da1b7 100644 --- a/ppsci/arch/base.py +++ b/ppsci/arch/base.py @@ -24,6 +24,10 @@ from ppsci.utils import logger +__all__ = [ + "Arch", +] + class Arch(nn.Layer): """Base class for Network.""" diff --git a/ppsci/arch/deeponet.py b/ppsci/arch/deeponet.py index 29bc62ff53..ffb1954ac0 100644 --- a/ppsci/arch/deeponet.py +++ b/ppsci/arch/deeponet.py @@ -24,6 +24,11 @@ from ppsci.arch import base from ppsci.arch import mlp +__all__ = [ + "DeepOperatorLayer", + "DeepONet", +] + class DeepOperatorLayer(base.Arch): """Deep operator network, core implementation of `DeepONet`. diff --git a/ppsci/arch/embedding_koopman.py b/ppsci/arch/embedding_koopman.py index 5bae9ce2b8..44eed36d1c 100644 --- a/ppsci/arch/embedding_koopman.py +++ b/ppsci/arch/embedding_koopman.py @@ -29,16 +29,24 @@ from ppsci.arch import base +__all__ = [ + "LorenzEmbeddingLayer", + "LorenzEmbedding", + "RosslerEmbeddingLayer", + "RosslerEmbedding", + "CylinderEmbeddingLayer", + "CylinderEmbedding", +] + + zeros_ = Constant(value=0.0) ones_ = Constant(value=1.0) -class LorenzEmbedding(base.Arch): - """Embedding Koopman model for the Lorenz ODE system. +class LorenzEmbeddingLayer(base.Arch): + """Embedding Koopman layer for the Lorenz ODE system, core implementation of LorenzEmbedding Args: - input_keys (Tuple[str, ...]): Input keys, such as ("states",). - output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states"). mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None. std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None. input_size (int, optional): Size of input data. Defaults to 3. @@ -53,8 +61,6 @@ class LorenzEmbedding(base.Arch): def __init__( self, - input_keys: Tuple[str, ...], - output_keys: Tuple[str, ...], mean: Optional[Tuple[float, ...]] = None, std: Optional[Tuple[float, ...]] = None, input_size: int = 3, @@ -63,8 +69,6 @@ def __init__( drop: float = 0.0, ): super().__init__() - self.input_keys = input_keys - self.output_keys = output_keys self.input_size = input_size self.hidden_size = hidden_size self.embed_size = embed_size @@ -167,7 +171,7 @@ def get_koopman_matrix(self): k_matrix = k_matrix + paddle.diag(self.k_diag) return k_matrix - def forward_tensor(self, x): + def forward(self, x): k_matrix = self.get_koopman_matrix() embed_data = self.encoder(x) recover_data = self.decoder(embed_data) @@ -177,6 +181,47 @@ def forward_tensor(self, x): return (pred_data[:, :-1, :], recover_data, k_matrix) + +class LorenzEmbedding(LorenzEmbeddingLayer): + """Embedding Koopman model for the Lorenz ODE system. + + Args: + input_keys (Tuple[str, ...]): Input keys, such as ("states",). 
+ output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states"). + mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None. + std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None. + input_size (int, optional): Size of input data. Defaults to 3. + hidden_size (int, optional): Number of hidden size. Defaults to 500. + embed_size (int, optional): Number of embedding size. Defaults to 32. + drop (float, optional): Probability of dropout the units. Defaults to 0.0. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.LorenzEmbedding(("x", "y"), ("u", "v")) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + input_size: int = 3, + hidden_size: int = 500, + embed_size: int = 32, + drop: float = 0.0, + ): + self.input_keys = input_keys + self.output_keys = output_keys + super().__init__( + mean, + std, + input_size, + hidden_size, + embed_size, + drop, + ) + def split_to_dict( self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] ): @@ -187,7 +232,7 @@ def forward(self, x): x = self._input_transform(x) x_tensor = self.concat_to_tensor(x, self.input_keys, axis=-1) - y = self.forward_tensor(x_tensor) + y = super().forward(x_tensor) y = self.split_to_dict(y, self.output_keys) if self._output_transform is not None: @@ -195,6 +240,41 @@ def forward(self, x): return y +class RosslerEmbeddingLayer(LorenzEmbeddingLayer): + """Embedding Koopman layer for the Rossler ODE system, core implementation of RosslerEmbedding. + + Args: + mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None. + std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None. + input_size (int, optional): Size of input data. Defaults to 3. + hidden_size (int, optional): Number of hidden size. Defaults to 500. + embed_size (int, optional): Number of embedding size. Defaults to 32. + drop (float, optional): Probability of dropout the units. Defaults to 0.0. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.RosslerEmbedding(("x", "y"), ("u", "v")) + """ + + def __init__( + self, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + input_size: int = 3, + hidden_size: int = 500, + embed_size: int = 32, + drop: float = 0.0, + ): + super().__init__( + mean, + std, + input_size, + hidden_size, + embed_size, + drop, + ) + + class RosslerEmbedding(LorenzEmbedding): """Embedding Koopman model for the Rossler ODE system. @@ -236,12 +316,10 @@ def __init__( ) -class CylinderEmbedding(base.Arch): - """Embedding Koopman model for the Cylinder system. +class CylinderEmbeddingLayer(base.Arch): + """Embedding Koopman layer for the Cylinder system, core implementation of CylinderEmbedding. Args: - input_keys (Tuple[str, ...]): Input keys, such as ("states", "visc"). - output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states"). mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None. std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None. embed_size (int, optional): Number of embedding size. Defaults to 128. 
@@ -256,8 +334,6 @@ class CylinderEmbedding(base.Arch): def __init__( self, - input_keys: Tuple[str, ...], - output_keys: Tuple[str, ...], mean: Optional[Tuple[float, ...]] = None, std: Optional[Tuple[float, ...]] = None, embed_size: int = 128, @@ -266,8 +342,6 @@ def __init__( drop: float = 0.0, ): super().__init__() - self.input_keys = input_keys - self.output_keys = output_keys self.embed_size = embed_size X, Y = np.meshgrid(np.linspace(-2, 14, 128), np.linspace(-4, 4, 64)) @@ -471,7 +545,7 @@ def _normalize(self, x: paddle.Tensor): def _unnormalize(self, x: paddle.Tensor): return self.std[:, :3] * x + self.mean[:, :3] - def forward_tensor(self, states, visc): + def forward(self, states, visc): # states.shape=(B, T, C, H, W) embed_data = self.encoder(states, visc) recover_data = self.decoder(embed_data) @@ -482,17 +556,58 @@ def forward_tensor(self, states, visc): return (pred_data[:, :-1], recover_data, k_matrix) + +class CylinderEmbedding(CylinderEmbeddingLayer): + """Embedding Koopman model for the Cylinder system. + + Args: + input_keys (Tuple[str, ...]): Input keys, such as ("states", "visc"). + output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states"). + mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None. + std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None. + embed_size (int, optional): Number of embedding size. Defaults to 128. + encoder_channels (Optional[Tuple[int, ...]]): Number of channels in encoder network. Defaults to None. + decoder_channels (Optional[Tuple[int, ...]]): Number of channels in decoder network. Defaults to None. + drop (float, optional): Probability of dropout the units. Defaults to 0.0. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.CylinderEmbedding(("x", "y"), ("u", "v")) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + embed_size: int = 128, + encoder_channels: Optional[Tuple[int, ...]] = None, + decoder_channels: Optional[Tuple[int, ...]] = None, + drop: float = 0.0, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + super().__init__( + mean, + std, + embed_size, + encoder_channels, + decoder_channels, + drop, + ) + def split_to_dict( self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] ): return {key: data_tensors[i] for i, key in enumerate(keys)} def forward(self, x): - if self._input_transform is not None: x = self._input_transform(x) - y = self.forward_tensor(**x) + y = super().forward(**x) y = self.split_to_dict(y, self.output_keys) if self._output_transform is not None: diff --git a/ppsci/arch/gan.py b/ppsci/arch/gan.py index 3c7ec2c192..dfe646505c 100644 --- a/ppsci/arch/gan.py +++ b/ppsci/arch/gan.py @@ -24,6 +24,13 @@ from ppsci.arch import activation as act_mod from ppsci.arch import base +__all__ = [ + "GeneratorLayer", + "Generator", + "DiscriminatorLayer", + "Discriminator", +] + class Conv2DBlock(nn.Layer): def __init__( @@ -151,13 +158,10 @@ def forward(self, x): return y -class Generator(base.Arch): - """Generator Net of GAN. Attention, the net using a kind of variant of ResBlock which is - unique to "tempoGAN" example but not an open source network. +class GeneratorLayer(base.Arch): + """Generator Layer of GAN, core implementation of Generator. Args: - input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2"). 
- output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2"). in_channel (int): Number of input channels of the first conv layer. out_channels_tuple (Tuple[Tuple[int, ...], ...]): Number of output channels of all conv layers, such as [[out_res0_conv0, out_res0_conv1], [out_res1_conv0, out_res1_conv1]] @@ -186,8 +190,6 @@ class Generator(base.Arch): def __init__( self, - input_keys: Tuple[str, ...], - output_keys: Tuple[str, ...], in_channel: int, out_channels_tuple: Tuple[Tuple[int, ...], ...], kernel_sizes_tuple: Tuple[Tuple[int, ...], ...], @@ -196,8 +198,6 @@ def __init__( acts_tuple: Tuple[Tuple[str, ...], ...], ): super().__init__() - self.input_keys = input_keys - self.output_keys = output_keys self.in_channel = in_channel self.out_channels_tuple = out_channels_tuple self.kernel_sizes_tuple = kernel_sizes_tuple @@ -228,18 +228,74 @@ def init_blocks(self): ) self.blocks = nn.LayerList(blocks_list) - def forward_tensor(self, x): + def forward(self, x): y = x for block in self.blocks: y = block(y) return y + +class Generator(GeneratorLayer): + """Generator Net of GAN. Attention, the net using a kind of variant of ResBlock which is + unique to "tempoGAN" example but not an open source network. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2"). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2"). + in_channel (int): Number of input channels of the first conv layer. + out_channels_tuple (Tuple[Tuple[int, ...], ...]): Number of output channels of all conv layers, + such as [[out_res0_conv0, out_res0_conv1], [out_res1_conv0, out_res1_conv1]] + kernel_sizes_tuple (Tuple[Tuple[int, ...], ...]): Number of kernel_size of all conv layers, + such as [[kernel_size_res0_conv0, kernel_size_res0_conv1], [kernel_size_res1_conv0, kernel_size_res1_conv1]] + strides_tuple (Tuple[Tuple[int, ...], ...]): Number of stride of all conv layers, + such as [[stride_res0_conv0, stride_res0_conv1], [stride_res1_conv0, stride_res1_conv1]] + use_bns_tuple (Tuple[Tuple[bool, ...], ...]): Whether to use the batch_norm layer after each conv layer. + acts_tuple (Tuple[Tuple[str, ...], ...]): Whether to use the activation layer after each conv layer. 
If so, which activation to use,
+            such as [[act_res0_conv0, act_res0_conv1], [act_res1_conv0, act_res1_conv1]]
+
+    Examples:
+        >>> import ppsci
+        >>> in_channel = 1
+        >>> rb_channel0 = (2, 8, 8)
+        >>> rb_channel1 = (128, 128, 128)
+        >>> rb_channel2 = (32, 8, 8)
+        >>> rb_channel3 = (2, 1, 1)
+        >>> out_channels_tuple = (rb_channel0, rb_channel1, rb_channel2, rb_channel3)
+        >>> kernel_sizes_tuple = (((5, 5), ) * 2 + ((1, 1), ), ) * 4
+        >>> strides_tuple = ((1, 1, 1), ) * 4
+        >>> use_bns_tuple = ((True, True, True), ) * 3 + ((False, False, False), )
+        >>> acts_tuple = (("relu", None, None), ) * 4
+        >>> model = ppsci.arch.Generator(("in",), ("out",), in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple)
+    """
+
+    def __init__(
+        self,
+        input_keys: Tuple[str, ...],
+        output_keys: Tuple[str, ...],
+        in_channel: int,
+        out_channels_tuple: Tuple[Tuple[int, ...], ...],
+        kernel_sizes_tuple: Tuple[Tuple[int, ...], ...],
+        strides_tuple: Tuple[Tuple[int, ...], ...],
+        use_bns_tuple: Tuple[Tuple[bool, ...], ...],
+        acts_tuple: Tuple[Tuple[str, ...], ...],
+    ):
+        self.input_keys = input_keys
+        self.output_keys = output_keys
+        super().__init__(
+            in_channel,
+            out_channels_tuple,
+            kernel_sizes_tuple,
+            strides_tuple,
+            use_bns_tuple,
+            acts_tuple,
+        )
+
     def forward(self, x):
         if self._input_transform is not None:
             x = self._input_transform(x)
 
         y = self.concat_to_tensor(x, self.input_keys, axis=-1)
-        y = self.forward_tensor(y)
+        y = super().forward(y)
         y = self.split_to_dict(y, self.output_keys, axis=-1)
 
         if self._output_transform is not None:
@@ -247,12 +303,10 @@ def forward(self, x):
 
         return y
 
-class Discriminator(base.Arch):
-    """Discriminator Net of GAN.
+class DiscriminatorLayer(base.Arch):
+    """Discriminator Net of GAN, core implementation of Discriminator.
 
     Args:
-        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
-        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
         in_channel (int): Number of input channels of the first conv layer.
         out_channels (Tuple[int, ...]): Number of output channels of all conv layers,
             such as (out_conv0, out_conv1, out_conv2).
@@ -282,8 +336,6 @@ class Discriminator(base.Arch):
 
     def __init__(
         self,
-        input_keys: Tuple[str, ...],
-        output_keys: Tuple[str, ...],
         in_channel: int,
         out_channels: Tuple[int, ...],
         fc_channel: int,
@@ -293,8 +345,6 @@ def __init__(
         acts: Tuple[str, ...],
     ):
         super().__init__()
-        self.input_keys = input_keys
-        self.output_keys = output_keys
         self.in_channel = in_channel
         self.out_channels = out_channels
         self.fc_channel = fc_channel
@@ -328,7 +378,7 @@ def init_layers(self):
         )
         self.layers = nn.LayerList(layers_list)
 
-    def forward_tensor(self, x):
+    def forward(self, x):
         y = x
         y_list = []
         for layer in self.layers:
@@ -336,6 +386,64 @@ def forward_tensor(self, x):
             y_list.append(y)
         return y_list  # y_conv1, y_conv2, y_conv3, y_conv4, y_fc(y_out)
 
+
+class Discriminator(DiscriminatorLayer):
+    """Discriminator Net of GAN.
+
+    Args:
+        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
+        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
+        in_channel (int): Number of input channels of the first conv layer.
+        out_channels (Tuple[int, ...]): Number of output channels of all conv layers,
+            such as (out_conv0, out_conv1, out_conv2).
+        fc_channel (int): Number of input features of linear layer. Number of output features of the layer
+            is set to 1 in this Net to construct a fully_connected layer.
+        kernel_sizes (Tuple[int, ...]): Number of kernel_size of all conv layers,
+            such as (kernel_size_conv0, kernel_size_conv1, kernel_size_conv2).
+        strides (Tuple[int, ...]): Number of stride of all conv layers,
+            such as (stride_conv0, stride_conv1, stride_conv2).
+        use_bns (Tuple[bool, ...]): Whether to use the batch_norm layer after each conv layer.
+        acts (Tuple[str, ...]): Whether to use the activation layer after each conv layer. If so, which activation to use,
+            such as (act_conv0, act_conv1, act_conv2).
+
+    Examples:
+        >>> import ppsci
+        >>> in_channel = 2
+        >>> in_channel_tempo = 3
+        >>> out_channels = (32, 64, 128, 256)
+        >>> fc_channel = 65536
+        >>> kernel_sizes = ((4, 4), (4, 4), (4, 4), (4, 4))
+        >>> strides = (2, 2, 2, 1)
+        >>> use_bns = (False, True, True, True)
+        >>> acts = ("leaky_relu", "leaky_relu", "leaky_relu", "leaky_relu", None)
+        >>> output_keys_disc = ("out_1", "out_2", "out_3", "out_4", "out_5", "out_6", "out_7", "out_8", "out_9", "out_10")
+        >>> model = ppsci.arch.Discriminator(("in_1","in_2"), output_keys_disc, in_channel, out_channels, fc_channel, kernel_sizes, strides, use_bns, acts)
+    """
+
+    def __init__(
+        self,
+        input_keys: Tuple[str, ...],
+        output_keys: Tuple[str, ...],
+        in_channel: int,
+        out_channels: Tuple[int, ...],
+        fc_channel: int,
+        kernel_sizes: Tuple[int, ...],
+        strides: Tuple[int, ...],
+        use_bns: Tuple[bool, ...],
+        acts: Tuple[str, ...],
+    ):
+        self.input_keys = input_keys
+        self.output_keys = output_keys
+        super().__init__(
+            in_channel,
+            out_channels,
+            fc_channel,
+            kernel_sizes,
+            strides,
+            use_bns,
+            acts,
+        )
+
     def forward(self, x):
         if self._input_transform is not None:
             x = self._input_transform(x)
@@ -343,7 +451,7 @@ def forward(self, x):
         y_list = []
         # y1_conv1, y1_conv2, y1_conv3, y1_conv4, y1_fc, y2_conv1, y2_conv2, y2_conv3, y2_conv4, y2_fc
         for k in x:
-            y_list.extend(self.forward_tensor(x[k]))
+            y_list.extend(super().forward(x[k]))
 
         y = self.split_to_dict(y_list, self.output_keys)
 
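One contract the new Discriminator split makes explicit: DiscriminatorLayer.forward returns five tensors per input (four conv feature maps plus the fully connected output), and Discriminator.forward extends them into a single flat list before splitting it into the output dict. A model with N input keys therefore needs exactly 5 * N output keys, which is why the example above pairs two input keys with ten output keys. A sketch of just that bookkeeping, with stand-in strings instead of tensors so it runs anywhere:

    input_keys = ("in_1", "in_2")
    per_input = ("conv1", "conv2", "conv3", "conv4", "fc")  # what the layer yields per input

    y_list = []
    for k in input_keys:
        # In the real class this line is: y_list.extend(super().forward(x[k]))
        y_list.extend(f"{k}_{name}" for name in per_input)

    output_keys = tuple(f"out_{i + 1}" for i in range(len(y_list)))
    y = {key: y_list[i] for i, key in enumerate(output_keys)}  # split_to_dict
    assert len(y) == len(input_keys) * len(per_input)  # 10 outputs for 2 inputs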
diff --git a/ppsci/arch/mlp.py b/ppsci/arch/mlp.py
index 20f5e5a853..3ba5cb372a 100644
--- a/ppsci/arch/mlp.py
+++ b/ppsci/arch/mlp.py
@@ -24,6 +24,12 @@
 from ppsci.arch import base
 from ppsci.utils import initializer
 
+__all__ = [
+    "WeightNormLinear",
+    "FullyConnectedLayer",
+    "MLP",
+]
+
 
 class WeightNormLinear(nn.Layer):
     def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
diff --git a/ppsci/arch/model_list.py b/ppsci/arch/model_list.py
index f463f17226..746a787c27 100644
--- a/ppsci/arch/model_list.py
+++ b/ppsci/arch/model_list.py
@@ -20,6 +20,10 @@
 from ppsci.arch import base
 
+__all__ = [
+    "ModelList",
+]
+
 
 class ModelList(base.Arch):
     """ModelList layer which wrap more than one model that shares inputs.
diff --git a/ppsci/arch/phylstm.py b/ppsci/arch/phylstm.py
index 14ead3f6c1..765e340d8e 100644
--- a/ppsci/arch/phylstm.py
+++ b/ppsci/arch/phylstm.py
@@ -17,9 +17,15 @@
 from ppsci.arch import base
 
+__all__ = [
+    "DeepPhyLSTM",
+]
+
 
 class DeepPhyLSTM(base.Arch):
-    """DeepPhyLSTM init function.
+    """Physics-informed LSTM Network.
+    Zhang, R., Liu, Y., & Sun, H. (2020). Physics-informed multi-LSTM networks for metamodeling of nonlinear structures.
+    Computer Methods in Applied Mechanics and Engineering 369, 113226.
 
     Args:
         input_size (int): The input size.
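All of the renames above implement one recipe: the core layer stays a plain tensor-to-tensor nn.Layer, and the keyed subclass wraps it with the optional input/output transforms plus dict packing. A self-contained sketch of the recipe (class names here are illustrative, not part of ppsci, and base.Arch's concat_to_tensor/split_to_dict helpers are approximated with paddle.concat/paddle.split):

    import paddle
    import paddle.nn as nn

    class CoreLayer(nn.Layer):
        # Tensor in, tensor out: the role of the *Layer classes in this patch.
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)

        def forward(self, x):
            return self.fc(x)

    class KeyedModel(CoreLayer):
        # Dict in, dict out: the role of the derived classes.
        def __init__(self, input_keys, output_keys):
            super().__init__()
            self.input_keys = input_keys
            self.output_keys = output_keys

        def forward(self, x):
            y = paddle.concat([x[k] for k in self.input_keys], axis=-1)
            y = super().forward(y)
            outs = paddle.split(y, len(self.output_keys), axis=-1)
            return dict(zip(self.output_keys, outs))

    model = KeyedModel(("x", "y"), ("u", "v"))
    out = model({"x": paddle.randn([8, 1]), "y": paddle.randn([8, 1])})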
diff --git a/ppsci/arch/physx_transformer.py b/ppsci/arch/physx_transformer.py
index a3fdb81207..22ae0b30a8 100644
--- a/ppsci/arch/physx_transformer.py
+++ b/ppsci/arch/physx_transformer.py
@@ -29,6 +29,12 @@
 from ppsci.arch import base
 
+__all__ = [
+    "PhysformerGPT2Layer",
+    "PhysformerGPT2",
+]
+
+
 zeros_ = Constant(value=0.0)
 ones_ = Constant(value=1.0)
 
@@ -237,12 +243,10 @@ def forward(
         return outputs
 
-class PhysformerGPT2(base.Arch):
-    """Transformer decoder model for modeling physics.
+class PhysformerGPT2Layer(base.Arch):
+    """Transformer decoder layer for modeling physics, core implementation of PhysformerGPT2.
 
     Args:
-        input_keys (Tuple[str, ...]): Input keys, such as ("embeds",).
-        output_keys (Tuple[str, ...]): Output keys, such as ("pred_embeds",).
         num_layers (int): Number of transformer layers.
         num_ctx (int): Contex length of block.
         embed_size (int): The number of embedding size.
         num_heads (int): The number of heads in multi-head attention.
@@ -259,8 +263,6 @@
 
     def __init__(
         self,
-        input_keys: Tuple[str, ...],
-        output_keys: Tuple[str, ...],
         num_layers: int,
         num_ctx: int,
         embed_size: int,
@@ -271,9 +273,6 @@ def __init__(
         initializer_range: float = 0.05,
     ):
         super().__init__()
-        self.input_keys = input_keys
-        self.output_keys = output_keys
-
         self.num_layers = num_layers
         self.num_ctx = num_ctx
         self.embed_size = embed_size
@@ -349,7 +348,7 @@ def generate(self, x, max_length=256):
         outputs = self._generate_time_series(x, max_length)
         return outputs
 
-    def forward_tensor(self, x):
+    def forward(self, x):
         position_embeds = self.get_position_embed(x)
         # Combine input embedding, position embeding
         hidden_states = x + position_embeds
@@ -367,18 +366,69 @@ def forward_eval(self, x):
         outputs = self.generate(input_embeds)
         return (outputs[:, 1:],)
 
+
+class PhysformerGPT2(PhysformerGPT2Layer):
+    """Transformer decoder model for modeling physics.
+
+    Args:
+        input_keys (Tuple[str, ...]): Input keys, such as ("embeds",).
+        output_keys (Tuple[str, ...]): Output keys, such as ("pred_embeds",).
+        num_layers (int): Number of transformer layers.
+        num_ctx (int): Context length of block.
+        embed_size (int): The number of embedding size.
+        num_heads (int): The number of heads in multi-head attention.
+        embd_pdrop (float, optional): The dropout probability used on embedding features. Defaults to 0.0.
+        attn_pdrop (float, optional): The dropout probability used on attention weights. Defaults to 0.0.
+        resid_pdrop (float, optional): The dropout probability used on block outputs. Defaults to 0.0.
+        initializer_range (float, optional): Initializer range of linear layer. Defaults to 0.05.
+ + Examples: + >>> import ppsci + >>> model = ppsci.arch.PhysformerGPT2(("embeds", ), ("pred_embeds", ), 6, 16, 128, 4) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + num_layers: int, + num_ctx: int, + embed_size: int, + num_heads: int, + embd_pdrop: float = 0.0, + attn_pdrop: float = 0.0, + resid_pdrop: float = 0.0, + initializer_range: float = 0.05, + ): + self.input_keys = input_keys + self.output_keys = output_keys + super().__init__( + num_layers, + num_ctx, + embed_size, + num_heads, + embd_pdrop, + attn_pdrop, + resid_pdrop, + initializer_range, + ) + def split_to_dict(self, data_tensors, keys): return {key: data_tensors[i] for i, key in enumerate(keys)} def forward(self, x): if self._input_transform is not None: x = self._input_transform(x) + x_tensor = self.concat_to_tensor(x, self.input_keys, axis=-1) + if self.training: - y = self.forward_tensor(x_tensor) + y = super().forward(x_tensor) else: - y = self.forward_eval(x_tensor) + y = super().forward_eval(x_tensor) + y = self.split_to_dict(y, self.output_keys) + if self._output_transform is not None: y = self._output_transform(x, y) return y diff --git a/ppsci/arch/unetex.py b/ppsci/arch/unetex.py index 4972e55088..6df5e8b2e4 100644 --- a/ppsci/arch/unetex.py +++ b/ppsci/arch/unetex.py @@ -21,6 +21,11 @@ from ppsci.arch import base +__all__ = [ + "UNetExLayer", + "UNetEx", +] + def create_layer( in_channel, @@ -173,14 +178,12 @@ def create_decoder( return nn.Sequential(*decoder) -class UNetEx(base.Arch): - """U-Net +class UNetExLayer(base.Arch): + """U-NetEx layer, core implementation of UNetEx [Ribeiro M D, Rehman A, Ahmed S, et al. DeepCFD: Efficient steady-state laminar flow approximation with deep convolutional neural networks[J]. arXiv preprint arXiv:2004.08826, 2020.](https://arxiv.org/abs/2004.08826) Args: - input_key (str): Name of function data for input. - output_key (str): Name of function data for output. in_channel (int): Number of channels of input. out_channel (int): Number of channels of output. kernel_size (int, optional): Size of kernel of convolution layer. Defaults to 3. @@ -198,8 +201,6 @@ class UNetEx(base.Arch): def __init__( self, - input_key: str, - output_key: str, in_channel: int, out_channel: int, kernel_size: int = 3, @@ -214,8 +215,6 @@ def __init__( raise ValueError("The filters shouldn't be empty ") super().__init__() - self.input_keys = (input_key,) - self.output_keys = (output_key,) self.final_activation = final_activation self.encoder = create_encoder( in_channel, @@ -265,9 +264,75 @@ def decode(self, x, tensors, indices, sizes): return paddle.concat(y, axis=1) def forward(self, x): - x = x[self.input_keys[0]] x, tensors, indices, sizes = self.encode(x) x = self.decode(x, tensors, indices, sizes) if self.final_activation is not None: x = self.final_activation(x) - return {self.output_keys[0]: x} + return x + + +class UNetEx(UNetExLayer): + """U-NetEx. + + [Ribeiro M D, Rehman A, Ahmed S, et al. DeepCFD: Efficient steady-state laminar flow approximation with deep convolutional neural networks[J]. arXiv preprint arXiv:2004.08826, 2020.](https://arxiv.org/abs/2004.08826) + + Args: + input_key (str): Name of function data for input. + output_key (str): Name of function data for output. + in_channel (int): Number of channels of input. + out_channel (int): Number of channels of output. + kernel_size (int, optional): Size of kernel of convolution layer. Defaults to 3. + filters (Tuple[int, ...], optional): Number of filters. Defaults to (16, 32, 64). 
+        layers (int, optional): Number of encoders or decoders. Defaults to 3.
+        weight_norm (bool, optional): Whether use weight normalization layer. Defaults to True.
+        batch_norm (bool, optional): Whether add batch normalization layer. Defaults to True.
+        activation (Type[nn.Layer], optional): Name of activation function. Defaults to nn.ReLU.
+        final_activation (Optional[Type[nn.Layer]]): Name of final activation function. Defaults to None.
+
+    Examples:
+        >>> import ppsci
+        >>> model = ppsci.arch.UNetEx("input", "output", 3, 3, kernel_size=5, filters=(8, 16, 32, 32), weight_norm=False, batch_norm=False)
+    """
+
+    def __init__(
+        self,
+        input_key: str,
+        output_key: str,
+        in_channel: int,
+        out_channel: int,
+        kernel_size: int = 3,
+        filters: Tuple[int, ...] = (16, 32, 64),
+        layers: int = 3,
+        weight_norm: bool = True,
+        batch_norm: bool = True,
+        activation: Type[nn.Layer] = nn.ReLU,
+        final_activation: Optional[Type[nn.Layer]] = None,
+    ):
+        if len(filters) == 0:
+            raise ValueError("The filters shouldn't be empty ")
+
+        self.input_keys = (input_key,)
+        self.output_keys = (output_key,)
+        super().__init__(
+            in_channel,
+            out_channel,
+            kernel_size,
+            filters,
+            layers,
+            weight_norm,
+            batch_norm,
+            activation,
+            final_activation,
+        )
+
+    def forward(self, x):
+        if self._input_transform is not None:
+            x = self._input_transform(x)
+
+        x_tensor = x[self.input_keys[0]]
+        y = super().forward(x_tensor)
+        y = {self.output_keys[0]: y}
+
+        if self._output_transform is not None:
+            y = self._output_transform(x, y)
+        return y

From a668f658232dce144043611a7285f38a5c469af0 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 10 Oct 2023 08:58:50 +0000
Subject: [PATCH 9/9] update code

---
 docs/zh/api/arch.md             |   2 -
 ppsci/arch/__init__.py          |   8 +-
 ppsci/arch/afno.py              | 177 ++++++--------------------------
 ppsci/arch/embedding_koopman.py |  12 +--
 ppsci/arch/gan.py               |   4 +-
 ppsci/arch/phylstm.py           |   9 +-
 ppsci/arch/physx_transformer.py |   2 +-
 ppsci/arch/unetex.py            |   2 +-
 8 files changed, 47 insertions(+), 169 deletions(-)

diff --git a/docs/zh/api/arch.md b/docs/zh/api/arch.md
index 09329d6cd2..fa972b55a6 100644
--- a/docs/zh/api/arch.md
+++ b/docs/zh/api/arch.md
@@ -12,9 +12,7 @@
           - CylinderEmbeddingLayer
           - DiscriminatorLayer
           - PhysformerGPT2Layer
-          - AdaptiveFourierLayer
           - GeneratorLayer
-          - PrecipLayer
           - UNetExLayer
           - MLP
           - DeepONet
diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py
index 523d3f8676..4cb1f6c2d5 100644
--- a/ppsci/arch/__init__.py
+++ b/ppsci/arch/__init__.py
@@ -34,11 +34,9 @@
 from ppsci.arch.physx_transformer import PhysformerGPT2Layer  # isort:skip
 from ppsci.arch.model_list import ModelList  # isort:skip
 from ppsci.arch.afno import AFNONet  # isort:skip
-from ppsci.arch.afno import AdaptiveFourierLayer  # isort:skip
 from ppsci.arch.afno import PrecipNet  # isort:skip
-from ppsci.arch.afno import PrecipLayer  # isort:skip
-from ppsci.arch.unetex import UNetExLayer  # isort:skip
 from ppsci.arch.unetex import UNetEx  # isort:skip
+from ppsci.arch.unetex import UNetExLayer  # isort:skip
 from ppsci.utils import logger  # isort:skip
 
 
@@ -63,11 +61,9 @@
     "PhysformerGPT2Layer",
     "ModelList",
     "AFNONet",
-    "AdaptiveFourierLayer",
     "PrecipNet",
-    "PrecipLayer",
-    "UNetExLayer",
     "UNetEx",
+    "UNetExLayer",
     "build_model",
 ]
 
diff --git a/ppsci/arch/afno.py b/ppsci/arch/afno.py
index 483a13ffde..da7666d709 100644
--- a/ppsci/arch/afno.py
+++ b/ppsci/arch/afno.py
@@ -31,9 +31,7 @@
 from ppsci.utils import initializer
 
 __all__ = [
-    "AdaptiveFourierLayer",
     "AFNONet",
-    "PrecipLayer",
     "PrecipNet",
 ]
 
@@ -398,10 +396,12 @@ def 
forward(self, x): return x -class AdaptiveFourierLayer(base.Arch): - """Adaptive Fourier Neural Operators Network, core implementation of AFNO. +class AFNONet(base.Arch): + """Adaptive Fourier Neural Network. Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). in_channels (int, optional): The input tensor channels. Defaults to 20. @@ -418,11 +418,13 @@ class AdaptiveFourierLayer(base.Arch): Examples: >>> import ppsci - >>> model = ppsci.arch.AdaptiveFourierLayer() + >>> model = ppsci.arch.AFNONet(("input", ), ("output", )) """ def __init__( self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], img_size: Tuple[int, ...] = (720, 1440), patch_size: Tuple[int, ...] = (8, 8), in_channels: int = 20, @@ -438,6 +440,9 @@ def __init__( num_timestamps: int = 1, ): super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + self.img_size = img_size self.patch_size = patch_size self.in_channels = in_channels @@ -505,7 +510,7 @@ def _init_weights(self, m): elif isinstance(m, nn.Conv2D): initializer.conv_init_(m) - def forward(self, x): + def forward_tensor(self, x): B = x.shape[0] x = self.patch_embed(x) x = x + self.pos_embed @@ -529,68 +534,10 @@ def forward(self, x): return x - -class AFNONet(AdaptiveFourierLayer): - """Adaptive Fourier Neural Operators Network. - Different from `AdaptiveFourierLayer`, this class accepts input/output string key(s) for symbolic computation. - - Args: - input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). - output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). - img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). - patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). - in_channels (int, optional): The input tensor channels. Defaults to 20. - out_channels (int, optional): The output tensor channels. Defaults to 20. - embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. - depth (int, optional): Number of transformer depth. Defaults to 12. - mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. - drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. - drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. - num_blocks (int, optional): Number of blocks. Defaults to 8. - sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. - hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. - num_timestamps (int, optional): Number of timestamp. Defaults to 1. - - Examples: - >>> import ppsci - >>> model = ppsci.arch.AFNONet(("input", ), ("output", )) - """ - - def __init__( - self, - input_keys: Tuple[str, ...], - output_keys: Tuple[str, ...], - img_size: Tuple[int, ...] = (720, 1440), - patch_size: Tuple[int, ...] = (8, 8), - in_channels: int = 20, - out_channels: int = 20, - embed_dim: int = 768, - depth: int = 12, - mlp_ratio: float = 4.0, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, - num_blocks: int = 8, - sparsity_threshold: float = 0.01, - hard_thresholding_fraction: float = 1.0, - num_timestamps: int = 1, + def split_to_dict( + self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] 
): - self.input_keys = input_keys - self.output_keys = output_keys - super().__init__( - img_size, - patch_size, - in_channels, - out_channels, - embed_dim, - depth, - mlp_ratio, - drop_rate, - drop_path_rate, - num_blocks, - sparsity_threshold, - hard_thresholding_fraction, - num_timestamps, - ) + return {key: data_tensors[i] for i, key in enumerate(keys)} def forward(self, x): if self._input_transform is not None: @@ -601,7 +548,7 @@ def forward(self, x): y = [] input = x_tensor for _ in range(self.num_timestamps): - out = super().forward(input) + out = self.forward_tensor(input) y.append(out) input = out y = self.split_to_dict(y, self.output_keys) @@ -610,16 +557,13 @@ def forward(self, x): y = self._output_transform(x, y) return y - def split_to_dict( - self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] - ): - return {key: data_tensors[i] for i, key in enumerate(keys)} - -class PrecipLayer(base.Arch): - """Precipitation Network, core implementation of PrecipNet. +class PrecipNet(base.Arch): + """Precipitation Network. Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). wind_model (base.Arch): Wind model. img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). @@ -643,6 +587,8 @@ class PrecipLayer(base.Arch): def __init__( self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], wind_model: base.Arch, img_size: Tuple[int, ...] = (720, 1440), patch_size: Tuple[int, ...] = (8, 8), @@ -659,6 +605,9 @@ def __init__( num_timestamps=1, ): super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + self.img_size = img_size self.patch_size = patch_size self.in_channels = in_channels @@ -666,7 +615,9 @@ def __init__( self.embed_dim = embed_dim self.num_blocks = num_blocks self.num_timestamps = num_timestamps - self.backbone = AdaptiveFourierLayer( + self.backbone = AFNONet( + ("input",), + ("output",), img_size=img_size, patch_size=patch_size, in_channels=in_channels, @@ -700,80 +651,13 @@ def _init_weights(self, m): elif isinstance(m, nn.Conv2D): initializer.conv_init_(m) - def forward(self, x): - x = self.backbone.forward(x) + def forward_tensor(self, x): + x = self.backbone.forward_tensor(x) x = self.ppad(x) x = self.conv(x) x = self.act(x) return x - -class PrecipNet(PrecipLayer): - """Precipitation Network. - Different from `PrecipLayer`, this class accepts input/output string key(s) for symbolic computation. - - Args: - input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). - output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). - wind_model (base.Arch): Wind model. - img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). - patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). - in_channels (int, optional): The input tensor channels. Defaults to 20. - out_channels (int, optional): The output tensor channels. Defaults to 1. - embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. - depth (int, optional): Number of transformer depth. Defaults to 12. - mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. - drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. - drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. - num_blocks (int, optional): Number of blocks. Defaults to 8. 
- sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. - hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. - num_timestamps (int, optional): Number of timestamp. Defaults to 1. - - Examples: - >>> import ppsci - >>> wind_model = ppsci.arch.AFNONet(("input", ), ("output", )) - >>> model = ppsci.arch.PrecipNet(("input", ), ("output", ), wind_model) - """ - - def __init__( - self, - input_keys: Tuple[str, ...], - output_keys: Tuple[str, ...], - wind_model: base.Arch, - img_size: Tuple[int, ...] = (720, 1440), - patch_size: Tuple[int, ...] = (8, 8), - in_channels: int = 20, - out_channels: int = 1, - embed_dim: int = 768, - depth: int = 12, - mlp_ratio: float = 4.0, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, - num_blocks: int = 8, - sparsity_threshold: float = 0.01, - hard_thresholding_fraction: float = 1.0, - num_timestamps=1, - ): - self.input_keys = input_keys - self.output_keys = output_keys - super().__init__( - wind_model, - img_size, - patch_size, - in_channels, - out_channels, - embed_dim, - depth, - mlp_ratio, - drop_rate, - drop_path_rate, - num_blocks, - sparsity_threshold, - hard_thresholding_fraction, - num_timestamps, - ) - def split_to_dict( self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] ): @@ -790,10 +674,9 @@ def forward(self, x): for _ in range(self.num_timestamps): with paddle.no_grad(): out_wind = self.wind_model.forward_tensor(input_wind) - out = super().forward(out_wind) + out = self.forward_tensor(out_wind) y.append(out) input_wind = out_wind - y = self.split_to_dict(y, self.output_keys) if self._output_transform is not None: diff --git a/ppsci/arch/embedding_koopman.py b/ppsci/arch/embedding_koopman.py index 44eed36d1c..a43c6e5714 100644 --- a/ppsci/arch/embedding_koopman.py +++ b/ppsci/arch/embedding_koopman.py @@ -30,12 +30,12 @@ from ppsci.arch import base __all__ = [ - "LorenzEmbeddingLayer", "LorenzEmbedding", - "RosslerEmbeddingLayer", + "LorenzEmbeddingLayer", "RosslerEmbedding", - "CylinderEmbeddingLayer", + "RosslerEmbeddingLayer", "CylinderEmbedding", + "CylinderEmbeddingLayer", ] @@ -56,7 +56,7 @@ class LorenzEmbeddingLayer(base.Arch): Examples: >>> import ppsci - >>> model = ppsci.arch.LorenzEmbedding(("x", "y"), ("u", "v")) + >>> model = ppsci.arch.LorenzEmbeddingLayer() """ def __init__( @@ -253,7 +253,7 @@ class RosslerEmbeddingLayer(LorenzEmbeddingLayer): Examples: >>> import ppsci - >>> model = ppsci.arch.RosslerEmbedding(("x", "y"), ("u", "v")) + >>> model = ppsci.arch.RosslerEmbeddingLayer() """ def __init__( @@ -329,7 +329,7 @@ class CylinderEmbeddingLayer(base.Arch): Examples: >>> import ppsci - >>> model = ppsci.arch.CylinderEmbedding(("x", "y"), ("u", "v")) + >>> model = ppsci.arch.CylinderEmbeddingLayer() """ def __init__( diff --git a/ppsci/arch/gan.py b/ppsci/arch/gan.py index dfe646505c..da39037104 100644 --- a/ppsci/arch/gan.py +++ b/ppsci/arch/gan.py @@ -185,7 +185,7 @@ class GeneratorLayer(base.Arch): >>> strides_tuple = ((1, 1, 1), ) * 4 >>> use_bns_tuple = ((True, True, True), ) * 3 + ((False, False, False), ) >>> acts_tuple = (("relu", None, None), ) * 4 - >>> model = ppsci.arch.Generator(("in",), ("out",), in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple) + >>> model = ppsci.arch.GeneratorLayer(in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple) """ def __init__( @@ -331,7 +331,7 @@ class DiscriminatorLayer(base.Arch): >>> 
use_bns = (False, True, True, True)
         >>> acts = ("leaky_relu", "leaky_relu", "leaky_relu", "leaky_relu", None)
         >>> output_keys_disc = ("out_1", "out_2", "out_3", "out_4", "out_5", "out_6", "out_7", "out_8", "out_9", "out_10")
-        >>> model = ppsci.arch.Discriminator(("in_1","in_2"), output_keys_disc, in_channel, out_channels, fc_channel, kernel_sizes, strides, use_bns, acts)
+        >>> model = ppsci.arch.DiscriminatorLayer(in_channel, out_channels, fc_channel, kernel_sizes, strides, use_bns, acts)
     """
 
     def __init__(
diff --git a/ppsci/arch/phylstm.py b/ppsci/arch/phylstm.py
index 765e340d8e..b90a1e6561 100644
--- a/ppsci/arch/phylstm.py
+++ b/ppsci/arch/phylstm.py
@@ -105,12 +105,13 @@ def forward(self, x):
             x = self._input_transform(x)
 
         if self.model_type == 2:
-            result_dict = self._forward_type_2(x)
+            y = self._forward_type_2(x)
         elif self.model_type == 3:
-            result_dict = self._forward_type_3(x)
+            y = self._forward_type_3(x)
+
         if self._output_transform is not None:
-            result_dict = self._output_transform(x, result_dict)
-        return result_dict
+            y = self._output_transform(x, y)
+        return y
 
     def _forward_type_2(self, x):
         output = self.lstm_model(x["ag"])
diff --git a/ppsci/arch/physx_transformer.py b/ppsci/arch/physx_transformer.py
index 22ae0b30a8..2e4f49d215 100644
--- a/ppsci/arch/physx_transformer.py
+++ b/ppsci/arch/physx_transformer.py
@@ -258,7 +258,7 @@ class PhysformerGPT2Layer(base.Arch):
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.PhysformerGPT2(("embeds", ), ("pred_embeds", ), 6, 16, 128, 4)
+        >>> model = ppsci.arch.PhysformerGPT2Layer(6, 16, 128, 4)
 
     def __init__(
diff --git a/ppsci/arch/unetex.py b/ppsci/arch/unetex.py
index 6df5e8b2e4..6fa19b4f15 100644
--- a/ppsci/arch/unetex.py
+++ b/ppsci/arch/unetex.py
@@ -196,7 +196,7 @@ class UNetExLayer(base.Arch):
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.ppsci.arch.UNetEx("input", "output", 3, 3, (8, 16, 32, 32), 5, Flase, False)
+        >>> model = ppsci.arch.UNetExLayer(3, 3, kernel_size=5, filters=(8, 16, 32, 32), weight_norm=False, batch_norm=False)
 
     def __init__(