Skip to content

Commit

Permalink
Read default settings from file, fix regression where FrequencyRanges…
Browse files Browse the repository at this point in the history
… were not being properly read from YAML files
  • Loading branch information
Toni M. Brotons authored and Toni M. Brotons committed Feb 6, 2025
1 parent d18640d commit d9a83fd
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
4 changes: 2 additions & 2 deletions py_neuromodulation/stream/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import model_validator, ValidationError
from pydantic.functional_validators import ModelWrapValidatorHandler

from py_neuromodulation import logger, user_features
from py_neuromodulation import logger, user_features, PYNM_DIR

from py_neuromodulation.utils.types import (
BoolSelector,
Expand Down Expand Up @@ -296,7 +296,7 @@ def from_file(PATH: _PathLike) -> "NMSettings":

@staticmethod
def get_default() -> "NMSettings":
return NMSettings()
return NMSettings.from_file(PYNM_DIR / "default_settings.yaml")

@staticmethod
def list_normalization_methods() -> list[NORM_METHOD]:
Expand Down
25 changes: 22 additions & 3 deletions py_neuromodulation/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,30 @@ def __iter__(self): # type: ignore
@model_validator(mode="after")
def validate_range(self):
if not (isnan(self.frequency_high_hz) or isnan(self.frequency_low_hz)):
assert (
self.frequency_high_hz > self.frequency_low_hz
), "Frequency high must be greater than frequency low"
assert self.frequency_high_hz > self.frequency_low_hz, (
"Frequency high must be greater than frequency low"
)
return self

@model_validator(mode="before")
@classmethod
def check_input(cls, input):
"""Pydantic validator to convert the input to a dictionary when passed as a list
as we have it by default in the default_settings.yaml file
For example, [1,2] will be converted to {"frequency_low_hz": 1, "frequency_high_hz": 2}
"""
match input:
case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input:
return input
case Sequence() if len(input) == 2:
return {"frequency_low_hz": input[0], "frequency_high_hz": input[1]}
case _:
raise ValueError(
"Value for FrequencyRange must be a dictionary, "
"or a sequence of 2 numeric values, "
f"but got {input} instead."
)


class BoolSelector(NMBaseModel):
def get_enabled(self):
Expand Down

0 comments on commit d9a83fd

Please sign in to comment.