diff --git a/airio/_src/pygrain/dataset_providers.py b/airio/_src/pygrain/dataset_providers.py
index 6a55416..1fd6fba 100644
--- a/airio/_src/pygrain/dataset_providers.py
+++ b/airio/_src/pygrain/dataset_providers.py
@@ -87,6 +87,7 @@ def get_lazy_dataset(
       shard_info: core_dataset_providers.ShardInfo | None,
       num_epochs: int | None,
       num_prefetch_threads: int | None,
+      drop_remainder: bool = False,
   ) -> lazy_dataset.LazyMapDataset | lazy_dataset.LazyIterDataset:
     """Returns a lazy dataset for Task source and preprocessors."""
     # Step 1: Get Source.
@@ -156,7 +157,7 @@ def get_lazy_dataset(
       runtime_preps.extend(runtime_preprocessors)
     if batch_size:
       runtime_preps.append(
-          grain.Batch(batch_size=batch_size, drop_remainder=False)
+          grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder)
       )
     unused_next_epoch_rng, prep_rng = jax.random.split(next_epoch_rng)
     ds, _, _ = _apply_preprocessors_to_lazy_dataset(
@@ -185,6 +186,7 @@ def get_dataset(
       num_epochs: int | None = 1,
       num_prefetch_threads: int | None = None,
       num_workers: int | None = 0,
+      drop_remainder: bool = False,
   ) -> clu_dataset_iterator.DatasetIterator:
     """Returns the dataset iterator as per the task configuration."""
     # TODO(b/311720936): Until Task preprocessing is fully switched to
@@ -201,6 +203,7 @@ def get_dataset(
           shard_info=shard_info,
           num_epochs=num_epochs,
           num_prefetch_threads=num_prefetch_threads,
+          drop_remainder=drop_remainder,
       )
       if num_epochs is None:
         ds = lazy_dataset.RepeatLazyMapDataset(ds, num_epochs=None)
@@ -230,7 +233,7 @@ def get_dataset(
     if runtime_preprocessors:
       ops.extend(runtime_preprocessors)
     if batch_size:
-      ops.append(grain.Batch(batch_size=batch_size, drop_remainder=False))
+      ops.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder))
 
     # Add runtime args
     runtime_args = core_preprocessors_lib.AirIOInjectedRuntimeArgs(
@@ -300,6 +303,7 @@ def get_dataset_by_step(
       batch_size: int | None = None,
       shuffle: bool = True,
       seed: int | None = 0,
+      drop_remainder: bool = False,
   ) -> Iterable[Iterable[Mapping[str, Any]]]:
     """Returns a step-by-step transformation of a sample of records.
 
@@ -314,6 +318,7 @@ def get_dataset_by_step(
       batch_size: the batch size.
       shuffle: whether to shuffle or not.
       seed: dataset seed.
+      drop_remainder: whether to drop the last batch if it's smaller than batch_size.
 
     Returns: a list indexed by processing step. For example:
       |-----------------------------|
@@ -345,7 +350,7 @@ def get_dataset_by_step(
     if runtime_preprocessors:
       all_ops.extend(runtime_preprocessors)
     if batch_size:
-      all_ops.append(grain.Batch(batch_size=batch_size, drop_remainder=False))
+      all_ops.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder))
 
     # Raw data
     records_step0 = self._load_data(source=source, sampler=sampler, ops=[])
@@ -432,6 +437,7 @@ def get_lazy_dataset(
       shard_info: core_dataset_providers.ShardInfo | None = None,
       num_epochs: int | None = 1,
       num_prefetch_threads: int | None = None,
+      drop_remainder: bool = False,
   ) -> lazy_dataset.LazyMapDataset | lazy_dataset.LazyIterDataset:
     """Returns a lazy dataset for the Mixture."""
     if num_epochs is None and shuffle:
@@ -454,6 +460,7 @@ def get_lazy_dataset(
               shard_info=shard_info,
               num_epochs=num_epochs,
               num_prefetch_threads=num_prefetch_threads,
+              drop_remainder=drop_remainder,
           )
       )
       proportions.append(self.get_proportion(task))
@@ -493,7 +500,7 @@ def get_lazy_dataset(
       post_mix_preps.extend(runtime_preprocessors)
     if batch_size:
       post_mix_preps.append(
-          grain.Batch(batch_size=batch_size, drop_remainder=False)
+          grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder)
       )
     # Note: Use updated runtime args from the first Task. All updated runtime
     # args must match, or mixing won't work (compute all updated runtime args
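The change threads a new drop_remainder flag (default False, so existing behavior is unchanged) from get_dataset, get_dataset_by_step, and the Task and Mixture get_lazy_dataset entry points down to every grain.Batch operation. A minimal usage sketch follows; the task object and the non-drop_remainder argument values are hypothetical and only illustrate how a caller would opt in to dropping the final short batch.

    # Hypothetical caller: `task` is assumed to be an already-constructed
    # pygrain Task; only the new drop_remainder argument is the point here.
    iterator = task.get_dataset(
        sequence_lengths={"inputs": 32, "targets": 16},  # illustrative values
        split="train",
        batch_size=8,
        shuffle=False,
        drop_remainder=True,  # drop the final short batch instead of emitting it
    )
    for batch in iterator:
        ...  # every batch now has exactly batch_size examples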