Skip to content

Commit 21d262d

Browse files
committed
now handles the clipping & added docstring
1 parent 006b53e commit 21d262d

File tree

1 file changed

+236
-40
lines changed

1 file changed

+236
-40
lines changed

sunkit_image/coalignment.py

Lines changed: 236 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,218 @@
11
"""
2-
This module provides routines for the co-alignment of images and
3-
`~sunpy.map.mapsequence.MapSequence` objects through template matching.
2+
This module provides routines for the co-alignment of images through template
3+
matching.
44
"""
55

66
import warnings
77

8+
import astropy.units as u
89
import numpy as np
910
import sunpy.map
10-
from scipy.ndimage import shift
1111
from skimage.feature import match_template
1212
from sunpy.util.exceptions import SunpyUserWarning
1313

1414
__all__ = ["coalignment_interface", "register_coalignment_method", "registered_methods"]
1515

16-
## This will be further replaced, once the decorator structure is in place.
16+
## This dictionary will be further replaced in a more appropriate location, once the decorator structure is in place.
1717
registered_methods = {}
1818

1919

2020
def register_coalignment_method(name, method):
21+
"""
22+
Registers a coalignment method to be used by the coalignment interface.
23+
24+
Parameters
25+
----------
26+
name : str
27+
The name of the coalignment method.
28+
method : callable
29+
The function implementing the coalignment method.
30+
"""
2131
registered_methods[name] = method
2232

2333

2434
############################ Coalignment Interface begins #################################
25-
def replace_nan_with_mean(array):
26-
mean_value = np.nanmean(array)
27-
array[np.isnan(array)] = mean_value
2835

29-
return array
3036

31-
32-
def convert_map_to_array(map_obj):
33-
# Convert map object to array
34-
return np.array(map_obj.data)
37+
@u.quantity_input
38+
def _clip_edges(data, yclips: u.pix, xclips: u.pix):
39+
"""
40+
Clips off the "y" and "x" edges of a 2D array according to a list of pixel
41+
values. This function is useful for removing data at the edge of 2d images
42+
that may be affected by shifts from solar de- rotation and layer co-
43+
registration, leaving an image unaffected by edge effects.
44+
45+
Parameters
46+
----------
47+
data : `numpy.ndarray`
48+
A numpy array of shape ``(ny, nx)``.
49+
yclips : `astropy.units.Quantity`
50+
The amount to clip in the y-direction of the data. Has units of
51+
pixels, and values should be whole non-negative numbers.
52+
xclips : `astropy.units.Quantity`
53+
The amount to clip in the x-direction of the data. Has units of
54+
pixels, and values should be whole non-negative numbers.
55+
56+
Returns
57+
-------
58+
`numpy.ndarray`
59+
A 2D image with edges clipped off according to ``yclips`` and ``xclips``
60+
arrays.
61+
"""
62+
ny = data.shape[0]
63+
nx = data.shape[1]
64+
# The purpose of the int below is to ensure integer type since by default
65+
# astropy quantities are converted to floats.
66+
return data[int(yclips[0].value) : ny - int(yclips[1].value), int(xclips[0].value) : nx - int(xclips[1].value)]
67+
68+
69+
@u.quantity_input
70+
def _calculate_clipping(y: u.pix, x: u.pix):
71+
"""
72+
Return the upper and lower clipping values for the "y" and "x" directions.
73+
74+
Parameters
75+
----------
76+
y : `astropy.units.Quantity`
77+
An array of pixel shifts in the y-direction for an image.
78+
x : `astropy.units.Quantity`
79+
An array of pixel shifts in the x-direction for an image.
80+
81+
Returns
82+
-------
83+
`tuple`
84+
The tuple is of the form ``([y0, y1], [x0, x1])``.
85+
The number of (integer) pixels that need to be clipped off at each
86+
edge in an image. The first element in the tuple is a list that gives
87+
the number of pixels to clip in the y-direction. The first element in
88+
that list is the number of rows to clip at the lower edge of the image
89+
in y. The clipped image has "clipping[0][0]" rows removed from its
90+
lower edge when compared to the original image. The second element in
91+
that list is the number of rows to clip at the upper edge of the image
92+
in y. The clipped image has "clipping[0][1]" rows removed from its
93+
upper edge when compared to the original image. The second element in
94+
the "clipping" tuple applies similarly to the x-direction (image
95+
columns). The parameters ``y0, y1, x0, x1`` have the type
96+
`~astropy.units.Quantity`.
97+
"""
98+
return (
99+
[_lower_clip(y.value), _upper_clip(y.value)] * u.pix,
100+
[_lower_clip(x.value), _upper_clip(x.value)] * u.pix,
101+
)
102+
103+
104+
def _upper_clip(z):
105+
"""
106+
Find smallest integer bigger than all the positive entries in the input
107+
array.
108+
"""
109+
zupper = 0
110+
zcond = z >= 0
111+
if np.any(zcond):
112+
zupper = int(np.max(np.ceil(z[zcond])))
113+
return zupper
114+
115+
116+
def _lower_clip(z):
117+
"""
118+
Find smallest positive integer bigger than the absolute values of the
119+
negative entries in the input array.
120+
"""
121+
zlower = 0
122+
zcond = z <= 0
123+
if np.any(zcond):
124+
zlower = int(np.max(np.ceil(-z[zcond])))
125+
return zlower
35126

