|
1 | 1 | # Copyright (c) OpenMMLab. All rights reserved.
|
2 | 2 | import collections
|
3 | 3 | import copy
|
4 |
| -from typing import Sequence, Union |
| 4 | +from typing import List, Sequence, Union |
5 | 5 |
|
6 |
| -from mmengine.dataset import BaseDataset, force_full_init |
| 6 | +from mmengine.dataset import BaseDataset |
| 7 | +from mmengine.dataset import ConcatDataset as MMENGINE_ConcatDataset |
| 8 | +from mmengine.dataset import force_full_init |
7 | 9 |
|
8 | 10 | from mmdet.registry import DATASETS, TRANSFORMS
|
9 | 11 |
|
@@ -167,3 +169,84 @@ def update_skip_type_keys(self, skip_type_keys):
|
167 | 169 | isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
|
168 | 170 | ])
|
169 | 171 | self._skip_type_keys = skip_type_keys
|
| 172 | + |
| 173 | + |
@DATASETS.register_module()
class ConcatDataset(MMENGINE_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as ``torch.utils.data.dataset.ConcatDataset``, support
    lazy_init and get_dataset_source.

    Note:
        ``ConcatDataset`` should not inherit from ``BaseDataset`` since
        ``get_subset`` and ``get_subset_`` could produce ambiguous meaning
        sub-dataset which conflicts with original dataset. If you want to use
        a sub-dataset of ``ConcatDataset``, you should set ``indices``
        arguments for wrapped dataset which inherit from ``BaseDataset``.

    Args:
        datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets
            which will be concatenated.
        lazy_init (bool, optional): Whether to load annotation during
            instantiation. Defaults to False.
        ignore_keys (List[str] or str): Ignore the keys that can be
            unequal in `dataset.metainfo`. Defaults to None.
            `New in version 0.3.0.`
    """

    def __init__(self,
                 datasets: Sequence[Union[BaseDataset, dict]],
                 lazy_init: bool = False,
                 ignore_keys: Union[str, List[str], None] = None):
        self.datasets: List[BaseDataset] = []
        # Accept either ready-made dataset instances or registry configs.
        for dataset in datasets:
            if isinstance(dataset, dict):
                self.datasets.append(DATASETS.build(dataset))
            elif isinstance(dataset, BaseDataset):
                self.datasets.append(dataset)
            else:
                raise TypeError(
                    'elements in datasets sequence should be config or '
                    f'`BaseDataset` instance, but got {type(dataset)}')
        # Normalize ignore_keys to a list of key names.
        if ignore_keys is None:
            self.ignore_keys = []
        elif isinstance(ignore_keys, str):
            self.ignore_keys = [ignore_keys]
        elif isinstance(ignore_keys, list):
            self.ignore_keys = ignore_keys
        else:
            raise TypeError('ignore_keys should be a list or str, '
                            f'but got {type(ignore_keys)}')

        # Union of metainfo keys across all wrapped datasets.
        meta_keys: set = set()
        for dataset in self.datasets:
            meta_keys |= dataset.metainfo.keys()
        # if the metainfo of multiple datasets are the same, use metainfo
        # of the first dataset, else the metainfo is a list with metainfo
        # of all the datasets
        is_all_same = True
        self._metainfo_first = self.datasets[0].metainfo
        for dataset in self.datasets:
            for key in meta_keys:
                if key in self.ignore_keys:
                    continue
                if key not in dataset.metainfo:
                    is_all_same = False
                    break
                if self._metainfo_first[key] != dataset.metainfo[key]:
                    is_all_same = False
                    break
            # BUGFIX: stop scanning at the first mismatch. The previous code
            # continued to later datasets and could hit
            # ``self._metainfo_first[key]`` for a key absent from the first
            # dataset's metainfo, raising KeyError instead of falling back to
            # the per-dataset metainfo list.
            if not is_all_same:
                break

        if is_all_same:
            self._metainfo = self.datasets[0].metainfo
        else:
            self._metainfo = [dataset.metainfo for dataset in self.datasets]

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    def get_dataset_source(self, idx: int) -> int:
        """Return the index of the wrapped dataset that sample ``idx``
        belongs to.

        Args:
            idx (int): Global sample index into the concatenated dataset.

        Returns:
            int: Index of the source dataset in ``self.datasets``.
        """
        dataset_idx, _ = self._get_ori_dataset_idx(idx)
        return dataset_idx
0 commit comments