Skip to content

Commit 743a185

Browse files
committed
Updated TS2Vec docs
1 parent a715ba5 commit 743a185

File tree

1 file changed

+12
-0
lines changed
  • aeon/transformations/collection/contrastive_based

1 file changed

+12
-0
lines changed

aeon/transformations/collection/contrastive_based/_ts2vec.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313
class TS2Vec(BaseCollectionTransformer):
1414
"""TS2Vec Transformer.
1515
16+
TS2Vec [1]_ is a self-supervised model designed to learn universal representations
17+
of time series data. It employs a hierarchical contrastive learning framework
18+
that captures both local and global temporal dependencies. This approach
19+
enables TS2Vec to generate robust representations for each timestamp
20+
and allows for flexible aggregation to obtain representations for arbitrary
21+
subsequences.
22+
1623
Parameters
1724
----------
1825
output_dim : int, default=320
@@ -39,6 +46,8 @@ class TS2Vec(BaseCollectionTransformer):
3946
that sets n_iters to 200 for datasets with size <= 100000, and 600 otherwise.
4047
verbose : bool, default=False
4148
Whether to print the training loss after each epoch.
49+
after_epoch_callback : callable, default=None
50+
A callback function to be called after each epoch.
4251
device : None or str, default=None
4352
The device to use for training and inference. If None, it will automatically
4453
select 'cuda' if available, otherwise 'cpu'.
@@ -92,6 +101,7 @@ def __init__(
92101
n_iters=None,
93102
device=None,
94103
n_jobs=1,
104+
after_epoch_callback=None,
95105
verbose=False,
96106
):
97107
self.output_dim = output_dim
@@ -107,6 +117,7 @@ def __init__(
107117
self.verbose = verbose
108118
self.n_epochs = n_epochs
109119
self.n_iters = n_iters
120+
self.after_epoch_callback = after_epoch_callback
110121
super().__init__()
111122

112123
def _transform(self, X, y=None):
@@ -138,6 +149,7 @@ def _fit(self, X, y=None):
138149
batch_size=self.batch_size,
139150
max_train_length=self.max_train_length,
140151
temporal_unit=self.temporal_unit,
152+
after_epoch_callback=self.after_epoch_callback,
141153
)
142154
self.loss_ = self._ts2vec.fit(
143155
X.transpose(0, 2, 1),

0 commit comments

Comments
 (0)