
Commit bbfa179

Support Detic and Multi-Datasets training (#10926)
1 parent b09d183 commit bbfa179

26 files changed (+4087, -5 lines)

mmdet/datasets/__init__.py (+2, -2)

@@ -10,7 +10,7 @@
 from .coco_panoptic import CocoPanopticDataset
 from .coco_semantic import CocoSegDataset
 from .crowdhuman import CrowdHumanDataset
-from .dataset_wrappers import MultiImageMixDataset
+from .dataset_wrappers import ConcatDataset, MultiImageMixDataset
 from .deepfashion import DeepFashionDataset
 from .dsdl import DSDLDetDataset
 from .isaid import iSAIDDataset
@@ -42,5 +42,5 @@
     'ReIDDataset', 'YouTubeVISDataset', 'TrackAspectRatioBatchSampler',
     'ADE20KPanopticDataset', 'CocoCaptionDataset', 'RefCocoDataset',
     'BaseSegDataset', 'ADE20KSegDataset', 'CocoSegDataset',
-    'ADE20KInstanceDataset', 'iSAIDDataset', 'V3DetDataset'
+    'ADE20KInstanceDataset', 'iSAIDDataset', 'V3DetDataset', 'ConcatDataset'
 ]

mmdet/datasets/dataset_wrappers.py (+85, -2)

@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import collections
 import copy
-from typing import Sequence, Union
+from typing import List, Sequence, Union

-from mmengine.dataset import BaseDataset, force_full_init
+from mmengine.dataset import BaseDataset
+from mmengine.dataset import ConcatDataset as MMENGINE_ConcatDataset
+from mmengine.dataset import force_full_init

 from mmdet.registry import DATASETS, TRANSFORMS

@@ -167,3 +169,84 @@ def update_skip_type_keys(self, skip_type_keys):
             isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
         ])
         self._skip_type_keys = skip_type_keys
+
+
+@DATASETS.register_module()
+class ConcatDataset(MMENGINE_ConcatDataset):
+    """A wrapper of concatenated dataset.
+
+    Same as ``torch.utils.data.dataset.ConcatDataset``, but also supports
+    ``lazy_init`` and ``get_dataset_source``.
+
+    Note:
+        ``ConcatDataset`` should not inherit from ``BaseDataset``, since
+        ``get_subset`` and ``get_subset_`` could produce a sub-dataset with
+        ambiguous meaning that conflicts with the original dataset. To use
+        a sub-dataset of ``ConcatDataset``, set the ``indices`` argument of
+        the wrapped datasets, which inherit from ``BaseDataset``.
+
+    Args:
+        datasets (Sequence[BaseDataset] or Sequence[dict]): A list of
+            datasets which will be concatenated.
+        lazy_init (bool, optional): Whether to load annotations during
+            instantiation. Defaults to False.
+        ignore_keys (List[str] or str): Keys in ``dataset.metainfo`` that
+            are allowed to differ across the datasets. Defaults to None.
+    """
+
+    def __init__(self,
+                 datasets: Sequence[Union[BaseDataset, dict]],
+                 lazy_init: bool = False,
+                 ignore_keys: Union[str, List[str], None] = None):
+        self.datasets: List[BaseDataset] = []
+        for dataset in datasets:
+            if isinstance(dataset, dict):
+                self.datasets.append(DATASETS.build(dataset))
+            elif isinstance(dataset, BaseDataset):
+                self.datasets.append(dataset)
+            else:
+                raise TypeError(
+                    'elements in datasets sequence should be config or '
+                    f'`BaseDataset` instance, but got {type(dataset)}')
+        if ignore_keys is None:
+            self.ignore_keys = []
+        elif isinstance(ignore_keys, str):
+            self.ignore_keys = [ignore_keys]
+        elif isinstance(ignore_keys, list):
+            self.ignore_keys = ignore_keys
+        else:
+            raise TypeError('ignore_keys should be a list or str, '
+                            f'but got {type(ignore_keys)}')
+
+        meta_keys: set = set()
+        for dataset in self.datasets:
+            meta_keys |= dataset.metainfo.keys()
+        # if the metainfo of all datasets is the same, use the metainfo of
+        # the first dataset; otherwise store a list with the metainfo of
+        # every dataset
+        is_all_same = True
+        self._metainfo_first = self.datasets[0].metainfo
+        for dataset in self.datasets:
+            for key in meta_keys:
+                if key in self.ignore_keys:
+                    continue
+                if key not in dataset.metainfo:
+                    is_all_same = False
+                    break
+                # use .get(): the key may be absent from the first
+                # dataset's metainfo
+                if self._metainfo_first.get(key) != dataset.metainfo[key]:
+                    is_all_same = False
+                    break
+
+        if is_all_same:
+            self._metainfo = self.datasets[0].metainfo
+        else:
+            self._metainfo = [dataset.metainfo for dataset in self.datasets]
+
+        self._fully_initialized = False
+        if not lazy_init:
+            self.full_init()
+
+    def get_dataset_source(self, idx: int) -> int:
+        """Return the index of the source dataset for sample ``idx``."""
+        dataset_idx, _ = self._get_ori_dataset_idx(idx)
+        return dataset_idx
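
Since ``ConcatDataset`` is registered in ``DATASETS``, it can be built straight from a config dict. A minimal sketch of the intended usage (the dataset types exist in mmdet, but the paths and the choice of ``ignore_keys`` are illustrative assumptions, not values from this commit):

```python
from mmdet.registry import DATASETS

concat_cfg = dict(
    type='ConcatDataset',
    # 'classes' and 'palette' may legitimately differ between sources,
    # so exclude them from the metainfo consistency check.
    ignore_keys=['classes', 'palette'],
    datasets=[
        dict(
            type='CocoDataset',
            data_root='data/coco/',  # hypothetical paths
            ann_file='annotations/instances_train2017.json',
            data_prefix=dict(img='train2017/')),
        dict(
            type='V3DetDataset',
            data_root='data/V3Det/',
            ann_file='annotations/v3det_2023_v1_train.json',
            data_prefix=dict(img='')),
    ])

dataset = DATASETS.build(concat_cfg)
# Map a concatenated index back to its source: indices smaller than
# len(datasets[0]) report source id 0, the rest source id 1.
source_id = dataset.get_dataset_source(0)  # -> 0
```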

mmdet/datasets/samplers/__init__.py (+4, -1)

@@ -1,12 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .batch_sampler import (AspectRatioBatchSampler,
+                            MultiDataAspectRatioBatchSampler,
                             TrackAspectRatioBatchSampler)
 from .class_aware_sampler import ClassAwareSampler
+from .multi_data_sampler import MultiDataSampler
 from .multi_source_sampler import GroupMultiSourceSampler, MultiSourceSampler
 from .track_img_sampler import TrackImgSampler

 __all__ = [
     'ClassAwareSampler', 'AspectRatioBatchSampler', 'MultiSourceSampler',
     'GroupMultiSourceSampler', 'TrackImgSampler',
-    'TrackAspectRatioBatchSampler'
+    'TrackAspectRatioBatchSampler', 'MultiDataSampler',
+    'MultiDataAspectRatioBatchSampler'
 ]

mmdet/datasets/samplers/batch_sampler.py (+77)

@@ -114,3 +114,80 @@ def __iter__(self) -> Sequence[int]:
             else:
                 yield left_data[:self.batch_size]
                 left_data = left_data[self.batch_size:]
+
+
+@DATA_SAMPLERS.register_module()
+class MultiDataAspectRatioBatchSampler(BatchSampler):
+    """A sampler wrapper that groups images with similar aspect ratio
+    (< 1 or >= 1) into the same batch for multi-source datasets.
+
+    Args:
+        sampler (Sampler): Base sampler.
+        batch_size (Sequence[int]): Size of mini-batch for each of the
+            multi-source datasets.
+        num_datasets (int): Number of multi-source datasets.
+        drop_last (bool): If ``True``, the sampler will drop the last batch
+            if its size would be less than ``batch_size``.
+    """
+
+    def __init__(self,
+                 sampler: Sampler,
+                 batch_size: Sequence[int],
+                 num_datasets: int,
+                 drop_last: bool = True) -> None:
+        if not isinstance(sampler, Sampler):
+            raise TypeError('sampler should be an instance of ``Sampler``, '
+                            f'but got {sampler}')
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.num_datasets = num_datasets
+        self.drop_last = drop_last
+        # two groups (w < h and w >= h) for each dataset --> 2 * num_datasets
+        self._buckets = [[] for _ in range(2 * self.num_datasets)]
+
+    def __iter__(self) -> Sequence[int]:
+        for idx in self.sampler:
+            data_info = self.sampler.dataset.get_data_info(idx)
+            width, height = data_info['width'], data_info['height']
+            dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
+            aspect_ratio_bucket_id = 0 if width < height else 1
+            bucket_id = dataset_source_idx * 2 + aspect_ratio_bucket_id
+            bucket = self._buckets[bucket_id]
+            bucket.append(idx)
+            # yield a batch of indices from the same aspect ratio group
+            if len(bucket) == self.batch_size[dataset_source_idx]:
+                yield bucket[:]
+                del bucket[:]
+
+        # yield the remaining data and reset the buckets
+        for i in range(self.num_datasets):
+            left_data = self._buckets[i * 2 + 0] + self._buckets[i * 2 + 1]
+            while len(left_data) > 0:
+                if len(left_data) <= self.batch_size[i]:
+                    if not self.drop_last:
+                        yield left_data[:]
+                    left_data = []
+                else:
+                    yield left_data[:self.batch_size[i]]
+                    left_data = left_data[self.batch_size[i]:]
+
+        self._buckets = [[] for _ in range(2 * self.num_datasets)]
+
+    def __len__(self) -> int:
+        sizes = [0 for _ in range(self.num_datasets)]
+        for idx in self.sampler:
+            dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
+            sizes[dataset_source_idx] += 1
+
+        if self.drop_last:
+            lens = 0
+            for i in range(self.num_datasets):
+                lens += sizes[i] // self.batch_size[i]
+            return lens
+        else:
+            lens = 0
+            for i in range(self.num_datasets):
+                lens += (sizes[i] + self.batch_size[i] -
+                         1) // self.batch_size[i]
+            return lens
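
The bucket arithmetic is easiest to see in miniature. A self-contained sketch (not part of the commit) of how an index maps onto the ``2 * num_datasets`` buckets:

```python
# With num_datasets=2 there are four buckets, indexed
# dataset_source_idx * 2 + (0 if width < height else 1).
def bucket_id(dataset_source_idx: int, width: int, height: int) -> int:
    aspect_ratio_bucket_id = 0 if width < height else 1
    return dataset_source_idx * 2 + aspect_ratio_bucket_id

assert bucket_id(0, 480, 640) == 0  # dataset 0, w < h
assert bucket_id(0, 640, 480) == 1  # dataset 0, w >= h
assert bucket_id(1, 480, 640) == 2  # dataset 1, w < h
assert bucket_id(1, 640, 480) == 3  # dataset 1, w >= h
```

Because a batch is only yielded once a single bucket fills, no batch mixes aspect-ratio groups or source datasets, and each source keeps its own batch size.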
mmdet/datasets/samplers/multi_data_sampler.py (+110)

@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Iterator, Optional, Sequence, Sized
+
+import torch
+from mmengine.dist import get_dist_info, sync_random_seed
+from mmengine.registry import DATA_SAMPLERS
+from torch.utils.data import Sampler
+
+
+@DATA_SAMPLERS.register_module()
+class MultiDataSampler(Sampler):
+    """The default data sampler for both distributed and non-distributed
+    environments.
+
+    It has several differences from the PyTorch ``DistributedSampler``:
+
+    1. This sampler supports non-distributed environments.
+
+    2. The round-up behaviors are a little different.
+
+       - If ``round_up=True``, this sampler will add extra samples to make
+         the number of samples evenly divisible by the world size, which is
+         the same behavior as ``DistributedSampler`` with
+         ``drop_last=False``.
+       - If ``round_up=False``, this sampler won't remove or add any
+         samples, while ``DistributedSampler`` with ``drop_last=True`` will
+         remove tail samples.
+
+    Args:
+        dataset (Sized): The dataset.
+        dataset_ratio (Sequence[int]): The ratios of the different datasets.
+        seed (int, optional): Random seed used to shuffle the sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. Defaults to None.
+        round_up (bool): Whether to add extra samples to make the number of
+            samples evenly divisible by the world size. Defaults to True.
+    """
+
+    def __init__(self,
+                 dataset: Sized,
+                 dataset_ratio: Sequence[int],
+                 seed: Optional[int] = None,
+                 round_up: bool = True) -> None:
+        rank, world_size = get_dist_info()
+        self.rank = rank
+        self.world_size = world_size
+
+        self.dataset = dataset
+        self.dataset_ratio = dataset_ratio
+
+        if seed is None:
+            seed = sync_random_seed()
+        self.seed = seed
+        self.epoch = 0
+        self.round_up = round_up
+
+        if self.round_up:
+            self.num_samples = math.ceil(len(self.dataset) / world_size)
+            self.total_size = self.num_samples * self.world_size
+        else:
+            self.num_samples = math.ceil(
+                (len(self.dataset) - rank) / world_size)
+            self.total_size = len(self.dataset)
+
+        self.sizes = [len(dataset) for dataset in self.dataset.datasets]
+
+        # each sample of dataset i gets weight
+        # max(sizes) / sizes[i] * ratio[i] / sum(ratios), so a dataset's
+        # total weight is proportional to its requested ratio
+        dataset_weight = [
+            torch.ones(s) * max(self.sizes) / s * r / sum(self.dataset_ratio)
+            for r, s in zip(self.dataset_ratio, self.sizes)
+        ]
+        self.weights = torch.cat(dataset_weight)
+
+    def __iter__(self) -> Iterator[int]:
+        """Iterate the indices."""
+        # deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.epoch)
+
+        indices = torch.multinomial(
+            self.weights, len(self.weights), generator=g,
+            replacement=True).tolist()
+
+        # add extra samples to make it evenly divisible
+        if self.round_up:
+            indices = (
+                indices *
+                int(self.total_size / len(indices) + 1))[:self.total_size]
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.world_size]
+
+        return iter(indices)
+
+    def __len__(self) -> int:
+        """The number of samples in this rank."""
+        return self.num_samples
+
+    def set_epoch(self, epoch: int) -> None:
+        """Sets the epoch for this sampler.
+
+        When :attr:`shuffle=True`, this ensures all replicas use a
+        different random ordering for each epoch. Otherwise, the next
+        iteration of this sampler will yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+        """
+        self.epoch = epoch
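
The weighting is worth a worked example: with sizes ``[1000, 100]`` and ``dataset_ratio=[1, 4]``, every sample of the large dataset gets weight ``1000/1000 * 1/5 = 0.2`` (total 200) and every sample of the small one ``1000/100 * 4/5 = 8.0`` (total 800), so ``multinomial`` draws the two sources at the requested 1:4 ratio. A standalone sketch with illustrative numbers (not from the commit):

```python
import torch

sizes = [1000, 100]     # two concatenated datasets
dataset_ratio = [1, 4]  # sample them at a 1:4 ratio

# per-sample weights, mirroring MultiDataSampler.__init__
dataset_weight = [
    torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio)
    for r, s in zip(dataset_ratio, sizes)
]
weights = torch.cat(dataset_weight)

g = torch.Generator()
g.manual_seed(0)
indices = torch.multinomial(
    weights, len(weights), generator=g, replacement=True)
# fraction of draws hitting the second (small) dataset: ~0.8
print((indices >= sizes[0]).float().mean())
```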

projects/Detic/README.md (+2)

@@ -1,3 +1,5 @@
+# Note: This project has been deprecated, please use [Detic_new](../Detic_new).
+
 # Detecting Twenty-thousand Classes using Image-level Supervision

 ## Description
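
Taken together, the three new components are wired up through the train dataloader config. A hedged sketch in the style of the Detic_new project configs (the batch sizes, the 1:4 ratio, and the two placeholder dataset configs ``dataset_det`` / ``dataset_cls`` are assumptions, not values from this commit); mmengine's runner forwards the dataloader-level ``batch_size`` to the batch sampler:

```python
# dataset_det / dataset_cls: placeholder configs for a detection source
# and an image-level (classification) source, e.g. dataset dicts wrapped
# by ConcatDataset as shown earlier.
train_dataloader = dict(
    batch_size=[8, 32],  # per-source batch sizes for the batch sampler
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='MultiDataSampler', dataset_ratio=[1, 4]),
    batch_sampler=dict(
        type='MultiDataAspectRatioBatchSampler', num_datasets=2),
    dataset=dict(
        type='ConcatDataset', datasets=[dataset_det, dataset_cls]))
```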
