- [Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization (GroupDRO)](https://arxiv.org/abs/1911.08731)
- [Deep CORAL: Correlation Alignment for Deep Domain Adaptation (Deep CORAL, ECCV 2016)](https://arxiv.org/abs/1607.01719)

## Usage
The shell files provide the scripts to reproduce the benchmarks with the specified hyper-parameters.

For example, if you want to train IRM on Office-Home, use the following script:

```shell script
# Train with IRM on the Office-Home Ar Cl Rw -> Pr task using ResNet-50.
# Assume you have put the datasets under the path `data/office-home`;
# otherwise, they will be downloaded automatically from the Internet to this path.
CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/irm/OfficeHome_Pr
```
Note that `-s` specifies the source domains, `-t` specifies the target domain,
and `--log` specifies where to store results.

For more information please refer to [Get Started](/docs/get_started/quickstart.rst) for help.
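
To run all four Office-Home tasks reported in the tables below, you can launch the same command once per target domain. The sketch below does this from Python with `subprocess`; it is not one of the provided shell files, and the domain list and paths simply mirror the example command above (datasets under `data/office-home`, a single GPU):

```python
# Hypothetical sweep script: treat each Office-Home domain in turn as the target
# and the remaining three as sources, reusing the flags from the example above.
import os
import subprocess

domains = ["Ar", "Cl", "Pr", "Rw"]
env = {**os.environ, "CUDA_VISIBLE_DEVICES": "0"}

for target in domains:
    sources = [d for d in domains if d != target]
    subprocess.run(
        [
            "python", "irm.py", "data/office-home",
            "-d", "OfficeHome",
            "-s", *sources,
            "-t", target,
            "-a", "resnet50",
            "--seed", "0",
            "--log", f"logs/irm/OfficeHome_{target}",
        ],
        check=True,
        env=env,
    )
```

Each run stores its results under the directory given by `--log`.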
## Experiment and Results
Following [DomainBed](https://github.com/facebookresearch/DomainBed), we select hyper-parameters based on
the model's performance on the `training-domain validation set` (the first model-selection criterion in DomainBed).
Concretely, we save the model with the highest accuracy on the `training-domain validation set` and then
load this checkpoint to evaluate on the target domain.
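
For intuition, here is a minimal, self-contained sketch of this selection rule. The toy model, random stand-in data, and file names are placeholders rather than TLlib code; only the checkpointing logic mirrors the description above (select by source-domain validation accuracy, then evaluate once on the target domain):

```python
# Illustrative sketch of "training-domain validation" model selection.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def make_loader(n=256, dim=16, num_classes=4):
    # Random stand-in data; in practice these splits come from the real domains.
    x, y = torch.randn(n, dim), torch.randint(0, num_classes, (n,))
    return DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)


def accuracy(model, loader):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            correct += (model(x).argmax(dim=1) == y).sum().item()
            total += y.numel()
    return correct / total


model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 4))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

train_loader = make_loader()   # training split of the source domains
val_loader = make_loader()     # held-out split of the *source* domains
target_loader = make_loader()  # target domain, used only for the final test

best_val_acc = 0.0
for epoch in range(10):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()
    # Select the checkpoint by accuracy on the training-domain validation set;
    # the target domain is never consulted during selection.
    val_acc = accuracy(model, val_loader)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best.pth")

# Load the selected checkpoint and report accuracy on the target domain.
model.load_state_dict(torch.load("best.pth"))
print(f"target accuracy: {accuracy(model, target_loader):.3f}")
```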

Here are the main differences between our implementation and DomainBed. For the model,
we do not freeze the `BatchNorm2d` layers and do not insert an additional `Dropout` layer except for the `PACS` dataset.
For the optimizer, we use `SGD` with momentum by default and find that this usually achieves better performance than `Adam`.
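
For reference, the setup described above looks roughly like the following sketch. It uses plain `torchvision` rather than TLlib's model-building code, and the dropout probability, learning rate, and weight decay are placeholder values, not the hyper-parameters used for the tables below:

```python
# Sketch of the configuration above: BatchNorm2d layers stay trainable,
# Dropout is only inserted for PACS, and the optimizer is SGD with momentum.
# Hyper-parameter values here are placeholders. Requires a recent torchvision.
import torch.nn as nn
import torch.optim as optim
from torchvision import models


def build_classifier(num_classes, insert_dropout=False):
    backbone = models.resnet50(weights="IMAGENET1K_V1")  # BatchNorm2d left unfrozen
    head = [nn.Linear(backbone.fc.in_features, num_classes)]
    if insert_dropout:  # only done for the PACS dataset
        head.insert(0, nn.Dropout(p=0.5))
    backbone.fc = nn.Sequential(*head)
    return backbone


model = build_classifier(num_classes=65)  # Office-Home has 65 classes, so no Dropout
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
```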

**Notations**
- `ERM` refers to the model trained with data from the source domains.
- `Avg` is the average accuracy reported by `TLlib`.

### PACS accuracy on ResNet-50

Each column reports the accuracy when that domain (A: Art painting, C: Cartoon, P: Photo, S: Sketch) is the target and the remaining three domains are the sources.

| Methods  | Avg  | A    | C    | P    | S    |
|----------|------|------|------|------|------|
| ERM      | 86.4 | 88.5 | 78.4 | 97.2 | 81.4 |
| IBN      | 87.8 | 88.2 | 84.5 | 97.1 | 81.4 |
| MixStyle | 87.4 | 87.8 | 82.3 | 95.0 | 84.5 |
| MLDG     | 87.2 | 88.2 | 81.4 | 96.6 | 82.5 |
| IRM      | 86.9 | 88.0 | 82.5 | 98.0 | 79.0 |
| VREx     | 87.0 | 87.2 | 82.3 | 97.4 | 81.0 |
| GroupDRO | 87.3 | 88.9 | 81.7 | 97.8 | 80.8 |
| CORAL    | 86.4 | 89.1 | 80.0 | 97.4 | 79.1 |

### Office-Home accuracy on ResNet-50

Each column reports the accuracy when that domain (A: Art, C: Clipart, P: Product, R: Real-World) is the target and the remaining three domains are the sources.

| Methods  | Avg  | A    | C    | P    | R    |
|----------|------|------|------|------|------|
| ERM      | 70.8 | 68.3 | 55.9 | 78.9 | 80.0 |
| IBN      | 69.9 | 67.4 | 55.2 | 77.3 | 79.6 |
| MixStyle | 71.7 | 66.8 | 58.1 | 78.0 | 79.9 |
| MLDG     | 70.3 | 65.9 | 57.6 | 78.2 | 79.6 |
| IRM      | 70.3 | 66.7 | 54.8 | 78.6 | 80.9 |
| VREx     | 70.2 | 66.9 | 54.9 | 78.2 | 80.9 |
| GroupDRO | 70.0 | 66.7 | 55.2 | 78.8 | 79.9 |
| CORAL    | 70.9 | 68.3 | 55.4 | 78.8 | 81.0 |

## Citation
If you use these methods in your research, please consider citing.