
Commit

Removed '/' to make path management platform independent
h-0-0 committed Nov 22, 2023
1 parent 4bba472 commit 9f1c3a1
Showing 5 changed files with 72 additions and 70 deletions.
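
The recurring pattern in this commit: hard-coded '/' separators are replaced with os.path.join and os.path.sep so paths are built and parsed correctly on both POSIX and Windows. A minimal sketch of the idea (not taken from the repository; the parameter name is made up):

import os

# Build a path from components; the OS picks the separator.
path = os.path.join('tuning_results', '--lr=0.1', 'results_0.csv')
# 'tuning_results/--lr=0.1/results_0.csv' on POSIX,
# 'tuning_results\--lr=0.1\results_0.csv' on Windows.

# Split it back into components using the same platform separator.
parts = path.split(os.path.sep)  # ['tuning_results', '--lr=0.1', 'results_0.csv']
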
27 changes: 14 additions & 13 deletions src/slune/savers/csv.py
@@ -11,7 +11,7 @@ class SaverCsv(BaseSaver):
Saves the results of each run in a CSV file in a hierarchical directory structure based on argument names.
Handles parallel runs by waiting a random time
"""
-def __init__(self, logger_instance: BaseLogger, params: List[str] = None, root_dir: Optional[str] ='./tuning_results'):
+def __init__(self, logger_instance: BaseLogger, params: List[str] = None, root_dir: Optional[str] = os.path.join('.', 'tuning_results')):
super(SaverCsv, self).__init__(logger_instance)
self.root_dir = root_dir
if params != None:
@@ -38,20 +38,21 @@ def get_match(self, params: List[str]):
match = find_directory_path(stripped_params, root_directory=self.root_dir)
# Add on missing parameters
if match == self.root_dir:
match = "/".join(stripped_params)
match = os.path.join(*stripped_params)
else:
missing_params = [p for p in stripped_params if p not in match]
if missing_params != []:
-match = match + '/' + '/'.join(missing_params)
+match = [match] + missing_params
+match = os.path.join(*match)
# Take the root directory out of the match
match = match.replace(self.root_dir, '')
-if match.startswith('/'):
+if match.startswith(os.path.sep):
match = match[1:]
# Now we add back in the values we stripped out
-match = match.split('/')
+match = match.split(os.path.sep)
match = [[p for p in params if m in p][0] for m in match]
# Check if there is an existing path with the same numerical values, if so use that instead
-match = get_numeric_equiv("/".join(match), root_directory=self.root_dir)
+match = get_numeric_equiv(os.path.join(*match), root_directory=self.root_dir)
return match

def get_path(self, params: List[str]):
@@ -94,8 +95,8 @@ def save_collated_from_results(self, results: pd.DataFrame):
"""
# If path does not exist, create it
# Remove the csv file name from the path
-dir_path = self.current_path.split('/')[:-1]
-dir_path = '/'.join(dir_path)
+dir_path = self.current_path.split(os.path.sep)[:-1]
+dir_path = os.path.join(*dir_path)
if not os.path.exists(dir_path):
time.sleep(random.random()) # Wait a random amount of time under 1 second to avoid multiple processes creating the same directory
os.makedirs(dir_path)
@@ -130,9 +131,9 @@ def read(self, params: List[str], metric_name: str, select_by: str ='max', avg:
values = {}
# Do averaging for different runs of same params if avg is True, otherwise just read the metric from each path
if avg:
-paths_same_params = set(['/'.join(p.split('/')[:-1]) for p in paths])
+paths_same_params = set([os.path.join(*p.split(os.path.sep)[:-1]) for p in paths])
for path in paths_same_params:
-runs = get_all_paths(path.split('/'), root_directory=self.root_dir)
+runs = get_all_paths(path.split(os.path.sep), root_directory=self.root_dir)
cumsum = 0
for r in runs:
df = pd.read_csv(r)
@@ -142,7 +143,7 @@ else:
else:
for path in paths:
df = pd.read_csv(path)
-values['/'.join(path.split('/')[:-1])] = self.read_log(df, metric_name, select_by)
+values[os.path.join(*path.split(os.path.sep)[:-1])] = self.read_log(df, metric_name, select_by)
# Get the key of the min/max value
if select_by == 'min':
best_params = min(values, key=values.get)
@@ -154,9 +155,9 @@ def read(self, params: List[str], metric_name: str, select_by: str ='max', avg:
best_value = values[best_params]
# Format the path into a list of arguments
best_params = best_params.replace(self.root_dir, '')
-if best_params.startswith('/'):
+if best_params.startswith(os.path.sep):
best_params = best_params[1:]
-best_params = best_params.split('/')
+best_params = best_params.split(os.path.sep)
return best_params, best_value

def exists(self, params: List[str]):
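
For illustration, a minimal sketch of the relativise-and-split idiom used in get_match and read above, with a hypothetical root_dir and made-up parameter names (not from the repository):

import os

root_dir = os.path.join('.', 'tuning_results')
match = os.path.join(root_dir, '--lr=0.1', '--bs=32')
# Take the root directory out of the match, then strip the leading separator.
match = match.replace(root_dir, '')
if match.startswith(os.path.sep):
    match = match[1:]
# Split what is left back into one component per parameter.
print(match.split(os.path.sep))  # ['--lr=0.1', '--bs=32']
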
10 changes: 5 additions & 5 deletions src/slune/utils.py
@@ -18,7 +18,7 @@ def _find_directory_path(curr_strings, curr_root, depth, max_depth, max_path):
if string in stripped_dir_list:
dir_list = [d for d in dir_list if d.startswith(string)]
for d in dir_list:
-new_depth, new_path = _find_directory_path([s for s in curr_strings if s != string], curr_root + '/' + d, depth + 1, max_depth, max_path)
+new_depth, new_path = _find_directory_path([s for s in curr_strings if s != string], os.path.join(curr_root, d), depth + 1, max_depth, max_path)
if new_depth > max_depth:
max_depth, max_path = new_depth, new_path
if depth > max_depth:
@@ -28,9 +28,9 @@ def _find_directory_path(curr_strings, curr_root, depth, max_depth, max_path):
max_depth, max_path = _find_directory_path(strings, root_directory, 0, -1, '')
if max_depth > 0:
max_path = max_path[len(root_directory):]
-dirs = max_path[1:].split('/')
+dirs = max_path[1:].split(os.path.sep)
dirs = [d.split('=')[0].strip() +"=" for d in dirs]
-max_path = '/'.join(dirs)
+max_path = os.path.join(*dirs)
max_path = os.path.join(root_directory, max_path)
return max_path

@@ -50,7 +50,7 @@ def is_numeric(s):
except ValueError:
return False

-dirs = path.split('/')
+dirs = path.split(os.path.sep)
equiv = root_directory
for d in dirs:
next_dir = os.path.join(equiv, d)
@@ -116,7 +116,7 @@ def get_all_paths(params, root_directory='.'):
all_csv = find_csv_files(root_directory)
matches = []
for csv in all_csv:
-path = csv.split('/')
+path = csv.split(os.path.sep)
if all([p in path for p in params]):
matches.append(csv)
return matches
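
A small sketch of the component check get_all_paths performs, with hypothetical params and a hypothetical CSV path (illustrative only):

import os

params = ['--lr=0.1', '--bs=32']
csv_path = os.path.join('tuning_results', '--lr=0.1', '--bs=32', 'results_0.csv')
components = csv_path.split(os.path.sep)
# A file matches when every requested parameter appears as a path component.
print(all(p in components for p in params))  # True
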
42 changes: 21 additions & 21 deletions tests/test_savers_csv.py
@@ -45,47 +45,47 @@ def test_get_match_full_match(self):

# Test if get_match finds correct match and builds correct directory path using the parameters
matching_dir = saver.get_match(["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"])
-self.assertEqual(matching_dir, os.path.join(self.test_dir, "--folder1=0.1/--folder2=0.2/--folder3=0.3"))
+self.assertEqual(matching_dir, os.path.join(*[self.test_dir, '--folder1=0.1','--folder2=0.2','--folder3=0.3']))

def test_get_match_partial_match(self):
# Create a SaverCsv instance
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Test if get_match finds correct match and builds correct directory path using the parameters
matching_dir = saver.get_match(["--folder2=0.2", "--folder1=0.1"])
-self.assertEqual(matching_dir, os.path.join(self.test_dir, "--folder1=0.1/--folder2=0.2"))
+self.assertEqual(matching_dir, os.path.join(*[self.test_dir, '--folder1=0.1','--folder2=0.2']))

def test_get_match_partial_match_more_params(self):
# Create a SaverCsv instance
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Test if get_match finds correct match and builds correct directory path using the parameters
matching_dir = saver.get_match(["--folder1=0.1", "--folder6=0.6", "--folder5=0.5", "--folder7=0.7"])
-self.assertEqual(matching_dir, os.path.join(self.test_dir, "--folder1=0.1/--folder5=0.5/--folder6=0.6/--folder7=0.7"))
+self.assertEqual(matching_dir, os.path.join(*[self.test_dir, '--folder1=0.1','--folder5=0.5','--folder6=0.6','--folder7=0.7']))

def test_get_match_different_values(self):
# Create a SaverCsv instance
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Test if get_match finds correct match and builds correct directory path using the parameters
matching_dir = saver.get_match(["--folder2=2.2", "--folder1=1.1"])
-self.assertEqual(matching_dir, os.path.join(self.test_dir, "--folder1=1.1/--folder2=2.2"))
+self.assertEqual(matching_dir, os.path.join(*[self.test_dir, '--folder1=1.1','--folder2=2.2']))

def test_get_match_too_deep(self):
# Create a SaverCsv instance
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Test if get_match finds correct match and builds correct directory path using the parameters
matching_dir = saver.get_match(["--folder2=0.2", "--folder3=0.3"])
-self.assertEqual(matching_dir, os.path.join(self.test_dir, '--folder2=0.2', '--folder3=0.3'))
+self.assertEqual(matching_dir, os.path.join(*[self.test_dir, '--folder2=0.2', '--folder3=0.3']))

def test_get_match_no_match(self):
# Create a SaverCsv instance
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Test if get_match finds correct match and builds correct directory path using the parameters
matching_dir = saver.get_match(["--folder_not_there=0", "--folder_also_not_there=0.1"])
-self.assertEqual(matching_dir, os.path.join(self.test_dir, "--folder_not_there=0/--folder_also_not_there=0.1"))
+self.assertEqual(matching_dir, os.path.join(*[self.test_dir, '--folder_not_there=0','--folder_also_not_there=0.1']))

def test_get_match_duplicate_params(self):
# Create a SaverCsv instance
@@ -121,15 +121,15 @@ def test_get_path_no_results(self):

# Test if get_path gets the correct path
path = saver.get_path(["--folder5=0.5","--folder1=0.1", "--folder6=0.6"])
-self.assertEqual(path, os.path.join(self.test_dir, "--folder1=0.1/--folder5=0.5/--folder6=0.6/results_0.csv"))
+self.assertEqual(path, os.path.join(*[self.test_dir, '--folder1=0.1','--folder5=0.5','--folder6=0.6','results_0.csv']))

def test_get_path_already_results(self):
# Create a SaverCsv instance
saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)

# Test if get_path gets the correct path
path = saver.get_path(["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"])
-self.assertEqual(path, os.path.join(self.test_dir, "--folder1=0.1/--folder2=0.2/--folder3=0.3/results_1.csv"))
+self.assertEqual(path, os.path.join(*[self.test_dir, '--folder1=0.1','--folder2=0.2','--folder3=0.3','results_1.csv']))

def test_save_collated(self):
# Create a SaverCsv instance
@@ -139,7 +139,7 @@ def test_save_collated(self):
# Save the results
saver.save_collated_from_results(results)
# Check if the results were saved correctly
-read_results = pd.read_csv(os.path.join(self.test_dir, "--folder1=0.1/--folder2=0.2/--folder3=0.3/results_1.csv"))
+read_results = pd.read_csv(os.path.join(*[self.test_dir, '--folder1=0.1','--folder2=0.2','--folder3=0.3','results_1.csv']))
self.assertEqual(read_results.shape, (3,2))
self.assertEqual(results.columns.tolist(), read_results.columns.tolist())
read_values = [x for x in read_results.values.tolist() if str(x) != 'nan']
@@ -150,7 +150,7 @@ def test_save_collated(self):
# Save the results
saver.save_collated_from_results(results)
# Check if the results were saved correctly
-read_results = pd.read_csv(os.path.join(self.test_dir, "--folder1=0.1/--folder2=0.2/--folder3=0.3/results_1.csv"))
+read_results = pd.read_csv(os.path.join(*[self.test_dir, '--folder1=0.1','--folder2=0.2','--folder3=0.3','results_1.csv']))
results = pd.concat([pd.DataFrame({'a': [1,2,3], 'b': [4,5,6]}), results], ignore_index=True)
self.assertEqual(read_results.shape, (6,3))
self.assertEqual(results.columns.tolist(), read_results.columns.tolist())
@@ -168,7 +168,7 @@ def test_creates_path_if_no_full_match(self):
# Save the results
saver.save_collated_from_results(results)
# Check if the results were saved correctly
-read_results = pd.read_csv(os.path.join(self.test_dir, "--folder1=0.01/--folder2=0.02/--folder3=0.03/results_0.csv"))
+read_results = pd.read_csv(os.path.join(*[self.test_dir, '--folder1=0.01','--folder2=0.02','--folder3=0.03','results_0.csv']))
self.assertEqual(read_results.shape, (3,2))
self.assertEqual(results.columns.tolist(), read_results.columns.tolist())
read_values = [x for x in read_results.values.tolist() if str(x) != 'nan']
@@ -182,13 +182,13 @@ def test_creates_path_if_no_full_match(self):

def test_root_dir_forwardslash(self):
# Create a SaverCsv instance
-saver = SaverCsv(LoggerDefault(), ["--folder3=0.3", "--folder2=0.2"], root_dir=self.test_dir+'/--folder1=0.1')
+saver = SaverCsv(LoggerDefault(), ["--folder3=0.3", "--folder2=0.2"], root_dir=os.path.join(self.test_dir,'--folder1=0.1'))
# Create a data frame with some results
results = pd.DataFrame({'a': [1,2,3], 'b': [4,5,6]})
# Save the results
saver.save_collated_from_results(results)
# Check if the results were saved correctly
-read_results = pd.read_csv(os.path.join(self.test_dir, "--folder1=0.1/--folder2=0.2/--folder3=0.3/results_1.csv"))
+read_results = pd.read_csv(os.path.join(*[self.test_dir, '--folder1=0.1','--folder2=0.2','--folder3=0.3','results_1.csv']))
self.assertEqual(read_results.shape, (3,2))
self.assertEqual(results.columns.tolist(), read_results.columns.tolist())
read_values = [x for x in read_results.values.tolist() if str(x) != 'nan']
Expand All @@ -206,18 +206,18 @@ def setUp(self):

# Creating some CSV files with specific subdirectory paths
self.csv_files = [
-'dir1/file1.csv',
-'dir2/file2.csv',
-'dir1/subdir1/file3.csv',
-'dir2/subdir2/file4.csv',
-'dir2/subdir2/subdir3/file5.csv'
+os.path.join('dir1','file1.csv'),
+os.path.join('dir2','file2.csv'),
+os.path.join('dir1','subdir1','file3.csv'),
+os.path.join('dir2','subdir2','file4.csv'),
+os.path.join('dir2','subdir2','subdir3','file5.csv')
]

for i, file in enumerate(self.csv_files):
file_path = os.path.join(self.test_dir, file)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
# Create a data frame with different values for each CSV file
-if file_path == 'dir2/subdir2/subdir3/file5.csv':
+if file_path == os.path.join('dir2','subdir2','subdir3','file5.csv'):
results = pd.DataFrame({'a': [i+1,i+2,i+3], 'b': [i+4,i+5,i+6], 'c': [i+7,i+8,i+9]})
else:
results = pd.DataFrame({'a': [i+1,i+2,i+3], 'b': [i+4,i+5,i+6]})
@@ -365,7 +365,7 @@ def test_exists_no_file(self):
def test_read_multi_files_avg(self):
# Create another results file with different values
results = pd.DataFrame({'a': [7,8,9], 'd': [10,11,12]})
-results.to_csv(os.path.join(self.test_dir, 'dir2/subdir2/subdir3/more_results.csv'), mode='w', index=False)
+results.to_csv(os.path.join(self.test_dir, 'dir2','subdir2','subdir3','more_results.csv'), mode='w', index=False)
# Create some params to use for testing
params = ['dir2', 'subdir2', 'subdir3']
# Create an instance of SaverCsv
Expand All @@ -379,7 +379,7 @@ def test_read_multi_files_avg(self):
self.assertEqual(value, 8)

# Remove the results file
-os.remove(os.path.join(self.test_dir, 'dir2/subdir2/subdir3/more_results.csv'))
+os.remove(os.path.join(self.test_dir, 'dir2','subdir2','subdir3','more_results.csv'))



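
The tests above build expected paths with os.path.join(*[...]). Unpacking a list this way is equivalent to passing the components as separate arguments; a short illustration with made-up names:

import os

parts = ['test_directory', '--folder1=0.1', '--folder2=0.2', 'results_0.csv']
assert os.path.join(*parts) == os.path.join('test_directory', '--folder1=0.1', '--folder2=0.2', 'results_0.csv')
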
7 changes: 4 additions & 3 deletions tests/test_slune.py
@@ -1,12 +1,13 @@
import unittest
from unittest.mock import patch, call, MagicMock
from slune import garg, submit_job, sbatchit
+import os

class TestSubmitJob(unittest.TestCase):
@patch('subprocess.run')
def test_regular(self, mock_run):
# Arrange
-sh_path = "/path/to/bash/script"
+sh_path = os.path.join('path','to','bash','script')
args = ["arg1", "arg2"]

# Act
@@ -19,8 +20,8 @@ class TestSbatchit(unittest.TestCase):
@patch('subprocess.run')
def test_sbatchit(self, mock_run):
# Arrange
-script_path = "/path/to/script"
-template_path = "/path/to/template"
+script_path = os.path.join('path','to','script')
+template_path = os.path.join('path','to','template')
searcher = MagicMock()
searcher.__iter__.return_value = [['arg1', 'arg2'], ['arg3', 'arg4']]
cargs = ["carg1", "carg2"]