Skip to content

Commit

Permalink
Open catalyst (ORNL#211)
Browse files Browse the repository at this point in the history
* open catalysts pre-load scripts added

* train.py added

* data loading, data processing, and training added

* test set changed to validation because we need labels

* exception handler in data loader removed

* skip command added if len(traj_logs) != len(traj_frames)

* fix adios

* removing second np.split

* uncompress file added

* energy float converted into tensor

* add update_predicted_values in adios reader

* add var_config for adios

* file reader uses glob.iglob

* minor update

* choice of adios or pickle in pre-loading

* minor fixes

* formatting fixed

* black reformatting of utils/adiosdataset.py

* black reformatting of scripts in examples/open_catalyst_2020

* Dataset section in the JSON file is removed

---------

Co-authored-by: Massimiliano Lupo Pasini <[email protected]>
Co-authored-by: Massimiliano Lupo Pasini <[email protected]>
Co-authored-by: Jong Choi <[email protected]>
Co-authored-by: Massimiliano Lupo Pasini <[email protected]>
Co-authored-by: Massimiliano Lupo Pasini <[email protected]>
Co-authored-by: Jong Choi <[email protected]>
Co-authored-by: Massimiliano Lupo Pasini <[email protected]>
Co-authored-by: Jong Choi <[email protected]>
Co-authored-by: Massimiliano Lupo Pasini <[email protected]>
  • Loading branch information
10 people authored Mar 18, 2024
1 parent af7ed45 commit 8b7f542
Show file tree
Hide file tree
Showing 9 changed files with 982 additions and 4 deletions.
153 changes: 153 additions & 0 deletions examples/open_catalyst_2020/download_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import argparse
import glob
import logging
import os

"""
This script provides users with an automated way to download, preprocess (where
applicable), and organize data to readily be used by the existing config files.
"""

# Download URLs keyed by task. "s2ef" maps split names to per-split archives;
# "is2re" is a single archive containing train/val/test LMDBs.
DOWNLOAD_LINKS = {
    "s2ef": {
        "200k": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_200K.tar",
        "2M": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_2M.tar",
        "20M": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_20M.tar",
        "all": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_all.tar",
        "val_id": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_val_id.tar",
        "val_ood_ads": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_val_ood_ads.tar",
        "val_ood_cat": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_val_ood_cat.tar",
        "val_ood_both": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_val_ood_both.tar",
        "test": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_test_lmdbs.tar.gz",
        "rattled": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_rattled.tar",
        "md": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_md.tar",
    },
    "is2re": "https://dl.fbaipublicfiles.com/opencatalystproject/data/is2res_train_val_test_lmdbs.tar.gz",
}

# Expected sample counts per S2EF split; used by verify_count() to confirm
# that preprocessing produced the full dataset.
S2EF_COUNTS = {
    "s2ef": {
        "200k": 200000,
        "2M": 2000000,
        "20M": 20000000,
        "all": 133934018,
        "val_id": 999866,
        "val_ood_ads": 999838,
        "val_ood_cat": 999809,
        "val_ood_both": 999944,
        "rattled": 16677031,
        "md": 38315405,
    },
}


def get_data(datadir, task, split, del_intmd_files):
    """Download, extract, and organize an Open Catalyst dataset.

    Args:
        datadir: Destination directory for the downloaded archive and data.
        task: Dataset task, either "s2ef" or "is2re".
        split: S2EF split name (required for "s2ef"; ignored for "is2re").
        del_intmd_files: If True, intermediate files would be removed after
            processing (the cleanup call is currently disabled below).

    Raises:
        NotImplementedError: If task is "s2ef" and no split was given.
        ValueError: If task is not one of the supported tasks.
    """
    os.makedirs(datadir, exist_ok=True)

    if task == "s2ef" and split is None:
        raise NotImplementedError("S2EF requires a split to be defined.")

    if task == "s2ef":
        assert (
            split in DOWNLOAD_LINKS[task]
        ), f'S2EF "{split}" split not defined, please specify one of the following: {list(DOWNLOAD_LINKS["s2ef"].keys())}'
        download_link = DOWNLOAD_LINKS[task][split]
    elif task == "is2re":
        download_link = DOWNLOAD_LINKS[task]
    else:
        # Fail with a clear message instead of a NameError on download_link.
        raise ValueError(f"Unknown task: {task}")

    os.system(f"wget {download_link} -P {datadir}")
    filename = os.path.join(datadir, os.path.basename(download_link))
    logging.info("Extracting contents...")
    # BUG FIX: the tar invocation did not reference the downloaded archive;
    # extract the file we just fetched into datadir.
    os.system(f"tar -xvf {filename} -C {datadir}")
    # Archives unpack into a directory named after the file stem,
    # e.g. s2ef_train_200K.tar -> s2ef_train_200K.
    dirname = os.path.join(
        datadir,
        os.path.basename(filename).split(".")[0],
    )
    if task == "s2ef" and split != "test":
        compressed_dir = os.path.join(dirname, os.path.basename(dirname))
        if split in ["200k", "2M", "20M", "all", "rattled", "md"]:
            output_path = os.path.join(datadir, task, split, "train")
        else:
            output_path = os.path.join(datadir, task, "all", split)
        uncompressed_dir = uncompress_data(compressed_dir)
        # preprocess_data(uncompressed_dir, output_path)

        # verify_count(output_path, task, split)
    if task == "s2ef" and split == "test":
        if not (os.path.exists(f"{datadir}/s2ef/all")):
            os.makedirs(f"{datadir}/s2ef/all")
        os.system(f"mv {dirname}/test_data/s2ef/all/test_* {datadir}/s2ef/all")
    elif task == "is2re":
        os.system(f"mv {dirname}/data/is2re {datadir}")

    # if del_intmd_files:
    #     cleanup(filename, dirname)


def uncompress_data(compressed_dir):
    """Decompress trajectory archives via the local ``uncompress`` helper.

    The helper's own parser supplies its default options; only the input
    and output directories are overridden here.

    Returns:
        The path of the directory holding the uncompressed output.
    """
    import uncompress

    cli = uncompress.get_parser()
    options, _ = cli.parse_known_args()
    options.ipdir = compressed_dir
    options.opdir = os.path.dirname(compressed_dir) + "_uncompressed"
    uncompress.main(options)
    return options.opdir


def preprocess_data(uncompressed_dir, output_path):
    """Convert uncompressed trajectory data via the local ``preprocess`` helper.

    BUG FIX: the original body referenced a bare ``args`` that is neither a
    parameter nor defined in this function — it only existed as a global
    inside the ``__main__`` guard, so importing this module and calling this
    function raised NameError. Build the namespace from the helper module's
    own parser instead, mirroring uncompress_data above.

    Args:
        uncompressed_dir: Directory produced by uncompress_data().
        output_path: Directory where preprocessed output is written.
    """
    import preprocess as preprocess_ad

    # NOTE(review): assumes preprocess exposes get_parser() like the sibling
    # uncompress module does — confirm against the helper script.
    parser = preprocess_ad.get_parser()
    args, _ = parser.parse_known_args()
    args.data_path = uncompressed_dir
    args.out_path = output_path
    preprocess_ad.main(args)


def verify_count(output_path, task, split):
    """Assert that the extracted .txt files hold the expected sample count.

    Counts lines across every ``*.txt`` file under output_path and compares
    the total against S2EF_COUNTS[task][split].

    Args:
        output_path: Directory containing the preprocessed ``*.txt`` files.
        task: Dataset task key into S2EF_COUNTS (e.g. "s2ef").
        split: Split key into S2EF_COUNTS[task].

    Raises:
        AssertionError: If the total line count does not match the expectation.
    """
    count = 0
    for path in glob.glob(os.path.join(output_path, "*.txt")):
        # Stream line-by-line under a context manager: the original read each
        # whole file into memory and never closed the file handles.
        with open(path, "r") as f:
            count += sum(1 for _ in f)
    assert (
        count == S2EF_COUNTS[task][split]
    ), f"S2EF {split} count incorrect, verify preprocessing has completed successfully."


def cleanup(filename, dirname):
    """Delete the downloaded archive and both extraction directories.

    Args:
        filename: Path of the downloaded archive to remove.
        dirname: Extraction directory; its "_uncompressed" sibling is
            removed as well. Paths that do not exist are skipped.
    """
    import shutil

    if os.path.exists(filename):
        os.remove(filename)
    for directory in (dirname, dirname + "_uncompressed"):
        if os.path.exists(directory):
            shutil.rmtree(directory)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, help="Task to download")
