forked from ORNL/HydraGNN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* open catalysts pre-load scripts added * train.py added * data loading, data processing, and training added * test set changed to validation because we need labels * exception handler in data loader removed * skip command added if len(traj_logs) != len(traj_frames) * fix adios * removing second np.split * uncompress file added * energy float converted into tensor * add update_predicted_values in adios reader * add var_config for adios * file reader uses glob.iglob * minor update * choice of adios or pickle in pre-loading * minor fixes * formatting fixed * black reformatting of utils/adiosdataset.py * black reformarring of scripts in examples/open_catalyst_2020 * Dataset section in the JSON file is removed --------- Co-authored-by: Massimiliano Lupo Pasini <[email protected]> Co-authored-by: Massimiliano Lupo Pasini <[email protected]> Co-authored-by: Jong Choi <[email protected]> Co-authored-by: Massimiliano Lupo Pasini <[email protected]> Co-authored-by: Massimiliano Lupo Pasini <[email protected]> Co-authored-by: Jong Choi <[email protected]> Co-authored-by: Massimiliano Lupo Pasini <[email protected]> Co-authored-by: Jong Choi <[email protected]> Co-authored-by: Massimiliano Lupo Pasini <[email protected]>
- Loading branch information
1 parent
af7ed45
commit 8b7f542
Showing
9 changed files
with
982 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import argparse | ||
import glob | ||
import logging | ||
import os | ||
|
||
""" | ||
This script provides users with an automated way to download, preprocess (where | ||
applicable), and organize data to readily be used by the existing config files. | ||
""" | ||
|
||
# Public OC20 (Open Catalyst 2020) archive URLs, keyed by task.
# "s2ef" (structure-to-energy-and-forces) maps split name -> tarball URL;
# "is2re" (initial-structure-to-relaxed-energy) is a single combined archive.
DOWNLOAD_LINKS = {
    "s2ef": {
        "200k": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_200K.tar",
        "2M": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_2M.tar",
        "20M": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_20M.tar",
        "all": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_all.tar",
        "val_id": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_val_id.tar",
        "val_ood_ads": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_val_ood_ads.tar",
        "val_ood_cat": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_val_ood_cat.tar",
        "val_ood_both": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_val_ood_both.tar",
        "test": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_test_lmdbs.tar.gz",
        "rattled": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_rattled.tar",
        "md": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_md.tar",
    },
    "is2re": "https://dl.fbaipublicfiles.com/opencatalystproject/data/is2res_train_val_test_lmdbs.tar.gz",
}
|
||
# Expected number of frames per S2EF split; used by verify_count() to check
# that preprocessing produced a complete dataset. (The "test" split has no
# entry because it ships as ready-made LMDBs and skips preprocessing.)
S2EF_COUNTS = {
    "s2ef": {
        "200k": 200000,
        "2M": 2000000,
        "20M": 20000000,
        "all": 133934018,
        "val_id": 999866,
        "val_ood_ads": 999838,
        "val_ood_cat": 999809,
        "val_ood_both": 999944,
        "rattled": 16677031,
        "md": 38315405,
    },
}
|
||
|
||
def get_data(datadir, task, split, del_intmd_files):
    """Download, extract, and organize one OC20 dataset under *datadir*.

    Args:
        datadir: destination directory; created if missing.
        task: "s2ef" or "is2re".
        split: S2EF split name (key of DOWNLOAD_LINKS["s2ef"]); ignored for
            is2re.
        del_intmd_files: when True the caller intends intermediate archives to
            be removed (the cleanup call is currently commented out).

    Raises:
        NotImplementedError: if task is "s2ef" and no split was given.
        AssertionError: if the requested S2EF split is unknown.
    """
    os.makedirs(datadir, exist_ok=True)

    if task == "s2ef" and split is None:
        raise NotImplementedError("S2EF requires a split to be defined.")

    if task == "s2ef":
        assert (
            split in DOWNLOAD_LINKS[task]
        ), f'S2EF "{split}" split not defined, please specify one of the following: {list(DOWNLOAD_LINKS["s2ef"].keys())}'
        download_link = DOWNLOAD_LINKS[task][split]

    elif task == "is2re":
        download_link = DOWNLOAD_LINKS[task]

    os.system(f"wget {download_link} -P {datadir}")
    filename = os.path.join(datadir, os.path.basename(download_link))
    logging.info("Extracting contents...")
    # BUG FIX: extract the archive we just downloaded; the archive path in the
    # tar command was previously garbled, so nothing was ever unpacked.
    os.system(f"tar -xvf {filename} -C {datadir}")
    # Archives unpack into a directory named after the tarball (sans suffix).
    dirname = os.path.join(
        datadir,
        os.path.basename(filename).split(".")[0],
    )
    if task == "s2ef" and split != "test":
        compressed_dir = os.path.join(dirname, os.path.basename(dirname))
        if split in ["200k", "2M", "20M", "all", "rattled", "md"]:
            output_path = os.path.join(datadir, task, split, "train")
        else:
            output_path = os.path.join(datadir, task, "all", split)
        uncompressed_dir = uncompress_data(compressed_dir)
        # preprocess_data(uncompressed_dir, output_path)

        # verify_count(output_path, task, split)
    if task == "s2ef" and split == "test":
        # Test split ships pre-built LMDBs; just move them into place.
        if not (os.path.exists(f"{datadir}/s2ef/all")):
            os.makedirs(f"{datadir}/s2ef/all")
        os.system(f"mv {dirname}/test_data/s2ef/all/test_* {datadir}/s2ef/all")
    elif task == "is2re":
        os.system(f"mv {dirname}/data/is2re {datadir}")

    # if del_intmd_files:
    #     cleanup(filename, dirname)
