
IndexError when num_shards > len(dataset) #7443

@eminorhan

Description

In ds.push_to_hub() and ds.save_to_disk(), num_shards must be less than or equal to the number of rows in the dataset, but currently this is not checked anywhere inside these functions. Invoking them with num_shards > len(dataset) should raise an informative ValueError instead of failing later.

I frequently work with datasets that have a small number of rows, each of which is quite large, so I often run into this issue: the function runs until the shard index in ds.shard(num_shards, index) goes out of bounds. Ideally, a ValueError should be raised before this point, i.e. as soon as ds.push_to_hub() or ds.save_to_disk() is invoked with num_shards > len(dataset).
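For context, here is a minimal reproduction sketch, assuming only that the datasets library is installed (the path "tmp_dataset_dir" is a placeholder):

from datasets import Dataset

# Tiny dataset: 2 rows (in practice each row may be very large).
ds = Dataset.from_dict({"text": ["row one", "row two"]})

# num_shards (4) exceeds len(ds) (2). On affected versions this runs until
# the internal ds.shard(num_shards, index) call goes out of bounds and
# raises an IndexError, instead of failing fast with a clear ValueError.
ds.save_to_disk("tmp_dataset_dir", num_shards=4)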

It seems that adding something like:

if len(self) < num_shards:
    raise ValueError(
        f"num_shards ({num_shards}) must be smaller than or equal to the number of rows "
        f"in the dataset ({len(self)}). Please either reduce num_shards or increase "
        f"max_shard_size to make sure num_shards <= len(dataset)."
    )

to the beginning of the ds.shard() function here would deal with this issue for both ds.push_to_hub() and ds.save_to_disk(), but I'm not sure this is the best place to raise the ValueError (a more correct approach might be separate checks inside ds.push_to_hub() and ds.save_to_disk(), as sketched below). I'd be happy to submit a PR if you think something along these lines would be acceptable.
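For illustration, the separate-checks variant could look roughly like this; _check_num_shards is a hypothetical helper name, not actual datasets code:

def _check_num_shards(num_rows: int, num_shards: int) -> None:
    # Validate num_shards before any sharding work begins.
    if num_shards > num_rows:
        raise ValueError(
            f"num_shards ({num_shards}) must be smaller than or equal to the number of "
            f"rows in the dataset ({num_rows}). Please either reduce num_shards or "
            f"increase max_shard_size to make sure num_shards <= len(dataset)."
        )

Each of ds.push_to_hub() and ds.save_to_disk() would call _check_num_shards(len(self), num_shards) right after resolving num_shards from its arguments (or from max_shard_size), so the error surfaces immediately rather than mid-upload or mid-write.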
