Skip to content

Commit

Permalink
Add methods for data interface management in ProcessingModule
Browse files Browse the repository at this point in the history
  • Loading branch information
bendichter committed Jan 27, 2025
1 parent 5afdf2e commit 84223b4
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/pynwb/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,46 @@ def get_data_interface(self, **kwargs):
warn(PendingDeprecationWarning('get_data_interface will be replaced by get'))
return self.get(kwargs['data_interface_name'])

def __len__(self):
"""Get the number of data interfaces in this ProcessingModule.
Returns
-------
int
Number of data interfaces
"""
return len(self.data_interfaces)

def keys(self):
"""Get the names of data interfaces in this ProcessingModule.
Returns
-------
KeysView
View of interface names
"""
return self.data_interfaces.keys()

def values(self):
"""Get the data interfaces in this ProcessingModule.
Returns
-------
ValuesView
View of interfaces
"""
return self.data_interfaces.values()

def items(self):
"""Get the (name, interface) pairs in this ProcessingModule.
Returns
-------
ItemsView
View of (name, interface) pairs
"""
return self.data_interfaces.items()


@register_class('TimeSeries', CORE_NAMESPACE)
class TimeSeries(NWBDataInterface):
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,35 @@ def test_getitem(self):
tmp = self.pm["test_ts"]
self.assertIs(tmp, ts)

def test_len(self):
"""Test that len() returns number of data interfaces."""
self.assertEqual(len(self.pm), 0)
ts = self._create_time_series()
self.pm.add(ts)
self.assertEqual(len(self.pm), 1)
ts2 = TimeSeries(name="test_ts2", data=[1, 2, 3], unit="unit", rate=1.0)
self.pm.add(ts2)
self.assertEqual(len(self.pm), 2)

def test_dict_methods(self):
"""Test dictionary-like methods (keys, values, items)."""
ts = self._create_time_series()
ts2 = TimeSeries(name="test_ts2", data=[1, 2, 3], unit="unit", rate=1.0)
self.pm.add(ts)
self.pm.add(ts2)

# Test keys()
keys = self.pm.keys()
self.assertEqual(set(keys), {"test_ts", "test_ts2"})

# Test values()
values = self.pm.values()
self.assertEqual(set(values), {ts, ts2})

# Test items()
items = self.pm.items()
self.assertEqual(set(items), {("test_ts", ts), ("test_ts2", ts2)})


class TestTimeSeries(TestCase):
def test_init_no_parent(self):
Expand Down

0 comments on commit 84223b4

Please sign in to comment.