From eac393d7444d298cbfe8bcaedc92c750752f6d0f Mon Sep 17 00:00:00 2001
From: h-aze
Date: Wed, 4 Oct 2023 16:38:05 +0100
Subject: [PATCH] Added more testing, added reading capability, fixed a bug
 where directories were not being created, and completely revamped the class
 structure so that it now uses virtual classes

---
 README.md             |  10 +--
 slune/base.py         |  68 +++++++++++------
 slune/loggers.py      |  33 ++++++--
 slune/savers.py       |  71 ++++++++++++++----
 slune/searchers.py    |   8 +-
 slune/slune.py        |   8 +-
 slune/utils.py        |  32 +++++++-
 tests/test_loggers.py |  48 +++++++++++-
 tests/test_savers.py  | 170 ++++++++++++++++++++++++++++++++++++++++--
 tests/test_utils.py   | 107 +++++++++++++++++++++++++-
 10 files changed, 489 insertions(+), 66 deletions(-)

diff --git a/README.md b/README.md
index 7a09030..4149259 100644
--- a/README.md
+++ b/README.md
@@ -6,13 +6,13 @@ So early I haven't even written the docs yet! Will be adding a quick example her
 ## Coming soon
 Currently very much in early stages, first things still to do:
-- Auto save when finished a tuning run.
-- Auto sbatch job naming and job output naming.
-- Add ability to read results, currently can only submit jobs and log metrics during tuning.
-- Refine class structure, ie. subclassing, making sure classes have essential methods, what are the essential methods and attributes? etc.
+- Get Searcher to check which tunings have already been run and submit only those that haven't, controlled by a flag, i.e. set the flag if you want to re-run all tunings.
 - Refine package structure and sort out github actions like test coverage, running tests etc.
-- Add interfacing with SLURM to check for and re-submit failed jobs etc.
 - Add more tests and documentation.
+- Auto sbatch job naming and job output naming.
+- Auto save when finished a tuning run.
+- Add interfacing with SLURM to check for and re-submit failed jobs etc.
+- Cancel submitted jobs quickly and easily.
 - Add some more subclasses for saving job results in different ways and for different tuning methods.
 
 Although the idea for this package is to keep it ultra bare-bones and make it easy for the user to mod and add things themselves to their liking.
diff --git a/slune/base.py b/slune/base.py
index 565c28c..93d2020 100644
--- a/slune/base.py
+++ b/slune/base.py
@@ -1,51 +1,77 @@
+import abc
+from typing import Type
 
-class Searcher():
+class BaseSearcher(metaclass=abc.ABCMeta):
     """
-    Class that creates search space and returns arguments to pass to sbatch script
+    Base class for all searchers, which should be subclassed to implement different search algorithms.
+    Must implement the __len__ and next_tune methods.
     """
-    def __init__(self):
+    @abc.abstractmethod
+    def __init__(self, *args, **kwargs):
         pass
-
-    def __len__(self):
+
+    @abc.abstractmethod
+    def __len__(self, *args, **kwargs):
         """
         Returns the number of hyperparameter configurations to try.
         """
-        return len(self.searcher)
+        pass
 
+    @abc.abstractmethod
     def next_tune(self, *args, **kwargs):
         """
         Returns the next hyperparameter configuration to try.
         """
-        return self.searcher.next_tune(*args, **kwargs)
+        pass
 
-class Slog():
+class BaseLogger(metaclass=abc.ABCMeta):
     """
-    Class used to log metrics during tuning run and to save the results.
-    Args:
-        - Logger (object): Class that handles logging of metrics, including the formatting that will be used to save and read the results.
-        - Saver (object): Class that handles saving logs from Logger to storage and fetching correct logs from storage to give to Logger to read.
+    Base class for all loggers, which should be subclassed to implement different logging algorithms.
+    Must implement the log and read_log methods.
     """
+    @abc.abstractmethod
+    def __init__(self, *args, **kwargs):
+        pass
 
-    def __init__(self, Logger, Saver):
-        self.logger = Logger
-        self.saver = Saver # TODO: Have to instantiate this with params, is there way you can just define the class?
-
+    @abc.abstractmethod
     def log(self, *args, **kwargs):
         """
         Logs the metric/s for the current hyperparameter configuration,
         stores them in a data frame that we can later save in storage.
+        You can use this method directly from your saver object (inheriting from BaseSaver) that you instantiate with a BaseLogger subclass object.
+        """
+        pass
+
+    @abc.abstractmethod
+    def read_log(self, *args, **kwargs):
+        """
+        Reads the minimum or maximum value of a metric from a data frame.
+        You can use this method directly from your saver object (inheriting from BaseSaver) that you instantiate with a BaseLogger subclass object.
         """
-        self.logger.log(*args, **kwargs)
-
+        pass
+
+class BaseSaver(metaclass=abc.ABCMeta):
+    """
+    Base class for all savers, which should be subclassed to implement different saving algorithms.
+    Must implement the save_collated and read methods. Wraps a BaseLogger subclass instance, exposing its log and read_log methods.
+    """
+    @abc.abstractmethod
+    def __init__(self, logger_instance: BaseLogger, *args, **kwargs):
+        # Given an object that inherits from BaseLogger, we make it accessible through self.logger and make its methods accessible through self.log and self.read_log
+        self.logger = logger_instance
+        self.log = self.logger.log
+        self.read_log = self.logger.read_log
+
+    @abc.abstractmethod
    def save_collated(self, *args, **kwargs):
        """
        Saves the current results in logger to storage.
        """
-        self.saver.save_collated(self.logger.results, *args, **kwargs)
+        pass
 
+    @abc.abstractmethod
    def read(self, *args, **kwargs):
        """
        Reads results from storage.
        """
-        return self.saver.read(*args, **kwargs)
-
-
+        pass
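
# Usage sketch (illustrative, not applied by the patch): the new abstract base
# classes above use abc.ABCMeta, so a subclass must implement every
# @abc.abstractmethod before it can be instantiated. A hypothetical minimal
# searcher might look like this:
from slune.base import BaseSearcher

class SearcherSingle(BaseSearcher):
    """Hypothetical searcher that yields one fixed configuration."""
    def __init__(self, config):
        self.config = config

    def __len__(self):
        return 1

    def next_tune(self):
        return self.config

searcher = SearcherSingle(['--lr=0.1'])  # OK: all abstract methods implemented
# Leaving out next_tune, by contrast, would make instantiation raise TypeError.
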
diff --git a/slune/loggers.py b/slune/loggers.py
index 0357c19..97daf11 100644
--- a/slune/loggers.py
+++ b/slune/loggers.py
@@ -1,14 +1,16 @@
 import pandas as pd
+from slune.base import BaseLogger
 
-class LoggerDefault():
+class LoggerDefault(BaseLogger):
     """
     Logs the metric/s for the current hyperparameter configuration,
     stores them in a data frame that we can later save in storage.
     """
-    def __init__(self):
+    def __init__(self, *args, **kwargs):
+        super(LoggerDefault, self).__init__(*args, **kwargs)
         self.results = pd.DataFrame()
 
-    def log(self, metrics):
+    def log(self, metrics: dict):
         """
         Logs the metric/s for the current hyperparameter configuration,
         stores them in a data frame that we can later save in storage.
@@ -27,6 +29,25 @@ def log(self, metrics):
         # Append metrics dataframe to results dataframe
         self.results = pd.concat([self.results, metrics_df], ignore_index=True)
 
-    def read_log(self, *args, **kwargs):
-        # TODO: implement this function
-        raise NotImplementedError
\ No newline at end of file
+    def read_log(self, data_frame: pd.DataFrame, metric_name: str, min_max: str = 'max'):
+        """
+        Reads the minimum or maximum value of a metric from a data frame.
+        Args:
+            - data_frame (pd.DataFrame): Data frame containing the metric to be read.
+            - metric_name (string): Name of the metric to be read.
+            - min_max (string): Whether to read the minimum or maximum value of the metric, default is 'max'.
+        Returns:
+            - value (float): Minimum or maximum value of the metric.
+        """
+        # Get the metric column
+        metric_col = data_frame[metric_name]
+        # Get the index of the minimum or maximum value
+        if min_max == 'max':
+            index = metric_col.idxmax()
+        elif min_max == 'min':
+            index = metric_col.idxmin()
+        else:
+            raise ValueError(f"min_max must be 'min' or 'max', got {min_max}")
+        # Get the value of the metric (idxmax/idxmin return an index label, so look it up with loc rather than iloc)
+        value = metric_col.loc[index]
+        return value
\ No newline at end of file
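
# Usage sketch (illustrative, not applied by the patch): how the revamped
# LoggerDefault above can be used on its own. Metric names and values here
# are hypothetical.
from slune.loggers import LoggerDefault

logger = LoggerDefault()
# Each call to log appends a row (plus a time stamp) to logger.results.
logger.log({'loss': 0.9, 'accuracy': 0.4})
logger.log({'loss': 0.5, 'accuracy': 0.7})
# read_log takes a data frame and returns the min/max of a metric column.
best_loss = logger.read_log(logger.results, 'loss', min_max='min')  # -> 0.5
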
+ """ + # Get the metric column + metric_col = data_frame[metric_name] + # Get the index of the minimum or maximum value + if min_max == 'max': + index = metric_col.idxmax() + elif min_max == 'min': + index = metric_col.idxmin() + else: + raise ValueError(f"min_max must be 'min' or 'max', got {min_max}") + # Get the value of the metric + value = metric_col.iloc[index] + return value \ No newline at end of file diff --git a/slune/savers.py b/slune/savers.py index d643361..6ff42ac 100644 --- a/slune/savers.py +++ b/slune/savers.py @@ -1,16 +1,20 @@ import os import pandas as pd -from slune.utils import find_directory_path +from slune.utils import find_directory_path, get_all_paths +from slune.base import BaseSaver, BaseLogger +from typing import List, Optional, Type -class SaverCsv(): +class SaverCsv(BaseSaver): """ Saves the results of each run in a CSV file in a hierarchical directory structure based on argument names. """ - def __init__(self, params, root_dir='./tuning_results'): + def __init__(self, logger_instance: BaseLogger, params: List[str] = None, root_dir: Optional[str] ='./tuning_results'): + super(SaverCsv, self).__init__(logger_instance) self.root_dir = root_dir - self.current_path = self.get_path(params) + if params != None: + self.current_path = self.get_path(params) - def strip_params(self, params): + def strip_params(self, params: List[str]): """ Strips the argument names from the arguments given by args. eg. ["--argument_name=argument_value", ...] -> ["--argument_name=", ...] @@ -18,7 +22,7 @@ def strip_params(self, params): """ return [p.split('=')[0].strip() for p in params] - def get_match(self, params): + def get_match(self, params: List[str]): """ Searches the root directory for a directory tree that matches the parameters given. If only partial matches are found, returns the deepest matching directory with the missing parameters appended. @@ -42,13 +46,14 @@ def get_match(self, params): match = os.path.join(*match) return match - def get_path(self, params): + def get_path(self, params: List[str]): """ Creates a path using the parameters by checking existing directories in the root directory. Check get_match for how we create the path, we then check if results files for this path already exist, if they do we increment the number of the results file name that we will use. TODO: Add option to dictate order of parameters in directory structure. TODO: Return warnings if there exist multiple paths that match the parameters but in a different order, or paths that don't go as deep as others. + TODO: Should use same directory if number equal but string not, eg. 1 and 1.0 Args: - params (list): List of strings containing the arguments used, in form ["--argument_name=argument_value", ...]. 
""" @@ -76,16 +81,56 @@ def get_path(self, params): csv_file_path = os.path.join(dir_path, f'results_{csv_file_number}.csv') return csv_file_path - def save_collated(self, results): - # We add results onto the end of the current results in the csv file if it already exists - # if not then we create a new csv file and save the results there + def save_collated(self, results: pd.DataFrame): + """ + We add results onto the end of the current results in the csv file if it already exists, + if not then we create a new csv file and save the results there + """ + # If path does not exist, create it + # Remove the csv file name from the path + dir_path = self.current_path.split('/')[:-1] + dir_path = '/'.join(dir_path) + if not os.path.exists(dir_path): + os.makedirs(dir_path) + # If csv file already exists, append results to the end if os.path.exists(self.current_path): results = pd.concat([pd.read_csv(self.current_path), results]) results.to_csv(self.current_path, mode='w', index=False) + # If csv file does not exist, create it else: results.to_csv(self.current_path, index=False) - def read(self, *args, **kwargs): - # TODO: implement this function - raise NotImplementedError + def read(self, params: List[str], metric_name: str, min_max: str ='max'): + """ + Finds the min/max value of a metric from all csv files in the root directory that match the parameters given. + Args: + - params (list): List of strings containing the arguments used, in form ["--argument_name=argument_value", ...]. + - metric_name (string): Name of the metric to be read. + - min_max (string): Whether to read the minimum or maximum value of the metric, default is 'max'. + Returns: + - min_max_params (list): List of strings containing the arguments used to get the min/max value of the metric. + - min_max_value (float): Minimum or maximum value of the metric. + """ + # Get all paths that match the parameters given + paths = get_all_paths(params, root_directory=self.root_dir) + # Read the metric from each path + values = {} + for path in paths: + df = pd.read_csv(path) + values[path] = self.read_log(df, metric_name, min_max) + # Get the key of the min/max value + if min_max == 'min': + min_max_params = min(values, key=values.get) + elif min_max == 'max': + min_max_params = max(values, key=values.get) + else: + raise ValueError(f"min_max must be 'min' or 'max', got {min_max}") + # Find the min/max value of the metric from the key + min_max_value = values[min_max_params] + # Format the path into a list of arguments + min_max_params = min_max_params.replace(self.root_dir, '') + if min_max_params.startswith('/'): + min_max_params = min_max_params[1:] + min_max_params = min_max_params.split('/') + return min_max_params, min_max_value diff --git a/slune/searchers.py b/slune/searchers.py index a8ded88..a45051b 100644 --- a/slune/searchers.py +++ b/slune/searchers.py @@ -1,6 +1,7 @@ +from slune.base import BaseSearcher from slune.utils import dict_to_strings -class SearcherGrid(): +class SearcherGrid(BaseSearcher): """ Given dictionary of hyperparameters and values to try, creates grid of all possible hyperparameter configurations, and returns them one by one for each call to next_tune. @@ -9,7 +10,8 @@ class SearcherGrid(): Structure of dictionary should be: { "--argument_name" : [Value_1, Value_2, ...], ... } TODO: Add extra functionality by using nested dictionaries to specify which hyperparameters to try together. 
""" - def __init__(self, hyperparameters): + def __init__(self, hyperparameters: dict): + super().__init__() self.hyperparameters = hyperparameters self.grid = self.get_grid(hyperparameters) self.grid_index = None @@ -20,7 +22,7 @@ def __len__(self): """ return len(self.grid) - def get_grid(self, param_dict): + def get_grid(self, param_dict: dict): """ Generate all possible combinations of values for each argument in the given dictionary using recursion. diff --git a/slune/slune.py b/slune/slune.py index 6366506..2683bb5 100644 --- a/slune/slune.py +++ b/slune/slune.py @@ -2,7 +2,6 @@ import subprocess import sys -from slune.base import Slog from slune.savers import SaverCsv from slune.loggers import LoggerDefault @@ -71,10 +70,7 @@ def single_garg(arg_name): else: return single_garg(arg_names) -def get_slogcsv(params): - """ - Creates a Slog object from the SaverCsv and LoggerDefault classes. - """ - return Slog(LoggerDefault(), SaverCsv(params)) +def get_csv_slog(params, root_directory='.'): + return SaverCsv(LoggerDefault(), params, root_directory=root_directory) # TODO: add functions for reading results \ No newline at end of file diff --git a/slune/utils.py b/slune/utils.py index 71c0d8d..3373e56 100644 --- a/slune/utils.py +++ b/slune/utils.py @@ -53,4 +53,34 @@ def dict_to_strings(d): out.append('{}={}'.format(key, value)) else: out.append('--{}={}'.format(key, value)) - return out \ No newline at end of file + return out + +def find_csv_files(root_directory='.'): + """ + Recursively finds all csv files in all subdirectories of the root directory and returns their paths. + Args: + - root_directory (string): Path to the root directory to be searched, default is current working directory. + Returns: + - csv_files (list): List of strings containing the paths to all csv files found. + """ + csv_files = [] + for root, dirs, files in os.walk(root_directory): + for file in files: + if file.endswith('.csv'): + csv_files.append(os.path.join(root, file)) + return csv_files + +def get_all_paths(params, root_directory='.'): + """ + Finds all paths of csv files in all subdirectories of the root directory that have a directory in their path matching one of each of all the parameters given. + Args: + - params (list): List of strings containing the arguments used, in form ["--argument_name=argument_value", ...]. + - root_directory (string): Path to the root directory to be searched, default is current working directory. 
+ """ + all_csv = find_csv_files(root_directory) + matches = [] + for csv in all_csv: + path = csv.split('/') + if all([p in path for p in params]): + matches.append(csv) + return matches \ No newline at end of file diff --git a/tests/test_loggers.py b/tests/test_loggers.py index 9163006..4414bfd 100644 --- a/tests/test_loggers.py +++ b/tests/test_loggers.py @@ -5,7 +5,7 @@ from datetime import datetime import time -class TestLoggerDefault(unittest.TestCase): +class TestLoggerDefaultWrite(unittest.TestCase): def setUp(self): self.logger = LoggerDefault() @@ -50,5 +50,51 @@ def test_log_method_adds_correct_values(self): self.assertEqual(row['time_stamp'].round('s'), rounded_timestamp) +class TestLoggerDefaultRead(unittest.TestCase): + def setUp(self): + # Create an instance of LoggerDefault for testing + self.logger = LoggerDefault() + + def test_read_min_metric(self): + # Create a sample DataFrame + data = {'Metric1': [1, 2, 3, 4], + 'Metric2': [5, 6, 7, 8]} + df = pd.DataFrame(data) + + # Test reading the minimum value of Metric1 + result = self.logger.read_log(df, 'Metric1', min_max='min') + self.assertEqual(result, 1) + + def test_read_max_metric(self): + # Create a sample DataFrame + data = {'Metric1': [1, 2, 3, 4], + 'Metric2': [5, 6, 7, 8]} + df = pd.DataFrame(data) + + # Test reading the maximum value of Metric2 + result = self.logger.read_log(df, 'Metric2', min_max='max') + self.assertEqual(result, 8) + + def test_invalid_metric_name(self): + # Create a sample DataFrame + data = {'Metric1': [1, 2, 3, 4], + 'Metric2': [5, 6, 7, 8]} + df = pd.DataFrame(data) + + # Test providing an invalid metric name + with self.assertRaises(KeyError): + self.logger.read_log(df, 'InvalidMetric', min_max='min') + + def test_invalid_min_max_argument(self): + # Create a sample DataFrame + data = {'Metric1': [1, 2, 3, 4], + 'Metric2': [5, 6, 7, 8]} + df = pd.DataFrame(data) + + # Test providing an invalid value for min_max argument + with self.assertRaises(ValueError): + self.logger.read_log(df, 'Metric1', min_max='invalid_value') + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_savers.py b/tests/test_savers.py index 44e5335..40cd0e2 100644 --- a/tests/test_savers.py +++ b/tests/test_savers.py @@ -2,6 +2,7 @@ import os import pandas as pd from slune.savers import SaverCsv +from slune.loggers import LoggerDefault class TestSaverCsv(unittest.TestCase): def setUp(self): @@ -32,7 +33,7 @@ def tearDown(self): def test_get_match_full_match(self): # Create a SaverCsv instance - saver = SaverCsv(["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"], root_dir=self.test_dir) + saver = SaverCsv(LoggerDefault(), ["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"], root_dir=self.test_dir) # Test if get_match finds correct match and builds correct directory path using the parameters matching_dir = saver.get_match(["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"]) @@ -40,7 +41,7 @@ def test_get_match_full_match(self): def test_get_match_partial_match(self): # Create a SaverCsv instance - saver = SaverCsv(["--folder2=0.2", "--folder1=0.1"], root_dir=self.test_dir) + saver = SaverCsv(LoggerDefault(), ["--folder2=0.2", "--folder1=0.1"], root_dir=self.test_dir) # Test if get_match finds correct match and builds correct directory path using the parameters matching_dir = saver.get_match(["--folder2=0.2", "--folder1=0.1"]) @@ -48,7 +49,7 @@ def test_get_match_partial_match(self): def test_get_match_different_values(self): # Create a SaverCsv instance - saver = SaverCsv(["--folder2=2.2", 
"--folder1=1.1"], root_dir=self.test_dir) + saver = SaverCsv(LoggerDefault(), ["--folder2=2.2", "--folder1=1.1"], root_dir=self.test_dir) # Test if get_match finds correct match and builds correct directory path using the parameters matching_dir = saver.get_match(["--folder2=2.2", "--folder1=1.1"]) @@ -56,7 +57,7 @@ def test_get_match_different_values(self): def test_get_match_too_deep(self): # Create a SaverCsv instance - saver = SaverCsv(["--folder2=0.2", "--folder3=0.3"], root_dir=self.test_dir) + saver = SaverCsv(LoggerDefault(), ["--folder2=0.2", "--folder3=0.3"], root_dir=self.test_dir) # Test if get_match finds correct match and builds correct directory path using the parameters matching_dir = saver.get_match(["--folder2=0.2", "--folder3=0.3"]) @@ -64,7 +65,7 @@ def test_get_match_too_deep(self): def test_get_match_no_match(self): # Create a SaverCsv instance - saver = SaverCsv(["--folder_not_there=0", "--folder_also_not_there=0.1"], root_dir=self.test_dir) + saver = SaverCsv(LoggerDefault(), ["--folder_not_there=0", "--folder_also_not_there=0.1"], root_dir=self.test_dir) # Test if get_match finds correct match and builds correct directory path using the parameters matching_dir = saver.get_match(["--folder_not_there=0", "--folder_also_not_there=0.1"]) @@ -72,7 +73,7 @@ def test_get_match_no_match(self): def test_get_path_no_results(self): # Create a SaverCsv instance - saver = SaverCsv(["--folder5=0.5","--folder1=0.1", "--folder6=0.6"], root_dir=self.test_dir) + saver = SaverCsv(LoggerDefault(), ["--folder5=0.5","--folder1=0.1", "--folder6=0.6"], root_dir=self.test_dir) # Test if get_path gets the correct path path = saver.current_path @@ -80,7 +81,7 @@ def test_get_path_no_results(self): def test_get_path_already_results(self): # Create a SaverCsv instance - saver = SaverCsv(["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"], root_dir=self.test_dir) + saver = SaverCsv(LoggerDefault(), ["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"], root_dir=self.test_dir) # Test if get_path gets the correct path path = saver.current_path @@ -88,7 +89,7 @@ def test_get_path_already_results(self): def test_save_collated(self): # Create a SaverCsv instance - saver = SaverCsv(["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"], root_dir=self.test_dir) + saver = SaverCsv(LoggerDefault(), ["--folder3=0.3", "--folder2=0.2", "--folder1=0.1"], root_dir=self.test_dir) # Create a data frame with some results results = pd.DataFrame({'a': [1,2,3], 'b': [4,5,6]}) # Save the results @@ -114,8 +115,161 @@ def test_save_collated(self): self.assertEqual(read_values, values) # Remove the results file os.remove(os.path.join(self.test_dir, '--folder1=0.1', '--folder2=0.2', '--folder3=0.3', 'results_1.csv')) + + def test_creates_path_if_no_full_match(self): + # Create a SaverCsv instance + saver = SaverCsv(LoggerDefault(), ["--folder3=0.03", "--folder2=0.02", "--folder1=0.01"], root_dir=self.test_dir) + # Create a data frame with some results + results = pd.DataFrame({'a': [1,2,3], 'b': [4,5,6]}) + # Save the results + saver.save_collated(results) + # Check if the results were saved correctly + read_results = pd.read_csv(os.path.join(self.test_dir, "--folder1=0.01/--folder2=0.02/--folder3=0.03/results_0.csv")) + self.assertEqual(read_results.shape, (3,2)) + self.assertEqual(results.columns.tolist(), read_results.columns.tolist()) + read_values = [x for x in read_results.values.tolist() if str(x) != 'nan'] + values = [x for x in results.values.tolist() if str(x) != 'nan'] + self.assertEqual(values, read_values) + # 
+        # Remove the results file and the directories created
+        os.remove(os.path.join(self.test_dir, '--folder1=0.01', '--folder2=0.02', '--folder3=0.03', 'results_0.csv'))
+        os.rmdir(os.path.join(self.test_dir, '--folder1=0.01', '--folder2=0.02', '--folder3=0.03'))
+        os.rmdir(os.path.join(self.test_dir, '--folder1=0.01', '--folder2=0.02'))
+        os.rmdir(os.path.join(self.test_dir, '--folder1=0.01'))
 
 # TODO: add tests for root_dir that has '/'s in it
+class TestSaverCsvReadMethod(unittest.TestCase):
+
+    def setUp(self):
+        # Create a temporary directory with some CSV files for testing
+        self.test_dir = 'test_directory'
+        os.makedirs(self.test_dir, exist_ok=True)
+
+        # Creating some CSV files with specific subdirectory paths
+        self.csv_files = [
+            'dir1/file1.csv',
+            'dir2/file2.csv',
+            'dir1/subdir1/file3.csv',
+            'dir2/subdir2/file4.csv',
+            'dir2/subdir2/subdir3/file5.csv'
+        ]
+
+        for i, file in enumerate(self.csv_files):
+            file_path = os.path.join(self.test_dir, file)
+            os.makedirs(os.path.dirname(file_path), exist_ok=True)
+            # Create a data frame with different values for each CSV file
+            # (compare the relative path, as file_path includes the test directory prefix)
+            if file == 'dir2/subdir2/subdir3/file5.csv':
+                results = pd.DataFrame({'a': [i+1,i+2,i+3], 'b': [i+4,i+5,i+6], 'c': [i+7,i+8,i+9]})
+            else:
+                results = pd.DataFrame({'a': [i+1,i+2,i+3], 'b': [i+4,i+5,i+6]})
+            # Save the results
+            results.to_csv(file_path, mode='w', index=False)
+        # The data frames we created should look like this:
+        # file1.csv: a b
+        #            1 4
+        #            2 5
+        #            3 6
+        # file2.csv: a b
+        #            2 5
+        #            3 6
+        #            4 7
+        # file3.csv: a b
+        #            3 6
+        #            4 7
+        #            5 8
+        # file4.csv: a b
+        #            4 7
+        #            5 8
+        #            6 9
+        # file5.csv: a b  c
+        #            5 8  11
+        #            6 9  12
+        #            7 10 13
+
+    def tearDown(self):
+        # Clean up the temporary directory and files after testing
+        for root, dirs, files in os.walk(self.test_dir, topdown=False):
+            for name in files:
+                os.remove(os.path.join(root, name))
+            for name in dirs:
+                os.rmdir(os.path.join(root, name))
+        os.rmdir(self.test_dir)
+
+    def test_read_max_metric(self):
+        # Create some params to use for testing
+        params = ['dir1', 'subdir1']
+        # Create an instance of SaverCsv
+        saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)
+
+        # Call the read method to get the max values and their params
+        max_param_a, max_value_a = saver.read(params, 'a')
+        max_param_b, max_value_b = saver.read(params, 'b')
+
+        # Perform assertions based on your expectations
+        self.assertEqual(max_param_a, ['dir1','subdir1','file3.csv'])
+        self.assertEqual(max_param_b, ['dir1','subdir1','file3.csv'])
+        self.assertEqual(max_value_a, 5)
+        self.assertEqual(max_value_b, 8)
+
+    def test_read_min_metric(self):
+        # Create some params to use for testing
+        params = ['dir1', 'subdir1']
+        # Create an instance of SaverCsv
+        saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)
+
+        # Call the read method to get the min values and their params
+        min_param_a, min_value_a = saver.read(params, 'a', min_max='min')
+        min_param_b, min_value_b = saver.read(params, 'b', min_max='min')
+
+        # Perform assertions based on your expectations
+        self.assertEqual(min_param_a, ['dir1','subdir1','file3.csv'])
+        self.assertEqual(min_param_b, ['dir1','subdir1','file3.csv'])
+        self.assertEqual(min_value_a, 3)
+        self.assertEqual(min_value_b, 6)
+
+    def test_read_nonexistent_metric(self):
+        # Create some params to use for testing
+        params = ['dir1', 'subdir1']
+        # Create an instance of SaverCsv
+        saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)
+
+        # Reading a metric that none of the matching files contain should raise a KeyError
+        with self.assertRaises(KeyError):
+            saver.read(params, 'c')
+
+    def test_multiple_matching_paths(self):
+        # Create some params to use for testing
+        params = ['dir2']
+        # Create an instance of SaverCsv
+        saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)
+
+        # Call the read method to get the max value and its params
+        param, value = saver.read(params, 'a')
+
+        # Check results are as expected
+        self.assertEqual(param, ['dir2','subdir2','subdir3','file5.csv'])
+        self.assertEqual(value, 7)
+
+    def test_no_matching_paths(self):
+        # Create some params to use for testing
+        params = ['dir3']
+        # Create an instance of SaverCsv
+        saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)
+
+        # Reading when no paths match should raise a ValueError
+        with self.assertRaises(ValueError):
+            saver.read(params, 'a')
+
+    def test_multiple_matching_paths_missing_metrics(self):
+        # Create some params to use for testing
+        params = ['dir2']
+        # Create an instance of SaverCsv
+        saver = SaverCsv(LoggerDefault(), root_dir=self.test_dir)
+
+        # Reading a metric that some matching files lack should raise a KeyError
+        with self.assertRaises(KeyError):
+            saver.read(params, 'c')
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 6fa427a..82030d7 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,6 @@
 import unittest
 import os
-from slune.utils import find_directory_path, dict_to_strings
+from slune.utils import find_directory_path, dict_to_strings, find_csv_files, get_all_paths
 
 class TestFindDirectoryPath(unittest.TestCase):
 
@@ -72,4 +72,107 @@ def test_dict_to_strings(self):
         d = {'arg1': 1, 'arg2': 2}
         result = dict_to_strings(d)
         self.assertEqual(result, ['--arg1=1', '--arg2=2'])
-
\ No newline at end of file
+
+
+class TestFindCSVFiles(unittest.TestCase):
+
+    def setUp(self):
+        # Create a temporary directory with some CSV files for testing
+        self.test_dir = 'test_directory'
+        os.makedirs(self.test_dir, exist_ok=True)
+
+        # Creating some CSV files
+        self.csv_files = [
+            'file1.csv',
+            'file2.csv',
+            'subdir1/file3.csv',
+            'subdir2/file4.csv',
+            'subdir2/subdir3/file5.csv'
+        ]
+
+        for file in self.csv_files:
+            file_path = os.path.join(self.test_dir, file)
+            os.makedirs(os.path.dirname(file_path), exist_ok=True)
+            with open(file_path, 'w') as f:
+                f.write("Sample CSV content")
+
+    def tearDown(self):
+        # Clean up the temporary directory and files after testing
+        for root, dirs, files in os.walk(self.test_dir, topdown=False):
+            for name in files:
+                os.remove(os.path.join(root, name))
+            for name in dirs:
+                os.rmdir(os.path.join(root, name))
+        os.rmdir(self.test_dir)
+
+    def test_find_csv_files(self):
+        # Test the find_csv_files function
+
+        # Call the function to get the result
+        result = find_csv_files(self.test_dir)
+
+        # Define the expected result based on the files we created
+        expected_result = [
+            os.path.join(self.test_dir, file) for file in self.csv_files
+        ]
+
+        # Sort both lists for comparison, as the order might not be guaranteed
+        result.sort()
+        expected_result.sort()
+
+        # Assert that the result matches the expected result
+        self.assertEqual(result, expected_result)
+
+
+class TestGetAllPaths(unittest.TestCase):
+
+    def setUp(self):
+        # Create a temporary directory with some CSV files for testing
+        self.test_dir = 'test_directory'
+        os.makedirs(self.test_dir, exist_ok=True)
+
+        # Creating some CSV files with specific subdirectory paths
+        self.csv_files = [
+            'dir1/file1.csv',
+            'dir2/file2.csv',
+            'dir1/subdir1/file3.csv',
+            'dir2/subdir2/file4.csv',
+            'dir2/subdir2/subdir3/file5.csv'
+        ]
+
+        for file in self.csv_files:
+            file_path = os.path.join(self.test_dir, file)
+            os.makedirs(os.path.dirname(file_path), exist_ok=True)
+            with open(file_path, 'w') as f:
+                f.write("Sample CSV content")
+
+    def tearDown(self):
+        # Clean up the temporary directory and files after testing
+        for root, dirs, files in os.walk(self.test_dir, topdown=False):
+            for name in files:
+                os.remove(os.path.join(root, name))
+            for name in dirs:
+                os.rmdir(os.path.join(root, name))
+        os.rmdir(self.test_dir)
+
+    def test_get_all_paths(self):
+        # Test the get_all_paths function
+
+        # Call the function to get the result
+        result = get_all_paths(['dir1', 'subdir1'], self.test_dir)
+
+        # Define the expected result based on the files we created
+        expected_result = [
+            os.path.join(self.test_dir, 'dir1/subdir1/file3.csv')
+        ]
+
+        # Sort both lists for comparison, as the order might not be guaranteed
+        result.sort()
+        expected_result.sort()
+
+        # Assert that the result matches the expected result
+        self.assertEqual(result, expected_result)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
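
# End-to-end sketch (illustrative, not applied by the patch): how the pieces
# changed here could fit together in a training script. The argument names
# and metric are hypothetical; get_csv_slog is the helper from slune/slune.py
# above.
from slune.slune import get_csv_slog

# Build a saver (wrapping a LoggerDefault) for this run's parameters.
slog = get_csv_slog(['--lr=0.01', '--batch_size=32'], root_directory='./tuning_results')
for epoch in range(10):
    loss = 1.0 / (epoch + 1)  # stand-in for a real training step
    slog.log({'loss': loss})
# Persist the collated results once the run finishes.
slog.save_collated(slog.logger.results)

# Later: find which parameters achieved the lowest loss so far.
best_params, best_loss = slog.read(['--batch_size=32'], 'loss', min_max='min')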