Lighter, Better, Faster Multi-Source Domain Adaptation with Gaussian Mixture Models and Optimal Transport
In this repository, we provide the source code for our paper "Lighter, Better, Faster Multi-Source Domain Adaptation with Gaussian Mixture Models and Optimal Transport", accepted at the European Conference on Machine Learning and Principles and Practice of Knowledge Discovery in Databases (ECML-PKDD'24), which proposes new tools for multi-source domain adaptation through Gaussian Mixture Model-based optimal transport. In particular, you will find the source code for reproducing the toy example in Section 4.1 of our main paper.
In this paper, we tackle Multi-Source Domain Adaptation (MSDA), a transfer learning task in which one adapts multiple heterogeneous, labeled source probability measures towards a different, unlabeled target measure. We propose a novel MSDA framework based on Optimal Transport (OT) and Gaussian Mixture Models (GMMs). Our framework has two key advantages. First, OT between GMMs can be solved efficiently via linear programming. Second, it provides a convenient model for supervised learning, especially classification, as components of the GMM can be associated with existing classes. Building on the GMM-OT problem, we propose a novel technique for computing barycenters of GMMs, from which we derive two new strategies for MSDA: GMM-Wasserstein Barycenter Transport (WBT) and GMM-Dataset Dictionary Learning (DaDiL). We empirically evaluate our methods on four benchmarks in image classification and fault diagnosis, showing that they improve over the prior art while being faster and involving fewer parameters.
Keywords: Domain Adaptation, Optimal Transport, Gaussian Mixture Models
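As a quick illustration of the first advantage above (OT between GMMs reduces to a linear program over the mixture weights, with a cost matrix of pairwise squared Bures-Wasserstein distances between Gaussian components), here is a minimal, self-contained sketch. It is not the code of this repository: it only assumes NumPy, SciPy, and the POT library (`pip install pot`), and the helper names `bures_wasserstein_sq` and `gmm_ot` are ours for this example.

```python
# Minimal sketch of GMM-OT: a small linear program between mixture weights,
# with pairwise squared Bures-Wasserstein distances between components as cost.
import numpy as np
import ot  # POT: Python Optimal Transport
from scipy.linalg import sqrtm


def bures_wasserstein_sq(m1, C1, m2, C2):
    """Squared Wasserstein-2 distance between N(m1, C1) and N(m2, C2)."""
    C1_half = np.real(sqrtm(C1))
    cross = np.real(sqrtm(C1_half @ C2 @ C1_half))
    bures = np.trace(C1 + C2 - 2.0 * cross)
    return np.sum((m1 - m2) ** 2) + bures


def gmm_ot(weights_p, means_p, covs_p, weights_q, means_q, covs_q):
    """Solve the GMM-OT linear program; returns the coupling and its cost."""
    K, L = len(weights_p), len(weights_q)
    M = np.zeros((K, L))
    for k in range(K):
        for l in range(L):
            M[k, l] = bures_wasserstein_sq(means_p[k], covs_p[k],
                                           means_q[l], covs_q[l])
    plan = ot.emd(weights_p, weights_q, M)  # exact LP over the mixture weights
    return plan, np.sum(plan * M)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d = 2
    # Two toy GMMs with 3 and 2 components.
    wp = np.array([0.5, 0.3, 0.2]); wq = np.array([0.6, 0.4])
    mp = rng.normal(size=(3, d)); mq = rng.normal(size=(2, d))
    cp = np.stack([np.eye(d) * s for s in (0.5, 1.0, 2.0)])
    cq = np.stack([np.eye(d) * s for s in (1.0, 0.8)])
    plan, cost = gmm_ot(wp, mp, cp, wq, mq, cq)
    print("coupling:\n", plan, "\ncost:", cost)
```

The resulting plan couples components of the two mixtures, which is what makes GMM-OT convenient for classification: components associated with a given class can be matched across domains.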
We summarize our contributions as follows:
- We propose a novel strategy for mapping the parameters of GMMs using OT (a minimal illustration of the underlying Gaussian-to-Gaussian map is sketched after this list).
- We propose a novel algorithm for computing mixture-Wasserstein barycenters of GMMs.
- We propose an efficient parametric extension of the WBT and DaDiL algorithms based on GMMs.
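Regarding the first contribution, the basic building block for moving GMM parameters across domains is the closed-form affine Monge map between two Gaussian components. Below is a minimal sketch of that map, again an illustrative reimplementation under NumPy/SciPy rather than the repository's code; the function name `gaussian_monge_map` is an assumption for this example.

```python
# Closed-form affine Monge map between two Gaussians.
import numpy as np
from numpy.linalg import inv
from scipy.linalg import sqrtm


def gaussian_monge_map(m_s, C_s, m_t, C_t):
    """Return (A, b) such that T(x) = A x + b pushes N(m_s, C_s) onto N(m_t, C_t)."""
    Cs_half = np.real(sqrtm(C_s))
    Cs_half_inv = inv(Cs_half)
    # A = Cs^{-1/2} (Cs^{1/2} Ct Cs^{1/2})^{1/2} Cs^{-1/2}
    A = Cs_half_inv @ np.real(sqrtm(Cs_half @ C_t @ Cs_half)) @ Cs_half_inv
    b = m_t - A @ m_s
    return A, b


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    m_s, C_s = np.zeros(2), np.array([[1.0, 0.3], [0.3, 0.5]])
    m_t, C_t = np.ones(2), np.array([[2.0, -0.4], [-0.4, 1.0]])
    A, b = gaussian_monge_map(m_s, C_s, m_t, C_t)
    # Transport samples of the source component: their empirical mean and
    # covariance should be close to (m_t, C_t).
    x = rng.multivariate_normal(m_s, C_s, size=5000)
    y = x @ A.T + b
    print(y.mean(axis=0))
    print(np.cov(y, rowvar=False))
```

Applying this map to a matched pair of components (or to the samples assigned to the source component) transports both the mean and the covariance onto the target side.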
Below, we compare our methods against prior art on the four benchmarks considered in the paper, reporting classification accuracy (%) per target domain.

Results on Office-Home:

Algorithm | Ar | Cl | Pr | Rw | Avg. |
---|---|---|---|---|---|
ResNet101 | 72.90 | 62.20 | 83.70 | 85.00 | 75.95 |
M$^{3}$SDA | 71.13 | 61.41 | 80.18 | 80.64 | 73.34 |
LtC-MSDA | 74.52 | 60.56 | 85.52 | 83.63 | 76.05 |
KD3A | 73.80 | 63.10 | 84.30 | 83.50 | 76.17 |
Co-MDA$^{\ddag}$ | 74.40 | 64.00 | 85.30 | 83.90 | 76.90 |
WJDOT | 74.28 | 63.80 | 83.78 | 84.52 | 76.59 |
WBT | 75.72 | 63.80 | 84.23 | 84.63 | 77.09 |
DaDiL-E | 77.16 | 64.95 | 85.47 | 84.97 | 78.14 |
DaDiL-R | 75.92 | 64.83 | 85.36 | 85.32 | 77.86 |
GMM-WBT | 75.31 | 64.26 | 86.71 | 85.21 | 77.87 |
GMM-DaDiL | 77.16 | 66.21 | 86.15 | 85.32 | 78.81 |

Results on Office 31:

Algorithm | A | D | W | Avg. |
---|---|---|---|---|
ResNet50 | 67.50 | 95.00 | 96.83 | 86.40 |
M$^{3}$SDA | 66.75 | 97.00 | 96.83 | 86.86 |
LtC-MSDA | 66.82 | 100.00 | 97.12 | 87.98 |
KD3A | 65.20 | 100.0 | 98.70 | 87.96 |
Co-MDA | 64.80 | 99.83 | 98.70 | 87.83 |
WJDOT | 67.77 | 97.32 | 95.32 | 86.80 |
WBT | 67.94 | 98.21 | 97.66 | 87.93 |
DaDiL-E | 70.55 | 100.0 | 98.83 | 89.79 |
DaDiL-R | 70.90 | 100.0 | 98.83 | 89.91 |
GMM-WBT | 70.13 | 99.11 | 96.49 | 88.54 |
GMM-DaDiL | 72.47 | 100.0 | 99.41 | 90.63 |

Results on the fault diagnosis benchmark with domains A, B, and C:

Algorithm | A | B | C | Avg. |
---|---|---|---|---|
MLP$^{\star}$ | 70.90 | 79.76 | 72.26 | 74.31 |
M$^{3}$SDA | 56.86 | 69.81 | 61.06 | 62.57 |
LtC-MSDA$^{\star}$ | 82.21 | 75.33 | 81.04 | 79.52 |
KD3A | 81.02 | 78.04 | 74.64 | 77.90 |
Co-MDA | 62.66 | 55.78 | 76.35 | 64.93 |
WJDOT | 99.96 | 98.86 | 100.0 | 99.60 |
WBT$^{\star}$ | 99.28 | 79.91 | 97.71 | 92.30 |
DaDiL-R$^{\star}$ | 99.86 | 99.85 | 100.00 | 99.90 |
DaDiL-E$^{\star}$ | 93.71 | 83.63 | 99.97 $\pm$ 0.05 | 92.33 |
GMM-WBT | 100.00 | 99.95 | 100.00 | 99.98 |
GMM-DaDiL | 100.00 | 99.95 | 100.00 | 99.98 |

Results on the fault diagnosis benchmark with operating modes 1-6:

Algorithm | Mode 1 | Mode 2 | Mode 3 | Mode 4 | Mode 5 | Mode 6 | Avg. |
---|---|---|---|---|---|---|---|
CNN$^{\dag}$ | 80.82 | 63.69 | 87.47 | 79.96 | 74.44 | 84.53 | 78.48 |
M$^{3}$SDA$^{\dag}$ | 81.17 | 61.61 | 79.99 | 79.12 | 75.16 | 78.91 | 75.99 |
LtC-MSDA$^{1}$ | - | - | - | - | - | - | - |
KD3A | 72.52 | 18.96 | 81.02 | 74.42 | 67.18 | 78.22 | 65.38 |
Co-MDA | 64.56 | 35.99 | 79.66 | 72.06 | 66.33 | 78.91 | 66.34 |
WJDOT | 89.06 | 75.60 | 89.99 | 89.38 $\pm$ 0.77 | 85.32 | 87.43 | 86.13 |
WBT$^{\dag}$ | 92.38 | 73.74 | 88.89 | 89.38 $\pm$ 1.26 | 85.53 | 86.60 | 86.09 |
DaDiL-R$^{\ddag}$ | 91.97 | 77.15 | 85.41 | 89.39 | 84.49 | 88.44 | 86.14 |
DaDiL-E$^{\ddag}$ | 90.45 | 77.08 $\pm$ 1.21 | 86.79 | 89.01 | 84.04 | 87.85 | 85.87 |
GMM-WBT | 92.23 $\pm$ 0.70 | 71.81 | 84.72 | 89.28 | 87.51 | 82.49 | 84.67 |
GMM-DaDiL | 91.72 | 76.41 | 89.68 $\pm$ 1.49 | 89.18 | 86.05 $\pm$ 1.46 | 88.02 $\pm$ 1.12 | 86.85 |

If you find this work useful in your research, please consider citing our paper:

@article{montesuma2024lighter,
title={Lighter, Better, Faster Multi-Source Domain Adaptation with Gaussian Mixture Models and Optimal Transport},
author={Montesuma, Eduardo Fernandes and Mboula, Fred Ngol{\`e} and Souloumiac, Antoine},
journal={arXiv preprint arXiv:2404.10261},
year={2024}
}