Skip to content

Commit 1a330f3

Browse files
Aditya2755lhoestq
andauthored
Add type overloads to load_dataset for better static type inference (#7888)
* Add type overloads to load_dataset for better static type inference Fixes #7883 This PR adds @overload decorators to load_dataset() to help type checkers like Pylance and mypy correctly infer the return type based on the split and streaming parameters. Changes: - Added typing imports (Literal, overload) to load.py - Added 4 @overload signatures that map argument combinations to specific return types: * split=None, streaming=False -> DatasetDict * split specified, streaming=False -> Dataset * split=None, streaming=True -> IterableDatasetDict * split specified, streaming=True -> IterableDataset This resolves the Pylance error where to_csv() was not recognized on Dataset objects returned by load_dataset(..., split='train'), since the type checker previously saw the return type as a Union that included types without to_csv(). No runtime behavior changes - this is purely a static typing improvement. * make style --------- Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 2ed6f72 commit 1a330f3

File tree

1 file changed

+96
-1
lines changed

1 file changed

+96
-1
lines changed

src/datasets/load.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from collections.abc import Mapping, Sequence
2626
from dataclasses import dataclass, field
2727
from pathlib import Path
28-
from typing import Any, Optional, Union
28+
from typing import Any, Literal, Optional, Union, overload
2929

3030
import fsspec
3131
import httpx
@@ -1187,6 +1187,101 @@ def load_dataset_builder(
11871187
return builder_instance
11881188

11891189

1190+
@overload
1191+
def load_dataset(
1192+
path: str,
1193+
name: Optional[str] = None,
1194+
data_dir: Optional[str] = None,
1195+
data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None,
1196+
split: None = None,
1197+
cache_dir: Optional[str] = None,
1198+
features: Optional[Features] = None,
1199+
download_config: Optional[DownloadConfig] = None,
1200+
download_mode: Optional[Union[DownloadMode, str]] = None,
1201+
verification_mode: Optional[Union[VerificationMode, str]] = None,
1202+
keep_in_memory: Optional[bool] = None,
1203+
save_infos: bool = False,
1204+
revision: Optional[Union[str, Version]] = None,
1205+
token: Optional[Union[bool, str]] = None,
1206+
streaming: Literal[False] = False,
1207+
num_proc: Optional[int] = None,
1208+
storage_options: Optional[dict] = None,
1209+
**config_kwargs: Any,
1210+
) -> DatasetDict: ...
1211+
1212+
1213+
@overload
1214+
def load_dataset(
1215+
path: str,
1216+
name: Optional[str] = None,
1217+
data_dir: Optional[str] = None,
1218+
data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None,
1219+
*,
1220+
split: Union[str, Split, list[str], list[Split]],
1221+
cache_dir: Optional[str] = None,
1222+
features: Optional[Features] = None,
1223+
download_config: Optional[DownloadConfig] = None,
1224+
download_mode: Optional[Union[DownloadMode, str]] = None,
1225+
verification_mode: Optional[Union[VerificationMode, str]] = None,
1226+
keep_in_memory: Optional[bool] = None,
1227+
save_infos: bool = False,
1228+
revision: Optional[Union[Version, str]] = None,
1229+
token: Optional[Union[bool, str]] = None,
1230+
streaming: Literal[False] = False,
1231+
num_proc: Optional[int] = None,
1232+
storage_options: Optional[dict] = None,
1233+
**config_kwargs: Any,
1234+
) -> Dataset: ...
1235+
1236+
1237+
@overload
1238+
def load_dataset(
1239+
path: str,
1240+
name: Optional[str] = None,
1241+
data_dir: Optional[str] = None,
1242+
data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None,
1243+
split: None = None,
1244+
cache_dir: Optional[str] = None,
1245+
features: Optional[Features] = None,
1246+
download_config: Optional[DownloadConfig] = None,
1247+
download_mode: Optional[Union[DownloadMode, str]] = None,
1248+
verification_mode: Optional[Union[VerificationMode, str]] = None,
1249+
keep_in_memory: Optional[bool] = None,
1250+
save_infos: bool = False,
1251+
revision: Optional[Union[Version, str]] = None,
1252+
token: Optional[Union[bool, str]] = None,
1253+
*,
1254+
streaming: Literal[True],
1255+
num_proc: Optional[int] = None,
1256+
storage_options: Optional[dict] = None,
1257+
**config_kwargs: Any,
1258+
) -> IterableDatasetDict: ...
1259+
1260+
1261+
@overload
1262+
def load_dataset(
1263+
path: str,
1264+
name: Optional[str] = None,
1265+
data_dir: Optional[str] = None,
1266+
data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None,
1267+
*,
1268+
split: Union[str, Split, list[str], list[Split]],
1269+
cache_dir: Optional[str] = None,
1270+
features: Optional[Features] = None,
1271+
download_config: Optional[DownloadConfig] = None,
1272+
download_mode: Optional[Union[DownloadMode, str]] = None,
1273+
verification_mode: Optional[Union[VerificationMode, str]] = None,
1274+
keep_in_memory: Optional[bool] = None,
1275+
save_infos: bool = False,
1276+
revision: Optional[Union[Version, str]] = None,
1277+
token: Optional[Union[bool, str]] = None,
1278+
streaming: Literal[True],
1279+
num_proc: Optional[int] = None,
1280+
storage_options: Optional[dict] = None,
1281+
**config_kwargs: Any,
1282+
) -> IterableDataset: ...
1283+
1284+
11901285
def load_dataset(
11911286
path: str,
11921287
name: Optional[str] = None,

0 commit comments

Comments
 (0)