Skip to content

Commit bd90715

Browse files
committed
Implement Daubechies MODWT
1 parent a67a5dd commit bd90715

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

pyriodicity/detectors/robustperiod.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Union
44

55
import numpy as np
6+
import pywt
67
from numpy.typing import ArrayLike, NDArray
78
from scipy.sparse import dia_matrix, eye
89
from scipy.sparse.linalg import spsolve
@@ -100,6 +101,7 @@ def detect(
100101
x: ArrayLike,
101102
lamb: Union[float, str] = "ravn-uhlig",
102103
c: float = 1.5,
104+
db_n: int = 10,
103105
) -> NDArray:
104106
"""
105107
Find periods in the given series.
@@ -119,12 +121,21 @@ def detect(
119121
The constant threshold that determines the robustness of the Huber function.
120122
A smaller value makes the Huber function more sensitive to outliers. Huber
121123
recommends using a value between 1 and 2 [3]_.
124+
db_n : int, default = 10
125+
The number of vanishing moments for the Daubechies wavelet [4]_ used to
126+
compute the Maximal Overlap Discrete Wavelet Transform (MODWT) [5]_. Must
127+
be an integer between 1 and 38, inclusive.
122128
123129
Returns
124130
-------
125131
NDArray
126132
List of detected periods.
127133
134+
Raises
135+
------
136+
AssertionError
137+
If `db_n` is not between 1 and 38, inclusive.
138+
128139
References
129140
----------
130141
.. [1] Hodrick, R. J., & Prescott, E. C. (1997).
@@ -137,8 +148,17 @@ def detect(
137148
https://doi.org/10.1162/003465302317411604
138149
.. [3] Huber, P. J., & Ronchetti, E. (2009). Robust Statistics. Wiley.
139150
https://doi.org/10.1002/9780470434697
151+
.. [4] Daubechies, I. (1992). Ten lectures on wavelets. Society for industrial
152+
and applied mathematics.
153+
https://doi.org/10.1137/1.9781611970104
154+
.. [5] Percival, D. B. (2000). Wavelet methods for time series analysis.
155+
Cambridge University Press.
156+
https://doi.org/10.1017/CBO9780511841040
140157
"""
141158

159+
# Validate the db_n parameter
160+
assert 1 <= db_n <= 38, "Invalid db_n parameter value: '{}'".format(db_n)
161+
142162
# Preprocess the data
143163
lamb = RobustPeriod.LambdaSelection(lamb) if isinstance(lamb, str) else lamb
144164
y = RobustPeriod._preprocess(x, lamb, c)
@@ -254,3 +274,32 @@ def _huber(x: ArrayLike, c: float) -> ArrayLike:
254274
An array-like object with the Huber function applied element-wise.
255275
"""
256276
return np.sign(x) * np.minimum(np.abs(x), c)
277+
278+
@staticmethod
279+
def _modwt(x: ArrayLike, db_n: int, level: int) -> NDArray:
280+
"""
281+
Compute the Maximal Overlap Discrete Wavelet Transform (MODWT) of a series using
282+
the Daubechies wavelet.
283+
284+
Parameters
285+
----------
286+
x : array_like
287+
Input data to be transformed. Must be squeezable to 1-d.
288+
db_n : int
289+
The number of vanishing moments for the Daubechies wavelet. Must be an
290+
integer between 1 and 38, inclusive.
291+
level : int
292+
The number of decomposition steps to perform.
293+
294+
Returns
295+
-------
296+
NDArray
297+
The MODWT coefficients of the input data.
298+
"""
299+
300+
# Pad the input data to the nearest 2^level multiple
301+
padding = (2**level - (len(x) % 2**level)) % 2**level
302+
y = np.pad(x, (0, padding), "wrap")
303+
304+
# Compute the Maximal Overlap Discrete Wavelet Transform
305+
return pywt.swt(y, "db{}".format(db_n), level, norm=True)

0 commit comments

Comments
 (0)