Skip to content

Commit cb30da3

Browse files
authored
BUG: raise a clear error when cwt() is given a discrete wavelet (gh-776) (#849)
Passing a discrete wavelet, e.g. cwt(data, scales, 'coif1'), crashed with a confusing `AttributeError: 'Wavelet' object has no attribute 'complex_cwt'`, because `complex_cwt` exists only on ContinuousWavelet. Validate that the resolved wavelet is continuous and otherwise raise a ValueError that names the offending wavelet and points to pywt.wavelist(kind='continuous').
2 parents ffb1c8c + 7ef1011 commit cb30da3

2 files changed

Lines changed: 19 additions & 0 deletions

File tree

pywt/_cwt.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1,
111111
dt_cplx = np.result_type(dt, np.complex64)
112112
if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
113113
wavelet = DiscreteContinuousWavelet(wavelet)
114+
if not isinstance(wavelet, ContinuousWavelet):
115+
raise ValueError(
116+
f"cwt() requires a continuous wavelet, but {wavelet.name!r} is a "
117+
f"discrete wavelet. Use a continuous wavelet such as those returned "
118+
f"by pywt.wavelist(kind='continuous') (e.g. 'morl', 'mexh', 'cmor')."
119+
)
114120

115121
scales = np.atleast_1d(scales)
116122
if np.any(scales <= 0):

pywt/tests/test_cwt_wavelets.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,3 +480,16 @@ def test_continuous_wavelet_pickle(tmpdir):
480480
wavelet2 = pickle.load(f)
481481
assert isinstance(wavelet2, pywt.ContinuousWavelet)
482482
assert wavelet2.name == wavelet.name
483+
484+
485+
def test_cwt_discrete_wavelet_raises():
486+
# A discrete wavelet such as 'coif1' has no continuous form; cwt should
487+
# raise a clear error rather than an opaque AttributeError (gh-776).
488+
data = np.ones(100)
489+
for bad in ['coif1', 'db2', pywt.Wavelet('coif1')]:
490+
with pytest.raises(ValueError, match='continuous wavelet'):
491+
pywt.cwt(data, [1, 2], bad)
492+
493+
# a continuous wavelet still works
494+
out, _ = pywt.cwt(data, [1, 2], 'morl')
495+
assert out.shape == (2, 100)

0 commit comments

Comments
 (0)