Skip to content

Commit

Permalink
hyperparameter adjust for conv_head tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pzhanggit committed Mar 4, 2024
1 parent a60dfeb commit 6126b6b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
12 changes: 6 additions & 6 deletions tests/inputs/ci_conv_head.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@
"num_spherical": 7,
"num_filters": 126,
"periodic_boundary_conditions": false,
"hidden_dim": 50,
"num_conv_layers": 4,
"hidden_dim": 20,
"num_conv_layers": 2,
"output_heads": {
"node": {
"num_headlayers": 2,
"dim_headlayers": [50,50],
"dim_headlayers": [20,10],
"type": "conv"
}
},
Expand All @@ -60,10 +60,10 @@
"Training": {
"num_epoch": 100,
"perc_train": 0.7,
"EarlyStopping": true,
"patience": 10,
"EarlyStopping": false,
"patience": 10,
"loss_function_type": "mse",
"batch_size": 1,
"batch_size": 32,
"Optimizer": {
"type": "AdamW",
"use_zero_redundancy": false,
Expand Down
8 changes: 2 additions & 6 deletions tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,6 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False
thresholds["PNA"] = [0.10, 0.10]
if use_lengths and "vector" in ci_input:
thresholds["PNA"] = [0.2, 0.15]
if ci_input == "ci_conv_head.json":
thresholds["PNA"] = [0.3, 0.3]
thresholds["EGNN"] = [0.5, 0.5]
thresholds["SchNet"] = [0.6, 0.6]
verbosity = 2

for ihead in range(len(true_values)):
Expand Down Expand Up @@ -207,7 +203,7 @@ def pytest_train_model_vectoroutput(model_type, overwrite_data=False):

@pytest.mark.parametrize(
"model_type",
["PNA", "EGNN", "SchNet"],
["SAGE", "GIN", "GAT", "MFC", "PNA", "SchNet", "DimeNet", "EGNN"]
)
def pytest_train_model_conv_head(model_type, overwrite_data=False):
unittest_train_model(model_type, "ci_conv_head.json", True, overwrite_data)
unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data)

0 comments on commit 6126b6b

Please sign in to comment.