
tk-rusch/linoss


Oscillatory State-Space Models (ICLR 2025 Oral)

This repository contains the official implementation for the paper Oscillatory State-Space Models by T. Konstantin Rusch and Daniela Rus.

This repository is an extension of https://github.com/Benjamin-Walker/log-neural-cdes.


We propose Linear Oscillatory State-Space models (LinOSS) for efficiently learning on long sequences. Inspired by cortical dynamics of biological neural networks, we base our proposed LinOSS model on a system of forced harmonic oscillators. A stable discretization, integrated over time using fast associative parallel scans, yields the proposed state-space model.
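Each discretized step is an affine map of the oscillator state, and the scan is possible because composing affine maps is associative. The following is a minimal pure-Python sketch for a single scalar oscillator y'' = -a y + b u (with a >= 0 and velocity z = y') that illustrates both points. It assumes that the two schemes selectable through the linoss_discretization option (IM and IMEX) are implicit Euler and implicit-explicit (symplectic) Euler updates of (z, y). The derivation and every function name here are ours, not the repository's JAX implementation, which works with learned diagonal matrices over many channels. The scan is a Hillis-Steele style prefix scan written as a plain loop. Its rounds consist of independent combines, so on parallel hardware they can run at once.

```python
# Minimal sketch, not the repository's implementation: one scalar forced
# harmonic oscillator y'' = -a*y + b*u with a >= 0, state x = (z, y), z = y'.
# Assumption: IM = implicit Euler, IMEX = implicit-explicit (symplectic) Euler.


def step_matrices(a, b, u, dt, scheme="IM"):
    """Return (M, F) so that x_n = M x_{n-1} + F for the input u = u_n."""
    if scheme == "IM":
        # Both updates use new values; the 2x2 solve has a closed form.
        s = 1.0 / (1.0 + dt * dt * a)
        M = ((s, -dt * a * s), (dt * s, s))
        F = (dt * s * b * u, dt * dt * s * b * u)
    elif scheme == "IMEX":
        # z is updated with the old y, then y is updated with the new z.
        M = ((1.0, -dt * a), (dt, 1.0 - dt * dt * a))
        F = (dt * b * u, dt * dt * b * u)
    else:
        raise ValueError("scheme must be 'IM' or 'IMEX'")
    return M, F


def matmul(P, Q):
    return tuple(tuple(sum(P[i][k] * Q[k][j] for k in range(2)) for j in range(2))
                 for i in range(2))


def matvec(P, v):
    return tuple(sum(P[i][k] * v[k] for k in range(2)) for i in range(2))


def combine(first, second):
    """Compose affine maps: apply `first`, then `second`. This is associative."""
    (M1, F1), (M2, F2) = first, second
    MF = matvec(M2, F1)
    return matmul(M2, M1), (MF[0] + F2[0], MF[1] + F2[1])


def prefix_scan(elems):
    """Inclusive Hillis-Steele scan: ceil(log2(n)) rounds of independent combines.

    On parallel hardware each round runs at once; here it is a plain loop."""
    out = list(elems)
    offset = 1
    while offset < len(out):
        out = [out[i] if i < offset else combine(out[i - offset], out[i])
               for i in range(len(out))]
        offset *= 2
    return out


def run_scan(a, b, inputs, dt, scheme="IM"):
    """Positions y_1..y_N from x_0 = (0, 0), computed with the scan."""
    prefixes = prefix_scan([step_matrices(a, b, u, dt, scheme) for u in inputs])
    # With x_0 = 0, the composed map's offset F equals the state x_n.
    return [F[1] for _, F in prefixes]


def run_loop(a, b, inputs, dt, scheme="IM"):
    """Reference: the same recurrence evaluated step by step."""
    x, ys = (0.0, 0.0), []
    for u in inputs:
        M, F = step_matrices(a, b, u, dt, scheme)
        Mx = matvec(M, x)
        x = (Mx[0] + F[0], Mx[1] + F[1])
        ys.append(x[1])
    return ys


if __name__ == "__main__":
    inputs = [1.0, 0.0, -0.5, 0.25, 0.0, 0.0, 1.0]
    for scheme in ("IM", "IMEX"):
        scan = run_scan(4.0, 1.0, inputs, 0.1, scheme)
        loop = run_loop(4.0, 1.0, inputs, 0.1, scheme)
        # The difference should be at the level of floating-point rounding.
        print(scheme, max(abs(p - q) for p, q in zip(scan, loop)))
```

In this sketch, for a >= 0 and dt > 0, the IM matrix has determinant s = 1/(1 + dt*dt*a) <= 1, and both of its eigenvalues have modulus sqrt(s), so the state is damped. The IMEX matrix has determinant exactly 1, and its eigenvalues stay on the unit circle whenever dt*dt*a <= 4.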


Requirements

This repository is implemented in Python 3.10 and uses JAX as its machine learning framework.

Environment

The code for preprocessing the datasets, training LinOSS, S5, LRU, NCDE, NRDE, and Log-NCDE uses the following packages:

  • jax and jaxlib for automatic differentiation.
  • equinox for constructing neural networks.
  • optax for neural network optimisers.
  • diffrax for differential equation solvers.
  • signax for calculating the signature.
  • sktime for handling time series data in ARFF format.
  • tqdm for progress bars.
  • matplotlib for plotting.
  • pre-commit for code formatting.

The environment can be set up as follows:

conda create -n LinOSS python=3.10
conda activate LinOSS
conda install pre-commit=3.7.1 sktime=0.30.1 tqdm=4.66.4 matplotlib=3.8.4 -c conda-forge
# Substitute the correct JAX pip install command for your hardware: https://jax.readthedocs.io/en/latest/installation.html
pip install -U "jax[cuda12]" "jaxlib[cuda12]" equinox==0.11.4 optax==0.2.2 diffrax==0.5.1 signax==0.1.1

If running data_dir/process_uea.py raises the error No module named 'packaging', install the missing package with pip install packaging.

After installing the requirements, run pre-commit install to install the pre-commit hooks.
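You can quickly confirm that an environment matches the pinned versions from the install commands above with a small standalone script. This is a hypothetical helper that is not shipped with the repository. It relies only on the standard-library importlib.metadata module and leaves out jax and jaxlib because those are installed unpinned with -U.

```python
# Hypothetical helper (not part of the repository): report which pinned
# packages from the install commands are missing or at a different version.
from importlib.metadata import PackageNotFoundError, version

PINNED = {
    "pre-commit": "3.7.1",
    "sktime": "0.30.1",
    "tqdm": "4.66.4",
    "matplotlib": "3.8.4",
    "equinox": "0.11.4",
    "optax": "0.2.2",
    "diffrax": "0.5.1",
    "signax": "0.1.1",
}


def check_pins(pins):
    """Return {name: (expected, found)} for every package that does not match."""
    problems = {}
    for name, expected in pins.items():
        try:
            found = version(name)
        except PackageNotFoundError:
            found = None  # not installed in this environment
        if found != expected:
            problems[name] = (expected, found)
    return problems


if __name__ == "__main__":
    for name, (expected, found) in check_pins(PINNED).items():
        print(f"{name}: expected {expected}, found {found or 'not installed'}")
```

If the script prints nothing, every pinned package is present at the listed version.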


Data

The folder data_dir contains the scripts for downloading data, preprocessing the data, and creating dataloaders and datasets. Raw data should be downloaded into the data_dir/raw folder. Processed data should be saved into the data_dir/processed folder in the following format:

processed/{collection}/{dataset_name}/data.pkl, 
processed/{collection}/{dataset_name}/labels.pkl,
processed/{collection}/{dataset_name}/original_idxs.pkl (if the dataset has original data splits)

where data.pkl and labels.pkl are jnp.arrays with shapes (n_samples, n_timesteps, n_features) and (n_samples, n_classes), respectively. If the dataset has original data splits, original_idxs.pkl should contain a list of jnp.arrays with shapes [(n_train,), (n_val,), (n_test,)].
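A minimal sketch of writing one dataset in this layout is shown below. The function save_processed and its signature are hypothetical and do not come from the repository's scripts. For the repository you would pass jnp.arrays, but the sketch accepts any picklable array-like so that it runs without JAX.

```python
# Hypothetical helper, not from the repository: writes one dataset in the
# processed/{collection}/{dataset_name}/ layout described above.
import os
import pickle


def save_processed(processed_dir, collection, dataset_name, data, labels,
                   original_idxs=None):
    """Pickle data, labels and (optionally) the original train/val/test indices."""
    if len(data) != len(labels):
        raise ValueError("data and labels must have the same n_samples")
    out_dir = os.path.join(processed_dir, collection, dataset_name)
    os.makedirs(out_dir, exist_ok=True)
    files = {"data.pkl": data, "labels.pkl": labels}
    if original_idxs is not None:
        # Expected as [train_idxs, val_idxs, test_idxs], each of shape (n_split,).
        if len(original_idxs) != 3:
            raise ValueError("original_idxs must be [train, val, test]")
        files["original_idxs.pkl"] = list(original_idxs)
    for filename, obj in files.items():
        with open(os.path.join(out_dir, filename), "wb") as f:
            pickle.dump(obj, f)
    return out_dir
```

For example, save_processed("data_dir/processed", "MyCollection", "MyDataset", data, labels) would create the two required files. "MyCollection" and "MyDataset" are placeholder names.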

The UEA Datasets

The UEA datasets are a collection of multivariate time series classification benchmarks. They can be downloaded by running data_dir/download_uea.py and preprocessed by running data_dir/process_uea.py.
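The processed labels.pkl has shape (n_samples, n_classes), so each sample stores one label vector. For classification, a natural encoding is one-hot. The sketch below converts per-series class labels, such as those sktime reads from the ARFF files (typically strings), into one-hot rows. It is illustrative only and does not reproduce process_uea.py. Ordering the classes by sort order is an assumption.

```python
# Illustrative sketch, not taken from process_uea.py: turn one class label per
# series into one-hot rows, giving the (n_samples, n_classes) labels layout.


def one_hot(labels):
    """Return (rows, classes); classes are sorted, which is an assumption."""
    classes = sorted(set(labels))
    index = {c: i for i, c in enumerate(classes)}
    rows = []
    for label in labels:
        row = [0.0] * len(classes)
        row[index[label]] = 1.0
        rows.append(row)
    return rows, classes


rows, classes = one_hot(["walk", "run", "walk"])
print(classes)  # ['run', 'walk']
print(rows)     # [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
```

In JAX, jax.nn.one_hot(indices, num_classes) performs the same encoding once the labels have been mapped to integer class indices.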

The PPG-DaLiA Dataset

The PPG-DaLiA dataset is a multivariate time series regression dataset, where the aim is to predict a person’s heart rate using data collected from a wrist-worn device. The dataset can be downloaded from the UCI Machine Learning Repository. The data should be unzipped and saved in the data_dir/raw folder using the layout PPG_FieldStudy/S{i}/S{i}.pkl. The data can then be preprocessed by running the process_ppg.py script.
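Before preprocessing, it can help to confirm that the unzipped files follow the expected PPG_FieldStudy/S{i}/S{i}.pkl layout. The helper below is hypothetical and not part of the repository. Its default subject range assumes the 15 subjects S1 to S15 described on the PPG-DaLiA dataset page, so adjust subject_ids if your copy differs.

```python
# Hypothetical check, not part of the repository: list which raw PPG-DaLiA
# files are missing from the PPG_FieldStudy/S{i}/S{i}.pkl layout.
import os


def missing_ppg_files(raw_dir, subject_ids=range(1, 16)):
    """Return expected per-subject pickle paths that are not present.

    The default assumes the 15 subjects S1 to S15 of PPG-DaLiA."""
    expected = [os.path.join(raw_dir, "PPG_FieldStudy", f"S{i}", f"S{i}.pkl")
                for i in subject_ids]
    return [path for path in expected if not os.path.isfile(path)]


if __name__ == "__main__":
    missing = missing_ppg_files(os.path.join("data_dir", "raw"))
    print("all raw files found" if not missing else "\n".join(missing))
```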


Experiments

The code for training and evaluating the models is contained in train.py. Experiments can be run using the run_experiment.py script. This script requires you to specify the names of the models you want to train, the names of the datasets you want to train on, and a directory which contains configuration files. By default, it will run the LinOSS experiments. The configuration files should be organised as config_dir/{model_name}/{dataset_name}.json and contain the following fields:

  • seeds: A list of seeds to use for training.
  • data_dir: The directory containing the data.
  • output_parent_dir: The directory to save the output.
  • lr_scheduler: A function which takes the learning rate and returns the new learning rate.
  • num_steps: The number of steps to train for.
  • print_steps: The number of steps between printing the loss.
  • batch_size: The batch size.
  • metric: The metric to use for evaluation.
  • classification: Whether the task is a classification task.
  • linoss_discretization: LinOSS only. Which discretization to use; choices are ['IM', 'IMEX'].
  • lr: The initial learning rate.
  • time: Whether to include time as a channel.
  • Any further specific model parameters.

See experiment_configs/repeats for examples.
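The sketch below checks a config for the fields listed above and writes it to config_dir/{model_name}/{dataset_name}.json. Both helpers are hypothetical, and every value in EXAMPLE_CONFIG is a placeholder. In particular, storing lr_scheduler as a string is an assumption, because JSON cannot hold a function. Copy the actual encoding and settings from the files in experiment_configs/repeats.

```python
# Hypothetical helpers, not part of the repository: check that a config dict
# has the fields listed above and write it to config_dir/{model}/{dataset}.json.
import json
import os

REQUIRED_FIELDS = [
    "seeds", "data_dir", "output_parent_dir", "lr_scheduler", "num_steps",
    "print_steps", "batch_size", "metric", "classification", "lr", "time",
]


def validate_config(cfg, is_linoss=False):
    """Return a list of problems; an empty list means the check passed."""
    problems = [f"missing field: {key}" for key in REQUIRED_FIELDS if key not in cfg]
    if is_linoss and cfg.get("linoss_discretization") not in ("IM", "IMEX"):
        problems.append("linoss_discretization must be 'IM' or 'IMEX'")
    return problems


def write_config(config_dir, model_name, dataset_name, cfg, is_linoss=False):
    """Validate cfg, then save it as JSON in the expected directory layout."""
    problems = validate_config(cfg, is_linoss)
    if problems:
        raise ValueError("; ".join(problems))
    path = os.path.join(config_dir, model_name, f"{dataset_name}.json")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(cfg, f, indent=2)
    return path


# Every value below is a placeholder; the string form of lr_scheduler is an
# assumption, not the repository's documented format.
EXAMPLE_CONFIG = {
    "seeds": [0, 1, 2],
    "data_dir": "data_dir",
    "output_parent_dir": "outputs",
    "lr_scheduler": "lambda lr: lr",
    "num_steps": 1000,
    "print_steps": 100,
    "batch_size": 32,
    "metric": "accuracy",
    "classification": True,
    "linoss_discretization": "IMEX",
    "lr": 0.001,
    "time": False,
}
```

For example, write_config("my_configs", "LinOSS", "MyDataset", EXAMPLE_CONFIG, is_linoss=True) would create my_configs/LinOSS/MyDataset.json. The directory, model, and dataset names are illustrative.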


Reproducing the Results

The configuration files for all the experiments with fixed hyperparameters can be found in the experiment_configs folder, and run_experiment.py is currently configured to run the repeat experiments on the UEA datasets. The outputs folder contains a zip file of the output files from the UEA and PPG experiments.


Citation

If you found our work useful in your research, please cite our paper:

@inproceedings{rusch2025linoss,
  title={Oscillatory State-Space Models},
  author={Rusch, T Konstantin and Rus, Daniela},
  booktitle={International Conference on Learning Representations},
  year={2025}
}

(Also consider starring the project on GitHub.)
