Description
In `ds.push_to_hub()` and `ds.save_to_disk()`, `num_shards` must be smaller than or equal to the number of rows in the dataset, but currently this is not checked anywhere inside these functions. Attempting to invoke these functions with `num_shards > len(dataset)` should raise an informative `ValueError`.
I frequently work with datasets that have a small number of rows where each row is pretty large, so I often run into this issue: the function runs until the shard index in `ds.shard(num_shards, index)` goes out of bounds. Ideally, a `ValueError` should be raised before reaching this point (i.e. as soon as `ds.push_to_hub()` or `ds.save_to_disk()` is invoked with `num_shards > len(dataset)`).
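For context, here is a minimal sketch of how the failure can be triggered (the dataset contents and the `"tmp_dataset"` path are illustrative, not from an actual run):

```python
from datasets import Dataset

# A 2-row dataset saved with more shards than rows: the expectation in
# this issue is that this fails fast with an informative ValueError
# instead of erroring later inside the sharding loop.
ds = Dataset.from_dict({"text": ["row one", "row two"]})
ds.save_to_disk("tmp_dataset", num_shards=5)  # num_shards > len(ds)
```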
It seems that adding something like:

```python
if len(self) < num_shards:
    raise ValueError(
        f"num_shards ({num_shards}) must be smaller than or equal to the number "
        f"of rows in the dataset ({len(self)}). Please either reduce num_shards "
        "or increase max_shard_size to make sure num_shards <= len(dataset)."
    )
```
to the beginning of the definition of the `ds.shard()` function here would deal with this issue for both `ds.push_to_hub()` and `ds.save_to_disk()`, but I'm not exactly sure this is the best place to raise the `ValueError` (a more correct approach might be to add separate checks to `ds.push_to_hub()` and `ds.save_to_disk()`, along the lines of the sketch below). I'd be happy to submit a PR if you think something along these lines would be acceptable.
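For illustration, the separate checks could share something like the following (the `_check_num_shards` helper name is hypothetical, not existing library code):

```python
def _check_num_shards(num_shards, num_rows):
    # Hypothetical helper: validate num_shards up front, so that both
    # push_to_hub() and save_to_disk() can call it right after resolving
    # num_shards (e.g. from max_shard_size), before any sharding work starts.
    if num_shards is not None and num_shards > num_rows:
        raise ValueError(
            f"num_shards ({num_shards}) must be smaller than or equal to the "
            f"number of rows in the dataset ({num_rows}). Please either reduce "
            "num_shards or increase max_shard_size to make sure num_shards <= len(dataset)."
        )
```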