-
Notifications
You must be signed in to change notification settings - Fork 8
Add TileDataset #63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Add TileDataset #63
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1004,7 +1004,7 @@ def shuffle(self, reshuffle: bool = False, | |
else: | ||
raise ValueError(reshuffle, self) | ||
|
||
def tile(self, reps: int, shuffle: bool = False) -> 'Dataset': | ||
def tile(self, reps: int, shuffle: bool = False) -> "Dataset": | ||
""" | ||
Constructs a new dataset by repeating the dataset the number of | ||
times given by `reps`. This is done by copying the dataset and | ||
|
@@ -1022,21 +1022,16 @@ def tile(self, reps: int, shuffle: bool = False) -> 'Dataset': | |
>>> ds | ||
ListDataset(len=5) | ||
MapDataset(_pickle.loads) | ||
ListDataset(len=5) | ||
MapDataset(_pickle.loads) | ||
ListDataset(len=5) | ||
MapDataset(_pickle.loads) | ||
ConcatenateDataset() | ||
TileDataset(repetitions=3) | ||
>>> list(ds) | ||
[1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5] | ||
""" | ||
datasets = [self] * reps | ||
if shuffle: | ||
datasets = [ | ||
ds.shuffle() | ||
for ds in datasets | ||
] | ||
return self.__class__.concatenate(*datasets) | ||
datasets = [self] * reps | ||
datasets = [ds.shuffle() for ds in datasets] | ||
return self.__class__.concatenate(*datasets) | ||
else: | ||
return TileDataset(self, reps) | ||
|
||
def cycle(self) -> 'CycleDataset': | ||
""" | ||
|
@@ -2763,6 +2758,88 @@ def __getitem__(self, item): | |
return super().__getitem__(item) | ||
|
||
|
||
class TileDataset(Dataset): | ||
""" | ||
Iterates over all elements of the input_dataset for `repetitions` times. | ||
|
||
""" | ||
|
||
def __init__(self, input_dataset, repetitions): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you rename |
||
""" | ||
|
||
Args: | ||
input_dataset: dataset | ||
repetitions: int | ||
|
||
""" | ||
self.input_dataset = input_dataset | ||
self.repetitions = repetitions | ||
|
||
def copy(self, freeze=False): | ||
return self.__class__(self.input_dataset.copy(freeze=freeze), self.repetitions) | ||
|
||
@property | ||
def indexable(self): | ||
return self.input_dataset.indexable | ||
|
||
@property | ||
def ordered(self) -> bool: | ||
return self.input_dataset.ordered | ||
|
||
def __str__(self): | ||
return f"{self.__class__.__name__}(repetitions={self.repetitions})" | ||
|
||
def __iter__(self, with_key=False): | ||
for _ in range(self.repetitions): | ||
if with_key: | ||
iterable = self.input_dataset.__iter__(with_key=True) | ||
else: | ||
iterable = self.input_dataset | ||
for example in iterable: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you use |
||
yield example | ||
|
||
def __len__(self): | ||
return self.repetitions * len(self.input_dataset) | ||
|
||
def __getitem__(self, item): | ||
""" | ||
>>> ds = DictDataset({'a': {}, 'b': {}}) | ||
>>> ds = ds.items().map(lambda x: {'example_id': x[0], **x[1]}) | ||
>>> ds = ds.tile(2) | ||
>>> len(ds) | ||
4 | ||
>>> ds['a'] | ||
{'example_id': 'a'} | ||
>>> ds['b'] | ||
{'example_id': 'b'} | ||
>>> ds[5] | ||
Traceback (most recent call last): | ||
... | ||
IndexError: 5 | ||
>>> ds[-1] | ||
{'example_id': 'b'} | ||
>>> ds[-5] | ||
Traceback (most recent call last): | ||
... | ||
IndexError: -5 | ||
|
||
""" | ||
if isinstance(item, str): | ||
return self.input_dataset[item] | ||
elif isinstance(item, numbers.Integral): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure whether this code will have an effect on performance. How about changing the code to the following? input_len = len(self.input_dataset)
if not (-self.repetitions <= item // input_len < self.repetitions):
raise IndexError(_item)
return self.input_dataset[item % input_len] |
||
_item = item | ||
if item < 0: | ||
item = item + len(self) | ||
if item < 0: | ||
raise IndexError(_item) | ||
if item > self.repetitions * len(self.input_dataset): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||
raise IndexError(_item) | ||
item = item % len(self.input_dataset) | ||
return self.input_dataset[item] | ||
else: | ||
return super().__getitem__(item) | ||
|
||
|
||
class IntersperseDataset(Dataset): | ||
""" | ||
See Dataset.intersperse | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To minimize overhead: could you return self when reps is one? concatenate does this already.