|
||
|
||
def uncompress_data(compressed_dir):
    """Decompress an extracted S2EF archive via the project's uncompress helper.

    Returns the path of the sibling "<parent>_uncompressed" output directory.
    """
    import uncompress  # project-local helper script

    # Reuse the helper's own CLI parser so its defaults apply, then override
    # the input/output directories programmatically.
    helper_parser = uncompress.get_parser()
    helper_args, _ = helper_parser.parse_known_args()
    helper_args.ipdir = compressed_dir
    helper_args.opdir = os.path.dirname(compressed_dir) + "_uncompressed"
    uncompress.main(helper_args)
    return helper_args.opdir
|
||
|
||
def preprocess_data(uncompressed_dir, output_path):
    """Run the project's preprocess step on an uncompressed S2EF directory.

    NOTE(review): this reads the module-level ``args`` namespace created in
    the ``__main__`` block — calling it from another module, or before CLI
    parsing, raises NameError. Its only call site in get_data is currently
    commented out; confirm the global dependency before re-enabling.
    """
    import preprocess as preprocess_ad  # project-local helper script

    # Mutates the global CLI namespace in place before delegating.
    args.data_path = uncompressed_dir
    args.out_path = output_path
    preprocess_ad.main(args)
|
||
|
||
def verify_count(output_path, task, split):
    """Assert that the .txt files under *output_path* hold the expected frame count."""
    expected = S2EF_COUNTS[task][split]
    total = 0
    # Sum line counts across every text file produced by preprocessing.
    for txt_path in glob.glob(os.path.join(output_path, "*.txt")):
        with open(txt_path, "r") as handle:
            total += len(handle.read().splitlines())
    assert (
        total == expected
    ), f"S2EF {split} count incorrect, verify preprocessing has completed successfully."
|
||
|
||
def cleanup(filename, dirname):
    """Delete the downloaded archive and both extraction directories, if present."""
    import shutil

    # Pair each intermediate artifact with the remover appropriate to its kind.
    targets = (
        (filename, os.remove),
        (dirname, shutil.rmtree),
        (dirname + "_uncompressed", shutil.rmtree),
    )
    for path, remove in targets:
        if os.path.exists(path):
            remove(path)
