Skip to content

Commit

Permalink
black reformarring of scripts in examples/open_catalyst_2020
Browse files Browse the repository at this point in the history
  • Loading branch information
allaffa committed Mar 14, 2024
1 parent 7e47fcf commit fa100bd
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 40 deletions.
3 changes: 2 additions & 1 deletion examples/open_catalyst_2020/download_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_data(datadir, task, split, del_intmd_files):

# verify_count(output_path, task, split)
if task == "s2ef" and split == "test":
if not(os.path.exists(f"{datadir}/s2ef/all")):
if not (os.path.exists(f"{datadir}/s2ef/all")):
os.makedirs(f"{datadir}/s2ef/all")
os.system(f"mv {dirname}/test_data/s2ef/all/test_* {datadir}/s2ef/all")
elif task == "is2re":
Expand All @@ -98,6 +98,7 @@ def uncompress_data(compressed_dir):

def preprocess_data(uncompressed_dir, output_path):
import preprocess as preprocess_ad

args.data_path = uncompressed_dir
args.out_path = output_path
preprocess_ad.main(args)
Expand Down
54 changes: 30 additions & 24 deletions examples/open_catalyst_2020/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import subprocess
from hydragnn.utils import nsplit


def info(*args, logtype="info", sep=" "):
getattr(logging, logtype)(sep.join(map(str, args)))

Expand All @@ -48,7 +49,6 @@ def info(*args, logtype="info", sep=" "):


class OpenCatalystDataset(AbstractBaseDataset):

def __init__(self, dirpath, var_config, data_type, dist=False):
super().__init__()

Expand Down Expand Up @@ -77,20 +77,18 @@ def __init__(self, dirpath, var_config, data_type, dist=False):
rx = list(nsplit(range(mx), self.world_size))[self.rank]
chunked_txt_files = list()
for n in rx:
fname = os.path.join(self.data_path, "%d.txt"%n)
fname = os.path.join(self.data_path, "%d.txt" % n)
chunked_txt_files.append(fname)

if len(chunked_txt_files) == 0:
print(self.rank, "WARN: No files to process. Continue ...")

# Initialize feature extractor.
a2g = AtomsToGraphs(
max_neigh=50,
radius=6,
r_pbc=False
)
a2g = AtomsToGraphs(max_neigh=50, radius=6, r_pbc=False)

self.dataset.extend(write_images_to_adios(a2g, chunked_txt_files, self.data_path))
self.dataset.extend(
write_images_to_adios(a2g, chunked_txt_files, self.data_path)
)

def len(self):
return len(self.dataset)
Expand All @@ -113,10 +111,16 @@ def get(self, idx):
"--inputfile", help="input file", type=str, default="open_catalyst_energy.json"
)
parser.add_argument(
"--train_path", help="path to training data", type=str, default="s2ef_train_200K_uncompressed"
"--train_path",
help="path to training data",
type=str,
default="s2ef_train_200K_uncompressed",
)
parser.add_argument(
"--test_path", help="path to testing data", type=str, default="s2ef_val_id_uncompressed"
"--test_path",
help="path to testing data",
type=str,
default="s2ef_val_id_uncompressed",
)
parser.add_argument("--ddstore", action="store_true", help="ddstore dataset")
parser.add_argument("--ddstore_width", type=int, help="ddstore width", default=None)
Expand Down Expand Up @@ -190,10 +194,7 @@ def get(self, idx):
if args.preonly:
## local data
trainset = OpenCatalystDataset(
os.path.join(datadir),
var_config,
data_type=args.train_path,
dist=True
os.path.join(datadir), var_config, data_type=args.train_path, dist=True
)
## This is a local split
trainset, valset1, valset2 = split_dataset(
Expand All @@ -203,10 +204,7 @@ def get(self, idx):
)
valset = [*valset1, *valset2]
testset = OpenCatalystDataset(
os.path.join(datadir),
var_config,
data_type=args.test_path,
dist=True
os.path.join(datadir), var_config, data_type=args.test_path, dist=True
)
## Need as a list
testset = testset[:]
Expand All @@ -218,8 +216,10 @@ def get(self, idx):
setnames = ["trainset", "valset", "testset"]