36127

37128
def convert_array_to_map(array_obj, map_obj):
129+
"""
130+
Convert a 2D numpy array to a sunpy Map object using the header of a given
131+
map object.
132+
133+
Parameters
134+
----------
135+
array_obj : `numpy.ndarray`
136+
The 2D numpy array to be converted.
137+
map_obj : `sunpy.map.Map`
138+
The map object whose header is to be used for the new map.
139+
140+
Returns
141+
-------
142+
`sunpy.map.Map`
143+
A new sunpy map object with the data from `array_obj` and the header from `map_obj`.
144+
"""
38145
header = map_obj.meta.copy()
39146
header["crpix1"] -= array_obj.shape[1] / 2.0 - map_obj.data.shape[1] / 2.0
40147
header["crpix2"] -= array_obj.shape[0] / 2.0 - map_obj.data.shape[0] / 2.0
41148
return sunpy.map.Map(array_obj, header)
42149

43150

44-
def coalignment_interface(method, input_map, template_map):
151+
def coalignment_interface(method, input_map, template_map, handle_nan=None):
152+
"""
153+
Interface for performing image coalignment using a specified method.
154+
155+
Parameters
156+
----------
157+
method : str
158+
The name of the registered coalignment method to use.
159+
input_map : `sunpy.map.Map`
160+
The input map to be coaligned.
161+
template_map : `sunpy.map.Map`
162+
The template map to which the input map is to be coaligned.
163+
handle_nan : callable, optional
164+
Function to handle NaN values in the input and template arrays.
165+
166+
Returns
167+
-------
168+
`sunpy.map.Map`
169+
The coaligned input map.
170+
171+
Raises
172+
------
173+
ValueError
174+
If the specified method is not registered.
175+
"""
45176
if method not in registered_methods:
46177
msg = f"Method {method} is not a registered method. Please register before using."
47178
raise ValueError(msg)
48-
input_array = np.float64(convert_map_to_array(input_map))
49-
template_array = np.float64(convert_map_to_array(template_map))
179+
input_array = np.float64(input_map.data)
180+
template_array = np.float64(template_map.data)
50181

51182
# Warn user if any NANs, Infs, etc are present in the input or the template array
52183
if not np.all(np.isfinite(input_array)):
53-
warnings.warn(
54-
"The layer image has nonfinite entries. "
55-
"This could cause errors when calculating shift between two "
56-
"images. Please make sure there are no infinity or "
57-
"Not a Number values. For instance, replacing them with a "
58-
"local mean.",
59-
SunpyUserWarning,
60-
stacklevel=3,
61-
)
62-
## By default replace with mean
63-
input_array = replace_nan_with_mean(input_array)
184+
if not handle_nan:
185+
warnings.warn(
186+
"The layer image has nonfinite entries. "
187+
"This could cause errors when calculating shift between two "
188+
"images. Please make sure there are no infinity or "
189+
"Not a Number values. For instance, replacing them with a "
190+
"local mean.",
191+
SunpyUserWarning,
192+
stacklevel=3,
193+
)
194+
else:
195+
input_array = handle_nan(input_array)
64196

