Skip to content

Commit 91c83bf

Browse files
authored
fix(DataFrame): #799 to_dict (#1283)
* fix(DataFrame): to_dict("index") and typevar * feat: https://github.com/pandas-dev/pandas/blob/v2.3.1/pandas/core/common.py#L416-L417 * fix: #1283 (comment) * fix(comment): #1283 (comment)
1 parent 0ecc8ea commit 91c83bf

File tree

2 files changed

+134
-55
lines changed

2 files changed

+134
-55
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ from builtins import (
22
bool as _bool,
33
str as _str,
44
)
5+
from collections import (
6+
OrderedDict,
7+
defaultdict,
8+
)
59
from collections.abc import (
610
Callable,
711
Hashable,
@@ -19,6 +23,7 @@ from typing import (
1923
Generic,
2024
Literal,
2125
NoReturn,
26+
TypeVar,
2227
final,
2328
overload,
2429
)
@@ -167,6 +172,8 @@ from pandas._typing import (
167172
from pandas.io.formats.style import Styler
168173
from pandas.plotting import PlotAccessor
169174

175+
_T_MUTABLE_MAPPING = TypeVar("_T_MUTABLE_MAPPING", bound=MutableMapping, covariant=True)
176+
170177
class _iLocIndexerFrame(_iLocIndexer, Generic[_T]):
171178
@overload
172179
def __getitem__(self, idx: tuple[int, int]) -> Scalar: ...
@@ -447,13 +454,21 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
447454
na_value: Scalar = ...,
448455
) -> np.ndarray: ...
449456
@overload
457+
def to_dict(
458+
self,
459+
orient: str = ...,
460+
*,
461+
into: type[defaultdict],
462+
index: Literal[True] = ...,
463+
) -> Never: ...
464+
@overload
450465
def to_dict(
451466
self,
452467
orient: Literal["records"],
453468
*,
454-
into: MutableMapping | type[MutableMapping],
469+
into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING],
455470
index: Literal[True] = ...,
456-
) -> list[MutableMapping[Hashable, Any]]: ...
471+
) -> list[_T_MUTABLE_MAPPING]: ...
457472
@overload
458473
def to_dict(
459474
self,
@@ -465,51 +480,67 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
465480
@overload
466481
def to_dict(
467482
self,
468-
orient: Literal["dict", "list", "series", "index"],
483+
orient: Literal["index"],
469484
*,
470-
into: MutableMapping | type[MutableMapping],
485+
into: defaultdict,
471486
index: Literal[True] = ...,
472-
) -> MutableMapping[Hashable, Any]: ...
487+
) -> defaultdict[Hashable, dict[Hashable, Any]]: ...
473488
@overload
474489
def to_dict(
475490
self,
476-
orient: Literal["split", "tight"],
491+
orient: Literal["index"],
477492
*,
478-
into: MutableMapping | type[MutableMapping],
479-
index: bool = ...,
480-
) -> MutableMapping[Hashable, Any]: ...
493+
into: OrderedDict | type[OrderedDict],
494+
index: Literal[True] = ...,
495+
) -> OrderedDict[Hashable, dict[Hashable, Any]]: ...
481496
@overload
482497
def to_dict(
483498
self,
484-
orient: Literal["dict", "list", "series", "index"] = ...,
499+
orient: Literal["index"],
485500
*,
486-
into: MutableMapping | type[MutableMapping],
501+
into: type[MutableMapping],
487502
index: Literal[True] = ...,
488-
) -> MutableMapping[Hashable, Any]: ...
503+
) -> MutableMapping[Hashable, dict[Hashable, Any]]: ...
489504
@overload
490505
def to_dict(
491506
self,
492-
orient: Literal["split", "tight"] = ...,
507+
orient: Literal["index"],
493508
*,
494-
into: MutableMapping | type[MutableMapping],
495-
index: bool = ...,
496-
) -> MutableMapping[Hashable, Any]: ...
509+
into: type[dict] = ...,
510+
index: Literal[True] = ...,
511+
) -> dict[Hashable, dict[Hashable, Any]]: ...
512+
@overload
513+
def to_dict(
514+
self,
515+
orient: Literal["dict", "list", "series"] = ...,
516+
*,
517+
into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING],
518+
index: Literal[True] = ...,
519+
) -> _T_MUTABLE_MAPPING: ...
497520
@overload
498521
def to_dict(
499522
self,
500-
orient: Literal["dict", "list", "series", "index"] = ...,
523+
orient: Literal["dict", "list", "series"] = ...,
501524
*,
502525
into: type[dict] = ...,
503526
index: Literal[True] = ...,
504527
) -> dict[Hashable, Any]: ...
505528
@overload
506529
def to_dict(
507530
self,
508-
orient: Literal["split", "tight"] = ...,
531+
orient: Literal["split", "tight"],
532+
*,
533+
into: MutableMapping | type[MutableMapping],
534+
index: bool = ...,
535+
) -> MutableMapping[str, list]: ...
536+
@overload
537+
def to_dict(
538+
self,
539+
orient: Literal["split", "tight"],
509540
*,
510541
into: type[dict] = ...,
511542
index: bool = ...,
512-
) -> dict[Hashable, Any]: ...
543+
) -> dict[str, list]: ...
513544
@classmethod
514545
def from_records(
515546
cls,

tests/test_frame.py

Lines changed: 84 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from __future__ import annotations
22

3-
from collections import defaultdict
3+
from collections import (
4+
OrderedDict,
5+
defaultdict,
6+
)
47
from collections.abc import (
8+
Callable,
59
Hashable,
610
Iterable,
711
Iterator,
@@ -20,7 +24,6 @@
2024
from typing import (
2125
TYPE_CHECKING,
2226
Any,
23-
Callable,
2427
Generic,
2528
TypedDict,
2629
TypeVar,
@@ -38,6 +41,7 @@
3841
)
3942
import pytest
4043
from typing_extensions import (
44+
Never,
4145
TypeAlias,
4246
assert_never,
4347
assert_type,
@@ -2199,19 +2203,6 @@ def test_types_resample() -> None:
21992203
df.resample(datetime.timedelta(minutes=20), origin="epoch", on="date")
22002204

22012205

2202-
def test_types_to_dict() -> None:
2203-
data = pd.DataFrame({"a": [1], "b": [2]})
2204-
check(assert_type(data.to_dict(orient="records"), list[dict[Hashable, Any]]), list)
2205-
check(assert_type(data.to_dict(orient="dict"), dict[Hashable, Any]), dict)
2206-
check(assert_type(data.to_dict(orient="list"), dict[Hashable, Any]), dict)
2207-
check(assert_type(data.to_dict(orient="series"), dict[Hashable, Any]), dict)
2208-
check(assert_type(data.to_dict(orient="split"), dict[Hashable, Any]), dict)
2209-
check(assert_type(data.to_dict(orient="index"), dict[Hashable, Any]), dict)
2210-
2211-
# orient param accepting "tight" added in 1.4.0 https://pandas.pydata.org/docs/whatsnew/v1.4.0.html
2212-
check(assert_type(data.to_dict(orient="tight"), dict[Hashable, Any]), dict)
2213-
2214-
22152206
def test_types_from_dict() -> None:
22162207
check(
22172208
assert_type(
@@ -3746,33 +3737,87 @@ def test_to_records() -> None:
37463737
)
37473738

37483739

3749-
def test_to_dict() -> None:
3750-
check(assert_type(DF.to_dict(), dict[Hashable, Any]), dict)
3751-
check(assert_type(DF.to_dict("split"), dict[Hashable, Any]), dict)
3740+
def test_to_dict_simple() -> None:
3741+
data = pd.DataFrame({"a": [1], "b": [2]})
3742+
check(assert_type(data.to_dict(), dict[Hashable, Any]), dict)
3743+
check(assert_type(data.to_dict("records"), list[dict[Hashable, Any]]), list)
3744+
check(assert_type(data.to_dict("index"), dict[Hashable, dict[Hashable, Any]]), dict)
3745+
check(assert_type(data.to_dict("dict"), dict[Hashable, Any]), dict)
3746+
check(assert_type(data.to_dict("list"), dict[Hashable, Any]), dict)
3747+
check(assert_type(data.to_dict("series"), dict[Hashable, Any]), dict)
3748+
check(assert_type(data.to_dict("split"), dict[str, list]), dict, str)
3749+
3750+
# orient param accepting "tight" added in 1.4.0 https://pandas.pydata.org/docs/whatsnew/v1.4.0.html
3751+
check(assert_type(data.to_dict("tight"), dict[str, list]), dict, str)
37523752

3753-
target: MutableMapping = defaultdict(list)
3753+
if TYPE_CHECKING_INVALID_USAGE:
3754+
3755+
def test(mapping: Mapping) -> None: # pyright: ignore[reportUnusedFunction]
3756+
data.to_dict(into=mapping) # type: ignore[call-overload] # pyright: ignore[reportArgumentType,reportCallIssue]
3757+
3758+
assert_type(data.to_dict(into=defaultdict), Never)
3759+
assert_type(data.to_dict("records", into=defaultdict), Never)
3760+
assert_type(data.to_dict("index", into=defaultdict), Never)
3761+
assert_type(data.to_dict("dict", into=defaultdict), Never)
3762+
assert_type(data.to_dict("list", into=defaultdict), Never)
3763+
assert_type(data.to_dict("series", into=defaultdict), Never)
3764+
assert_type(data.to_dict("split", into=defaultdict), Never)
3765+
assert_type(data.to_dict("tight", into=defaultdict), Never)
3766+
3767+
3768+
def test_to_dict_into_defaultdict() -> None:
3769+
"""Test DataFrame.to_dict with `into` is an instance of defaultdict[Any, list]"""
3770+
3771+
data = pd.DataFrame({("str", "rts"): [[1, 2, 4], [2, 3], [3]]})
3772+
target: defaultdict[Any, list] = defaultdict(list)
3773+
3774+
check(
3775+
assert_type(data.to_dict(into=target), defaultdict[Any, list]),
3776+
defaultdict,
3777+
tuple,
3778+
)
37543779
check(
3755-
assert_type(DF.to_dict(into=target), MutableMapping[Hashable, Any]), defaultdict
3780+
assert_type(
3781+
data.to_dict("index", into=target),
3782+
defaultdict[Hashable, dict[Hashable, Any]],
3783+
),
3784+
defaultdict,
3785+
)
3786+
check(
3787+
assert_type(data.to_dict("tight", into=target), MutableMapping[str, list]),
3788+
defaultdict,
3789+
str,
37563790
)
3757-
target = defaultdict(list)
37583791
check(
3759-
assert_type(DF.to_dict("tight", into=target), MutableMapping[Hashable, Any]),
3792+
assert_type(data.to_dict("records", into=target), list[defaultdict[Any, list]]),
3793+
list,
37603794
defaultdict,
37613795
)
3762-
target = defaultdict(list)
3763-
check(assert_type(DF.to_dict("records"), list[dict[Hashable, Any]]), list)
3796+
3797+
3798+
def test_to_dict_into_ordered_dict() -> None:
3799+
"""Test DataFrame.to_dict with `into=OrderedDict`"""
3800+
3801+
data = pd.DataFrame({("str", "rts"): [[1, 2, 4], [2, 3], [3]]})
3802+
3803+
check(assert_type(data.to_dict(into=OrderedDict), OrderedDict), OrderedDict, tuple)
37643804
check(
37653805
assert_type(
3766-
DF.to_dict("records", into=target), list[MutableMapping[Hashable, Any]]
3806+
data.to_dict("index", into=OrderedDict),
3807+
OrderedDict[Hashable, dict[Hashable, Any]],
37673808
),
3809+
OrderedDict,
3810+
)
3811+
check(
3812+
assert_type(data.to_dict("tight", into=OrderedDict), MutableMapping[str, list]),
3813+
OrderedDict,
3814+
str,
3815+
)
3816+
check(
3817+
assert_type(data.to_dict("records", into=OrderedDict), list[OrderedDict]),
37683818
list,
3819+
OrderedDict,
37693820
)
3770-
if TYPE_CHECKING_INVALID_USAGE:
3771-
3772-
def test(mapping: Mapping) -> None: # pyright: ignore[reportUnusedFunction]
3773-
DF.to_dict( # type: ignore[call-overload]
3774-
into=mapping # pyright: ignore[reportArgumentType,reportCallIssue]
3775-
)
37763821

37773822

37783823
def test_neg() -> None:
@@ -4247,19 +4292,22 @@ def test_to_dict_index() -> None:
42474292
assert_type(df.to_dict(orient="series", index=True), dict[Hashable, Any]), dict
42484293
)
42494294
check(
4250-
assert_type(df.to_dict(orient="index", index=True), dict[Hashable, Any]), dict
4295+
assert_type(
4296+
df.to_dict(orient="index", index=True), dict[Hashable, dict[Hashable, Any]]
4297+
),
4298+
dict,
42514299
)
42524300
check(
4253-
assert_type(df.to_dict(orient="split", index=True), dict[Hashable, Any]), dict
4301+
assert_type(df.to_dict(orient="split", index=True), dict[str, list]), dict, str
42544302
)
42554303
check(
4256-
assert_type(df.to_dict(orient="tight", index=True), dict[Hashable, Any]), dict
4304+
assert_type(df.to_dict(orient="tight", index=True), dict[str, list]), dict, str
42574305
)
42584306
check(
4259-
assert_type(df.to_dict(orient="tight", index=False), dict[Hashable, Any]), dict
4307+
assert_type(df.to_dict(orient="tight", index=False), dict[str, list]), dict, str
42604308
)
42614309
check(
4262-
assert_type(df.to_dict(orient="split", index=False), dict[Hashable, Any]), dict
4310+
assert_type(df.to_dict(orient="split", index=False), dict[str, list]), dict, str
42634311
)
42644312
if TYPE_CHECKING_INVALID_USAGE:
42654313
check(assert_type(df.to_dict(orient="records", index=False), list[dict[Hashable, Any]]), list) # type: ignore[assert-type, call-overload] # pyright: ignore[reportArgumentType,reportAssertTypeFailure,reportCallIssue]

0 commit comments

Comments
 (0)