Skip to content

Commit 4deced0

Browse files
committed
Move getting all cloud nodes into utils
1 parent be6dd0e commit 4deced0

2 files changed

Lines changed: 25 additions & 18 deletions

File tree

citc/utils.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

33
import yaml
44

5+
from . import aws, oracle, cloud
6+
57

68
def load_yaml(filename: str) -> dict:
79
with open(filename, "r") as f:
@@ -14,3 +16,23 @@ def get_nodespace(file="/etc/citc/startnode.yaml") -> Dict[str, str]:
1416
This will be static for all nodes in this cluster
1517
"""
1618
return load_yaml(file)
19+
20+
21+
def get_cloud_nodes() -> List[cloud.CloudNode]:
22+
nodespace = get_nodespace()
23+
24+
csp = nodespace["csp"]
25+
if csp == "aws":
26+
ec2 = aws.ec2_client(nodespace)
27+
cloud_nodes = aws.AwsNode.all(ec2, nodespace)
28+
elif csp == "google":
29+
cloud_nodes = []
30+
elif csp == "oracle":
31+
client_config = oracle.client_config(nodespace)
32+
cloud_nodes = oracle.OracleNode.all(client_config, nodespace)
33+
elif csp == "azure":
34+
cloud_nodes = []
35+
else:
36+
raise Exception(f"Cloud provider {csp} not found")
37+
38+
return cloud_nodes

citc/watchdog.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import List, Callable, Iterator
55
from pathlib import Path
66

7-
from . import aws, slurm, cloud, oracle, utils
7+
from . import slurm, cloud, utils
88

99

1010
class SignalHandler:
@@ -86,22 +86,7 @@ def main():
8686
SLURM_CONF = Path("/mnt/shared/etc/slurm/slurm.conf")
8787

8888
while handler.alive:
89-
nodespace = utils.get_nodespace()
90-
91-
csp = nodespace["csp"]
92-
if csp == "aws":
93-
ec2 = aws.ec2_client(nodespace)
94-
cloud_nodes = aws.AwsNode.all(ec2, nodespace)
95-
elif csp == "google":
96-
cloud_nodes = []
97-
elif csp == "oracle":
98-
client_config = oracle.client_config(nodespace)
99-
cloud_nodes = oracle.OracleNode.all(client_config, nodespace)
100-
elif csp == "azure":
101-
cloud_nodes = []
102-
else:
103-
raise Exception(f"Cloud provider {csp} not found")
104-
89+
cloud_nodes = utils.get_cloud_nodes()
10590
slurm_nodes = slurm.all_nodes(SLURM_CONF)
10691

10792
for task in crosscheck(slurm_nodes, cloud_nodes):

0 commit comments

Comments
 (0)