parser.add_argument(
"--split", type=str, help="Corresponding data split to download"
)
parser.add_argument(
"--keep",
action="store_true",
help="Keep intermediate directories and files upon data retrieval/processing",
)
parser.add_argument(
"--data-path",
type=str,
default="./dataset",
help="Specify path to save dataset. Defaults to './dataset'",
)

args, _ = parser.parse_known_args()
get_data(
datadir=args.data_path,
task=args.task,
split=args.split,
del_intmd_files=not args.keep,
)
58 changes: 58 additions & 0 deletions examples/open_catalyst_2020/open_catalyst_energy.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"Verbosity": {
"level": 2
},
"NeuralNetwork": {
"Architecture": {
"model_type": "EGNN",
"equivariance": true,
"radius": 5.0,
"max_neighbours": 100000,
"num_gaussians": 50,
"envelope_exponent": 5,
"int_emb_size": 64,
"basis_emb_size": 8,
"out_emb_size": 128,
"num_after_skip": 2,
"num_before_skip": 1,
"num_radial": 6,
"num_spherical": 7,
"num_filters": 126,
"edge_features": ["coord_x", "coord_y", "coord_z"],
"hidden_dim": 50,
"num_conv_layers": 3,
"output_heads": {
"graph":{
"num_sharedlayers": 2,
"dim_sharedlayers": 50,
"num_headlayers": 2,
"dim_headlayers": [50,25]
}
},
"task_weights": [1.0]
},
"Variables_of_interest": {
"input_node_features": [0, 1, 2, 3],
"output_names": ["energy"],
"output_index": [0],
"output_dim": [1],
"type": ["graph"]
},
"Training": {
"num_epoch": 50,
"perc_train": 0.8,
"loss_function_type": "mae",
"batch_size": 32,
"continue": 0,
"Optimizer": {
"type": "AdamW",
"learning_rate": 1e-3
}
}
},
"Visualization": {
"plot_init_solution": true,
"plot_hist_solution": false,
"create_plots": true
}
}
58 changes: 58 additions & 0 deletions examples/open_catalyst_2020/open_catalyst_forces.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"Verbosity": {
"level": 2
},
"NeuralNetwork": {
"Architecture": {
"model_type": "EGNN",
"equivariance": true,
"radius": 5.0,
"max_neighbours": 100000,
"num_gaussians": 50,
"envelope_exponent": 5,
"int_emb_size": 64,
"basis_emb_size": 8,
"out_emb_size": 128,
"num_after_skip": 2,
"num_before_skip": 1,
"num_radial": 6,
"num_spherical": 7,
"num_filters": 126,
"edge_features": ["coord_x", "coord_y", "coord_z"],
"hidden_dim": 50,
"num_conv_layers": 3,
"output_heads": {
"node": {
"num_headlayers": 2,
"dim_headlayers": [200,200],
"type": "mlp"
}
},
"task_weights": [1.0]
},
"Variables_of_interest": {
"input_node_features": [0, 1, 2, 3],
"output_names": ["forces"],
"output_index": [2],
"output_dim": [3],
"type": ["node"]
},
"Training": {
"num_epoch": 50,
"EarlyStopping": true,
"perc_train": 0.9,
"loss_function_type": "mae",
"batch_size": 32,
"continue": 0,
"Optimizer": {
"type": "AdamW",
"learning_rate": 1e-3
}
}
},
"Visualization": {
"plot_init_solution": true,
"plot_hist_solution": false,
"create_plots": true
}
}
Loading

0 comments on commit 8b7f542

Please sign in to comment.