
Commit b7d1849

readme for dg/classification

1 parent 4927d34 commit b7d1849

File tree

2 files changed: +46 −135 lines

docs/dglib/benchmarks/image_classification.rst

Lines changed: 0 additions & 121 deletions
This file was deleted.

examples/domain_generalization/image_classification/README.md

Lines changed: 46 additions & 14 deletions
@@ -1,14 +1,9 @@
 # Domain Generalization for Image Classification
 
 ## Installation
-Example scripts can deal with [WILDS datasets](https://wilds.stanford.edu/).
-You should first install ``wilds`` before using these scripts.
+It is suggested to use **pytorch==1.7.1** and **torchvision==0.8.2** to reproduce the benchmark results.
 
-```
-pip install wilds
-```
-
-Example scripts also support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).
+Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).
 You also need to install timm to use PyTorch-Image-Models.
 
 ```
@@ -23,9 +18,6 @@ Following datasets can be downloaded automatically:
 - [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html)
 - [DomainNet](http://ai.bu.edu/M3SDA/)
 - [PACS](https://domaingeneralization.github.io/#data)
-- [iwildcam (WILDS)](https://wilds.stanford.edu/datasets/)
-- [camelyon17 (WILDS)](https://wilds.stanford.edu/datasets/)
-- [fmow (WILDS)](https://wilds.stanford.edu/datasets/)
 
 ## Supported Methods
 
@@ -37,19 +29,59 @@ Following datasets can be downloaded automatically:
 - [Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization (GroupDRO)](https://arxiv.org/abs/1911.08731)
 - [Deep CORAL: Correlation Alignment for Deep Domain Adaptation (Deep Coral, 2016 ECCV)](https://arxiv.org/abs/1607.01719)
 
-## Experiment and Results
+## Usage
 
-The shell files give the script to reproduce the [benchmarks](/docs/dglib/benchmarks/image_classification.rst) with specified hyper-parameters.
-For example, if you want to reproduce IRM on Office-Home, use the following script
+The shell files provide scripts to reproduce the benchmarks with the specified hyper-parameters.
+For example, if you want to train IRM on Office-Home, use the following script:
 
 ```shell script
 # Train with IRM on Office-Home Ar Cl Rw -> Pr task using ResNet 50.
 # Assume you have put the datasets under the path `data/office-home`,
 # or you are glad to download the datasets automatically from the Internet to this path
 CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/irm/OfficeHome_Pr
 ```
+Note that ``-s`` specifies the source domains, ``-t`` specifies the target domain,
+and ``--log`` specifies where to store results.
 
-For more information please refer to [Get Started](/docs/get_started/quickstart.rst) for help.
+## Experiment and Results
+Following [DomainBed](https://github.com/facebookresearch/DomainBed), we select hyper-parameters based on
+the model's performance on the `training-domain validation set` (the first rule in DomainBed).
+Concretely, we save the model with the highest accuracy on the `training-domain validation set` and then
+load this checkpoint to test on the target domain.
+
+Here are some differences between our implementation and DomainBed. For the model,
+we do not freeze `BatchNorm2d` layers and do not insert an additional `Dropout` layer, except for the `PACS` dataset.
+For the optimizer, we use `SGD` with momentum by default and find that it usually achieves better performance than `Adam`.
+
+**Notations**
+- ``ERM`` refers to the model trained with data from the source domains.
+- ``Avg`` is the accuracy reported by `TLlib`.
+
+### PACS accuracy on ResNet-50
+
+| Methods  | Avg  | A    | C    | P    | S    |
+|----------|------|------|------|------|------|
+| ERM      | 86.4 | 88.5 | 78.4 | 97.2 | 81.4 |
+| IBN      | 87.8 | 88.2 | 84.5 | 97.1 | 81.4 |
+| MixStyle | 87.4 | 87.8 | 82.3 | 95.0 | 84.5 |
+| MLDG     | 87.2 | 88.2 | 81.4 | 96.6 | 82.5 |
+| IRM      | 86.9 | 88.0 | 82.5 | 98.0 | 79.0 |
+| VREx     | 87.0 | 87.2 | 82.3 | 97.4 | 81.0 |
+| GroupDRO | 87.3 | 88.9 | 81.7 | 97.8 | 80.8 |
+| CORAL    | 86.4 | 89.1 | 80.0 | 97.4 | 79.1 |
+
+### Office-Home accuracy on ResNet-50
+
+| Methods  | Avg  | A    | C    | P    | R    |
+|----------|------|------|------|------|------|
+| ERM      | 70.8 | 68.3 | 55.9 | 78.9 | 80.0 |
+| IBN      | 69.9 | 67.4 | 55.2 | 77.3 | 79.6 |
+| MixStyle | 71.7 | 66.8 | 58.1 | 78.0 | 79.9 |
+| MLDG     | 70.3 | 65.9 | 57.6 | 78.2 | 79.6 |
+| IRM      | 70.3 | 66.7 | 54.8 | 78.6 | 80.9 |
+| VREx     | 70.2 | 66.9 | 54.9 | 78.2 | 80.9 |
+| GroupDRO | 70.0 | 66.7 | 55.2 | 78.8 | 79.9 |
+| CORAL    | 70.9 | 68.3 | 55.4 | 78.8 | 81.0 |
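As a quick sanity check on the result tables in this diff: the `Avg` column is just the arithmetic mean of the per-domain accuracies. A minimal sketch, using the ERM row of the PACS table:

```python
# The Avg column is the mean of the per-domain accuracies.
# Numbers are the ERM row of the PACS table (domains A, C, P, S).
per_domain = {"A": 88.5, "C": 78.4, "P": 97.2, "S": 81.4}

avg = round(sum(per_domain.values()) / len(per_domain), 1)
print(avg)  # 86.4, matching the Avg column for ERM
```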
 
 ## Citation
 If you use these methods in your research, please consider citing.
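The model-selection rule added in the `Experiment and Results` section (keep the checkpoint with the highest accuracy on the training-domain validation set, then evaluate that checkpoint on the target domain) can be sketched as below. This is an illustrative sketch, not TLlib's actual API; `select_best_epoch` and the accuracy numbers are hypothetical.

```python
# Sketch of DomainBed's first model-selection rule (training-domain validation):
# pick the epoch with the highest accuracy on a validation split drawn from the
# *source* domains, then reload that checkpoint to test on the target domain.
def select_best_epoch(val_accs):
    """Return the index of the epoch with the highest validation accuracy."""
    return max(range(len(val_accs)), key=lambda e: val_accs[e])

# Hypothetical per-epoch accuracies on the training-domain validation split.
val_accs = [62.1, 70.4, 75.8, 74.9, 75.2]
best = select_best_epoch(val_accs)
print(best)  # 2 -> reload the epoch-2 checkpoint and test on the target domain
```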
