Skip to content

Commit fb8c5bf

Browse files
authored
Removed outdated torch version checks from transform functions (#8359)
Fixes #8348 ### Description Support for `torch` versions prior to `1.13` has been dropped, so those `1.8` version checks are not required anymore. Furthermore, as reported in the issue description, those checks led to unstable behaviour when using certain transforms in data pipelines. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Nicolas Kaenzig <[email protected]>
1 parent 960c59b commit fb8c5bf

File tree

2 files changed

+4
-16
lines changed

2 files changed

+4
-16
lines changed

monai/transforms/utility/array.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@
6666
optional_import,
6767
)
6868
from monai.utils.enums import TransformBackends
69-
from monai.utils.misc import is_module_ver_at_least
7069
from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype
7170

7271
PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
@@ -939,19 +938,10 @@ def __call__(
939938
data = img[[*select_labels]]
940939
else:
941940
where: Callable = np.where if isinstance(img, np.ndarray) else torch.where # type: ignore
942-
if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
943-
data = where(in1d(img, select_labels), True, False).reshape(img.shape)
944-
# pre pytorch 1.8.0, need to use 1/0 instead of True/False
945-
else:
946-
data = where(
947-
in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device)
948-
).reshape(img.shape)
941+
data = where(in1d(img, select_labels), True, False).reshape(img.shape)
949942

950943
if merge_channels or self.merge_channels:
951-
if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
952-
return data.any(0)[None]
953-
# pre pytorch 1.8.0 compatibility
954-
return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore
944+
return data.any(0)[None]
955945

956946
return data
957947

monai/transforms/utils_pytorch_numpy_unification.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import torch
1919

2020
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
21-
from monai.utils.misc import is_module_ver_at_least
2221
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
2322

2423
__all__ = [
@@ -215,10 +214,9 @@ def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor:
215214
Element-wise floor division between two arrays/tensors.
216215
"""
217216
if isinstance(a, torch.Tensor):
218-
if is_module_ver_at_least(torch, (1, 8, 0)):
219-
return torch.div(a, b, rounding_mode="floor")
220217
return torch.floor_divide(a, b)
221-
return np.floor_divide(a, b)
218+
else:
219+
return np.floor_divide(a, b)
222220

223221

224222
def unravel_index(idx, shape) -> NdarrayOrTensor:

0 commit comments

Comments
 (0)