## adios
if args.format=="adios":
fname = os.path.join(os.path.dirname(__file__), "./dataset/%s.bp" % modelname)
if args.format == "adios":
fname = os.path.join(
os.path.dirname(__file__), "./dataset/%s.bp" % modelname
)
adwriter = AdiosWriter(fname, comm)
adwriter.add("trainset", trainset)
adwriter.add("valset", valset)
Expand All @@ -230,7 +230,7 @@ def get(self, idx):
adwriter.save()

## pickle
elif args.format=="pickle":
elif args.format == "pickle":
basedir = os.path.join(
os.path.dirname(__file__), "dataset", "%s.pickle" % modelname
)
Expand Down Expand Up @@ -286,9 +286,15 @@ def get(self, idx):
basedir = os.path.join(
os.path.dirname(__file__), "dataset", "%s.pickle" % modelname
)
trainset = SimplePickleDataset(basedir=basedir, label="trainset", var_config=var_config)
valset = SimplePickleDataset(basedir=basedir, label="valset", var_config=var_config)
testset = SimplePickleDataset(basedir=basedir, label="testset", var_config=var_config)
trainset = SimplePickleDataset(
basedir=basedir, label="trainset", var_config=var_config
)
valset = SimplePickleDataset(
basedir=basedir, label="valset", var_config=var_config
)
testset = SimplePickleDataset(
basedir=basedir, label="testset", var_config=var_config
)
# minmax_node_feature = trainset.minmax_node_feature
# minmax_graph_feature = trainset.minmax_graph_feature
pna_deg = trainset.pna_deg
Expand Down
4 changes: 1 addition & 3 deletions examples/open_catalyst_2020/uncompress.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ def main(args: argparse.Namespace) -> None:
ip_op_pairs: List[Tuple[str, str]] = []
for filename in filelist:
fname_base = os.path.basename(filename)
ip_op_pairs.append(
(filename, os.path.join(args.opdir, fname_base[:-3]))
)
ip_op_pairs.append((filename, os.path.join(args.opdir, fname_base[:-3])))

pool = mp.Pool(args.num_workers)
list(
Expand Down
13 changes: 5 additions & 8 deletions examples/open_catalyst_2020/utils/atoms_to_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

from hydragnn.preprocess.utils import RadiusGraph, RadiusGraphPBC

#transform_coordinates = Spherical(norm=False, cat=False)
# transform_coordinates = Spherical(norm=False, cat=False)
transform_coordinates = LocalCartesian(norm=False, cat=False)


class AtomsToGraphs:
"""A class to help convert periodic atomic structures to graphs.
Expand Down Expand Up @@ -55,15 +56,11 @@ def __init__(

if self.r_pbc:
self.radius_graph = RadiusGraphPBC(
self.radius,
loop=False,
max_num_neighbors=self.max_neigh
self.radius, loop=False, max_num_neighbors=self.max_neigh
)
else:
self.radius_graph = RadiusGraph(
self.radius,
loop=False,
max_num_neighbors=self.max_neigh
self.radius, loop=False, max_num_neighbors=self.max_neigh
)

def convert(
Expand Down Expand Up @@ -165,4 +162,4 @@ def convert_all(
data, slices = collate(data_list)
torch.save((data, slices), processed_file_path)

return data_list
return data_list
10 changes: 6 additions & 4 deletions examples/open_catalyst_2020/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ def write_images_to_adios(a2g, samples, data_path, subtract_reference_energy=Fal

dataset = []
idx = 0

rank = torch.distributed.get_rank()
for sample in tqdm(samples, desc=os.path.basename(data_path), disable=False if rank == 0 else True):
for sample in tqdm(
samples, desc=os.path.basename(data_path), disable=False if rank == 0 else True
):
try:
traj_logs = open(sample, "r").read().splitlines()
xyz_idx = os.path.splitext(os.path.basename(sample))[0]
Expand All @@ -33,8 +35,8 @@ def write_images_to_adios(a2g, samples, data_path, subtract_reference_energy=Fal
continue

if len(traj_logs) != len(traj_frames):
## let's skip
continue
## let's skip
continue
except Exception as e:
print(f"WARN:", type(error).__name__)
continue
Expand Down

0 comments on commit fa100bd

Please sign in to comment.