Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions bluecellulab/cell/injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,14 @@ def add_replay_noise(
"""Add a replay noise stimulus."""
if section is None:
section = self.soma # type: ignore
mean = (stimulus.mean_percent * self.threshold) / 100.0 # type: ignore
variance = (stimulus.variance * self.threshold) / 100.0 # type: ignore

if stimulus.mean is not None:
mean = stimulus.mean
variance = stimulus.variance
else:
mean = (stimulus.mean_percent * self.threshold) / 100.0 # type: ignore
variance = (stimulus.variance * self.threshold) / 100.0 # type: ignore

tstim = self.add_noise_step(
section, # type: ignore
segx,
Expand Down
18 changes: 16 additions & 2 deletions bluecellulab/stimulus/circuit_stimulus_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,17 @@ def from_sonata(cls, stimulus_entry: dict, config_dir: Optional[str] = None) ->
raise ValueError("Stimulus entry must contain either 'node_set' or 'compartment_set'.")

if pattern == Pattern.NOISE:
has_mean = "mean" in stimulus_entry
has_mean_percent = "mean_percent" in stimulus_entry
if has_mean == has_mean_percent:
raise ValueError("Noise input must contain exactly one of 'mean' or 'mean_percent'.")

return Noise(
target=target_name,
delay=stimulus_entry["delay"],
duration=stimulus_entry["duration"],
mean_percent=stimulus_entry["mean_percent"],
mean=stimulus_entry.get("mean"),
mean_percent=stimulus_entry.get("mean_percent"),
variance=stimulus_entry["variance"],
node_set=node_set,
compartment_set=compartment_set,
Expand Down Expand Up @@ -380,8 +386,16 @@ def from_sonata(cls, stimulus_entry: dict, config_dir: Optional[str] = None) ->

@dataclass(frozen=True, config=dict(extra="forbid"))
class Noise(Stimulus):
mean_percent: float
variance: float
mean: Optional[float] = None # nA
mean_percent: Optional[float] = None # % of threshold

def __post_init__(self):
# exactly one of mean / mean_percent must be provided
if (self.mean is None) == (self.mean_percent is None):
raise ValueError("Noise stimulus must define exactly one of 'mean' or 'mean_percent'.")
if self.variance < 0:
raise ValueError("'variance' must be >= 0.")


@dataclass(frozen=True, config=dict(extra="forbid"))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_circuit/test_simulation_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_init_with_invalid_type():

def test_get_all_stimuli_entries():
sim = SonataSimulationConfig(multi_input_conf_path)
noise_stim = Noise("Mosaic_A", 10.0, 20.0, 200.0, 0.001, node_set="Mosaic_A")
noise_stim = Noise(target="Mosaic_A", delay=10.0, duration=20.0, mean_percent=200.0, variance=0.001, node_set="Mosaic_A",)
hyper_stim = Hyperpolarizing("Mosaic_A", 0.0, 50.0, node_set="Mosaic_A")
pulse_stim = Pulse("Mosaic_A", 10.0, 20.0, 0.1, 25, 10, node_set="Mosaic_A")
linear_stim = Linear("Mosaic_A", 10.0, 20.0, 0.1, 0.4, node_set="Mosaic_A")
Expand Down
34 changes: 33 additions & 1 deletion tests/test_stimulus/test_circuit_stimulus_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import pytest
from bluecellulab.stimulus.circuit_stimulus_definitions import Pattern
from bluecellulab.stimulus.circuit_stimulus_definitions import Noise, Pattern, Stimulus


def test_pattern_from_sonata_valid():
Expand All @@ -40,3 +40,35 @@ def test_pattern_from_sonata_invalid():
"""Test that an invalid pattern raises a ValueError."""
with pytest.raises(ValueError, match="Unknown pattern unknown_pattern"):
Pattern.from_sonata("unknown_pattern")


def test_noise_requires_exactly_one_mean_field():
"""Noise dataclass should validate exactly one of mean/mean_percent."""
with pytest.raises(ValueError, match="Noise stimulus must define exactly one of 'mean' or 'mean_percent'."):
Noise(target="T", delay=0.0, duration=1.0, variance=0.1)

with pytest.raises(ValueError, match="Noise stimulus must define exactly one of 'mean' or 'mean_percent'."):
Noise(target="T", delay=0.0, duration=1.0, variance=0.1, mean=0.01, mean_percent=5.0)


def test_noise_negative_variance_raises():
"""Noise variance must be non-negative."""
with pytest.raises(ValueError, match="'variance' must be >= 0."):
Noise(target="T", delay=0.0, duration=1.0, variance=-0.1, mean=0.01)


def test_from_sonata_noise_requires_one_mean_field():
"""Parsing SONATA noise stimulus enforces exactly one mean field."""
base = {
"module": "noise",
"delay": 0.0,
"duration": 1.0,
"variance": 0.1,
"node_set": "T",
}

with pytest.raises(ValueError, match="Noise input must contain exactly one of 'mean' or 'mean_percent'."):
Stimulus.from_sonata(dict(base))

with pytest.raises(ValueError, match="Noise input must contain exactly one of 'mean' or 'mean_percent'."):
Stimulus.from_sonata({**base, "mean": 0.01, "mean_percent": 5.0})