This repository contains the official implementation for the paper Oscillatory State-Space Models by T. Konstantin Rusch and Daniela Rus.
This repository is an extension of https://github.com/Benjamin-Walker/log-neural-cdes.
We propose Linear Oscillatory State-Space models (LinOSS) for efficiently learning on long sequences. Inspired by cortical dynamics of biological neural networks, we base our proposed LinOSS model on a system of forced harmonic oscillators. A stable discretization, integrated over time using fast associative parallel scans, yields the proposed state-space model.
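The stable discretization and parallel scan described above can be illustrated with a single oscillator. This is a minimal pure-Python sketch, not the repository's JAX implementation. It assumes one scalar oscillator y'' = -a y + b u with a ≥ 0 and state x = (z, y), where z = y'. It writes the implicit (IM) and implicit-explicit (IMEX) updates as affine maps x_n = M x_{n-1} + F_n. Because composing two such maps is associative, all states can be obtained with a scan in O(log n) parallel rounds. All function names are hypothetical.

```python
def matmul(p, q):
    """Product of two 2x2 matrices stored as nested tuples."""
    return tuple(
        tuple(sum(p[i][k] * q[k][j] for k in range(2)) for j in range(2))
        for i in range(2)
    )


def matvec(p, v):
    """Product of a 2x2 matrix with a length-2 vector."""
    return (p[0][0] * v[0] + p[0][1] * v[1], p[1][0] * v[0] + p[1][1] * v[1])


def step_im(a, b, dt, u):
    """Implicit (IM) step for y'' = -a*y + b*u with state x = (z, y), z = y'."""
    s = 1.0 / (1.0 + dt * dt * a)  # S = (1 + dt^2 a)^-1
    m = ((s, -dt * a * s), (dt * s, s))
    f = (dt * s * b * u, dt * dt * s * b * u)
    return m, f


def step_imex(a, b, dt, u):
    """Implicit-explicit (IMEX) step: z is updated from the old y, y from the new z."""
    m = ((1.0, -dt * a), (dt, 1.0 - dt * dt * a))
    f = (dt * b * u, dt * dt * b * u)
    return m, f


def combine(e1, e2):
    """Associative operator: applying (M1, F1) then (M2, F2) equals (M2 M1, M2 F1 + F2)."""
    m1, f1 = e1
    m2, f2 = e2
    g = matvec(m2, f1)
    return matmul(m2, m1), (g[0] + f2[0], g[1] + f2[1])


def sequential(elems):
    """Reference recurrence x_n = M_n x_{n-1} + F_n starting from x_0 = 0."""
    x = (0.0, 0.0)
    out = []
    for m, f in elems:
        mx = matvec(m, x)
        x = (mx[0] + f[0], mx[1] + f[1])
        out.append(x)
    return out


def scan(elems):
    """Hillis-Steele inclusive scan: ceil(log2 n) rounds, and every update
    within a round is independent, so a round could run in parallel."""
    res = list(elems)
    step = 1
    while step < len(res):
        res = [res[i] if i < step else combine(res[i - step], res[i]) for i in range(len(res))]
        step *= 2
    # With x_0 = 0, the state after step n is the forcing part of the n-th prefix.
    return [f for _, f in res]
```

Here det(M_IM) = S = 1/(1 + Δt² a), which is below 1 whenever a > 0 and Δt > 0, so the IM step is dissipative. In contrast, det(M_IMEX) = 1, so IMEX preserves phase-space volume. In JAX, the same combine operator could be handed to `jax.lax.associative_scan` over stacked (M, F) arrays. How the repository organises this is not shown here.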
This repository is implemented in Python 3.10 and uses JAX as its machine learning framework.
The code for preprocessing the datasets, training LinOSS, S5, LRU, NCDE, NRDE, and Log-NCDE uses the following packages:
- jax and jaxlib for automatic differentiation.
- equinox for constructing neural networks.
- optax for neural network optimisers.
- diffrax for differential equation solvers.
- signax for calculating the signature.
- sktime for handling time series data in ARFF format.
- tqdm for progress bars.
- matplotlib for plotting.
- pre-commit for code formatting.
conda create -n LinOSS python=3.10
conda activate LinOSS
conda install pre-commit=3.7.1 sktime=0.30.1 tqdm=4.66.4 matplotlib=3.8.4 -c conda-forge
# Substitute the correct JAX pip install command for your platform: https://jax.readthedocs.io/en/latest/installation.html
pip install -U "jax[cuda12]" "jaxlib[cuda12]" equinox==0.11.4 optax==0.2.2 diffrax==0.5.1 signax==0.1.1
If running data_dir/process_uea.py throws the error No module named 'packaging', run pip install packaging.
After installing the requirements, run pre-commit install to install the pre-commit hooks.
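After installation, a quick way to confirm that the packages listed above can be found in the active environment is a small standard-library check. This is a hypothetical helper, not part of the repository. The list uses import names, and pre-commit is omitted because it is used as a command-line tool rather than imported.

```python
import importlib.util

# Import names of the Python packages listed above (hypothetical check list).
REQUIRED = ["jax", "jaxlib", "equinox", "optax", "diffrax", "signax", "sktime", "tqdm", "matplotlib"]


def missing_packages(names=REQUIRED):
    """Return the subset of `names` that cannot be found in the active environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("All packages found." if not missing else "Missing: " + ", ".join(missing))
```

`find_spec` only locates modules without importing them, so the check is cheap. It will not, however, catch a JAX build that is installed but incompatible with the local CUDA setup.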
The folder data_dir contains the scripts for downloading data, preprocessing the data, and creating dataloaders and datasets. Raw data should be downloaded into the data_dir/raw folder. Processed data should be saved into the data_dir/processed folder in the following format:
processed/{collection}/{dataset_name}/data.pkl,
processed/{collection}/{dataset_name}/labels.pkl,
processed/{collection}/{dataset_name}/original_idxs.pkl (if the dataset has original data splits)
where data.pkl and labels.pkl contain jnp.arrays with shapes (n_samples, n_timesteps, n_features) and (n_samples, n_classes), respectively. If the dataset has original data splits, original_idxs.pkl should contain a list of jnp.arrays with shapes [(n_train,), (n_val,), (n_test,)].
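The processed-data layout above can be written with a few lines of `pickle`. This is a minimal sketch of a hypothetical `save_processed` helper. It uses NumPy arrays for illustration, while the repository stores jnp.arrays in these files. The shape checks mirror the shapes stated above.

```python
import pickle
from pathlib import Path

import numpy as np


def save_processed(root, collection, dataset_name, data, labels, original_idxs=None):
    """Hypothetical helper: write data.pkl, labels.pkl and (optionally)
    original_idxs.pkl to root/{collection}/{dataset_name}/.

    NumPy arrays are used here for illustration; the repository itself
    stores jnp.arrays in these files.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)
    if data.ndim != 3:
        raise ValueError("data must have shape (n_samples, n_timesteps, n_features)")
    if labels.ndim != 2 or labels.shape[0] != data.shape[0]:
        raise ValueError("labels must have shape (n_samples, n_classes)")

    out = Path(root) / collection / dataset_name
    out.mkdir(parents=True, exist_ok=True)
    with open(out / "data.pkl", "wb") as f:
        pickle.dump(data, f)
    with open(out / "labels.pkl", "wb") as f:
        pickle.dump(labels, f)

    if original_idxs is not None:
        # Expected layout: [train (n_train,), val (n_val,), test (n_test,)]
        idxs = [np.asarray(i) for i in original_idxs]
        if len(idxs) != 3 or any(i.ndim != 1 for i in idxs):
            raise ValueError("original_idxs must be three 1-D index arrays")
        with open(out / "original_idxs.pkl", "wb") as f:
            pickle.dump(idxs, f)
    return out
```

For a classification task, the labels array would typically be one-hot encoded. For example, `np.eye(n_classes)[class_ids]` turns integer class ids into the (n_samples, n_classes) shape expected here.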
The UEA datasets are a collection of multivariate time series classification benchmarks. They can be downloaded by running data_dir/download_uea.py and preprocessed by running data_dir/process_uea.py.
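Since the UEA datasets come with original data splits, reading a processed dataset back involves the optional original_idxs.pkl file. The sketch below shows one way to do this. `load_processed` and `split` are hypothetical helpers, not the repository's dataloaders. The code assumes the stored arrays support integer-array indexing, which holds for both NumPy and JAX arrays.

```python
import pickle
from pathlib import Path


def load_processed(root, collection, dataset_name):
    """Read data.pkl, labels.pkl and, if present, original_idxs.pkl.

    Returns (data, labels, idxs); idxs is None when the dataset has no
    original train/val/test split.
    """
    folder = Path(root) / collection / dataset_name
    with open(folder / "data.pkl", "rb") as f:
        data = pickle.load(f)
    with open(folder / "labels.pkl", "rb") as f:
        labels = pickle.load(f)
    idxs = None
    idx_file = folder / "original_idxs.pkl"
    if idx_file.exists():
        with open(idx_file, "rb") as f:
            idxs = pickle.load(f)
    return data, labels, idxs


def split(data, labels, idxs):
    """Index data and labels with the original (train, val, test) indices."""
    return [(data[i], labels[i]) for i in idxs]
```

A call might look like `load_processed("data_dir/processed", "UEA", "EthanolConcentration")`. The collection name `UEA` used there is an assumption about how process_uea.py names its output folder.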
The PPG-DaLiA dataset is a multivariate time series regression dataset, where the aim is to predict a person’s heart rate using data collected from a wrist-worn device. The dataset can be downloaded from the UCI Machine Learning Repository. The data should be unzipped and saved in the data_dir/raw folder in the format PPG_FieldStudy/S{i}/S{i}.pkl. The data can then be preprocessed by running the process_ppg.py script.
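Before preprocessing, it can help to confirm that the unzipped files match the expected PPG_FieldStudy/S{i}/S{i}.pkl layout. This is a hypothetical pre-flight check, not part of the repository. The default subject range 1 to 15 is an assumption about the PPG-DaLiA release and can be overridden.

```python
from pathlib import Path


def missing_ppg_subjects(raw_dir, subjects=range(1, 16)):
    """Return the expected PPG_FieldStudy/S{i}/S{i}.pkl paths that are not present.

    The default subject range (S1 to S15) is an assumption about the dataset.
    """
    root = Path(raw_dir) / "PPG_FieldStudy"
    expected = [root / f"S{i}" / f"S{i}.pkl" for i in subjects]
    return [p for p in expected if not p.is_file()]
```

Running `missing_ppg_subjects("data_dir/raw")` before process_ppg.py returns an empty list when every expected file is in place.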
The code for training and evaluating the models is contained in train.py. Experiments can be run using the run_experiment.py script. This script requires you to specify the names of the models you want to train, the names of the datasets you want to train on, and a directory which contains configuration files. By default, it will run the LinOSS experiments. The configuration files should be organised as config_dir/{model_name}/{dataset_name}.json and contain the following fields:
- seeds: A list of seeds to use for training.
- data_dir: The directory containing the data.
- output_parent_dir: The directory to save the output.
- lr_scheduler: A function which takes the learning rate and returns the new learning rate.
- num_steps: The number of steps to train for.
- print_steps: The number of steps between printing the loss.
- batch_size: The batch size.
- metric: The metric to use for evaluation.
- classification: Whether the task is a classification task.
- linoss_discretization: Only for LinOSS; which discretization to use. Choices are ['IM', 'IMEX'].
- lr: The initial learning rate.
- time: Whether to include time as a channel.
- Any further model-specific parameters.
See experiment_configs/repeats for examples.
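A configuration file with the fields listed above can be checked before launching a run. This is a minimal sketch, and `check_config` and `EXAMPLE_LINOSS_CONFIG` are hypothetical. The example values are illustrative, not the settings used in the paper. Because JSON cannot hold functions, the code assumes that lr_scheduler is stored as a string such as "lambda lr: lr" and turned into a function by the training code.

```python
import json

# Fields every config is expected to contain, per the list above.
REQUIRED_FIELDS = {
    "seeds", "data_dir", "output_parent_dir", "lr_scheduler", "num_steps",
    "print_steps", "batch_size", "metric", "classification", "lr", "time",
}
LINOSS_DISCRETIZATIONS = ("IM", "IMEX")

# Hypothetical example values; see experiment_configs/repeats for real ones.
EXAMPLE_LINOSS_CONFIG = {
    "seeds": [2345, 3456, 4567],
    "data_dir": "data_dir",
    "output_parent_dir": "",
    "lr_scheduler": "lambda lr: lr",  # assumption: stored as a string, not a function
    "num_steps": 100000,
    "print_steps": 1000,
    "batch_size": 32,
    "metric": "accuracy",
    "classification": True,
    "linoss_discretization": "IM",
    "lr": 0.001,
    "time": True,
}


def check_config(path, model_name):
    """Return the sorted list of required fields missing from the JSON config at `path`.

    Hypothetical validator; run_experiment.py may perform different checks.
    """
    with open(path) as f:
        cfg = json.load(f)
    required = set(REQUIRED_FIELDS)
    if model_name == "LinOSS":
        required.add("linoss_discretization")
        value = cfg.get("linoss_discretization")
        if value is not None and value not in LINOSS_DISCRETIZATIONS:
            raise ValueError(f"linoss_discretization must be one of {LINOSS_DISCRETIZATIONS}")
    return sorted(required - set(cfg))
```

Model-specific parameters are deliberately not checked, since they differ between LinOSS, S5, LRU, and the CDE-based models.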
The configuration files for all the experiments with fixed hyperparameters can be found in the experiment_configs folder, and run_experiment.py is currently configured to run the repeat experiments on the UEA datasets.
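Because each run is identified by a model name and a dataset name, the config_dir/{model_name}/{dataset_name}.json convention can be used to find configurations that are missing before a batch of experiments starts. This is a hypothetical sketch and not the logic inside run_experiment.py. The directory and dataset names in the usage line below are assumptions.

```python
from itertools import product
from pathlib import Path


def config_paths(config_dir, model_names, dataset_names):
    """Map each (model, dataset) pair to config_dir/{model_name}/{dataset_name}.json."""
    return {
        (m, d): Path(config_dir) / m / f"{d}.json"
        for m, d in product(model_names, dataset_names)
    }


def missing_configs(config_dir, model_names, dataset_names):
    """List the (model, dataset) pairs whose configuration file does not exist."""
    paths = config_paths(config_dir, model_names, dataset_names)
    return [pair for pair, path in paths.items() if not path.is_file()]
```

For instance, `missing_configs("experiment_configs/repeats", ["LinOSS"], ["EthanolConcentration", "Heartbeat"])` would list any of those pairs without a config file, assuming that folder follows the layout described above.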
The outputs folder contains a zip file of the output files from the UEA and PPG experiments.
If you found our work useful in your research, please cite our paper:
@inproceedings{rusch2025linoss,
title={Oscillatory State-Space Models},
author={Rusch, T Konstantin and Rus, Daniela},
booktitle={International Conference on Learning Representations},
year={2025}
}
(Also consider starring the project on GitHub.)