diff --git a/docs/zh/api/arch.md b/docs/zh/api/arch.md
index 03e4f62536..fa972b55a6 100644
--- a/docs/zh/api/arch.md
+++ b/docs/zh/api/arch.md
@@ -5,14 +5,23 @@
     options:
       members:
         - Arch
+        - FullyConnectedLayer
+        - DeepOperatorLayer
+        - LorenzEmbeddingLayer
+        - RosslerEmbeddingLayer
+        - CylinderEmbeddingLayer
+        - DiscriminatorLayer
+        - PhysformerGPT2Layer
+        - GeneratorLayer
+        - UNetExLayer
         - MLP
         - DeepONet
-        - DeepPhyLSTM
         - LorenzEmbedding
         - RosslerEmbedding
         - CylinderEmbedding
         - Generator
         - Discriminator
+        - DeepPhyLSTM
         - PhysformerGPT2
         - ModelList
         - AFNONet
diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py
index 8f1e417d75..4cb1f6c2d5 100644
--- a/ppsci/arch/__init__.py
+++ b/ppsci/arch/__init__.py
@@ -16,36 +16,54 @@
 
 from ppsci.arch.base import Arch  # isort:skip
 from ppsci.arch.mlp import MLP  # isort:skip
+from ppsci.arch.mlp import FullyConnectedLayer  # isort:skip
 from ppsci.arch.deeponet import DeepONet  # isort:skip
+from ppsci.arch.deeponet import DeepOperatorLayer  # isort:skip
 from ppsci.arch.embedding_koopman import LorenzEmbedding  # isort:skip
+from ppsci.arch.embedding_koopman import LorenzEmbeddingLayer  # isort:skip
 from ppsci.arch.embedding_koopman import RosslerEmbedding  # isort:skip
+from ppsci.arch.embedding_koopman import RosslerEmbeddingLayer  # isort:skip
 from ppsci.arch.embedding_koopman import CylinderEmbedding  # isort:skip
+from ppsci.arch.embedding_koopman import CylinderEmbeddingLayer  # isort:skip
 from ppsci.arch.gan import Generator  # isort:skip
+from ppsci.arch.gan import GeneratorLayer  # isort:skip
 from ppsci.arch.gan import Discriminator  # isort:skip
+from ppsci.arch.gan import DiscriminatorLayer  # isort:skip
 from ppsci.arch.phylstm import DeepPhyLSTM  # isort:skip
 from ppsci.arch.physx_transformer import PhysformerGPT2  # isort:skip
+from ppsci.arch.physx_transformer import PhysformerGPT2Layer  # isort:skip
 from ppsci.arch.model_list import ModelList  # isort:skip
 from ppsci.arch.afno import AFNONet  # isort:skip
 from ppsci.arch.afno import PrecipNet  # isort:skip
 from ppsci.arch.unetex import UNetEx  # isort:skip
+from ppsci.arch.unetex import UNetExLayer  # isort:skip
 from ppsci.utils import logger  # isort:skip
 
 
 __all__ = [
     "Arch",
     "MLP",
+    "FullyConnectedLayer",
     "DeepONet",
-    "DeepPhyLSTM",
+    "DeepOperatorLayer",
     "LorenzEmbedding",
+    "LorenzEmbeddingLayer",
     "RosslerEmbedding",
+    "RosslerEmbeddingLayer",
     "CylinderEmbedding",
+    "CylinderEmbeddingLayer",
     "Generator",
+    "GeneratorLayer",
     "Discriminator",
+    "DiscriminatorLayer",
+    "DeepPhyLSTM",
     "PhysformerGPT2",
+    "PhysformerGPT2Layer",
     "ModelList",
     "AFNONet",
     "PrecipNet",
     "UNetEx",
+    "UNetExLayer",
     "build_model",
 ]
 
diff --git a/ppsci/arch/activation.py b/ppsci/arch/activation.py
index 6277811227..6905856a95 100644
--- a/ppsci/arch/activation.py
+++ b/ppsci/arch/activation.py
@@ -24,6 +24,10 @@
 from ppsci.utils import initializer
 from ppsci.utils import misc
 
+__all__ = [
+    "get_activation",
+]
+
 
 class Stan(nn.Layer):
     """Self-scalable Tanh.
diff --git a/ppsci/arch/afno.py b/ppsci/arch/afno.py
index 6f820bd4d8..da7666d709 100644
--- a/ppsci/arch/afno.py
+++ b/ppsci/arch/afno.py
@@ -30,6 +30,11 @@
 from ppsci.arch import base
 from ppsci.utils import initializer
 
+__all__ = [
+    "AFNONet",
+    "PrecipNet",
+]
+
 
 def drop_path(
     x: paddle.Tensor,
diff --git a/ppsci/arch/base.py b/ppsci/arch/base.py
index cf4d79dc58..90071da1b7 100644
--- a/ppsci/arch/base.py
+++ b/ppsci/arch/base.py
@@ -24,6 +24,10 @@
 
 from ppsci.utils import logger
 
+__all__ = [
+    "Arch",
+]
+
 
 class Arch(nn.Layer):
     """Base class for Network."""
diff --git a/ppsci/arch/deeponet.py b/ppsci/arch/deeponet.py
index 374c09cd39..ffb1954ac0 100644
--- a/ppsci/arch/deeponet.py
+++ b/ppsci/arch/deeponet.py
@@ -24,16 +24,19 @@
 from ppsci.arch import base
 from ppsci.arch import mlp
 
+__all__ = [
+    "DeepOperatorLayer",
+    "DeepONet",
+]
 
-class DeepONet(base.Arch):
-    """Deep operator network.
+
+class DeepOperatorLayer(base.Arch):
+    """Deep operator network, core implementation of `DeepONet`.
 
     [Lu et al. Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nat Mach Intell, 2021.](https://doi.org/10.1038/s42256-021-00302-5)
 
     Args:
-        u_key (str): Name of function data for input function u(x).
-        y_key (str): Name of location data for input function G(u).
-        G_key (str): Output name of predicted G(u)(y).
+        trunk_dim (int): Dimension of trunk net's input y (1 for scalar function, >1 for vector function).
         num_loc (int): Number of sampled u(x), i.e. `m` in paper.
         num_features (int): Number of features extracted from u(x), same for y.
         branch_num_layers (int): Number of hidden layers of branch net.
@@ -52,8 +55,8 @@ class DeepONet(base.Arch):
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.DeepONet(
-        ...     "u", "y", "G",
+        >>> model = ppsci.arch.DeepOperatorLayer(
+        ...     1,
         ...     100, 40,
         ...     1, 1,
         ...     40, 40,
@@ -64,9 +67,7 @@ class DeepONet(base.Arch):
 
     def __init__(
         self,
-        u_key: str,
-        y_key: str,
-        G_key: str,
+        trunk_dim: int,
         num_loc: int,
         num_features: int,
         branch_num_layers: int,
@@ -82,33 +83,25 @@ def __init__(
         use_bias: bool = True,
     ):
         super().__init__()
-        self.u_key = u_key
-        self.y_key = y_key
-        self.input_keys = (u_key, y_key)
-        self.output_keys = (G_key,)
-
-        self.branch_net = mlp.MLP(
-            (self.u_key,),
-            ("b",),
+        self.trunk_dim = trunk_dim
+        self.branch_net = mlp.FullyConnectedLayer(
+            num_loc,
+            num_features,
             branch_num_layers,
             branch_hidden_size,
             branch_activation,
             branch_skip_connection,
             branch_weight_norm,
-            input_dim=num_loc,
-            output_dim=num_features,
         )
 
-        self.trunk_net = mlp.MLP(
-            (self.y_key,),
-            ("t",),
+        self.trunk_net = mlp.FullyConnectedLayer(
+            trunk_dim,
+            num_features,
             trunk_num_layers,
             trunk_hidden_size,
             trunk_activation,
             trunk_skip_connection,
             trunk_weight_norm,
-            input_dim=1,
-            output_dim=num_features,
         )
         self.trunk_act = act_mod.get_activation(trunk_activation)
 
@@ -120,28 +113,112 @@ def __init__(
                 attr=nn.initializer.Constant(0.0),
             )
 
-    def forward(self, x):
-        if self._input_transform is not None:
-            x = self._input_transform(x)
-
+    def forward(self, u, y):
         # Branch net to encode the input function
-        u_features = self.branch_net(x)[self.branch_net.output_keys[0]]
+        u_features = self.branch_net(u)
 
         # Trunk net to encode the domain of the output function
-        y_features = self.trunk_net(x)
-        y_features = self.trunk_act(y_features[self.trunk_net.output_keys[0]])
+        y_features = self.trunk_net(y)
+        y_features = self.trunk_act(y_features)
 
         # Dot product
         G_u = paddle.einsum("bi,bi->b", u_features, y_features)  # [batch_size, ]
-        G_u = paddle.reshape(G_u, [-1, 1])  # reshape [batch_size, ] to [batch_size, 1]
+        G_u = paddle.reshape(
+            G_u, [-1, self.trunk_dim]
+        )  # reshape [batch_size, ] to [batch_size, trunk_dim]
 
         # Add bias
         if self.use_bias:
             G_u += self.b
 
-        result_dict = {
-            self.output_keys[0]: G_u,
-        }
+        return G_u
+
+
+class DeepONet(DeepOperatorLayer):
+    """Deep operator network.
+    Different from `DeepOperatorLayer`, this class accepts input/output string key(s) for symbolic computation.
+
+    [Lu et al. Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nat Mach Intell, 2021.](https://doi.org/10.1038/s42256-021-00302-5)
+
+    Args:
+        u_key (str): Name of function data for input function u(x).
+        y_key (str): Name of location data for input function G(u).
+        G_key (str): Output name of predicted G(u)(y).
+        num_loc (int): Number of sampled u(x), i.e. `m` in paper.
+        num_features (int): Number of features extracted from u(x), same for y.
+        branch_num_layers (int): Number of hidden layers of branch net.
+        trunk_num_layers (int): Number of hidden layers of trunk net.
+        branch_hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size of branch net.
+            An integer for all layers, or list of integer specify each layer's size.
+        trunk_hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size of trunk net.
+            An integer for all layers, or list of integer specify each layer's size.
+        branch_skip_connection (bool, optional): Whether to use skip connection for branch net. Defaults to False.
+        trunk_skip_connection (bool, optional): Whether to use skip connection for trunk net. Defaults to False.
+        branch_activation (str, optional): Name of activation function. Defaults to "tanh".
+        trunk_activation (str, optional): Name of activation function. Defaults to "tanh".
+        branch_weight_norm (bool, optional): Whether to apply weight norm on parameter(s) for branch net. Defaults to False.
+        trunk_weight_norm (bool, optional): Whether to apply weight norm on parameter(s) for trunk net. Defaults to False.
+        use_bias (bool, optional): Whether to add bias on predicted G(u)(y). Defaults to True.
+
+    Examples:
+        >>> import ppsci
+        >>> model = ppsci.arch.DeepONet(
+        ...     "u", "y", "G",
+        ...     100, 40,
+        ...     1, 1,
+        ...     40, 40,
+        ...     branch_activation="relu", trunk_activation="relu",
+        ...     use_bias=True,
+        ... )
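+        >>> # a minimal forward sketch (shapes assumed): u sampled at num_loc=100
+        >>> # points, y a 1-D query coordinate, predicted G(u)(y) of shape [batch, 1]
+        >>> import paddle
+        >>> out = model({"u": paddle.rand([8, 100]), "y": paddle.rand([8, 1])})
+        >>> out["G"].shape
+        [8, 1]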
+    """
+
+    def __init__(
+        self,
+        u_key: str,
+        y_key: str,
+        G_key: str,
+        num_loc: int,
+        num_features: int,
+        branch_num_layers: int,
+        trunk_num_layers: int,
+        branch_hidden_size: Union[int, Tuple[int, ...]],
+        trunk_hidden_size: Union[int, Tuple[int, ...]],
+        branch_skip_connection: bool = False,
+        trunk_skip_connection: bool = False,
+        branch_activation: str = "tanh",
+        trunk_activation: str = "tanh",
+        branch_weight_norm: bool = False,
+        trunk_weight_norm: bool = False,
+        use_bias: bool = True,
+    ):
+        # NOTE: set input/output keys before calling the parent constructor below
+        self.input_keys = (u_key, y_key)
+        self.output_keys = (G_key,)
+
+        super().__init__(
+            1,
+            num_loc,
+            num_features,
+            branch_num_layers,
+            trunk_num_layers,
+            branch_hidden_size,
+            trunk_hidden_size,
+            branch_skip_connection,
+            trunk_skip_connection,
+            branch_activation,
+            trunk_activation,
+            branch_weight_norm,
+            trunk_weight_norm,
+            use_bias,
+        )
+
+    def forward(self, x):
+        if self._input_transform is not None:
+            x = self._input_transform(x)
+
+        G_u = super().forward(x[self.input_keys[0]], x[self.input_keys[1]])
+        result_dict = {self.output_keys[0]: G_u}
+
         if self._output_transform is not None:
             result_dict = self._output_transform(x, result_dict)
 
diff --git a/ppsci/arch/embedding_koopman.py b/ppsci/arch/embedding_koopman.py
index 5bae9ce2b8..a43c6e5714 100644
--- a/ppsci/arch/embedding_koopman.py
+++ b/ppsci/arch/embedding_koopman.py
@@ -29,16 +29,24 @@
 
 from ppsci.arch import base
 
+__all__ = [
+    "LorenzEmbedding",
+    "LorenzEmbeddingLayer",
+    "RosslerEmbedding",
+    "RosslerEmbeddingLayer",
+    "CylinderEmbedding",
+    "CylinderEmbeddingLayer",
+]
+
+
 zeros_ = Constant(value=0.0)
 ones_ = Constant(value=1.0)
 
 
-class LorenzEmbedding(base.Arch):
-    """Embedding Koopman model for the Lorenz ODE system.
+class LorenzEmbeddingLayer(base.Arch):
+    """Embedding Koopman layer for the Lorenz ODE system, core implementation of LorenzEmbedding
 
     Args:
-        input_keys (Tuple[str, ...]): Input keys, such as ("states",).
-        output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states").
         mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None.
         std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None.
         input_size (int, optional): Size of input data. Defaults to 3.
@@ -48,13 +56,11 @@ class LorenzEmbedding(base.Arch):
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.LorenzEmbedding(("x", "y"), ("u", "v"))
+        >>> model = ppsci.arch.LorenzEmbeddingLayer()
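+        >>> # a minimal forward sketch (shapes assumed): x is a [batch, time, 3]
+        >>> # trajectory of Lorenz states; the output is a 3-tuple
+        >>> import paddle
+        >>> pred, recover, k_matrix = model(paddle.rand([4, 16, 3]))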
     """
 
     def __init__(
         self,
-        input_keys: Tuple[str, ...],
-        output_keys: Tuple[str, ...],
         mean: Optional[Tuple[float, ...]] = None,
         std: Optional[Tuple[float, ...]] = None,
         input_size: int = 3,
@@ -63,8 +69,6 @@ def __init__(
         drop: float = 0.0,
     ):
         super().__init__()
-        self.input_keys = input_keys
-        self.output_keys = output_keys
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.embed_size = embed_size
@@ -167,7 +171,7 @@ def get_koopman_matrix(self):
         k_matrix = k_matrix + paddle.diag(self.k_diag)
         return k_matrix
 
-    def forward_tensor(self, x):
+    def forward(self, x):
         k_matrix = self.get_koopman_matrix()
         embed_data = self.encoder(x)
         recover_data = self.decoder(embed_data)
@@ -177,6 +181,47 @@ def forward_tensor(self, x):
 
         return (pred_data[:, :-1, :], recover_data, k_matrix)
 
+
+class LorenzEmbedding(LorenzEmbeddingLayer):
+    """Embedding Koopman model for the Lorenz ODE system.
+
+    Args:
+        input_keys (Tuple[str, ...]): Input keys, such as ("states",).
+        output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states").
+        mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None.
+        std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None.
+        input_size (int, optional): Size of input data. Defaults to 3.
+        hidden_size (int, optional): Number of hidden size. Defaults to 500.
+        embed_size (int, optional): Number of embedding size. Defaults to 32.
+        drop (float, optional):  Probability of dropout the units. Defaults to 0.0.
+
+    Examples:
+        >>> import ppsci
+        >>> model = ppsci.arch.LorenzEmbedding(("x", "y"), ("u", "v"))
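+        >>> # a minimal forward sketch (keys/shapes assumed): the input dict is
+        >>> # concatenated along the last axis before the core layer is applied
+        >>> import paddle
+        >>> model = ppsci.arch.LorenzEmbedding(("states",), ("pred_states", "recover_states"))
+        >>> out = model({"states": paddle.rand([4, 16, 3])})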
+    """
+
+    def __init__(
+        self,
+        input_keys: Tuple[str, ...],
+        output_keys: Tuple[str, ...],
+        mean: Optional[Tuple[float, ...]] = None,
+        std: Optional[Tuple[float, ...]] = None,
+        input_size: int = 3,
+        hidden_size: int = 500,
+        embed_size: int = 32,
+        drop: float = 0.0,
+    ):
+        self.input_keys = input_keys
+        self.output_keys = output_keys
+        super().__init__(
+            mean,
+            std,
+            input_size,
+            hidden_size,
+            embed_size,
+            drop,
+        )
+
     def split_to_dict(
         self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]
     ):
@@ -187,7 +232,7 @@ def forward(self, x):
             x = self._input_transform(x)
 
         x_tensor = self.concat_to_tensor(x, self.input_keys, axis=-1)
-        y = self.forward_tensor(x_tensor)
+        y = super().forward(x_tensor)
         y = self.split_to_dict(y, self.output_keys)
 
         if self._output_transform is not None:
@@ -195,6 +240,41 @@ def forward(self, x):
         return y
 
 
+class RosslerEmbeddingLayer(LorenzEmbeddingLayer):
+    """Embedding Koopman layer for the Rossler ODE system, core implementation of RosslerEmbedding.
+
+    Args:
+        mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None.
+        std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None.
+        input_size (int, optional): Size of input data. Defaults to 3.
+        hidden_size (int, optional): Number of hidden size. Defaults to 500.
+        embed_size (int, optional): Number of embedding size. Defaults to 32.
+        drop (float, optional):  Probability of dropout the units. Defaults to 0.0.
+
+    Examples:
+        >>> import ppsci
+        >>> model = ppsci.arch.RosslerEmbeddingLayer()
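+        >>> # same interface as LorenzEmbeddingLayer; a minimal sketch (shapes assumed)
+        >>> import paddle
+        >>> pred, recover, k_matrix = model(paddle.rand([4, 16, 3]))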
+    """
+
+    def __init__(
+        self,
+        mean: Optional[Tuple[float, ...]] = None,
+        std: Optional[Tuple[float, ...]] = None,
+        input_size: int = 3,
+        hidden_size: int = 500,
+        embed_size: int = 32,
+        drop: float = 0.0,
+    ):
+        super().__init__(
+            mean,
+            std,
+            input_size,
+            hidden_size,
+            embed_size,
+            drop,
+        )
+
+
 class RosslerEmbedding(LorenzEmbedding):
     """Embedding Koopman model for the Rossler ODE system.
 
@@ -236,12 +316,10 @@ def __init__(
         )
 
 
-class CylinderEmbedding(base.Arch):
-    """Embedding Koopman model for the Cylinder system.
+class CylinderEmbeddingLayer(base.Arch):
+    """Embedding Koopman layer for the Cylinder system, core implementation of CylinderEmbedding.
 
     Args:
-        input_keys (Tuple[str, ...]): Input keys, such as ("states", "visc").
-        output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states").
         mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None.
         std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None.
         embed_size (int, optional): Number of embedding size. Defaults to 128.
@@ -251,13 +329,11 @@ class CylinderEmbedding(base.Arch):
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.CylinderEmbedding(("x", "y"), ("u", "v"))
+        >>> model = ppsci.arch.CylinderEmbeddingLayer()
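+        >>> # a minimal forward sketch (shapes assumed): states of [B, T, C, H, W]
+        >>> # on the built-in 64x128 grid, plus a [B, 1] viscosity tensor
+        >>> import paddle
+        >>> pred, recover, k_matrix = model(paddle.rand([2, 4, 3, 64, 128]), paddle.rand([2, 1]))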
     """
 
     def __init__(
         self,
-        input_keys: Tuple[str, ...],
-        output_keys: Tuple[str, ...],
         mean: Optional[Tuple[float, ...]] = None,
         std: Optional[Tuple[float, ...]] = None,
         embed_size: int = 128,
@@ -266,8 +342,6 @@ def __init__(
         drop: float = 0.0,
     ):
         super().__init__()
-        self.input_keys = input_keys
-        self.output_keys = output_keys
         self.embed_size = embed_size
 
         X, Y = np.meshgrid(np.linspace(-2, 14, 128), np.linspace(-4, 4, 64))
@@ -471,7 +545,7 @@ def _normalize(self, x: paddle.Tensor):
     def _unnormalize(self, x: paddle.Tensor):
         return self.std[:, :3] * x + self.mean[:, :3]
 
-    def forward_tensor(self, states, visc):
+    def forward(self, states, visc):
         # states.shape=(B, T, C, H, W)
         embed_data = self.encoder(states, visc)
         recover_data = self.decoder(embed_data)
@@ -482,17 +556,58 @@ def forward_tensor(self, states, visc):
 
         return (pred_data[:, :-1], recover_data, k_matrix)
 
+
+class CylinderEmbedding(CylinderEmbeddingLayer):
+    """Embedding Koopman model for the Cylinder system.
+
+    Args:
+        input_keys (Tuple[str, ...]): Input keys, such as ("states", "visc").
+        output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states").
+        mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None.
+        std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None.
+        embed_size (int, optional): Number of embedding size. Defaults to 128.
+        encoder_channels (Optional[Tuple[int, ...]]): Number of channels in encoder network. Defaults to None.
+        decoder_channels (Optional[Tuple[int, ...]]): Number of channels in decoder network. Defaults to None.
+        drop (float, optional):  Probability of dropout the units. Defaults to 0.0.
+
+    Examples:
+        >>> import ppsci
+        >>> model = ppsci.arch.CylinderEmbedding(("x", "y"), ("u", "v"))
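+        >>> # a minimal forward sketch: the input dict is unpacked as (states, visc),
+        >>> # so the keys are assumed to be exactly ("states", "visc")
+        >>> import paddle
+        >>> model = ppsci.arch.CylinderEmbedding(("states", "visc"), ("pred_states", "recover_states"))
+        >>> out = model({"states": paddle.rand([2, 4, 3, 64, 128]), "visc": paddle.rand([2, 1])})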
+    """
+
+    def __init__(
+        self,
+        input_keys: Tuple[str, ...],
+        output_keys: Tuple[str, ...],
+        mean: Optional[Tuple[float, ...]] = None,
+        std: Optional[Tuple[float, ...]] = None,
+        embed_size: int = 128,
+        encoder_channels: Optional[Tuple[int, ...]] = None,
+        decoder_channels: Optional[Tuple[int, ...]] = None,
+        drop: float = 0.0,
+    ):
+        # NOTE: set input/output keys before calling the parent constructor below
+        self.input_keys = input_keys
+        self.output_keys = output_keys
+        super().__init__(
+            mean,
+            std,
+            embed_size,
+            encoder_channels,
+            decoder_channels,
+            drop,
+        )
+
     def split_to_dict(
         self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]
     ):
         return {key: data_tensors[i] for i, key in enumerate(keys)}
 
     def forward(self, x):
-
         if self._input_transform is not None:
             x = self._input_transform(x)
 
-        y = self.forward_tensor(**x)
+        y = super().forward(**x)
         y = self.split_to_dict(y, self.output_keys)
 
         if self._output_transform is not None:
diff --git a/ppsci/arch/gan.py b/ppsci/arch/gan.py
index 3c7ec2c192..da39037104 100644
--- a/ppsci/arch/gan.py
+++ b/ppsci/arch/gan.py
@@ -24,6 +24,13 @@
 from ppsci.arch import activation as act_mod
 from ppsci.arch import base
 
+__all__ = [
+    "GeneratorLayer",
+    "Generator",
+    "DiscriminatorLayer",
+    "Discriminator",
+]
+
 
 class Conv2DBlock(nn.Layer):
     def __init__(
@@ -151,13 +158,10 @@ def forward(self, x):
         return y
 
 
-class Generator(base.Arch):
-    """Generator Net of GAN. Attention, the net using a kind of variant of ResBlock which is
-        unique to "tempoGAN" example but not an open source network.
+class GeneratorLayer(base.Arch):
+    """Generator Layer of GAN, core implementation of Generator.
 
     Args:
-        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
-        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
         in_channel (int): Number of input channels of the first conv layer.
         out_channels_tuple (Tuple[Tuple[int, ...], ...]): Number of output channels of all conv layers,
             such as [[out_res0_conv0, out_res0_conv1], [out_res1_conv0, out_res1_conv1]]
@@ -181,13 +185,11 @@ class Generator(base.Arch):
         >>> strides_tuple = ((1, 1, 1), ) * 4
         >>> use_bns_tuple = ((True, True, True), ) * 3 + ((False, False, False), )
         >>> acts_tuple = (("relu", None, None), ) * 4
-        >>> model = ppsci.arch.Generator(("in",), ("out",), in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple)
+        >>> model = ppsci.arch.GeneratorLayer(in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple)
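+        >>> # a minimal forward sketch (NCHW input assumed, channels = in_channel)
+        >>> import paddle
+        >>> y = model(paddle.rand([1, 1, 64, 64]))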
     """
 
     def __init__(
         self,
-        input_keys: Tuple[str, ...],
-        output_keys: Tuple[str, ...],
         in_channel: int,
         out_channels_tuple: Tuple[Tuple[int, ...], ...],
         kernel_sizes_tuple: Tuple[Tuple[int, ...], ...],
@@ -196,8 +198,6 @@ def __init__(
         acts_tuple: Tuple[Tuple[str, ...], ...],
     ):
         super().__init__()
-        self.input_keys = input_keys
-        self.output_keys = output_keys
         self.in_channel = in_channel
         self.out_channels_tuple = out_channels_tuple
         self.kernel_sizes_tuple = kernel_sizes_tuple
@@ -228,18 +228,74 @@ def init_blocks(self):
             )
         self.blocks = nn.LayerList(blocks_list)
 
-    def forward_tensor(self, x):
+    def forward(self, x):
         y = x
         for block in self.blocks:
             y = block(y)
         return y
 
+
+class Generator(GeneratorLayer):
+    """Generator Net of GAN. Attention, the net using a kind of variant of ResBlock which is
+        unique to "tempoGAN" example but not an open source network.
+
+    Args:
+        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
+        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
+        in_channel (int): Number of input channels of the first conv layer.
+        out_channels_tuple (Tuple[Tuple[int, ...], ...]): Number of output channels of all conv layers,
+            such as [[out_res0_conv0, out_res0_conv1], [out_res1_conv0, out_res1_conv1]]
+        kernel_sizes_tuple (Tuple[Tuple[int, ...], ...]): Number of kernel_size of all conv layers,
+            such as [[kernel_size_res0_conv0, kernel_size_res0_conv1], [kernel_size_res1_conv0, kernel_size_res1_conv1]]
+        strides_tuple (Tuple[Tuple[int, ...], ...]): Number of stride of all conv layers,
+            such as [[stride_res0_conv0, stride_res0_conv1], [stride_res1_conv0, stride_res1_conv1]]
+        use_bns_tuple (Tuple[Tuple[bool, ...], ...]): Whether to use the batch_norm layer after each conv layer.
+        acts_tuple (Tuple[Tuple[str, ...], ...]): Whether to use the activation layer after each conv layer. If so, which activation to use,
+            such as [[act_res0_conv0, act_res0_conv1], [act_res1_conv0, act_res1_conv1]]
+
+    Examples:
+        >>> import ppsci
+        >>> in_channel = 1
+        >>> rb_channel0 = (2, 8, 8)
+        >>> rb_channel1 = (128, 128, 128)
+        >>> rb_channel2 = (32, 8, 8)
+        >>> rb_channel3 = (2, 1, 1)
+        >>> out_channels_tuple = (rb_channel0, rb_channel1, rb_channel2, rb_channel3)
+        >>> kernel_sizes_tuple = (((5, 5), ) * 2 + ((1, 1), ), ) * 4
+        >>> strides_tuple = ((1, 1, 1), ) * 4
+        >>> use_bns_tuple = ((True, True, True), ) * 3 + ((False, False, False), )
+        >>> acts_tuple = (("relu", None, None), ) * 4
+        >>> model = ppsci.arch.Generator(("in",), ("out",), in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple)
+    """
+
+    def __init__(
+        self,
+        input_keys: Tuple[str, ...],
+        output_keys: Tuple[str, ...],
+        in_channel: int,
+        out_channels_tuple: Tuple[Tuple[int, ...], ...],
+        kernel_sizes_tuple: Tuple[Tuple[int, ...], ...],
+        strides_tuple: Tuple[Tuple[int, ...], ...],
+        use_bns_tuple: Tuple[Tuple[bool, ...], ...],
+        acts_tuple: Tuple[Tuple[str, ...], ...],
+    ):
+        self.input_keys = input_keys
+        self.output_keys = output_keys
+        super().__init__(
+            in_channel,
+            out_channels_tuple,
+            kernel_sizes_tuple,
+            strides_tuple,
+            use_bns_tuple,
+            acts_tuple,
+        )
+
     def forward(self, x):
         if self._input_transform is not None:
             x = self._input_transform(x)
 
         y = self.concat_to_tensor(x, self.input_keys, axis=-1)
-        y = self.forward_tensor(y)
+        y = super().forward(y)
         y = self.split_to_dict(y, self.output_keys, axis=-1)
 
         if self._output_transform is not None:
@@ -247,12 +303,10 @@ def forward(self, x):
         return y
 
 
-class Discriminator(base.Arch):
-    """Discriminator Net of GAN.
+class DiscriminatorLayer(base.Arch):
+    """Discriminator Net of GAN, core implementation of Discriminator.
 
     Args:
-        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
-        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
         in_channel (int):  Number of input channels of the first conv layer.
         out_channels (Tuple[int, ...]): Number of output channels of all conv layers,
             such as (out_conv0, out_conv1, out_conv2).
@@ -277,13 +331,11 @@ class Discriminator(base.Arch):
         >>> use_bns = (False, True, True, True)
         >>> acts = ("leaky_relu", "leaky_relu", "leaky_relu", "leaky_relu", None)
         >>> output_keys_disc = ("out_1", "out_2", "out_3", "out_4", "out_5", "out_6", "out_7", "out_8", "out_9", "out_10")
-        >>> model = ppsci.arch.Discriminator(("in_1","in_2"), output_keys_disc, in_channel, out_channels, fc_channel, kernel_sizes, strides, use_bns, acts)
+        >>> model = ppsci.arch.DiscriminatorLayer(in_channel, out_channels, fc_channel, kernel_sizes, strides, use_bns, acts)
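+        >>> # a minimal forward sketch (128x128 input assumed so that the flattened
+        >>> # features match fc_channel): returns conv features plus the fc output
+        >>> import paddle
+        >>> y_list = model(paddle.rand([1, 2, 128, 128]))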
     """
 
     def __init__(
         self,
-        input_keys: Tuple[str, ...],
-        output_keys: Tuple[str, ...],
         in_channel: int,
         out_channels: Tuple[int, ...],
         fc_channel: int,
@@ -293,8 +345,6 @@ def __init__(
         acts: Tuple[str, ...],
     ):
         super().__init__()
-        self.input_keys = input_keys
-        self.output_keys = output_keys
         self.in_channel = in_channel
         self.out_channels = out_channels
         self.fc_channel = fc_channel
@@ -328,7 +378,7 @@ def init_layers(self):
         )
         self.layers = nn.LayerList(layers_list)
 
-    def forward_tensor(self, x):
+    def forward(self, x):
         y = x
         y_list = []
         for layer in self.layers:
@@ -336,6 +386,64 @@ def forward_tensor(self, x):
             y_list.append(y)
         return y_list  # y_conv1, y_conv2, y_conv3, y_conv4, y_fc(y_out)
 
+
+class Discriminator(DiscriminatorLayer):
+    """Discriminator Net of GAN.
+
+    Args:
+        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
+        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
+        in_channel (int):  Number of input channels of the first conv layer.
+        out_channels (Tuple[int, ...]): Number of output channels of all conv layers,
+            such as (out_conv0, out_conv1, out_conv2).
+        fc_channel (int):  Number of input features of linear layer. Number of output features of the layer
+            is set to 1 in this Net to construct a fully_connected layer.
+        kernel_sizes (Tuple[int, ...]): Number of kernel_size of all conv layers,
+            such as (kernel_size_conv0, kernel_size_conv1, kernel_size_conv2).
+        strides (Tuple[int, ...]): Number of stride of all conv layers,
+            such as (stride_conv0, stride_conv1, stride_conv2).
+        use_bns (Tuple[bool, ...]): Whether to use the batch_norm layer after each conv layer.
+        acts (Tuple[str, ...]): Whether to use the activation layer after each conv layer. If so, which activation to use,
+            such as (act_conv0, act_conv1, act_conv2).
+
+    Examples:
+        >>> import ppsci
+        >>> in_channel = 2
+        >>> in_channel_tempo = 3
+        >>> out_channels = (32, 64, 128, 256)
+        >>> fc_channel = 65536
+        >>> kernel_sizes = ((4, 4), (4, 4), (4, 4), (4, 4))
+        >>> strides = (2, 2, 2, 1)
+        >>> use_bns = (False, True, True, True)
+        >>> acts = ("leaky_relu", "leaky_relu", "leaky_relu", "leaky_relu", None)
+        >>> output_keys_disc = ("out_1", "out_2", "out_3", "out_4", "out_5", "out_6", "out_7", "out_8", "out_9", "out_10")
+        >>> model = ppsci.arch.Discriminator(("in_1","in_2"), output_keys_disc, in_channel, out_channels, fc_channel, kernel_sizes, strides, use_bns, acts)
+    """
+
+    def __init__(
+        self,
+        input_keys: Tuple[str, ...],
+        output_keys: Tuple[str, ...],
+        in_channel: int,
+        out_channels: Tuple[int, ...],
+        fc_channel: int,
+        kernel_sizes: Tuple[int, ...],
+        strides: Tuple[int, ...],
+        use_bns: Tuple[bool, ...],
+        acts: Tuple[str, ...],
+    ):
+        self.input_keys = input_keys
+        self.output_keys = output_keys
+        super().__init__(
+            in_channel,
+            out_channels,
+            fc_channel,
+            kernel_sizes,
+            strides,
+            use_bns,
+            acts,
+        )
+
     def forward(self, x):
         if self._input_transform is not None:
             x = self._input_transform(x)
@@ -343,7 +451,7 @@ def forward(self, x):
         y_list = []
         # y1_conv1, y1_conv2, y1_conv3, y1_conv4, y1_fc, y2_conv1, y2_conv2, y2_conv3, y2_conv4, y2_fc
         for k in x:
-            y_list.extend(self.forward_tensor(x[k]))
+            y_list.extend(super().forward(x[k]))
 
         y = self.split_to_dict(y_list, self.output_keys)
 
diff --git a/ppsci/arch/mlp.py b/ppsci/arch/mlp.py
index c874669d2e..3ba5cb372a 100644
--- a/ppsci/arch/mlp.py
+++ b/ppsci/arch/mlp.py
@@ -24,6 +24,12 @@
 from ppsci.arch import base
 from ppsci.utils import initializer
 
+__all__ = [
+    "WeightNormLinear",
+    "FullyConnectedLayer",
+    "MLP",
+]
+
 
 class WeightNormLinear(nn.Layer):
     def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
@@ -50,43 +56,35 @@ def forward(self, input):
         return nn.functional.linear(input, weight, self.bias)
 
 
-class MLP(base.Arch):
-    """Multi layer perceptron network.
+class FullyConnectedLayer(base.Arch):
+    """Fully Connected Layer, core implementation of MLP.
 
     Args:
-        input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z").
-        output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w").
+        input_dim (int): Number of input's dimension.
+        output_dim (int): Number of output's dimension.
         num_layers (int): Number of hidden layers.
         hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size.
             An integer for all layers, or list of integer specify each layer's size.
         activation (str, optional): Name of activation function. Defaults to "tanh".
         skip_connection (bool, optional): Whether to use skip connection. Defaults to False.
         weight_norm (bool, optional): Whether to apply weight norm on parameter(s). Defaults to False.
-        input_dim (Optional[int]): Number of input's dimension. Defaults to None.
-        output_dim (Optional[int]): Number of output's dimension. Defaults to None.
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128)
+        >>> model = ppsci.arch.FullyConnectedLayer(3, 4, num_layers=5, hidden_size=128)
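+        >>> # a minimal forward sketch: plain tensor in, plain tensor out
+        >>> import paddle
+        >>> y = model(paddle.rand([8, 3]))
+        >>> y.shape
+        [8, 4]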
     """
 
     def __init__(
         self,
-        input_keys: Tuple[str, ...],
-        output_keys: Tuple[str, ...],
+        input_dim: int,
+        output_dim: int,
         num_layers: int,
         hidden_size: Union[int, Tuple[int, ...]],
         activation: str = "tanh",
         skip_connection: bool = False,
         weight_norm: bool = False,
-        input_dim: Optional[int] = None,
-        output_dim: Optional[int] = None,
     ):
         super().__init__()
-        self.input_keys = input_keys
-        self.output_keys = output_keys
-        self.linears = []
-        self.acts = []
         if isinstance(hidden_size, (tuple, list)):
             if num_layers is not None:
                 raise ValueError(
@@ -104,8 +102,11 @@ def __init__(
                 f"but got {type(hidden_size)}"
             )
 
+        self.linears = nn.LayerList()
+        self.acts = nn.LayerList()
+
         # initialize FC layer(s)
-        cur_size = len(self.input_keys) if input_dim is None else input_dim
+        cur_size = input_dim
         for i, _size in enumerate(hidden_size):
             self.linears.append(
                 WeightNormLinear(cur_size, _size)
@@ -128,16 +129,11 @@ def __init__(
 
             cur_size = _size
 
-        self.linears = nn.LayerList(self.linears)
-        self.acts = nn.LayerList(self.acts)
-        self.last_fc = nn.Linear(
-            cur_size,
-            len(self.output_keys) if output_dim is None else output_dim,
-        )
+        self.last_fc = nn.Linear(cur_size, output_dim)
 
         self.skip_connection = skip_connection
 
-    def forward_tensor(self, x):
+    def forward(self, x):
         y = x
         skip = None
         for i, linear in enumerate(self.linears):
@@ -154,12 +150,58 @@ def forward_tensor(self, x):
 
         return y
 
+
+class MLP(FullyConnectedLayer):
+    """Multi layer perceptron network derivated by FullyConnectedLayer.
+    Different from `FullyConnectedLayer`, this class accepts input/output string key(s) for symbolic computation.
+
+    Args:
+        input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z").
+        output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w").
+        num_layers (int): Number of hidden layers.
+        hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size.
+            An integer for all layers, or list of integer specify each layer's size.
+        activation (str, optional): Name of activation function. Defaults to "tanh".
+        skip_connection (bool, optional): Whether to use skip connection. Defaults to False.
+        weight_norm (bool, optional): Whether to apply weight norm on parameter(s). Defaults to False.
+        input_dim (Optional[int]): Number of input's dimension. Defaults to None.
+        output_dim (Optional[int]): Number of output's dimension. Defaults to None.
+
+    Examples:
+        >>> import ppsci
+        >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), num_layers=5, hidden_size=128)
+    """
+
+    def __init__(
+        self,
+        input_keys: Tuple[str, ...],
+        output_keys: Tuple[str, ...],
+        num_layers: int,
+        hidden_size: Union[int, Tuple[int, ...]],
+        activation: str = "tanh",
+        skip_connection: bool = False,
+        weight_norm: bool = False,
+        input_dim: Optional[int] = None,
+        output_dim: Optional[int] = None,
+    ):
+        self.input_keys = input_keys
+        self.output_keys = output_keys
+        super().__init__(
+            len(input_keys) if input_dim is None else input_dim,
+            len(output_keys) if output_dim is None else output_dim,
+            num_layers,
+            hidden_size,
+            activation,
+            skip_connection,
+            weight_norm,
+        )
+
     def forward(self, x):
         if self._input_transform is not None:
             x = self._input_transform(x)
 
         y = self.concat_to_tensor(x, self.input_keys, axis=-1)
-        y = self.forward_tensor(y)
+        y = super().forward(y)
         y = self.split_to_dict(y, self.output_keys, axis=-1)
 
         if self._output_transform is not None:
diff --git a/ppsci/arch/model_list.py b/ppsci/arch/model_list.py
index f463f17226..746a787c27 100644
--- a/ppsci/arch/model_list.py
+++ b/ppsci/arch/model_list.py
@@ -20,6 +20,10 @@
 
 from ppsci.arch import base
 
+__all__ = [
+    "ModelList",
+]
+
 
 class ModelList(base.Arch):
     """ModelList layer which wrap more than one model that shares inputs.
diff --git a/ppsci/arch/phylstm.py b/ppsci/arch/phylstm.py
index 14ead3f6c1..b90a1e6561 100644
--- a/ppsci/arch/phylstm.py
+++ b/ppsci/arch/phylstm.py
@@ -17,9 +17,15 @@
 
 from ppsci.arch import base
 
+__all__ = [
+    "DeepPhyLSTM",
+]
+
 
 class DeepPhyLSTM(base.Arch):
-    """DeepPhyLSTM init function.
+    """Physics-informed LSTM Network.
+    Zhang, R., Liu, Y., & Sun, H. (2020). Physics-informed multi-LSTM networks for metamodeling of nonlinear structures.
+    Computer Methods in Applied Mechanics and Engineering 369, 113226.
 
     Args:
         input_size (int): The input size.
@@ -99,12 +105,13 @@ def forward(self, x):
             x = self._input_transform(x)
 
         if self.model_type == 2:
-            result_dict = self._forward_type_2(x)
+            y = self._forward_type_2(x)
         elif self.model_type == 3:
-            result_dict = self._forward_type_3(x)
+            y = self._forward_type_3(x)
+
         if self._output_transform is not None:
-            result_dict = self._output_transform(x, result_dict)
-        return result_dict
+            y = self._output_transform(x, y)
+        return y
 
     def _forward_type_2(self, x):
         output = self.lstm_model(x["ag"])
diff --git a/ppsci/arch/physx_transformer.py b/ppsci/arch/physx_transformer.py
index a3fdb81207..2e4f49d215 100644
--- a/ppsci/arch/physx_transformer.py
+++ b/ppsci/arch/physx_transformer.py
@@ -29,6 +29,12 @@
 
 from ppsci.arch import base
 
+__all__ = [
+    "PhysformerGPT2Layer",
+    "PhysformerGPT2",
+]
+
+
 zeros_ = Constant(value=0.0)
 ones_ = Constant(value=1.0)
 
@@ -237,12 +243,10 @@ def forward(
         return outputs
 
 
-class PhysformerGPT2(base.Arch):
-    """Transformer decoder model for modeling physics.
+class PhysformerGPT2Layer(base.Arch):
+    """Transformer decoder layer for modeling physics, core implementation of PhysformerGPT2.
 
     Args:
-        input_keys (Tuple[str, ...]): Input keys, such as ("embeds",).
-        output_keys (Tuple[str, ...]): Output keys, such as ("pred_embeds",).
         num_layers (int): Number of transformer layers.
         num_ctx (int): Contex length of block.
         embed_size (int): The number of embedding size.
@@ -254,13 +258,11 @@ class PhysformerGPT2(base.Arch):
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.PhysformerGPT2(("embeds", ), ("pred_embeds", ), 6, 16, 128, 4)
+        >>> model = ppsci.arch.PhysformerGPT2Layer(6, 16, 128, 4)
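+        >>> # a minimal forward sketch (shapes assumed): embeddings of
+        >>> # [batch, num_ctx, embed_size]
+        >>> import paddle
+        >>> outputs = model(paddle.rand([2, 16, 128]))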
     """
 
     def __init__(
         self,
-        input_keys: Tuple[str, ...],
-        output_keys: Tuple[str, ...],
         num_layers: int,
         num_ctx: int,
         embed_size: int,
@@ -271,9 +273,6 @@ def __init__(
         initializer_range: float = 0.05,
     ):
         super().__init__()
-        self.input_keys = input_keys
-        self.output_keys = output_keys
-
         self.num_layers = num_layers
         self.num_ctx = num_ctx
         self.embed_size = embed_size
@@ -349,7 +348,7 @@ def generate(self, x, max_length=256):
         outputs = self._generate_time_series(x, max_length)
         return outputs
 
-    def forward_tensor(self, x):
+    def forward(self, x):
         position_embeds = self.get_position_embed(x)
         # Combine input embedding, position embeding
         hidden_states = x + position_embeds
@@ -367,18 +366,69 @@ def forward_eval(self, x):
         outputs = self.generate(input_embeds)
         return (outputs[:, 1:],)
 
+
+class PhysformerGPT2(PhysformerGPT2Layer):
+    """Transformer decoder model for modeling physics.
+
+    Args:
+        input_keys (Tuple[str, ...]): Input keys, such as ("embeds",).
+        output_keys (Tuple[str, ...]): Output keys, such as ("pred_embeds",).
+        num_layers (int): Number of transformer layers.
+        num_ctx (int): Context length of block.
+        embed_size (int): The number of embedding size.
+        num_heads (int): The number of heads in multi-head attention.
+        embd_pdrop (float, optional): The dropout probability used on embedding features. Defaults to 0.0.
+        attn_pdrop (float, optional): The dropout probability used on attention weights. Defaults to 0.0.
+        resid_pdrop (float, optional): The dropout probability used on block outputs. Defaults to 0.0.
+        initializer_range (float, optional): Initializer range of linear layer. Defaults to 0.05.
+
+    Examples:
+        >>> import ppsci
+        >>> model = ppsci.arch.PhysformerGPT2(("embeds", ), ("pred_embeds", ), 6, 16, 128, 4)
+    """
+
+    def __init__(
+        self,
+        input_keys: Tuple[str, ...],
+        output_keys: Tuple[str, ...],
+        num_layers: int,
+        num_ctx: int,
+        embed_size: int,
+        num_heads: int,
+        embd_pdrop: float = 0.0,
+        attn_pdrop: float = 0.0,
+        resid_pdrop: float = 0.0,
+        initializer_range: float = 0.05,
+    ):
+        self.input_keys = input_keys
+        self.output_keys = output_keys
+        super().__init__(
+            num_layers,
+            num_ctx,
+            embed_size,
+            num_heads,
+            embd_pdrop,
+            attn_pdrop,
+            resid_pdrop,
+            initializer_range,
+        )
+
     def split_to_dict(self, data_tensors, keys):
         return {key: data_tensors[i] for i, key in enumerate(keys)}
 
     def forward(self, x):
         if self._input_transform is not None:
             x = self._input_transform(x)
+
         x_tensor = self.concat_to_tensor(x, self.input_keys, axis=-1)
+
         if self.training:
-            y = self.forward_tensor(x_tensor)
+            y = super().forward(x_tensor)
         else:
-            y = self.forward_eval(x_tensor)
+            y = super().forward_eval(x_tensor)
+
         y = self.split_to_dict(y, self.output_keys)
+
         if self._output_transform is not None:
             y = self._output_transform(x, y)
         return y
diff --git a/ppsci/arch/unetex.py b/ppsci/arch/unetex.py
index 4972e55088..6fa19b4f15 100644
--- a/ppsci/arch/unetex.py
+++ b/ppsci/arch/unetex.py
@@ -21,6 +21,11 @@
 
 from ppsci.arch import base
 
+__all__ = [
+    "UNetExLayer",
+    "UNetEx",
+]
+
 
 def create_layer(
     in_channel,
@@ -173,14 +178,12 @@ def create_decoder(
     return nn.Sequential(*decoder)
 
 
-class UNetEx(base.Arch):
-    """U-Net
+class UNetExLayer(base.Arch):
+    """U-NetEx layer, core implementation of UNetEx
 
     [Ribeiro M D, Rehman A, Ahmed S, et al. DeepCFD: Efficient steady-state laminar flow approximation with deep convolutional neural networks[J]. arXiv preprint arXiv:2004.08826, 2020.](https://arxiv.org/abs/2004.08826)
 
     Args:
-        input_key (str): Name of function data for input.
-        output_key (str): Name of function data for output.
         in_channel (int): Number of channels of input.
         out_channel (int): Number of channels of output.
         kernel_size (int, optional): Size of kernel of convolution layer. Defaults to 3.
@@ -193,13 +196,11 @@ class UNetEx(base.Arch):
 
     Examples:
         >>> import ppsci
-        >>> model = ppsci.arch.ppsci.arch.UNetEx("input", "output", 3, 3, (8, 16, 32, 32), 5, Flase, False)
+        >>> model = ppsci.arch.UNetExLayer(3, 3, filters=(8, 16, 32, 32), layers=5, weight_norm=False, batch_norm=False)
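+        >>> # a minimal forward sketch (NCHW input assumed, channels = in_channel)
+        >>> import paddle
+        >>> y = model(paddle.rand([1, 3, 64, 64]))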
     """
 
     def __init__(
         self,
-        input_key: str,
-        output_key: str,
         in_channel: int,
         out_channel: int,
         kernel_size: int = 3,
@@ -214,8 +215,6 @@ def __init__(
             raise ValueError("The filters shouldn't be empty ")
 
         super().__init__()
-        self.input_keys = (input_key,)
-        self.output_keys = (output_key,)
         self.final_activation = final_activation
         self.encoder = create_encoder(
             in_channel,
@@ -265,9 +264,75 @@ def decode(self, x, tensors, indices, sizes):
         return paddle.concat(y, axis=1)
 
     def forward(self, x):
-        x = x[self.input_keys[0]]
         x, tensors, indices, sizes = self.encode(x)
         x = self.decode(x, tensors, indices, sizes)
         if self.final_activation is not None:
             x = self.final_activation(x)
-        return {self.output_keys[0]: x}
+        return x
+
+
+class UNetEx(UNetExLayer):
+    """U-NetEx.
+
+    [Ribeiro M D, Rehman A, Ahmed S, et al. DeepCFD: Efficient steady-state laminar flow approximation with deep convolutional neural networks[J]. arXiv preprint arXiv:2004.08826, 2020.](https://arxiv.org/abs/2004.08826)
+
+    Args:
+        input_key (str): Name of function data for input.
+        output_key (str): Name of function data for output.
+        in_channel (int): Number of channels of input.
+        out_channel (int): Number of channels of output.
+        kernel_size (int, optional): Size of kernel of convolution layer. Defaults to 3.
+        filters (Tuple[int, ...], optional): Number of filters. Defaults to (16, 32, 64).
+        layers (int, optional): Number of encoders or decoders. Defaults to 3.
+        weight_norm (bool, optional): Whether use weight normalization layer. Defaults to True.
+        batch_norm (bool, optional): Whether add batch normalization layer. Defaults to True.
+        activation (Type[nn.Layer], optional): Name of activation function. Defaults to nn.ReLU.
+        final_activation (Optional[Type[nn.Layer]]): Name of final activation function. Defaults to None.
+
+    Examples:
+        >>> import ppsci
+        >>> model = ppsci.arch.UNetEx("input", "output", 3, 3, filters=(8, 16, 32, 32), layers=5, weight_norm=False, batch_norm=False)
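+        >>> # a minimal forward sketch (key/shapes assumed): dict in, dict out
+        >>> import paddle
+        >>> out = model({"input": paddle.rand([1, 3, 64, 64])})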
+    """
+
+    def __init__(
+        self,
+        input_key: str,
+        output_key: str,
+        in_channel: int,
+        out_channel: int,
+        kernel_size: int = 3,
+        filters: Tuple[int, ...] = (16, 32, 64),
+        layers: int = 3,
+        weight_norm: bool = True,
+        batch_norm: bool = True,
+        activation: Type[nn.Layer] = nn.ReLU,
+        final_activation: Optional[Type[nn.Layer]] = None,
+    ):
+        if len(filters) == 0:
+            raise ValueError("The filters shouldn't be empty ")
+
+        self.input_keys = (input_key,)
+        self.output_keys = (output_key,)
+        super().__init__(
+            in_channel,
+            out_channel,
+            kernel_size,
+            filters,
+            layers,
+            weight_norm,
+            batch_norm,
+            activation,
+            final_activation,
+        )
+
+    def forward(self, x):
+        if self._input_transform is not None:
+            x = self._input_transform(x)
+
+        x_tensor = x[self.input_keys[0]]
+        y = super().forward(x_tensor)
+        y = {self.output_keys[0]: y}
+
+        if self._output_transform is not None:
+            y = self._output_transform(x, y)
+        return y
diff --git a/test/equation/test_biharmonic.py b/test/equation/test_biharmonic.py
index 8e1d6c2be0..74aecb700f 100644
--- a/test/equation/test_biharmonic.py
+++ b/test/equation/test_biharmonic.py
@@ -31,10 +31,12 @@ def test_biharmonic(dim):
         input_data = paddle.concat([x, y, z], axis=1)
 
     # build NN model
-    model = arch.MLP(input_dims, output_dims, 2, 16)
+    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
+    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
+    model_sym.load_dict(model.state_dict())
 
     # manually generate output
-    u = model.forward_tensor(input_data)
+    u = model(input_data)
 
     # use self-defined jacobian and hessian
     def jacobian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor":
@@ -60,7 +62,7 @@ def hessian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor":
         if isinstance(expr, sp.Basic):
             biharmonic_equation.equations[name] = ppsci.lambdify(
                 expr,
-                model,
+                model_sym,
             )
     data_dict = {
         "x": x,
diff --git a/test/equation/test_laplace.py b/test/equation/test_laplace.py
index 6c438df3e4..1f79af940c 100644
--- a/test/equation/test_laplace.py
+++ b/test/equation/test_laplace.py
@@ -28,10 +28,12 @@ def test_l1loss_mean(dim):
         input_data = paddle.concat([x, y, z], axis=1)
 
     # build NN model
-    model = arch.MLP(input_dims, output_dims, 2, 16)
+    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
+    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
+    model_sym.load_dict(model.state_dict())
 
     # manually generate output
-    u = model.forward_tensor(input_data)
+    u = model(input_data)
 
     # use self-defined jacobian and hessian
     def jacobian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor":
@@ -51,7 +53,7 @@ def hessian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor":
         if isinstance(expr, sp.Basic):
             laplace_equation.equations[name] = ppsci.lambdify(
                 expr,
-                model,
+                model_sym,
             )
 
     data_dict = {
diff --git a/test/equation/test_linear_elasticity.py b/test/equation/test_linear_elasticity.py
index 973e3df104..ed418f01a0 100644
--- a/test/equation/test_linear_elasticity.py
+++ b/test/equation/test_linear_elasticity.py
@@ -172,14 +172,12 @@ def test_linear_elasticity(E, nu, lambda_, mu, rho, dim, time):
     if dim == 3:
         input_data = paddle.concat([input_data, z], axis=1)
 
-    model = arch.MLP(input_dims, output_dims, 2, 16)
+    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
+    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
+    model_sym.load_dict(model.state_dict())
 
-    # model = nn.Sequential(
-    #     nn.Linear(input_data.shape[1], 9 if dim == 3 else 5),
-    #     nn.Tanh(),
-    # )
-
-    output = model.forward_tensor(input_data)
+    # manually generate output
+    output = model(input_data)
 
     u, v, *other_outputs = paddle.split(output, num_or_sections=output.shape[1], axis=1)
 
@@ -234,7 +232,7 @@ def test_linear_elasticity(E, nu, lambda_, mu, rho, dim, time):
         if isinstance(expr, sp.Basic):
             linear_elasticity.equations[name] = ppsci.lambdify(
                 expr,
-                model,
+                model_sym,
             )
     data_dict = {
         "t": t,
diff --git a/test/equation/test_navier_stokes.py b/test/equation/test_navier_stokes.py
index 0279374ac8..7bc55b6fef 100644
--- a/test/equation/test_navier_stokes.py
+++ b/test/equation/test_navier_stokes.py
@@ -110,10 +110,12 @@ def test_navierstokes(nu, rho, dim, time):
         input_dims = input_dims + ("z",)
     input_data = paddle.concat(inputs, axis=1)
 
-    model = arch.MLP(input_dims, output_dims, 2, 16)
+    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
+    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
+    model_sym.load_dict(model.state_dict())
 
     # manually generate output
-    output = model.forward_tensor(input_data)
+    output = model(input_data)
 
     if dim == 2:
         u, v, p = paddle.split(output, num_or_sections=len(output_dims), axis=1)
@@ -140,7 +142,7 @@ def test_navierstokes(nu, rho, dim, time):
         if isinstance(expr, sp.Basic):
             navier_stokes_equation.equations[name] = ppsci.lambdify(
                 expr,
-                model,
+                model_sym,
             )
 
     data_dict = {"x": x, "y": y, "u": u, "v": v, "p": p}
diff --git a/test/equation/test_poisson.py b/test/equation/test_poisson.py
index ca86d98db2..627b8352bf 100644
--- a/test/equation/test_poisson.py
+++ b/test/equation/test_poisson.py
@@ -42,10 +42,12 @@ def test_poisson(dim):
         input_data = paddle.concat([x, y, z], axis=1)
 
     # build NN model
-    model = arch.MLP(input_dims, output_dims, 2, 16)
+    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
+    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
+    model_sym.load_dict(model.state_dict())
 
     # manually generate output
-    p = model.forward_tensor(input_data)
+    p = model(input_data)
 
     def jacobian(y: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor:
         return paddle.grad(y, x, create_graph=True)[0]
@@ -64,7 +66,7 @@ def hessian(y: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor:
         if isinstance(expr, sp.Basic):
             poisson_equation.equations[name] = ppsci.lambdify(
                 expr,
-                model,
+                model_sym,
             )
 
     data_dict = {
diff --git a/test/equation/test_viv.py b/test/equation/test_viv.py
index 2dc979912d..8806ef9888 100644
--- a/test/equation/test_viv.py
+++ b/test/equation/test_viv.py
@@ -33,10 +33,12 @@ def test_vibration(rho, k1, k2):
     input_data = t_f
     input_dims = ("t_f",)
     output_dims = ("eta",)
-    model = arch.MLP(input_dims, output_dims, 2, 16)
+    model = arch.FullyConnectedLayer(len(input_dims), len(output_dims), 2, 16)
+    model_sym = arch.MLP(input_dims, output_dims, 2, 16)
+    model_sym.load_dict(model.state_dict())
 
     # manually generate output
-    eta = model.forward_tensor(input_data)
+    eta = model(input_data)
 
     def jacobian(y: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor:
         return paddle.grad(y, x, create_graph=True)[0]
@@ -56,7 +58,7 @@ def hessian(y: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor:
         if isinstance(expr, sp.Basic):
             vibration_equation.equations[name] = ppsci.lambdify(
                 expr,
-                model,
+                model_sym,
                 vibration_equation.learnable_parameters,
             )
     input_data_dict = {"t_f": t_f}