-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathsubsample.py
85 lines (69 loc) · 3.52 KB
/
subsample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import numpy as np
import torch
class MaskFunc:
    """
    MaskFunc creates a sub-sampling mask of a given shape.

    The mask selects a subset of columns from the input k-space data. If the k-space data has N
    columns, the mask picks out:
        1. N_low_freqs = (N * center_fraction) columns in the center corresponding to
           low-frequencies
        2. The other columns are selected uniformly at random with a probability equal to:
           prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs).
    This ensures that the expected number of columns selected is equal to (N / acceleration)

    It is possible to use multiple center_fractions and accelerations, in which case one possible
    (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is
    called.

    For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there
    is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50%
    probability that 8-fold acceleration with 4% center fraction is selected.
    """

    def __init__(self, center_fractions, accelerations):
        """
        Args:
            center_fractions (List[float]): Fraction of low-frequency columns to be retained.
                If multiple values are provided, then one of these numbers is chosen uniformly
                each time.
            accelerations (List[int]): Amount of under-sampling. This should have the same length
                as center_fractions. If multiple values are provided, then one of these is chosen
                uniformly each time. An acceleration of 4 retains 25% of the columns, but they may
                not be spaced evenly.

        Raises:
            ValueError: If center_fractions and accelerations differ in length.
        """
        if len(center_fractions) != len(accelerations):
            raise ValueError('Number of center fractions should match number of accelerations')

        self.center_fractions = center_fractions
        self.accelerations = accelerations
        self.rng = np.random.RandomState()

    def __call__(self, shape, seed=None):
        """
        Args:
            shape (iterable[int]): The shape of the mask to be created. The shape should have
                at least 3 dimensions. Samples are drawn along the second last dimension.
            seed (int, optional): Seed for the random number generator. Setting the seed
                ensures the same mask is generated each time for the same shape. With the
                default of None the generator is re-seeded by NumPy without a fixed value,
                so successive calls produce different masks.

        Returns:
            torch.Tensor: A float32 mask whose second last dimension has shape[-2] entries
                (1.0 for sampled columns, 0.0 otherwise); every other dimension has size 1.

        Raises:
            ValueError: If shape has fewer than 3 dimensions.
        """
        if len(shape) < 3:
            raise ValueError('Shape should have 3 or more dimensions')

        # Honour the caller-supplied seed; it must not be overridden with a constant,
        # otherwise every call would return the same mask.
        self.rng.seed(seed)
        num_cols = shape[-2]

        # Pick one (center_fraction, acceleration) pair uniformly at random.
        choice = self.rng.randint(0, len(self.accelerations))
        center_fraction = self.center_fractions[choice]
        acceleration = self.accelerations[choice]

        # Create the mask
        num_low_freqs = int(round(num_cols * center_fraction))
        num_remaining = num_cols - num_low_freqs
        if num_remaining > 0:
            prob = (num_cols / acceleration - num_low_freqs) / num_remaining
        else:
            # The low-frequency band already covers every column, so there is
            # nothing left to sample (and no division by zero).
            prob = 0.0
        mask = self.rng.uniform(size=num_cols) < prob
        # Always keep the centred block of low-frequency columns.
        pad = (num_cols - num_low_freqs + 1) // 2
        mask[pad:pad + num_low_freqs] = True

        # Reshape the mask so that only the second last dimension is non-singleton.
        mask_shape = [1] * len(shape)
        mask_shape[-2] = num_cols
        mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32))

        return mask