|
||
|
||
if __name__ == "__main__":
    # CLI entry point: pick a task/split, download it, and organize the files.
    cli = argparse.ArgumentParser()
    cli.add_argument("--task", type=str, help="Task to download")
    cli.add_argument(
        "--split", type=str, help="Corresponding data split to download"
    )
    cli.add_argument(
        "--keep",
        action="store_true",
        help="Keep intermediate directories and files upon data retrieval/processing",
    )
    cli.add_argument(
        "--data-path",
        type=str,
        default="./dataset",
        help="Specify path to save dataset. Defaults to './dataset'",
    )

    # Keep the module-level name `args`: preprocess_data() reads it globally.
    args, _ = cli.parse_known_args()
    get_data(
        datadir=args.data_path,
        task=args.task,
        split=args.split,
        del_intmd_files=not args.keep,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
{ | ||
"Verbosity": { | ||
"level": 2 | ||
}, | ||
"NeuralNetwork": { | ||
"Architecture": { | ||
"model_type": "EGNN", | ||
"equivariance": true, | ||
"radius": 5.0, | ||
"max_neighbours": 100000, | ||
"num_gaussians": 50, | ||
"envelope_exponent": 5, | ||
"int_emb_size": 64, | ||
"basis_emb_size": 8, | ||
"out_emb_size": 128, | ||
"num_after_skip": 2, | ||
"num_before_skip": 1, | ||
"num_radial": 6, | ||
"num_spherical": 7, | ||
"num_filters": 126, | ||
"edge_features": ["coord_x", "coord_y", "coord_z"], | ||
"hidden_dim": 50, | ||
"num_conv_layers": 3, | ||
"output_heads": { | ||
"graph":{ | ||
"num_sharedlayers": 2, | ||
"dim_sharedlayers": 50, | ||
"num_headlayers": 2, | ||
"dim_headlayers": [50,25] | ||
} | ||
}, | ||
"task_weights": [1.0] | ||
}, | ||
"Variables_of_interest": { | ||
"input_node_features": [0, 1, 2, 3], | ||
"output_names": ["energy"], | ||
"output_index": [0], | ||
"output_dim": [1], | ||
"type": ["graph"] | ||
}, | ||
"Training": { | ||
"num_epoch": 50, | ||
"perc_train": 0.8, | ||
"loss_function_type": "mae", | ||
"batch_size": 32, | ||
"continue": 0, | ||
"Optimizer": { | ||
"type": "AdamW", | ||
"learning_rate": 1e-3 | ||
} | ||
} | ||
}, | ||
"Visualization": { | ||
"plot_init_solution": true, | ||
"plot_hist_solution": false, | ||
"create_plots": true | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
{ | ||
"Verbosity": { | ||
"level": 2 | ||
}, | ||
"NeuralNetwork": { | ||
"Architecture": { | ||
"model_type": "EGNN", | ||
"equivariance": true, | ||
"radius": 5.0, | ||
"max_neighbours": 100000, | ||
"num_gaussians": 50, | ||
"envelope_exponent": 5, | ||
"int_emb_size": 64, | ||
"basis_emb_size": 8, | ||
"out_emb_size": 128, | ||
"num_after_skip": 2, | ||
"num_before_skip": 1, | ||
"num_radial": 6, | ||
"num_spherical": 7, | ||
"num_filters": 126, | ||
"edge_features": ["coord_x", "coord_y", "coord_z"], | ||
"hidden_dim": 50, | ||
"num_conv_layers": 3, | ||
"output_heads": { | ||
"node": { | ||
"num_headlayers": 2, | ||
"dim_headlayers": [200,200], | ||
"type": "mlp" | ||
} | ||
}, | ||
"task_weights": [1.0] | ||
}, | ||
"Variables_of_interest": { | ||
"input_node_features": [0, 1, 2, 3], | ||
"output_names": ["forces"], | ||
"output_index": [2], | ||
"output_dim": [3], | ||
"type": ["node"] | ||
}, | ||
"Training": { | ||
"num_epoch": 50, | ||
"EarlyStopping": true, | ||
"perc_train": 0.9, | ||
"loss_function_type": "mae", | ||
"batch_size": 32, | ||
"continue": 0, | ||
"Optimizer": { | ||
"type": "AdamW", | ||
"learning_rate": 1e-3 | ||
} | ||
} | ||
}, | ||
"Visualization": { | ||
"plot_init_solution": true, | ||
"plot_hist_solution": false, | ||
"create_plots": true | ||
} | ||
} |
Oops, something went wrong.