@@ -26,6 +26,16 @@ class GraphConstruction(nn.Module, core.Configurable):
2626 2. For ``gearnet``, the feature of the edge :math:`e_{ij}` between residue :math:`i` and residue :math:`j`
2727 is the concatenation ``[residue_type(i), residue_type(j), edge_type(e_ij),
2828 sequential_distance(i,j), spatial_distance(i,j)]``.
29+
30+ .. note::
31+ You may customize your own edge features by inheriting this class and define a member function
32+ for your features. Use ``edge_feature="my_feature"`` to call the following feature function.
33+
34+ .. code:: python
35+
36+ def edge_my_feature(self, graph, edge_list, num_relation):
37+ ...
38+ return feature # the first dimension must be ``graph.num_edge``
2939 """
3040
3141 max_seq_dist = 10
@@ -43,7 +53,7 @@ def __init__(self, node_layers=None, edge_layers=None, edge_feature="residue_typ
4353 self .edge_layers = edge_layers
4454 self .edge_feature = edge_feature
4555
46- def edge_residue_type (self , graph , edge_list ):
56+ def edge_residue_type (self , graph , edge_list , num_relation ):
4757 node_in , node_out , _ = edge_list .t ()
4858 residue_in , residue_out = graph .atom2residue [node_in ], graph .atom2residue [node_out ]
4959 in_residue_type = graph .residue_type [residue_in ]
@@ -103,10 +113,8 @@ def apply_edge_layer(self, graph):
103113 num_edges = edge2graph .bincount (minlength = graph .batch_size )
104114 offsets = (graph .num_cum_nodes - graph .num_nodes ).repeat_interleave (num_edges )
105115
106- if self .edge_feature == "residue_type" :
107- edge_feature = self .edge_residue_type (graph , edge_list )
108- elif self .edge_feature == "gearnet" :
109- edge_feature = self .edge_gearnet (graph , edge_list , num_relation )
116+ if hasattr (self , "edge_%s" % self .edge_feature ):
117+ edge_feature = getattr (self , "edge_%s" % self .edge_feature )(graph , edge_list , num_relation )
110118 elif self .edge_feature is None :
111119 edge_feature = None
112120 else :
0 commit comments