Commit ce75ce4

add some documentation

1 parent 957cec7 commit ce75ce4

10 files changed: +85 -34 lines changed

asset/graph/correct_reference.png (18.2 KB)

asset/graph/inverse_edge.png (30.1 KB)

asset/graph/wrong_reference.png (18.1 KB)

doc/source/api/layers.rst

+6
@@ -214,6 +214,12 @@ Variadic

 .. autofunction:: variadic_sample

+.. autofunction:: variadic_meshgrid
+
+.. autofunction:: variadic_to_padded
+
+.. autofunction:: padded_to_variadic
+
 Tensor Reduction
 ^^^^^^^^^^^^^^^^
 .. autofunction:: masked_mean

doc/source/api/metrics.rst

+15 -3

@@ -26,9 +26,21 @@ R2
 ^^
 .. autofunction:: r2

-Variadic Accuracy
-^^^^^^^^^^^^^^^^^
-.. autofunction:: variadic_accuracy
+Accuracy
+^^^^^^^^
+.. autofunction:: accuracy
+
+Matthews Correlation Coefficient
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. autofunction:: matthews_corrcoef
+
+Pearson Correlation Coefficient
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. autofunction:: pearsonr
+
+Spearman Correlation Coefficient
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. autofunction:: spearmanr


 Chemical Metrics

doc/source/notes/reference.rst

+2 -31

@@ -30,6 +30,7 @@ the result is not desired. The edges are masked out correctly, but the values of
 inverse indexes are wrong.

 .. code:: python
+
     with graph.edge():
         graph.inv_edge_index = torch.tensor(inv_edge_index)
     g1 = graph.edge_mask([0, 2, 3])
@@ -55,34 +56,4 @@ since the corresponding inverse edge has been masked out.
     :width: 33%

 We can use ``graph.node_reference()`` and ``graph.graph_reference()`` for references
-to nodes and graphs respectively.
-
-Use Cases in Proteins
----------------------
-
-In :class:`data.Protein`, the mapping ``atom2residue`` is implemented as
-references. The intuition is that references enable flexible indexing on either atoms
-or residues, while maintaining the correspondence between two views.
-
-The following example shows how to track a specific residue with ``atom2residue`` in
-the atom view. For a protein, we first create a mask for atoms in a glutamine (GLN).
-
-.. code:: python
-
-    protein = data.Protein.from_sequence("KALKQMLDMG")
-    is_glutamine = protein.residue_type[protein.atom2residue] == protein.residue2id["GLN"]
-    with protein.node():
-        protein.is_glutamine = is_glutamine
-
-We then apply a mask to the protein residue sequence. In the output protein,
-``atom2residue`` is able to map the masked atoms back to the glutamine residue.
-
-.. code:: python
-
-    p1 = protein[3:6]
-    residue_type = p1.residue_type[p1.atom2residue[p1.is_glutamine]]
-    print([p1.id2residue[r] for r in residue_type.tolist()])
-
-.. code:: bash
-
-    ['GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN']
+to nodes and graphs respectively.

doc/source/notes/variadic.rst

+1
@@ -113,6 +113,7 @@ Naturally, the prediction over nodes also forms a variadic tensor with ``num_nod
 :func:`variadic_topk <torchdrug.layers.functional.variadic_topk>`,
 :func:`variadic_randperm <torchdrug.layers.functional.variadic_randperm>`,
 :func:`variadic_sample <torchdrug.layers.functional.variadic_sample>`,
+:func:`variadic_meshgrid <torchdrug.layers.functional.variadic_meshgrid>`,
 :func:`variadic_softmax <torchdrug.layers.functional.variadic_softmax>`,
 :func:`variadic_log_softmax <torchdrug.layers.functional.variadic_log_softmax>`,
 :func:`variadic_cross_entropy <torchdrug.layers.functional.variadic_cross_entropy>`,

torchdrug/data/graph.py

+22
@@ -699,6 +699,17 @@ def edge_mask(self, index):
                      num_relation=self.num_relation, meta_dict=meta_dict, **data_dict)

     def line_graph(self):
+        """
+        Construct a line graph of this graph.
+        The node feature of the line graph is inherited from the edge feature of the original graph.
+
+        In the line graph, each node corresponds to an edge in the original graph.
+        For a pair of edges (a, b) and (b, c) that share the same intermediate node in the original graph,
+        there is a directed edge (a, b) -> (b, c) in the line graph.
+
+        Returns:
+            Graph
+        """
         node_in, node_out = self.edge_list.t()[:2]
         edge_index = torch.arange(self.num_edge, device=self.device)
         edge_in = edge_index[node_out.argsort()]
@@ -1627,6 +1638,17 @@ def subbatch(self, index):
         return self.graph_mask(index, compact=True)

     def line_graph(self):
+        """
+        Construct a packed line graph of this packed graph.
+        The node features of the line graphs are inherited from the edge features of the original graphs.
+
+        In the line graph, each node corresponds to an edge in the original graph.
+        For a pair of edges (a, b) and (b, c) that share the same intermediate node in the original graph,
+        there is a directed edge (a, b) -> (b, c) in the line graph.
+
+        Returns:
+            PackedGraph
+        """
         node_in, node_out = self.edge_list.t()[:2]
         edge_index = torch.arange(self.num_edge, device=self.device)
         edge_in = edge_index[node_out.argsort()]
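To make the semantics in the new docstring concrete, here is a minimal pure-Python sketch of the line-graph construction it describes. The list-of-pairs edge format and the standalone function name are illustrative only; the actual method above operates on the packed `edge_list` tensor.

```python
def line_graph(edge_list):
    """Line graph of a directed graph given as a list of (u, v) edges.

    Node i of the line graph is edge i of the original graph; there is a
    directed edge i -> j whenever edge i ends at the node where edge j starts.
    """
    edges = []
    for i, (_, head_i) in enumerate(edge_list):
        for j, (tail_j, _) in enumerate(edge_list):
            if head_i == tail_j:
                edges.append((i, j))
    return edges

# Path 0 -> 1 -> 2: edges (0, 1) and (1, 2) share intermediate node 1,
# so the line graph has the single edge 0 -> 1.
assert line_graph([(0, 1), (1, 2)]) == [(0, 1)]
```

The quadratic scan here is for clarity; the tensor implementation instead sorts edges by head and tail node and matches the two orderings.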

torchdrug/layers/functional/functional.py

+38
@@ -375,6 +375,9 @@ def variadic_sort(input, size, descending=False):
         input (Tensor): input of shape :math:`(B, ...)`
         size (LongTensor): size of sets of shape :math:`(N,)`
         descending (bool, optional): return ascending or descending order
+
+    Returns:
+        (Tensor, LongTensor): sorted values and indexes
     """
     index2sample = _size_to_index(size)
     index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
@@ -445,6 +448,21 @@ def variadic_sample(input, size, num_sample):


 def variadic_meshgrid(input1, size1, input2, size2):
+    """
+    Compute the Cartesian product for two batches of sets with variadic sizes.
+
+    Suppose there are :math:`N` sets in each input,
+    and the sizes of all sets are summed to :math:`B_1` and :math:`B_2` respectively.
+
+    Parameters:
+        input1 (Tensor): input of shape :math:`(B_1, ...)`
+        size1 (LongTensor): size of :attr:`input1` of shape :math:`(N,)`
+        input2 (Tensor): input of shape :math:`(B_2, ...)`
+        size2 (LongTensor): size of :attr:`input2` of shape :math:`(N,)`
+
+    Returns:
+        (Tensor, Tensor): the first and the second elements in the Cartesian product
+    """
     grid_size = size1 * size2
     local_index = variadic_arange(grid_size)
     local_inner_size = size2.repeat_interleave(grid_size)
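As a sanity check on the documented behavior, here is a pure-Python sketch of the variadic Cartesian product, using flat lists plus per-set sizes in place of tensors; the name and signature mirror the function above but the implementation is only illustrative.

```python
def variadic_meshgrid(input1, size1, input2, size2):
    """For each pair of aligned sets, emit every (x, y) combination."""
    out1, out2 = [], []
    start1 = start2 = 0
    for n1, n2 in zip(size1, size2):
        for x in input1[start1:start1 + n1]:
            for y in input2[start2:start2 + n2]:
                out1.append(x)
                out2.append(y)
        start1 += n1
        start2 += n2
    return out1, out2

# Two samples: {1, 2} x {10} and {3} x {20, 30}
a, b = variadic_meshgrid([1, 2, 3], [2, 1], [10, 20, 30], [1, 2])
assert a == [1, 2, 3, 3]
assert b == [10, 10, 20, 30]
```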
@@ -456,6 +474,19 @@ def variadic_meshgrid(input1, size1, input2, size2):


 def variadic_to_padded(input, size, value=0):
+    """
+    Convert a variadic tensor to a padded tensor.
+
+    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
+
+    Parameters:
+        input (Tensor): input of shape :math:`(B, ...)`
+        size (LongTensor): size of sets of shape :math:`(N,)`
+        value (scalar): fill value for padding
+
+    Returns:
+        (Tensor, BoolTensor): padded tensor and mask
+    """
     num_sample = len(size)
     max_size = size.max()
     starts = torch.arange(num_sample, device=size.device) * max_size
@@ -469,6 +500,13 @@ def variadic_to_padded(input, size, value=0):


 def padded_to_variadic(padded, size):
+    """
+    Convert a padded tensor to a variadic tensor.
+
+    Parameters:
+        padded (Tensor): padded tensor of shape :math:`(N, ...)`
+        size (LongTensor): size of sets of shape :math:`(N,)`
+    """
     num_sample, max_size = padded.shape[:2]
     starts = torch.arange(num_sample, device=size.device) * max_size
     ends = starts + size
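The padded/variadic round trip documented above can be sketched in pure Python with nested lists standing in for tensors; names mirror the functions in this diff but the list-based bodies are illustrative only.

```python
def variadic_to_padded(values, size, value=0):
    """Split a flat list into per-set rows, padding each row to the max size."""
    max_size = max(size)
    rows, mask, start = [], [], 0
    for n in size:
        rows.append(values[start:start + n] + [value] * (max_size - n))
        mask.append([True] * n + [False] * (max_size - n))
        start += n
    return rows, mask

def padded_to_variadic(padded, size):
    """Inverse conversion: drop the padding and re-flatten."""
    return [x for row, n in zip(padded, size) for x in row[:n]]

values, size = [1, 2, 3, 4, 5], [2, 3]
padded, mask = variadic_to_padded(values, size)
assert padded == [[1, 2, 0], [3, 4, 5]]
assert padded_to_variadic(padded, size) == values
```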

torchdrug/transforms/transform.py

+1
@@ -11,6 +11,7 @@
 class TargetNormalize(object):
     """
     Normalize the target values in a sample.
+
     Parameters:
         mean (dict of float): mean of targets
         std (dict of float): standard deviation of targets
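The normalization this transform documents amounts to per-target standardization; a minimal sketch, assuming targets are a plain dict keyed like `mean` and `std` (the helper name is illustrative, not the transform's API):

```python
def normalize_targets(targets, mean, std):
    """Standardize each target value with its per-target mean and std."""
    return {k: (v - mean[k]) / std[k] for k, v in targets.items()}

sample = {"solubility": 3.0, "toxicity": 0.5}
mean = {"solubility": 1.0, "toxicity": 0.5}
std = {"solubility": 2.0, "toxicity": 1.0}
assert normalize_targets(sample, mean, std) == {"solubility": 1.0, "toxicity": 0.0}
```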
