A general-purpose, deep learning-first library for constrained optimization in PyTorch

Cooper

What is Cooper?

Cooper is a library for solving constrained optimization problems in PyTorch.

Cooper implements several Lagrangian-based (first-order) update schemes applicable to a wide range of continuous constrained optimization problems. Cooper primarily targets deep learning applications, where gradients are estimated from mini-batches, but it is also suitable for general continuous constrained optimization tasks.
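
As a quick refresher (in generic notation, not Cooper-specific), the sketch below shows the min-max problem that Lagrangian-based schemes solve for inequality constraints g(x) <= 0, together with one simultaneous gradient descent-ascent step; Cooper's formulations and optimizers implement variants of this update.

    \min_{x} \; \max_{\lambda \geq 0} \; \mathcal{L}(x, \lambda) = f(x) + \lambda^{\top} g(x)

    x_{t+1} = x_t - \eta_x \nabla_x \mathcal{L}(x_t, \lambda_t), \qquad
    \lambda_{t+1} = \left[ \lambda_t + \eta_\lambda \, g(x_t) \right]_{+}

Here [\,\cdot\,]_{+} denotes projection onto the non-negative orthant, which keeps the multipliers \lambda feasible for the dual (maximization) problem.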

There exist other libraries for constrained optimization in PyTorch, like CHOP and GeoTorch, but they rely on assumptions about the constraints (such as admitting efficient projection or proximal operators). These assumptions are often not met in modern machine learning problems. Cooper can be applied to a wider range of constrained optimization problems (including non-convex problems) thanks to its Lagrangian-based approach.

You can check out Cooper's FAQ here.

Cooper's companion paper is available here.

Installation

To install the latest release of Cooper, use the following command:

pip install cooper-optim

To install the latest development version, use the following command instead:

pip install git+https://github.com/cooper-org/cooper@main
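
To confirm the installation succeeded, a quick sanity check (not from Cooper's documentation) is to make sure the package imports:

python -c "import cooper"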

Getting Started

Quick Start

To use Cooper, you need to:

1. Define a ConstrainedMinimizationProblem (CMP) whose compute_cmp_state method returns the loss and the violation of each constraint, packaged in a CMPState.
2. Create a primal optimizer for your model parameters and a dual optimizer (with maximize=True) for the Lagrange multipliers, and wrap them in a Cooper optimizer such as SimultaneousOptimizer.
3. Call the Cooper optimizer's roll method inside your training loop to perform the primal and dual updates.

These steps are illustrated in the example below.

Example

This is an abstract example of how to solve a constrained optimization problem with Cooper. You can find runnable notebooks with concrete examples in our Tutorials.

import cooper
import torch

# Set up GPU acceleration
DEVICE = ...

class MyCMP(cooper.ConstrainedMinimizationProblem):
    def __init__(self):
        super().__init__()
        multiplier = cooper.multipliers.DenseMultiplier(num_constraints=..., device=DEVICE)
        # By default, constraints are built using `formulation_type=cooper.formulations.Lagrangian`
        self.constraint = cooper.Constraint(
            multiplier=multiplier, constraint_type=cooper.ConstraintType.INEQUALITY
        )

    def compute_cmp_state(self, model, inputs, targets):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        loss = ...
        constraint_state = cooper.ConstraintState(violation=...)
        observed_constraints = {self.constraint: constraint_state}

        return cooper.CMPState(loss=loss, observed_constraints=observed_constraints)


train_loader = ...
model = (...).to(DEVICE)
cmp = MyCMP()

primal_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Must set `maximize=True` since the Lagrange multipliers solve a _maximization_ problem
dual_optimizer = torch.optim.SGD(cmp.dual_parameters(), lr=1e-2, maximize=True)

cooper_optimizer = cooper.optim.SimultaneousOptimizer(
    cmp=cmp, primal_optimizers=primal_optimizer, dual_optimizers=dual_optimizer
)

for epoch_num in range(NUM_EPOCHS):
    for inputs, targets in train_loader:
        # `roll` is a convenience function that packages together the loss evaluation,
        # the gradient computation, the primal and dual updates, and the zero_grad calls
        compute_cmp_state_kwargs = {"model": model, "inputs": inputs, "targets": targets}
        roll_out = cooper_optimizer.roll(compute_cmp_state_kwargs=compute_cmp_state_kwargs)
        # `roll_out` is a namedtuple containing the loss, last CMPState, and the primal
        # and dual Lagrangian stores, useful for inspection and logging
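
For a concrete (toy) instance of the recipe above, the following self-contained sketch minimizes ||x||^2 subject to the elementwise constraint x >= 1. It uses the same Cooper API calls as the abstract example; the problem and hyperparameters are illustrative choices, not part of Cooper's documentation.

import torch

import cooper

class NormMinimization(cooper.ConstrainedMinimizationProblem):
    """Toy CMP: minimize ||x||^2 subject to x_i >= 1 for every coordinate."""

    def __init__(self, num_variables):
        super().__init__()
        multiplier = cooper.multipliers.DenseMultiplier(num_constraints=num_variables)
        self.constraint = cooper.Constraint(
            multiplier=multiplier, constraint_type=cooper.ConstraintType.INEQUALITY
        )

    def compute_cmp_state(self, x):
        loss = torch.sum(x ** 2)
        # Inequality constraints are expressed as g(x) <= 0, so x >= 1 becomes 1 - x <= 0
        violation = 1.0 - x
        observed_constraints = {self.constraint: cooper.ConstraintState(violation=violation)}
        return cooper.CMPState(loss=loss, observed_constraints=observed_constraints)


x = torch.nn.Parameter(torch.zeros(5))
cmp = NormMinimization(num_variables=5)

primal_optimizer = torch.optim.SGD([x], lr=1e-2)
dual_optimizer = torch.optim.SGD(cmp.dual_parameters(), lr=1e-2, maximize=True)

cooper_optimizer = cooper.optim.SimultaneousOptimizer(
    cmp=cmp, primal_optimizers=primal_optimizer, dual_optimizers=dual_optimizer
)

for step in range(2000):
    roll_out = cooper_optimizer.roll(compute_cmp_state_kwargs={"x": x})

print(x)  # Each entry should approach the constrained optimum of 1

As the multipliers grow, the dual ascent steps push the iterates toward the feasible region, so after enough steps each coordinate of x should be close to 1.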

Contributions

We appreciate all contributions. Please let us know if you encounter a bug by filing an issue.

If you plan to contribute new features, utility functions, or extensions, please first open an issue and discuss the feature with us. To learn more about making a contribution to Cooper, please see our Contribution page.

Papers Using Cooper

Cooper has enabled several papers published at top machine learning conferences: Gallego-Posada et al. (2022); Lachapelle and Lacoste-Julien (2022); Ramirez and Gallego-Posada (2022); Zhu et al. (2023); Hashemizadeh et al. (2024); Sohrabi et al. (2024); Lachapelle et al. (2024); Jang et al. (2024); Navarin et al. (2024); Chung et al. (2024).

Acknowledgements

We thank Manuel Del Verme, Daniel Otero, and Isabel Urrego for useful discussions during the early stages of Cooper.

Many Cooper features arose during the development of several research papers. We would like to thank our co-authors Yoshua Bengio, Juan Elenter, Akram Erraqabi, Golnoosh Farnadi, Ignacio Hounie, Alejandro Ribeiro, Rohan Sukumaran, Motahareh Sohrabi and Tianyue (Helen) Zhang.

License

Cooper is distributed under an MIT license, as found in the LICENSE file.

How to cite Cooper

To cite Cooper, please use the following reference:

@article{gallegoPosada2025cooper,
    author={Gallego-Posada, Jose and Ramirez, Juan and Hashemizadeh, Meraj and Lacoste-Julien, Simon},
    title={{Cooper: A Library for Constrained Optimization in Deep Learning}},
    journal={arXiv preprint arXiv:2504.01212},
    year={2025}
}