Skip to content

Commit

Permalink
Added more testing, added reading capability, fixed bug where it wasn…
Browse files Browse the repository at this point in the history
…'t creating directories and completely revamped class structure so now it uses virtual classes
  • Loading branch information
h-0-0 committed Oct 4, 2023
1 parent 22404df commit eac393d
Show file tree
Hide file tree
Showing 10 changed files with 489 additions and 66 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ So early I haven't even written the docs yet! Will be adding a quick example her

## Coming soon
Currently very much in early stages, first things still to do:
- Auto save when finished a tuning run.
- Auto sbatch job naming and job output naming.
- Add ability to read results, currently can only submit jobs and log metrics during tuning.
- Refine the class structure, i.e. subclassing: make sure classes have the essential methods, and decide what the essential methods and attributes are.
- Get Searcher to check which tunings have already been done and submit only those that haven't, controlled by a flag, i.e. if you want to re-run a tuning you can set a flag to re-run all tunings.
- Refine package structure and sort out github actions like test coverage, running tests etc.
- Add interfacing with SLURM to check for and re-submit failed jobs etc.
- Add more tests and documentation.
- Auto sbatch job naming and job output naming.
- Auto save when finished a tuning run.
- Add interfacing with SLURM to check for and re-submit failed jobs etc.
- Cancelling submitted jobs quickly and easily.
- Add some more subclasses for saving job results in different ways and for different tuning methods.
That said, the idea for this package is to keep it ultra bare-bones and to make it easy for users to modify it and add things themselves to their liking.

Expand Down
68 changes: 47 additions & 21 deletions slune/base.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,77 @@
import abc
from typing import Type

class Searcher():
class BaseSearcher(metaclass=abc.ABCMeta):
"""
Class that creates search space and returns arguments to pass to sbatch script
Base class for all searchers. Which should be subclassed to implement different search algorithms.
Must implement __len__ and next_tune methods.
"""
def __init__(self):
@abc.abstractmethod
def __init__(self, *args, **kwargs):
pass

def __len__(self):

@abc.abstractmethod
def __len__(self, *args, **kwargs):
"""
Returns the number of hyperparameter configurations to try.
"""
return len(self.searcher)
pass

@abc.abstractmethod
def next_tune(self, *args, **kwargs):
"""
Returns the next hyperparameter configuration to try.
"""
return self.searcher.next_tune(*args, **kwargs)
pass

class Slog():
class BaseLogger(metaclass=abc.ABCMeta):
"""
Class used to log metrics during tuning run and to save the results.
Args:
- Logger (object): Class that handles logging of metrics, including the formatting that will be used to save and read the results.
- Saver (object): Class that handles saving logs from Logger to storage and fetching correct logs from storage to give to Logger to read.
Base class for all loggers. Which should be subclassed to implement different logging algorithms.
Must implement log and read_log methods.
"""
def __init__(self, Logger, Saver):
self.logger = Logger
self.saver = Saver # TODO: Have to instantiate this with params, is there way you can just define the class?
@abc.abstractmethod
def __init__(self, *args, **kwargs):
pass

@abc.abstractmethod
def log(self, *args, **kwargs):
"""
Logs the metric/s for the current hyperparameter configuration,
stores them in a data frame that we can later save in storage.
You can use this method directly form your saver object (inheriting from BaseSaver) that you instantiate with a BaseLogger subclass object.
"""
pass

@abc.abstractmethod
def read_log(self, *args, **kwargs):
"""
Reads the minimum or maximum value of a metric from a data frame.
You can use this method directly form your saver object (inheriting from BaseSaver) that you instantiate with a BaseLogger subclass object.
"""
self.logger.log(*args, **kwargs)

pass

class BaseSaver(metaclass=abc.ABCMeta):
"""
Base class for all savers. Which should be subclassed to implement different saving algorithms.
Must implement save_collated and read methods. Inherits from BaseLogger.
"""
@abc.abstractmethod
def __init__(self, logger_instance: BaseLogger, *args, **kwargs):
# Given a class that inherits from BaseLogger we make it accessible through self.logger and make its methods accessible through self.log and self.read_log
self.logger = logger_instance
self.log = self.logger.log
self.read_log = self.logger.read_log

@abc.abstractmethod
def save_collated(self, *args, **kwargs):
"""
Saves the current results in logger to storage.
"""
self.saver.save_collated(self.logger.results, *args, **kwargs)
pass

@abc.abstractmethod
def read(self, *args, **kwargs):
"""
Reads results from storage.
"""
return self.saver.read(*args, **kwargs)


