|
66 | 66 | optional_import,
|
67 | 67 | )
|
68 | 68 | from monai.utils.enums import TransformBackends
|
69 |
| -from monai.utils.misc import is_module_ver_at_least |
70 | 69 | from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype
|
71 | 70 |
|
72 | 71 | PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
|
@@ -939,19 +938,10 @@ def __call__(
|
939 | 938 | data = img[[*select_labels]]
|
940 | 939 | else:
|
941 | 940 | where: Callable = np.where if isinstance(img, np.ndarray) else torch.where # type: ignore
|
942 |
| - if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): |
943 |
| - data = where(in1d(img, select_labels), True, False).reshape(img.shape) |
944 |
| - # pre pytorch 1.8.0, need to use 1/0 instead of True/False |
945 |
| - else: |
946 |
| - data = where( |
947 |
| - in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device) |
948 |
| - ).reshape(img.shape) |
| 941 | + data = where(in1d(img, select_labels), True, False).reshape(img.shape) |
949 | 942 |
|
950 | 943 | if merge_channels or self.merge_channels:
|
951 |
| - if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): |
952 |
| - return data.any(0)[None] |
953 |
| - # pre pytorch 1.8.0 compatibility |
954 |
| - return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore |
| 944 | + return data.any(0)[None] |
955 | 945 |
|
956 | 946 | return data
|
957 | 947 |
|
|
0 commit comments