14
14
import os
15
15
import time
16
16
import warnings
17
+ import concurrent .futures
17
18
18
19
from .pixel_ident import recombine_skeletons , isolateregions
19
20
from .utilities import eight_con , round_to_odd , threshold_local , in_ipynb
@@ -498,7 +499,7 @@ def create_mask(self, glob_thresh=None, adapt_thresh=None,
498
499
if in_ipynb ():
499
500
p .clf ()
500
501
501
- def medskel (self , verbose = False , save_png = False ):
502
+ def medskel (self , verbose = False , save_png = False , rng = None ):
502
503
'''
503
504
This function performs the medial axis transform (skeletonization)
504
505
on the mask. This is essentially a wrapper function of
@@ -516,6 +517,9 @@ def medskel(self, verbose=False, save_png=False):
516
517
Enables plotting.
517
518
save_png : bool, optional
518
519
Saves the plot made in verbose mode. Disabled by default.
520
+ rng : numpy.random.RandomState or int, optional
521
+ Random number generator for reproducibility. Used for tie breaks in
522
+ the `medial_axis <https://scikit-image.org/docs/stable/api/skimage.morphology.html#skimage.morphology.medial_axis>`_ function.
519
523
520
524
Attributes
521
525
----------
@@ -525,8 +529,17 @@ def medskel(self, verbose=False, save_png=False):
525
529
The distance transform used to create the skeletons.
526
530
'''
527
531
528
- self .skeleton , self .medial_axis_distance = \
529
- medial_axis (self .mask , return_distance = True )
532
+ if rng is None :
533
+ rng = np .random .default_rng ()
534
+
535
+ # The kwarg for rng has changed in different skimage versions.
536
+ try :
537
+ self .skeleton , self .medial_axis_distance = \
538
+ medial_axis (self .mask , return_distance = True , rng = rng )
539
+ except TypeError :
540
+ self .skeleton , self .medial_axis_distance = \
541
+ medial_axis (self .mask , return_distance = True , random_state = rng )
542
+
530
543
self .medial_axis_distance = \
531
544
self .medial_axis_distance * self .skeleton * u .pix
532
545
# Delete connection smaller than 2 pixels wide. Such a small
@@ -551,11 +564,18 @@ def medskel(self, verbose=False, save_png=False):
551
564
if in_ipynb ():
552
565
p .clf ()
553
566
554
- def analyze_skeletons (self , prune_criteria = 'all' , relintens_thresh = 0.2 ,
555
- nbeam_lengths = 5 , branch_nbeam_lengths = 3 ,
556
- skel_thresh = None , branch_thresh = None ,
567
+ def analyze_skeletons (self ,
568
+ nthreads = 1 ,
569
+ prune_criteria = 'all' ,
570
+ relintens_thresh = 0.2 ,
571
+ nbeam_lengths = 5 ,
572
+ branch_nbeam_lengths = 3 ,
573
+ skel_thresh = None ,
574
+ branch_thresh = None ,
557
575
max_prune_iter = 10 ,
558
- verbose = False , save_png = False , save_name = None ):
576
+ verbose = False ,
577
+ save_png = False ,
578
+ save_name = None ):
559
579
'''
560
580
561
581
Prune skeleton structure and calculate the branch and longest-path
@@ -564,6 +584,8 @@ def analyze_skeletons(self, prune_criteria='all', relintens_thresh=0.2,
564
584
565
585
Parameters
566
586
----------
587
+ nthreads : int, optional
588
+ Number of threads to use to parallelize the skeleton analysis.
567
589
prune_criteria : {'all', 'intensity', 'length'}, optional
568
590
Choose the property to base pruning on. 'all' requires that the
569
591
branch fails to satisfy the length and relative intensity checks.
@@ -638,25 +660,26 @@ def analyze_skeletons(self, prune_criteria='all', relintens_thresh=0.2,
638
660
# Relabel after deleting short skeletons.
639
661
labels , num = nd .label (self .skeleton , eight_con ())
640
662
663
+
641
664
self .filaments = [Filament2D (np .where (labels == lab ),
642
665
converter = self .converter ) for lab in
643
666
range (1 , num + 1 )]
644
667
645
- self .number_of_filaments = num
646
-
647
668
# Now loop over the skeleton analysis for each filament object
648
- for n , fil in enumerate (self .filaments ):
649
- savename = "{0}_{1}" .format (save_name , n )
650
- if verbose :
651
- print ("Filament: %s / %s" % (n + 1 , self .number_of_filaments ))
669
+ with concurrent .futures .ProcessPoolExecutor (nthreads ) as executor :
670
+ futures = [executor .submit (fil .skeleton_analysis , self .image ,
671
+ verbose = verbose ,
672
+ save_png = save_png ,
673
+ save_name = save_name ,
674
+ prune_criteria = prune_criteria ,
675
+ relintens_thresh = relintens_thresh ,
676
+ branch_thresh = self .branch_thresh ,
677
+ max_prune_iter = max_prune_iter ,
678
+ return_self = True )
679
+ for fil in self .filaments ]
680
+ self .filaments = [future .result () for future in futures ]
652
681
653
- fil .skeleton_analysis (self .image , verbose = verbose ,
654
- save_png = save_png ,
655
- save_name = savename ,
656
- prune_criteria = prune_criteria ,
657
- relintens_thresh = relintens_thresh ,
658
- branch_thresh = self .branch_thresh ,
659
- max_prune_iter = max_prune_iter )
682
+ self .number_of_filaments = num
660
683
661
684
self .array_offsets = [fil .pixel_extents for fil in self .filaments ]
662
685
@@ -749,7 +772,9 @@ def end_pts(self):
749
772
'''
750
773
return [fil .end_pts for fil in self .filaments ]
751
774
752
- def exec_rht (self , radius = 10 * u .pix ,
775
+ def exec_rht (self ,
776
+ nthreads = 1 ,
777
+ radius = 10 * u .pix ,
753
778
ntheta = 180 , background_percentile = 25 ,
754
779
branches = False , min_branch_length = 3 * u .pix ,
755
780
verbose = False , save_png = False , save_name = None ):
@@ -774,6 +799,8 @@ def exec_rht(self, radius=10 * u.pix,
774
799
775
800
Parameters
776
801
----------
802
+ nthreads : int, optional
803
+ The number of threads to use.
777
804
radius : int
778
805
Sets the patch size that the RHT uses.
779
806
ntheta : int, optional
@@ -792,6 +819,7 @@ def exec_rht(self, radius=10 * u.pix,
792
819
save_name : str, optional
793
820
Prefix for the saved plots.
794
821
822
+
795
823
Attributes
796
824
----------
797
825
rht_curvature : dict
@@ -813,23 +841,35 @@ def exec_rht(self, radius=10 * u.pix,
813
841
if save_name is None :
814
842
save_name = self .save_name
815
843
816
- for n , fil in enumerate (self .filaments ):
817
- if verbose :
818
- print ("Filament: %s / %s" % (n + 1 , self .number_of_filaments ))
819
844
820
- if branches :
821
- fil .rht_branch_analysis (radius = radius ,
845
+ if branches :
846
+ with concurrent .futures .ProcessPoolExecutor (nthreads ) as executor :
847
+ futures = [executor .submit (fil .rht_branch_analysis ,
848
+ radius = radius ,
822
849
ntheta = ntheta ,
823
850
background_percentile = background_percentile ,
824
- min_branch_length = min_branch_length )
851
+ min_branch_length = min_branch_length ,
852
+ return_self = True )
853
+ for fil in self .filaments ]
854
+ self .filaments = [future .result () for future in futures ]
825
855
826
- else :
827
- fil .rht_analysis (radius = radius , ntheta = ntheta ,
828
- background_percentile = background_percentile )
829
856
830
- if verbose :
857
+ else :
858
+ with concurrent .futures .ProcessPoolExecutor (nthreads ) as executor :
859
+ futures = [executor .submit (fil .rht_analysis ,
860
+ radius = radius ,
861
+ ntheta = ntheta ,
862
+ background_percentile = background_percentile ,
863
+ return_self = True )
864
+ for fil in self .filaments ]
865
+ self .filaments = [future .result () for future in futures ]
866
+
867
+
868
+ if verbose :
869
+ for n , fil in enumerate (self .filaments ):
870
+
831
871
if save_png :
832
- savename = "{0}_{1}_rht.png" .format (save_name , n )
872
+ save_name = "{0}_{1}_rht.png" .format (save_name , n )
833
873
else :
834
874
save_name = None
835
875
fil .plot_rht_distrib (save_name = save_name )
@@ -886,7 +926,9 @@ def pre_recombine_mask_corners(self):
886
926
'''
887
927
return self ._pre_recombine_mask_corners
888
928
889
- def find_widths (self , max_dist = 10 * u .pix ,
929
+ def find_widths (self ,
930
+ nthreads = 1 ,
931
+ max_dist = 10 * u .pix ,
890
932
pad_to_distance = 0 * u .pix ,
891
933
fit_model = 'gaussian_bkg' ,
892
934
fitter = None ,
@@ -915,12 +957,8 @@ def find_widths(self, max_dist=10 * u.pix,
915
957
916
958
Parameters
917
959
----------
918
- image : `~astropy.unit.Quantity` or `~numpy.ndarray`
919
- The image from which the filament was extracted.
920
- all_skeleton_array : np.ndarray
921
- An array with the skeletons of other filaments. This is used to
922
- avoid double-counting pixels in the radial profiles in nearby
923
- filaments.
960
+ nthreads : int, optional
961
+ Number of threads to use.
924
962
max_dist : `~astropy.units.Quantity`, optional
925
963
Largest radius around the skeleton to create the profile from. This
926
964
can be given in physical, angular, or physical units.
@@ -967,23 +1005,26 @@ def find_widths(self, max_dist=10 * u.pix,
967
1005
if save_name is None :
968
1006
save_name = self .save_name
969
1007
970
- for n , fil in enumerate (self .filaments ):
1008
+ with concurrent .futures .ProcessPoolExecutor (nthreads ) as executor :
1009
+ futures = [executor .submit (fil .width_analysis , self .image ,
1010
+ all_skeleton_array = self .skeleton ,
1011
+ max_dist = max_dist ,
1012
+ pad_to_distance = pad_to_distance ,
1013
+ fit_model = fit_model ,
1014
+ fitter = fitter , try_nonparam = try_nonparam ,
1015
+ use_longest_path = use_longest_path ,
1016
+ add_width_to_length = add_width_to_length ,
1017
+ deconvolve_width = deconvolve_width ,
1018
+ beamwidth = self .beamwidth ,
1019
+ fwhm_function = fwhm_function ,
1020
+ chisq_max = chisq_max ,
1021
+ return_self = True ,
1022
+ ** kwargs )
1023
+ for fil in self .filaments ]
1024
+ self .filaments = [future .result () for future in futures ]
971
1025
972
- if verbose :
973
- print ("Filament: %s / %s" % (n + 1 , self .number_of_filaments ))
974
-
975
- fil .width_analysis (self .image , all_skeleton_array = self .skeleton ,
976
- max_dist = max_dist ,
977
- pad_to_distance = pad_to_distance ,
978
- fit_model = fit_model ,
979
- fitter = fitter , try_nonparam = try_nonparam ,
980
- use_longest_path = use_longest_path ,
981
- add_width_to_length = add_width_to_length ,
982
- deconvolve_width = deconvolve_width ,
983
- beamwidth = self .beamwidth ,
984
- fwhm_function = fwhm_function ,
985
- chisq_max = chisq_max ,
986
- ** kwargs )
1026
+
1027
+ for n , fil in enumerate (self .filaments ):
987
1028
988
1029
if verbose :
989
1030
if save_png :
@@ -1409,7 +1450,10 @@ def save_fits(self, save_name=None,
1409
1450
out_hdu .writeto ("{0}_image_output.fits" .format (save_name ),
1410
1451
** kwargs )
1411
1452
1412
- def save_stamp_fits (self , save_name = None , pad_size = 20 * u .pix ,
1453
+ def save_stamp_fits (self ,
1454
+ image_dict = None ,
1455
+ save_name = None ,
1456
+ pad_size = 20 * u .pix ,
1413
1457
model_kwargs = {},
1414
1458
** kwargs ):
1415
1459
'''
@@ -1421,6 +1465,10 @@ def save_stamp_fits(self, save_name=None, pad_size=20 * u.pix,
1421
1465
1422
1466
Parameters
1423
1467
----------
1468
+ image_dict : dict, optional
1469
+ Dictionary of arrays to save matching the pixel extents of each filament.
1470
+ The shape of each array *must* be the same shape as the original image
1471
+ given to `~FilFinder2D`.
1424
1472
save_name : str, optional
1425
1473
The prefix for the saved file. If None, the save name specified
1426
1474
when `~FilFinder2D` was first called.
@@ -1436,10 +1484,21 @@ def save_stamp_fits(self, save_name=None, pad_size=20 * u.pix,
1436
1484
else :
1437
1485
save_name = os .path .splitext (save_name )[0 ]
1438
1486
1487
+ if image_dict is not None :
1488
+ for ii , key in enumerate (image_dict ):
1489
+ this_image = image_dict [key ]
1490
+ if this_image .shape != self .image .shape :
1491
+ raise ValueError ("All images in image_dict must be same shape as fil.image. "
1492
+ f"For index { ii } , found shape { this_image .shape } not { self .image .shape } " )
1493
+
1494
+
1439
1495
for n , fil in enumerate (self .filaments ):
1440
1496
1441
- savename = "{0 }_stamp_{1 }.fits". format ( save_name , n )
1497
+ savename = f" { save_name } _stamp_{ n } .fits"
1442
1498
1443
- fil .save_fits (savename , self .image , pad_size = pad_size ,
1499
+ fil .save_fits (savename ,
1500
+ self .image ,
1501
+ image_dict = image_dict ,
1502
+ pad_size = pad_size ,
1444
1503
model_kwargs = model_kwargs ,
1445
1504
** kwargs )
0 commit comments