Hi DGL team, I’m kindly following up on my Slack messages.
I’m attempting to use dgl.nn.pytorch.explain.GNNExplainer to provide edge-level explanations for a heterogeneous graph transformer with dgl.nn.pytorch.conv.HGTConv layers. From the documentation, it seems that “the required arguments of its forward function are graph, feat, and eweight (optional). The feat argument is for input node features.”
First, I've modified the HGTConv forward function to take the eweight argument as follows. Could you please advise whether this is correct?
Updated HGTConv code
import math
import types

import torch
from dgl import function as fn
from dgl.nn.pytorch import TypedLinear
from dgl.nn.pytorch import edge_softmax


def forward_exp(self, g, x, ntype, etype, *, presorted=False, eweight=None):
    """Forward computation.

    Parameters
    ----------
    g : DGLGraph
        The input graph.
    x : torch.Tensor
        A 2D tensor of node features. Shape: :math:`(|V|, D_{in})`.
    ntype : torch.Tensor
        A 1D integer tensor of node types. Shape: :math:`(|V|,)`.
    etype : torch.Tensor
        A 1D integer tensor of edge types. Shape: :math:`(|E|,)`.
    presorted : bool, optional
        Whether *both* the nodes and the edges of the input graph have been
        sorted by their types. Forward on pre-sorted graph may be faster.
        Graphs created by :func:`~dgl.to_homogeneous` automatically satisfy
        the condition. Also see :func:`~dgl.reorder_graph` for manually
        reordering the nodes and edges.
    eweight : torch.Tensor, optional
        A 1D tensor of edge weights supplied by GNNExplainer.
        Shape: :math:`(|E|,)`.

    Returns
    -------
    torch.Tensor
        New node features. Shape: :math:`(|V|, D_{head} * N_{head})`.
    """
    self.presorted = presorted
    if g.is_block:
        x_src = x
        x_dst = x[: g.num_dst_nodes()]
        srcntype = ntype
        dstntype = ntype[: g.num_dst_nodes()]
    else:
        x_src = x
        x_dst = x
        srcntype = ntype
        dstntype = ntype
    with g.local_scope():
        k = self.linear_k(x_src, srcntype, presorted).view(
            -1, self.num_heads, self.head_size
        )
        q = self.linear_q(x_dst, dstntype, presorted).view(
            -1, self.num_heads, self.head_size
        )
        v = self.linear_v(x_src, srcntype, presorted).view(
            -1, self.num_heads, self.head_size
        )
        g.srcdata["k"] = k
        g.dstdata["q"] = q
        g.srcdata["v"] = v
        g.edata["etype"] = etype
        g.apply_edges(self.message)
        g.edata["m"] = g.edata["m"] * edge_softmax(
            g, g.edata["a"]
        ).unsqueeze(-1)
        # Update for GNNExplainer: multiply messages by edge weights
        if eweight is not None:
            eweight = eweight.view(g.edata["m"].shape[0], 1, 1)
            g.edata["m"] = g.edata["m"] * eweight
        g.update_all(fn.copy_e("m", "m"), fn.sum("m", "h"))
        h = g.dstdata["h"].view(-1, self.num_heads * self.head_size)
        # target-specific aggregation
        h = self.drop(self.linear_a(h, dstntype, presorted))
        alpha = torch.sigmoid(self.skip[dstntype]).unsqueeze(-1)
        if x_dst.shape != h.shape:
            h = h * alpha + (x_dst @ self.residual_w) * (1 - alpha)
        else:
            h = h * alpha + x_dst * (1 - alpha)
        if self.use_norm:
            h = self.norm(h)
        return h
I then update the layers in my model with, for example:
# Replace the forward method
model.conv1.forward = types.MethodType(forward_exp, model.conv1)
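As a sanity check that instance-level patching behaves as expected (nn.Module.__call__ dispatches through self.forward, so a bound override on the instance is picked up on normal calls), here is a dependency-free toy stand-in; the class and function names are hypothetical:

```python
import types


class Layer:
    """Toy stand-in for an nn.Module-style layer (hypothetical)."""

    def __init__(self, scale):
        self.scale = scale

    def forward(self, x, eweight=None):
        return self.scale * x

    def __call__(self, *args, **kwargs):
        # Mimics nn.Module.__call__, which dispatches through self.forward,
        # so an instance-level override takes effect for normal calls too.
        return self.forward(*args, **kwargs)


def forward_exp(self, x, eweight=None):
    # Patched forward: additionally scale by the optional edge weight.
    out = self.scale * x
    if eweight is not None:
        out = out * eweight
    return out


layer = Layer(scale=3)
# Bind the new forward to this instance only, as done for model.conv1 above.
layer.forward = types.MethodType(forward_exp, layer)
print(layer(2))              # 6
print(layer(2, eweight=10))  # 60
```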
Critically, it seems that the current implementation of GNNExplainer is limited to node and graph explanations via the explain_node() and explain_graph() functions, respectively, but this is not a limitation in the original paper. What I would need is a function like:
explain_edge(edge_id, graph, feat, **kwargs)
which also takes an edge_id argument.
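For what it's worth, the piece that would need generalizing seems small: GNNExplainer's core is learning a sigmoid-reparameterized edge mask by gradient descent, and that objective could in principle target an edge-level score instead of a node's logits. Below is a dependency-free toy sketch of that mask-learning loop; the edge_scores values and the size_penalty are hypothetical stand-ins, and a real explain_edge would instead call model(sg, feat, eweight=sigmoid(mask)) on the subgraph around the edge's endpoints:

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical per-edge contributions of the subgraph's messages to the
# target edge's predicted score (stand-ins for real model outputs).
edge_scores = [0.1, 2.0, -0.3, 0.05, 1.5]

# One free mask logit per edge; sigmoid(mask) plays the role of eweight.
mask = [0.0] * len(edge_scores)
lr, size_penalty = 0.5, 0.05
for _ in range(500):
    for i, s in enumerate(edge_scores):
        w = sigmoid(mask[i])
        # Gradient of loss = -(prediction) + size_penalty * sum(weights),
        # i.e. maximize the target score while keeping the mask sparse.
        grad = (-s + size_penalty) * w * (1.0 - w)
        mask[i] -= lr * grad

# Edges that raise the target score end up with weights near 1,
# edges that hurt it end up near 0.
importance = [round(sigmoid(m), 3) for m in mask]
print(importance)
```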
Could you please advise whether it is possible to use the current implementation of GNNExplainer in DGL to provide edge explanations? If so, I would appreciate your guidance on how to implement this method (is it already on the roadmap? should I start from the source code for explain_node()?); if not, please let me know whether there are other explainability methods implemented in DGL that you could recommend for this task.
Thank you!
cc: @marinkaz; from the Slack conversation: @frozenbugs @jermainewang and team