9
9
# See the License for the specific language governing permissions and
10
10
# limitations under the License.
11
11
12
- from typing import TYPE_CHECKING , Mapping , Optional , Sequence , Union
12
+ from typing import TYPE_CHECKING , Dict , Mapping , Optional , Sequence , Union
13
13
14
14
import numpy as np
15
15
22
22
GridSampleMode ,
23
23
GridSamplePadMode ,
24
24
InterpolateMode ,
25
+ OptionalImportError ,
25
26
convert_data_type ,
26
27
look_up_option ,
27
28
optional_import ,
28
29
require_pkg ,
29
30
)
30
31
31
32
DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s"
33
+ EXT_WILDCARD = "*"
32
34
logger = get_logger (module_name = __name__ , fmt = DEFAULT_FMT )
33
35
34
36
if TYPE_CHECKING :
41
43
PILImage , _ = optional_import ("PIL.Image" )
42
44
43
45
44
- __all__ = ["ImageWriter" , "ITKWriter" , "NibabelWriter" , "PILWriter" , "logger" ]
46
+ __all__ = [
47
+ "ImageWriter" ,
48
+ "ITKWriter" ,
49
+ "NibabelWriter" ,
50
+ "PILWriter" ,
51
+ "SUPPORTED_WRITERS" ,
52
+ "register_writer" ,
53
+ "resolve_writer" ,
54
+ "logger" ,
55
+ ]
56
+
57
+ SUPPORTED_WRITERS : Dict = {}
58
+
59
+
60
+ def register_writer (ext_name , * im_writers ):
61
+ """
62
+ Register ``ImageWriter``, so that writing a file with filename extension ``ext_name``
63
+ could be resolved to a tuple of potentially appropriate ``ImageWriter``.
64
+ The customised writers could be registered by:
65
+
66
+ .. code-block:: python
67
+
68
+ from monai.data import register_writer
69
+ # `MyWriter` must implement `ImageWriter` interface
70
+ register_writer("nii", MyWriter)
71
+
72
+ Args:
73
+ ext_name: the filename extension of the image.
74
+ As an indexing key, it will be converted to a lower case string.
75
+ im_writers: one or multiple ImageWriter classes with high priority ones first.
76
+ """
77
+ fmt = f"{ ext_name } " .lower ()
78
+ if fmt .startswith ("." ):
79
+ fmt = fmt [1 :]
80
+ existing = look_up_option (fmt , SUPPORTED_WRITERS , default = ())
81
+ all_writers = im_writers + existing
82
+ SUPPORTED_WRITERS [fmt ] = all_writers
83
+
84
+
85
+ def resolve_writer (ext_name , error_if_not_found = True ) -> Sequence :
86
+ """
87
+ Resolves to a tuple of available ``ImageWriter`` in ``SUPPORTED_WRITERS``
88
+ according to the filename extension key ``ext_name``.
89
+
90
+ Args:
91
+ ext_name: the filename extension of the image.
92
+ As an indexing key it will be converted to a lower case string.
93
+ error_if_not_found: whether to raise an error if no suitable image writer is found.
94
+ if True , raise an ``OptionalImportError``, otherwise return an empty tuple. Default is ``True``.
95
+ """
96
+ if not SUPPORTED_WRITERS :
97
+ init ()
98
+ fmt = f"{ ext_name } " .lower ()
99
+ if fmt .startswith ("." ):
100
+ fmt = fmt [1 :]
101
+ avail_writers = []
102
+ default_writers = SUPPORTED_WRITERS .get (EXT_WILDCARD , ())
103
+ for _writer in look_up_option (fmt , SUPPORTED_WRITERS , default = default_writers ):
104
+ try :
105
+ _writer () # this triggers `monai.utils.module.require_pkg` to check the system availability
106
+ avail_writers .append (_writer )
107
+ except OptionalImportError :
108
+ continue
109
+ except Exception : # other writer init errors indicating it exists
110
+ avail_writers .append (_writer )
111
+ if not avail_writers and error_if_not_found :
112
+ raise OptionalImportError (f"No ImageWriter backend found for { fmt } ." )
113
+ writer_tuple = ensure_tuple (avail_writers )
114
+ SUPPORTED_WRITERS [fmt ] = writer_tuple
115
+ return writer_tuple
45
116
46
117
47
118
class ImageWriter :
@@ -297,7 +368,9 @@ def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs):
297
368
"""
298
369
super ().__init__ (output_dtype = output_dtype , affine = None , channel_dim = 0 , ** kwargs )
299
370
300
- def set_data_array (self , data_array , channel_dim : Optional [int ] = 0 , squeeze_end_dims : bool = True , ** kwargs ):
371
+ def set_data_array (
372
+ self , data_array : NdarrayOrTensor , channel_dim : Optional [int ] = 0 , squeeze_end_dims : bool = True , ** kwargs
373
+ ):
301
374
"""
302
375
Convert ``data_array`` into 'channel-last' numpy ndarray.
303
376
@@ -309,14 +382,15 @@ def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end
309
382
kwargs: keyword arguments passed to ``self.convert_to_channel_last``,
310
383
currently support ``spatial_ndim`` and ``contiguous``, defauting to ``3`` and ``False`` respectively.
311
384
"""
385
+ _r = len (data_array .shape )
312
386
self .data_obj = self .convert_to_channel_last (
313
387
data = data_array ,
314
388
channel_dim = channel_dim ,
315
389
squeeze_end_dims = squeeze_end_dims ,
316
390
spatial_ndim = kwargs .pop ("spatial_ndim" , 3 ),
317
391
contiguous = kwargs .pop ("contiguous" , True ),
318
392
)
319
- self .channel_dim = channel_dim
393
+ self .channel_dim = channel_dim if len ( self . data_obj . shape ) >= _r else None # channel dim is at the end
320
394
321
395
def set_metadata (self , meta_dict : Optional [Mapping ] = None , resample : bool = True , ** options ):
322
396
"""
@@ -335,7 +409,7 @@ def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = Tru
335
409
data_array = self .data_obj ,
336
410
affine = affine ,
337
411
target_affine = original_affine if resample else None ,
338
- output_spatial_shape = spatial_shape ,
412
+ output_spatial_shape = spatial_shape if resample else None ,
339
413
mode = options .pop ("mode" , GridSampleMode .BILINEAR ),
340
414
padding_mode = options .pop ("padding_mode" , GridSamplePadMode .BORDER ),
341
415
align_corners = options .pop ("align_corners" , False ),
@@ -476,7 +550,7 @@ def set_metadata(self, meta_dict: Optional[Mapping], resample: bool = True, **op
476
550
data_array = self .data_obj ,
477
551
affine = affine ,
478
552
target_affine = original_affine if resample else None ,
479
- output_spatial_shape = spatial_shape ,
553
+ output_spatial_shape = spatial_shape if resample else None ,
480
554
mode = options .pop ("mode" , GridSampleMode .BILINEAR ),
481
555
padding_mode = options .pop ("padding_mode" , GridSamplePadMode .BORDER ),
482
556
align_corners = options .pop ("align_corners" , False ),
@@ -716,3 +790,15 @@ def create_backend_obj(
716
790
data = np .moveaxis (data , 0 , 1 )
717
791
718
792
return PILImage .fromarray (data , mode = kwargs .pop ("image_mode" , None ))
793
+
794
+
795
+ def init ():
796
+ """
797
+ Initialize the image writer modules according to the filename extension.
798
+ """
799
+ for ext in ("png" , "jpg" , "jpeg" , "bmp" , "tiff" , "tif" ):
800
+ register_writer (ext , PILWriter ) # TODO: test 16-bit
801
+ for ext in ("nii.gz" , "nii" ):
802
+ register_writer (ext , NibabelWriter , ITKWriter )
803
+ register_writer ("nrrd" , ITKWriter , NibabelWriter )
804
+ register_writer (EXT_WILDCARD , ITKWriter , NibabelWriter , ITKWriter )
0 commit comments