diff --git a/py_neuromodulation/stream/settings.py b/py_neuromodulation/stream/settings.py index c2b45e4c..d03afbe8 100644 --- a/py_neuromodulation/stream/settings.py +++ b/py_neuromodulation/stream/settings.py @@ -5,7 +5,7 @@ from pydantic import model_validator, ValidationError from pydantic.functional_validators import ModelWrapValidatorHandler -from py_neuromodulation import logger, user_features +from py_neuromodulation import logger, user_features, PYNM_DIR from py_neuromodulation.utils.types import ( BoolSelector, @@ -296,7 +296,7 @@ def from_file(PATH: _PathLike) -> "NMSettings": @staticmethod def get_default() -> "NMSettings": - return NMSettings() + return NMSettings.from_file(PYNM_DIR / "default_settings.yaml") @staticmethod def list_normalization_methods() -> list[NORM_METHOD]: diff --git a/py_neuromodulation/utils/types.py b/py_neuromodulation/utils/types.py index 6f7ab636..0f9632ee 100644 --- a/py_neuromodulation/utils/types.py +++ b/py_neuromodulation/utils/types.py @@ -106,11 +106,30 @@ def __iter__(self): # type: ignore @model_validator(mode="after") def validate_range(self): if not (isnan(self.frequency_high_hz) or isnan(self.frequency_low_hz)): - assert ( - self.frequency_high_hz > self.frequency_low_hz - ), "Frequency high must be greater than frequency low" + assert self.frequency_high_hz > self.frequency_low_hz, ( + "Frequency high must be greater than frequency low" + ) return self + @model_validator(mode="before") + @classmethod + def check_input(cls, input): + """Pydantic validator to convert the input to a dictionary when passed as a list + as we have it by default in the default_settings.yaml file + For example, [1,2] will be converted to {"frequency_low_hz": 1, "frequency_high_hz": 2} + """ + match input: + case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input: + return input + case Sequence() if len(input) == 2: + return {"frequency_low_hz": input[0], "frequency_high_hz": input[1]} + case _: + raise ValueError( + "Value for FrequencyRange must be a dictionary, " + "or a sequence of 2 numeric values, " + f"but got {input} instead." + ) + class BoolSelector(NMBaseModel): def get_enabled(self):