13
13
class TS2Vec (BaseCollectionTransformer ):
14
14
"""TS2Vec Transformer.
15
15
16
+ TS2Vec [1]_ is a self-supervised model designed to learn universal representations
17
+ of time series data. It employs a hierarchical contrastive learning framework
18
+ that captures both local and global temporal dependencies. This approach
19
+ enables TS2Vec to generate robust representations for each timestamp
20
+ and allows for flexible aggregation to obtain representations for arbitrary
21
+ subsequences.
22
+
16
23
Parameters
17
24
----------
18
25
output_dim : int, default=320
@@ -39,6 +46,8 @@ class TS2Vec(BaseCollectionTransformer):
39
46
that sets n_iters to 200 for datasets with size <= 100000, and 600 otherwise.
40
47
verbose : bool, default=False
41
48
Whether to print the training loss after each epoch.
49
+ after_epoch_callback : callable, default=None
50
+ A callback function to be called after each epoch.
42
51
device : None or str, default=None
43
52
The device to use for training and inference. If None, it will automatically
44
53
select 'cuda' if available, otherwise 'cpu'.
@@ -92,6 +101,7 @@ def __init__(
92
101
n_iters = None ,
93
102
device = None ,
94
103
n_jobs = 1 ,
104
+ after_epoch_callback = None ,
95
105
verbose = False ,
96
106
):
97
107
self .output_dim = output_dim
@@ -107,6 +117,7 @@ def __init__(
107
117
self .verbose = verbose
108
118
self .n_epochs = n_epochs
109
119
self .n_iters = n_iters
120
+ self .after_epoch_callback = after_epoch_callback
110
121
super ().__init__ ()
111
122
112
123
def _transform (self , X , y = None ):
@@ -138,6 +149,7 @@ def _fit(self, X, y=None):
138
149
batch_size = self .batch_size ,
139
150
max_train_length = self .max_train_length ,
140
151
temporal_unit = self .temporal_unit ,
152
+ after_epoch_callback = self .after_epoch_callback ,
141
153
)
142
154
self .loss_ = self ._ts2vec .fit (
143
155
X .transpose (0 , 2 , 1 ),
0 commit comments