Skip to content

Commit bbd9522

Browse files
authored
Merge branch 'Project-MONAI:dev' into 8085-average-precision
2 parents 2cb04e7 + e538f7f commit bbd9522

File tree

3 files changed

+57
-17
lines changed

3 files changed

+57
-17
lines changed

monai/transforms/compose.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
def execute_compose(
4848
data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
4949
transforms: Sequence[Any],
50-
map_items: bool = True,
50+
map_items: bool | int = True,
5151
unpack_items: bool = False,
5252
start: int = 0,
5353
end: int | None = None,
@@ -65,8 +65,13 @@ def execute_compose(
6565
Args:
6666
data: a tensor-like object to be transformed
6767
transforms: a sequence of transforms to be carried out
68-
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
69-
defaults to `True`.
68+
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
69+
it can behave as follows:
70+
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
71+
to the first level of items in `data`.
72+
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
73+
should be recursively applied. This allows treating multi-sample transforms applied after another
74+
multi-sample transform while controlling how deep the mapping goes.
7075
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
7176
defaults to `False`.
7277
start: the index of the first transform to be executed. If not set, this defaults to 0
@@ -205,8 +210,14 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
205210
206211
Args:
207212
transforms: sequence of callables.
208-
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
209-
defaults to `True`.
213+
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
214+
it can behave as follows:
215+
216+
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
217+
to the first level of items in `data`.
218+
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
219+
should be recursively applied. This allows treating multi-sample transforms applied after another
220+
multi-sample transform while controlling how deep the mapping goes.
210221
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
211222
defaults to `False`.
212223
log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
@@ -227,7 +238,7 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
227238
def __init__(
228239
self,
229240
transforms: Sequence[Callable] | Callable | None = None,
230-
map_items: bool = True,
241+
map_items: bool | int = True,
231242
unpack_items: bool = False,
232243
log_stats: bool | str = False,
233244
lazy: bool | None = False,
@@ -238,9 +249,9 @@ def __init__(
238249
if transforms is None:
239250
transforms = []
240251

241-
if not isinstance(map_items, bool):
252+
if not isinstance(map_items, (bool, int)):
242253
raise ValueError(
243-
f"Argument 'map_items' should be boolean. Got {type(map_items)}."
254+
f"Argument 'map_items' should be boolean or int. Got {type(map_items)}."
244255
"Check brackets when passing a sequence of callables."
245256
)
246257

@@ -391,8 +402,14 @@ class OneOf(Compose):
391402
transforms: sequence of callables.
392403
weights: probabilities corresponding to each callable in transforms.
393404
Probabilities are normalized to sum to one.
394-
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
395-
defaults to `True`.
405+
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
406+
it can behave as follows:
407+
408+
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
409+
to the first level of items in `data`.
410+
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
411+
should be recursively applied. This allows treating multi-sample transforms applied after another
412+
multi-sample transform while controlling how deep the mapping goes.
396413
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
397414
defaults to `False`.
398415
log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
@@ -414,7 +431,7 @@ def __init__(
414431
self,
415432
transforms: Sequence[Callable] | Callable | None = None,
416433
weights: Sequence[float] | float | None = None,
417-
map_items: bool = True,
434+
map_items: bool | int = True,
418435
unpack_items: bool = False,
419436
log_stats: bool | str = False,
420437
lazy: bool | None = False,

monai/transforms/transform.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,12 @@ def _apply_transform(
101101
def apply_transform(
102102
transform: Callable[..., ReturnType],
103103
data: Any,
104-
map_items: bool = True,
104+
map_items: bool | int = True,
105105
unpack_items: bool = False,
106106
log_stats: bool | str = False,
107107
lazy: bool | None = None,
108108
overrides: dict | None = None,
109-
) -> list[ReturnType] | ReturnType:
109+
) -> list[Any] | ReturnType:
110110
"""
111111
Transform `data` with `transform`.
112112
@@ -117,8 +117,13 @@ def apply_transform(
117117
Args:
118118
transform: a callable to be used to transform `data`.
119119
data: an object to be transformed.
120-
map_items: whether to apply transform to each item in `data`,
121-
if `data` is a list or tuple. Defaults to True.
120+
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
121+
it can behave as follows:
122+
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
123+
to the first level of items in `data`.
124+
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
125+
should be recursively applied. This allows treating multi-sample transforms applied after another
126+
multi-sample transform while controlling how deep the mapping goes.
122127
unpack_items: whether to unpack parameters using `*`. Defaults to False.
123128
log_stats: log errors when they occur in the processing pipeline. By default, this is set to False, which
124129
disables the logger for processing pipeline errors. Setting it to None or True will enable logging to the
@@ -136,8 +141,12 @@ def apply_transform(
136141
Union[List[ReturnType], ReturnType]: The return type of `transform` or a list thereof.
137142
"""
138143
try:
139-
if isinstance(data, (list, tuple)) and map_items:
140-
return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
144+
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
145+
if isinstance(data, (list, tuple)) and map_items_ > 0:
146+
return [
147+
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
148+
for item in data
149+
]
141150
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
142151
except Exception as e:
143152
# if in debug mode, don't swallow exception so that the breakpoint

tests/transforms/compose/test_compose.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,20 @@ def b(i, i2):
141141
self.assertEqual(mt.Compose(transforms, unpack_items=True)(data), expected)
142142
self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected)
143143

144+
def test_list_non_dict_compose_with_unpack_map_2(self):
145+
146+
def a(i, i2):
147+
return i + "a", i2 + "a2"
148+
149+
def b(i, i2):
150+
return i + "b", i2 + "b2"
151+
152+
transforms = [a, b, a, b]
153+
data = [[("", ""), ("", "")], [("t", "t"), ("t", "t")]]
154+
expected = [[("abab", "a2b2a2b2"), ("abab", "a2b2a2b2")], [("tabab", "ta2b2a2b2"), ("tabab", "ta2b2a2b2")]]
155+
self.assertEqual(mt.Compose(transforms, map_items=2, unpack_items=True)(data), expected)
156+
self.assertEqual(execute_compose(data, transforms, map_items=2, unpack_items=True), expected)
157+
144158
def test_list_dict_compose_no_map(self):
145159

146160
def a(d): # transform to handle dict data

0 commit comments

Comments
 (0)