pass
33 changes: 27 additions & 6 deletions slune/loggers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import pandas as pd
from slune.base import BaseLogger

class LoggerDefault():
class LoggerDefault(BaseLogger):
"""
Logs the metric/s for the current hyperparameter configuration,
stores them in a data frame that we can later save in storage.
"""
def __init__(self):
def __init__(self, *args, **kwargs):
super(LoggerDefault, self).__init__(*args, **kwargs)
self.results = pd.DataFrame()

def log(self, metrics):
def log(self, metrics: dict):
"""
Logs the metric/s for the current hyperparameter configuration,
stores them in a data frame that we can later save in storage.
Expand All @@ -27,6 +29,25 @@ def log(self, metrics):
# Append metrics dataframe to results dataframe
self.results = pd.concat([self.results, metrics_df], ignore_index=True)

def read_log(self, *args, **kwargs):
# TODO: implement this function
raise NotImplementedError
def read_log(self, data_frame: pd.DataFrame, metric_name: str, min_max: str ='max'):
"""
Reads the minimum or maximum value of a metric from a data frame.
Args:
- data_frame (pd.DataFrame): Data frame containing the metric to be read.
- metric_name (string): Name of the metric to be read.
- min_max (string): Whether to read the minimum or maximum value of the metric, default is 'max'.
Returns:
- value (float): Minimum or maximum value of the metric.
"""
# Get the metric column
metric_col = data_frame[metric_name]
# Get the index of the minimum or maximum value
if min_max == 'max':
index = metric_col.idxmax()
elif min_max == 'min':
index = metric_col.idxmin()
else:
raise ValueError(f"min_max must be 'min' or 'max', got {min_max}")
# Get the value of the metric
value = metric_col.iloc[index]
return value
71 changes: 58 additions & 13 deletions slune/savers.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
import os
import pandas as pd
from slune.utils import find_directory_path
from slune.utils import find_directory_path, get_all_paths
from slune.base import BaseSaver, BaseLogger
from typing import List, Optional, Type

class SaverCsv():
class SaverCsv(BaseSaver):
"""
Saves the results of each run in a CSV file in a hierarchical directory structure based on argument names.
"""
def __init__(self, params, root_dir='./tuning_results'):
def __init__(self, logger_instance: BaseLogger, params: List[str] = None, root_dir: Optional[str] ='./tuning_results'):
super(SaverCsv, self).__init__(logger_instance)
self.root_dir = root_dir
self.current_path = self.get_path(params)
if params != None:
self.current_path = self.get_path(params)

def strip_params(self, params):
def strip_params(self, params: List[str]):
"""
Strips the argument names from the arguments given by args.
eg. ["--argument_name=argument_value", ...] -> ["--argument_name=", ...]
Also gets rid of blank spaces
"""
return [p.split('=')[0].strip() for p in params]

def get_match(self, params):
def get_match(self, params: List[str]):
"""
Searches the root directory for a directory tree that matches the parameters given.
If only partial matches are found, returns the deepest matching directory with the missing parameters appended.
Expand All @@ -42,13 +46,14 @@ def get_match(self, params):
match = os.path.join(*match)
return match

def get_path(self, params):
def get_path(self, params: List[str]):
"""
Creates a path using the parameters by checking existing directories in the root directory.
Check get_match for how we create the path, we then check if results files for this path already exist,
if they do we increment the number of the results file name that we will use.
TODO: Add option to dictate order of parameters in directory structure.
TODO: Return warnings if there exist multiple paths that match the parameters but in a different order, or paths that don't go as deep as others.
TODO: Should use same directory if number equal but string not, eg. 1 and 1.0
Args:
- params (list): List of strings containing the arguments used, in form ["--argument_name=argument_value", ...].
"""
Expand Down Expand Up @@ -76,16 +81,56 @@ def get_path(self, params):
csv_file_path = os.path.join(dir_path, f'results_{csv_file_number}.csv')
return csv_file_path

def save_collated(self, results):
# We add results onto the end of the current results in the csv file if it already exists
# if not then we create a new csv file and save the results there
def save_collated(self, results: pd.DataFrame):
"""
We add results onto the end of the current results in the csv file if it already exists,
if not then we create a new csv file and save the results there
"""
# If path does not exist, create it
# Remove the csv file name from the path
dir_path = self.current_path.split('/')[:-1]
dir_path = '/'.join(dir_path)
if not os.path.exists(dir_path):
os.makedirs(dir_path)
# If csv file already exists, append results to the end
if os.path.exists(self.current_path):
results = pd.concat([pd.read_csv(self.current_path), results])
results.to_csv(self.current_path, mode='w', index=False)
# If csv file does not exist, create it
else:
results.to_csv(self.current_path, index=False)

