-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added more testing, added reading capability, fixed bug where it wasn…
…'t creating directories and completely revamped class structure so now it uses virtual classes
- Loading branch information
Showing
10 changed files
with
489 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,51 +1,77 @@ | ||
import abc | ||
from typing import Type | ||
|
||
class Searcher(): | ||
class BaseSearcher(metaclass=abc.ABCMeta): | ||
""" | ||
Class that creates search space and returns arguments to pass to sbatch script | ||
Base class for all searchers. Which should be subclassed to implement different search algorithms. | ||
Must implement __len__ and next_tune methods. | ||
""" | ||
def __init__(self): | ||
@abc.abstractmethod | ||
def __init__(self, *args, **kwargs): | ||
pass | ||
|
||
def __len__(self): | ||
|
||
@abc.abstractmethod | ||
def __len__(self, *args, **kwargs): | ||
""" | ||
Returns the number of hyperparameter configurations to try. | ||
""" | ||
return len(self.searcher) | ||
pass | ||
|
||
@abc.abstractmethod | ||
def next_tune(self, *args, **kwargs): | ||
""" | ||
Returns the next hyperparameter configuration to try. | ||
""" | ||
return self.searcher.next_tune(*args, **kwargs) | ||
pass | ||
|
||
class Slog(): | ||
class BaseLogger(metaclass=abc.ABCMeta): | ||
""" | ||
Class used to log metrics during tuning run and to save the results. | ||
Args: | ||
- Logger (object): Class that handles logging of metrics, including the formatting that will be used to save and read the results. | ||
- Saver (object): Class that handles saving logs from Logger to storage and fetching correct logs from storage to give to Logger to read. | ||
Base class for all loggers. Which should be subclassed to implement different logging algorithms. | ||
Must implement log and read_log methods. | ||
""" | ||
def __init__(self, Logger, Saver): | ||
self.logger = Logger | ||
self.saver = Saver # TODO: Have to instantiate this with params, is there way you can just define the class? | ||
@abc.abstractmethod | ||
def __init__(self, *args, **kwargs): | ||
pass | ||
|
||
@abc.abstractmethod | ||
def log(self, *args, **kwargs): | ||
""" | ||
Logs the metric/s for the current hyperparameter configuration, | ||
stores them in a data frame that we can later save in storage. | ||
You can use this method directly form your saver object (inheriting from BaseSaver) that you instantiate with a BaseLogger subclass object. | ||
""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def read_log(self, *args, **kwargs): | ||
""" | ||
Reads the minimum or maximum value of a metric from a data frame. | ||
You can use this method directly form your saver object (inheriting from BaseSaver) that you instantiate with a BaseLogger subclass object. | ||
""" | ||
self.logger.log(*args, **kwargs) | ||
|
||
pass | ||
|
||
class BaseSaver(metaclass=abc.ABCMeta): | ||
""" | ||
Base class for all savers. Which should be subclassed to implement different saving algorithms. | ||
Must implement save_collated and read methods. Inherits from BaseLogger. | ||
""" | ||
@abc.abstractmethod | ||
def __init__(self, logger_instance: BaseLogger, *args, **kwargs): | ||
# Given a class that inherits from BaseLogger we make it accessible through self.logger and make its methods accessible through self.log and self.read_log | ||
self.logger = logger_instance | ||
self.log = self.logger.log | ||
self.read_log = self.logger.read_log | ||
|
||
@abc.abstractmethod | ||
def save_collated(self, *args, **kwargs): | ||
""" | ||
Saves the current results in logger to storage. | ||
""" | ||
self.saver.save_collated(self.logger.results, *args, **kwargs) | ||
pass | ||
|
||
@abc.abstractmethod | ||
def read(self, *args, **kwargs): | ||
""" | ||
Reads results from storage. | ||
""" | ||
return self.saver.read(*args, **kwargs) | ||
|
||
|
||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.