diff --git a/bluecellulab/cell/injector.py b/bluecellulab/cell/injector.py index 6b08100d..8b832e63 100644 --- a/bluecellulab/cell/injector.py +++ b/bluecellulab/cell/injector.py @@ -243,8 +243,14 @@ def add_replay_noise( """Add a replay noise stimulus.""" if section is None: section = self.soma # type: ignore - mean = (stimulus.mean_percent * self.threshold) / 100.0 # type: ignore - variance = (stimulus.variance * self.threshold) / 100.0 # type: ignore + + if stimulus.mean is not None: + mean = stimulus.mean + variance = stimulus.variance + else: + mean = (stimulus.mean_percent * self.threshold) / 100.0 # type: ignore + variance = (stimulus.variance * self.threshold) / 100.0 # type: ignore + tstim = self.add_noise_step( section, # type: ignore segx, diff --git a/bluecellulab/stimulus/circuit_stimulus_definitions.py b/bluecellulab/stimulus/circuit_stimulus_definitions.py index fae177f1..2849c8bb 100644 --- a/bluecellulab/stimulus/circuit_stimulus_definitions.py +++ b/bluecellulab/stimulus/circuit_stimulus_definitions.py @@ -242,11 +242,17 @@ def from_sonata(cls, stimulus_entry: dict, config_dir: Optional[str] = None) -> raise ValueError("Stimulus entry must contain either 'node_set' or 'compartment_set'.") if pattern == Pattern.NOISE: + has_mean = "mean" in stimulus_entry + has_mean_percent = "mean_percent" in stimulus_entry + if has_mean == has_mean_percent: + raise ValueError("Noise input must contain exactly one of 'mean' or 'mean_percent'.") + return Noise( target=target_name, delay=stimulus_entry["delay"], duration=stimulus_entry["duration"], - mean_percent=stimulus_entry["mean_percent"], + mean=stimulus_entry.get("mean"), + mean_percent=stimulus_entry.get("mean_percent"), variance=stimulus_entry["variance"], node_set=node_set, compartment_set=compartment_set, @@ -380,8 +386,16 @@ def from_sonata(cls, stimulus_entry: dict, config_dir: Optional[str] = None) -> @dataclass(frozen=True, config=dict(extra="forbid")) class Noise(Stimulus): - mean_percent: float variance: float + mean: Optional[float] = None # nA + mean_percent: Optional[float] = None # % of threshold + + def __post_init__(self): + # exactly one of mean / mean_percent must be provided + if (self.mean is None) == (self.mean_percent is None): + raise ValueError("Noise stimulus must define exactly one of 'mean' or 'mean_percent'.") + if self.variance < 0: + raise ValueError("'variance' must be >= 0.") @dataclass(frozen=True, config=dict(extra="forbid")) diff --git a/tests/test_circuit/test_simulation_config.py b/tests/test_circuit/test_simulation_config.py index 6fb568c9..fc250c50 100644 --- a/tests/test_circuit/test_simulation_config.py +++ b/tests/test_circuit/test_simulation_config.py @@ -73,7 +73,7 @@ def test_init_with_invalid_type(): def test_get_all_stimuli_entries(): sim = SonataSimulationConfig(multi_input_conf_path) - noise_stim = Noise("Mosaic_A", 10.0, 20.0, 200.0, 0.001, node_set="Mosaic_A") + noise_stim = Noise(target="Mosaic_A", delay=10.0, duration=20.0, mean_percent=200.0, variance=0.001, node_set="Mosaic_A",) hyper_stim = Hyperpolarizing("Mosaic_A", 0.0, 50.0, node_set="Mosaic_A") pulse_stim = Pulse("Mosaic_A", 10.0, 20.0, 0.1, 25, 10, node_set="Mosaic_A") linear_stim = Linear("Mosaic_A", 10.0, 20.0, 0.1, 0.4, node_set="Mosaic_A") diff --git a/tests/test_stimulus/test_circuit_stimulus_definitions.py b/tests/test_stimulus/test_circuit_stimulus_definitions.py index 4fd74986..dd8bf74e 100644 --- a/tests/test_stimulus/test_circuit_stimulus_definitions.py +++ b/tests/test_stimulus/test_circuit_stimulus_definitions.py @@ -14,7 +14,7 @@ # limitations under the License. import pytest -from bluecellulab.stimulus.circuit_stimulus_definitions import Pattern +from bluecellulab.stimulus.circuit_stimulus_definitions import Noise, Pattern, Stimulus def test_pattern_from_sonata_valid(): @@ -40,3 +40,35 @@ def test_pattern_from_sonata_invalid(): """Test that an invalid pattern raises a ValueError.""" with pytest.raises(ValueError, match="Unknown pattern unknown_pattern"): Pattern.from_sonata("unknown_pattern") + + +def test_noise_requires_exactly_one_mean_field(): + """Noise dataclass should validate exactly one of mean/mean_percent.""" + with pytest.raises(ValueError, match="Noise stimulus must define exactly one of 'mean' or 'mean_percent'."): + Noise(target="T", delay=0.0, duration=1.0, variance=0.1) + + with pytest.raises(ValueError, match="Noise stimulus must define exactly one of 'mean' or 'mean_percent'."): + Noise(target="T", delay=0.0, duration=1.0, variance=0.1, mean=0.01, mean_percent=5.0) + + +def test_noise_negative_variance_raises(): + """Noise variance must be non-negative.""" + with pytest.raises(ValueError, match="'variance' must be >= 0."): + Noise(target="T", delay=0.0, duration=1.0, variance=-0.1, mean=0.01) + + +def test_from_sonata_noise_requires_one_mean_field(): + """Parsing SONATA noise stimulus enforces exactly one mean field.""" + base = { + "module": "noise", + "delay": 0.0, + "duration": 1.0, + "variance": 0.1, + "node_set": "T", + } + + with pytest.raises(ValueError, match="Noise input must contain exactly one of 'mean' or 'mean_percent'."): + Stimulus.from_sonata(dict(base)) + + with pytest.raises(ValueError, match="Noise input must contain exactly one of 'mean' or 'mean_percent'."): + Stimulus.from_sonata({**base, "mean": 0.01, "mean_percent": 5.0})