Skip to content

Commit ad7133d

Browse files
ydaimingfacebook-github-bot
authored andcommitted
Patch for pytorch#40026 RandomSampler generates samples one at a time when replacement=True (pytorch#41682)
Summary: Fix pytorch#32530 Fix/Patch pytorch#40026 Resubmit this patch and fix the type error. Force the input type to `manual_seed()` in `sampler.py` to be `int`. ezyang Pull Request resolved: pytorch#41682 Reviewed By: izdeby Differential Revision: D22665477 Pulled By: ezyang fbshipit-source-id: 1725c8aa742c31e74321f20448f4b6a392afb38d
1 parent 2d15b39 commit ad7133d

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

test/test_dataloader.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1288,22 +1288,24 @@ def test_random_sampler(self):
12881288
def sample_stat(sampler, num_samples):
12891289
counts = Counter(sampler)
12901290
count_repeated = sum(val > 1 for val in counts.values())
1291-
return (count_repeated, min(counts.keys()), max(counts.keys()))
1291+
return (count_repeated, min(counts.keys()), max(counts.keys()), sum(counts.values()))
12921292

12931293
# test sample with replacement
12941294
n = len(self.dataset) + 1 # ensure at least one sample is drawn more than once
12951295
sampler_with_replacement = RandomSampler(self.dataset, replacement=True, num_samples=n)
1296-
count_repeated, minval, maxval = sample_stat(sampler_with_replacement, n)
1296+
count_repeated, minval, maxval, count_total = sample_stat(sampler_with_replacement, n)
12971297
self.assertTrue(count_repeated > 0)
12981298
self.assertTrue(minval >= 0)
12991299
self.assertTrue(maxval < len(self.dataset))
1300+
self.assertTrue(count_total == n)
13001301

13011302
# test sample without replacement
13021303
sampler_without_replacement = RandomSampler(self.dataset)
1303-
count_repeated, minval, maxval = sample_stat(sampler_without_replacement, len(self.dataset))
1304+
count_repeated, minval, maxval, count_total = sample_stat(sampler_without_replacement, len(self.dataset))
13041305
self.assertTrue(count_repeated == 0)
13051306
self.assertTrue(minval == 0)
13061307
self.assertTrue(maxval == len(self.dataset) - 1)
1308+
self.assertTrue(count_total == len(self.dataset))
13071309

13081310
# raise error when replacement=False and num_samples is not None
13091311
self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=len(self.dataset)))

torch/utils/data/sampler.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class RandomSampler(Sampler[int]):
7676
7777
Arguments:
7878
data_source (Dataset): dataset to sample from
79-
replacement (bool): samples are drawn with replacement if ``True``, default=``False``
79+
replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
8080
num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
8181
is supposed to be specified only when `replacement` is ``True``.
8282
generator (Generator): Generator used in sampling.
@@ -112,10 +112,17 @@ def num_samples(self) -> int:
112112

113113
def __iter__(self):
114114
n = len(self.data_source)
115+
if self.generator is None:
116+
generator = torch.Generator()
117+
generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
118+
else:
119+
generator = self.generator
115120
if self.replacement:
116-
rand_tensor = torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=self.generator)
117-
return iter(rand_tensor.tolist())
118-
return iter(torch.randperm(n, generator=self.generator).tolist())
121+
for _ in range(self.num_samples // 32):
122+
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
123+
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
124+
else:
125+
yield from torch.randperm(n, generator=self.generator).tolist()
119126

120127
def __len__(self):
121128
return self.num_samples

0 commit comments

Comments
 (0)