65197
if not np.all(np.isfinite(template_array)):
66-
warnings.warn(
67-
"The template image has nonfinite entries. "
68-
"This could cause errors when calculating shift between two "
69-
"images. Please make sure there are no infinity or "
70-
"Not a Number values. For instance, replacing them with a "
71-
"local mean.",
72-
SunpyUserWarning,
73-
stacklevel=3,
74-
)
75-
## By default replace with mean
76-
template_array = replace_nan_with_mean(template_array)
77-
78-
coaligned_input_array = registered_methods[method](input_array, template_array)
198+
if not handle_nan:
199+
warnings.warn(
200+
"The template image has nonfinite entries. "
201+
"This could cause errors when calculating shift between two "
202+
"images. Please make sure there are no infinity or "
203+
"Not a Number values. For instance, replacing them with a "
204+
"local mean.",
205+
SunpyUserWarning,
206+
stacklevel=3,
207+
)
208+
else:
209+
template_array = handle_nan(template_array)
210+
211+
shifts = registered_methods[method](input_array, template_array)
212+
# Calculate the clipping required
213+
yclips, xclips = _calculate_clipping(shifts["x"] * u.pix, shifts["y"] * u.pix)
214+
# Clip 'em
215+
coaligned_input_array = _clip_edges(input_array, yclips, xclips)
79216
return convert_array_to_map(coaligned_input_array, input_map)
80217

81218

@@ -84,12 +221,43 @@ def coalignment_interface(method, input_map, template_map):
84221

85222
######################################## Defining a method ###########################
86223
def _parabolic_turning_point(y):
224+
"""
225+
Calculate the turning point of a parabola given three points.
226+
227+
Parameters
228+
----------
229+
y : `numpy.ndarray`
230+
An array of three points defining the parabola.
231+
232+
Returns
233+
-------
234+
float
235+
The x-coordinate of the turning point.
236+
"""
87237
numerator = -0.5 * y.dot([-1, 0, 1])
88238
denominator = y.dot([1, -2, 1])
89239
return numerator / denominator
90240

91241

92242
def _get_correlation_shifts(array):
243+
"""
244+
Calculate the shifts in x and y directions based on the correlation array.
245+
246+
Parameters
247+
----------
248+
array : `numpy.ndarray`
249+
A 2D array representing the correlation values.
250+
251+
Returns
252+
-------
253+
tuple
254+
The shifts in y and x directions.
255+
256+
Raises
257+
------
258+
ValueError
259+
If the input array dimensions are greater than 3 in any direction.
260+
"""
93261
ny, nx = array.shape
94262
if nx > 3 or ny > 3:
95263
msg = "Input array dimension should not be greater than 3 in any dimension."
@@ -105,6 +273,19 @@ def _get_correlation_shifts(array):
105273

106274

107275
def _find_best_match_location(corr):
276+
"""
277+
Find the best match location in the correlation array.
278+
279+
Parameters
280+
----------
281+
corr : `numpy.ndarray`
282+
The correlation array.
283+
284+
Returns
285+
-------
286+
tuple
287+
The best match location in the y and x directions.
288+
"""
108289
ij = np.unravel_index(np.argmax(corr), corr.shape)
109290
cor_max_x, cor_max_y = ij[::-1]
110291

@@ -122,13 +303,28 @@ def _find_best_match_location(corr):
122303

123304

124305
def match_template_coalign(input_array, template_array):
306+
"""
307+
Perform coalignment by matching the template array to the input array.
308+
309+
Parameters
310+
----------
311+
input_array : `numpy.ndarray`
312+
The input 2D array to be coaligned.
313+
template_array : `numpy.ndarray`
314+
The template 2D array to align to.
315+
316+
Returns
317+
-------
318+
dict
319+
A dictionary containing the shifts in x and y directions.
320+
"""
125321
corr = match_template(input_array, template_array)
126322

127323
# Find the best match location
128324
y_shift, x_shift = _find_best_match_location(corr)
129325

130326
# Apply the shift to get the coaligned input array
131-
return shift(input_array, shift=[y_shift, x_shift])
327+
return {"x": x_shift, "y": y_shift}
132328

133329

134330
################################ Registering the defined method ########################

0 commit comments

Comments
 (0)