def read(self, *args, **kwargs):
# TODO: implement this function
raise NotImplementedError
def read(self, params: List[str], metric_name: str, min_max: str ='max'):
"""
Finds the min/max value of a metric from all csv files in the root directory that match the parameters given.
Args:
- params (list): List of strings containing the arguments used, in form ["--argument_name=argument_value", ...].
- metric_name (string): Name of the metric to be read.
- min_max (string): Whether to read the minimum or maximum value of the metric, default is 'max'.
Returns:
- min_max_params (list): List of strings containing the arguments used to get the min/max value of the metric.
- min_max_value (float): Minimum or maximum value of the metric.
"""
# Get all paths that match the parameters given
paths = get_all_paths(params, root_directory=self.root_dir)
# Read the metric from each path
values = {}
for path in paths:
df = pd.read_csv(path)
values[path] = self.read_log(df, metric_name, min_max)
# Get the key of the min/max value
if min_max == 'min':
min_max_params = min(values, key=values.get)
elif min_max == 'max':
min_max_params = max(values, key=values.get)
else:
raise ValueError(f"min_max must be 'min' or 'max', got {min_max}")
# Find the min/max value of the metric from the key
min_max_value = values[min_max_params]
# Format the path into a list of arguments
min_max_params = min_max_params.replace(self.root_dir, '')
if min_max_params.startswith('/'):
min_max_params = min_max_params[1:]
min_max_params = min_max_params.split('/')
return min_max_params, min_max_value

8 changes: 5 additions & 3 deletions slune/searchers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from slune.base import BaseSearcher
from slune.utils import dict_to_strings

class SearcherGrid():
class SearcherGrid(BaseSearcher):
"""
Given dictionary of hyperparameters and values to try, creates grid of all possible hyperparameter configurations,
and returns them one by one for each call to next_tune.
Expand All @@ -9,7 +10,8 @@ class SearcherGrid():
Structure of dictionary should be: { "--argument_name" : [Value_1, Value_2, ...], ... }
TODO: Add extra functionality by using nested dictionaries to specify which hyperparameters to try together.
"""
def __init__(self, hyperparameters):
def __init__(self, hyperparameters: dict):
super().__init__()
self.hyperparameters = hyperparameters
self.grid = self.get_grid(hyperparameters)
self.grid_index = None
Expand All @@ -20,7 +22,7 @@ def __len__(self):
"""
return len(self.grid)

def get_grid(self, param_dict):
def get_grid(self, param_dict: dict):
"""
Generate all possible combinations of values for each argument in the given dictionary using recursion.
Expand Down
8 changes: 2 additions & 6 deletions slune/slune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import subprocess
import sys

from slune.base import Slog
from slune.savers import SaverCsv
from slune.loggers import LoggerDefault

Expand Down Expand Up @@ -71,10 +70,7 @@ def single_garg(arg_name):
else:
return single_garg(arg_names)

def get_slogcsv(params):
"""
Creates a Slog object from the SaverCsv and LoggerDefault classes.
"""
return Slog(LoggerDefault(), SaverCsv(params))
def get_csv_slog(params, root_directory='.'):
return SaverCsv(LoggerDefault(), params, root_directory=root_directory)

# TODO: add functions for reading results
32 changes: 31 additions & 1 deletion slune/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,34 @@ def dict_to_strings(d):
out.append('{}={}'.format(key, value))
else:
out.append('--{}={}'.format(key, value))
return out
return out

def find_csv_files(root_directory='.'):
"""
Recursively finds all csv files in all subdirectories of the root directory and returns their paths.
Args:
- root_directory (string): Path to the root directory to be searched, default is current working directory.
Returns:
- csv_files (list): List of strings containing the paths to all csv files found.
"""
csv_files = []
for root, dirs, files in os.walk(root_directory):
for file in files:
if file.endswith('.csv'):
csv_files.append(os.path.join(root, file))
return csv_files

def get_all_paths(params, root_directory='.'):
"""
Finds all paths of csv files in all subdirectories of the root directory that have a directory in their path matching one of each of all the parameters given.
Args:
- params (list): List of strings containing the arguments used, in form ["--argument_name=argument_value", ...].
- root_directory (string): Path to the root directory to be searched, default is current working directory.
"""
all_csv = find_csv_files(root_directory)
matches = []
for csv in all_csv:
path = csv.split('/')
if all([p in path for p in params]):
matches.append(csv)
return matches
Loading

0 comments on commit eac393d

Please sign in to comment.