Skip to content

Commit

Permalink
bug fixed in x_node updates for convolutional heads
Browse files Browse the repository at this point in the history
  • Loading branch information
allaffa committed Jan 22, 2024
1 parent 3715acf commit fc19a97
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
7 changes: 3 additions & 4 deletions hydragnn/models/Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,10 @@ def forward(self, data):
else:
if self.node_NN_type == "conv":
for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
# x_node = self.activation_function(
# batch_norm(conv(x=x, pos=pos, **conv_args))
# )
c, pos = conv(x=x, pos=pos, **conv_args)
x_node = self.activation_function(c)
c = batch_norm(c)
x = self.activation_function(c)
x_node = x
else:
x_node = headloc(x=x, batch=data.batch)
outputs.append(x_node)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,9 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False
if use_lengths and "vector" in ci_input:
thresholds["PNA"] = [0.2, 0.15]
if ci_input == "ci_conv_head.json":
thresholds["PNA"] = [0.5, 0.5]
thresholds["EGNN"] = [0.5, 0.5]
thresholds["SchNet"] = [0.5, 0.5]
thresholds["PNA"] = [0.2, 0.2]
thresholds["EGNN"] = [0.2, 0.2]
thresholds["SchNet"] = [0.2, 0.2]
verbosity = 2

for ihead in range(len(true_values)):
Expand Down

0 comments on commit fc19a97

Please sign in to comment.