@@ -11,6 +11,7 @@ from typing import (
11
11
Generic ,
12
12
Literal ,
13
13
NamedTuple ,
14
+ Protocol ,
14
15
TypeVar ,
15
16
final ,
16
17
overload ,
@@ -208,28 +209,45 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
208
209
209
210
_TT = TypeVar ("_TT" , bound = Literal [True , False ])
210
211
212
+ # ty ignore needed because of https://github.com/astral-sh/ty/issues/157#issuecomment-3017337945
213
+ class DFCallable1 (Protocol [P ]): # ty: ignore[invalid-argument-type]
214
+ def __call__ (
215
+ self , df : DataFrame , / , * args : P .args , ** kwargs : P .kwargs
216
+ ) -> Scalar | list | dict : ...
217
+
218
+ class DFCallable2 (Protocol [P ]): # ty: ignore[invalid-argument-type]
219
+ def __call__ (
220
+ self , df : DataFrame , / , * args : P .args , ** kwargs : P .kwargs
221
+ ) -> DataFrame | Series : ...
222
+
223
+ class DFCallable3 (Protocol [P ]): # ty: ignore[invalid-argument-type]
224
+ def __call__ (self , df : Iterable , / , * args : P .args , ** kwargs : P .kwargs ) -> float : ...
225
+
211
226
class DataFrameGroupBy (GroupBy [DataFrame ], Generic [ByT , _TT ]):
212
227
# error: Overload 3 for "apply" will never be used because its parameters overlap overload 1
213
228
@overload # type: ignore[override]
214
229
def apply (
215
230
self ,
216
- func : Callable [[DataFrame ], Scalar | list | dict ],
217
- * args ,
218
- ** kwargs ,
231
+ func : DFCallable1 [P ],
232
+ / ,
233
+ * args : P .args ,
234
+ ** kwargs : P .kwargs ,
219
235
) -> Series : ...
220
236
@overload
221
237
def apply (
222
238
self ,
223
- func : Callable [[DataFrame ], Series | DataFrame ],
224
- * args ,
225
- ** kwargs ,
239
+ func : DFCallable2 [P ],
240
+ / ,
241
+ * args : P .args ,
242
+ ** kwargs : P .kwargs ,
226
243
) -> DataFrame : ...
227
244
@overload
228
- def apply ( # pyright: ignore[reportOverlappingOverload]
245
+ def apply (
229
246
self ,
230
- func : Callable [[Iterable ], float ],
231
- * args ,
232
- ** kwargs ,
247
+ func : DFCallable3 [P ],
248
+ / ,
249
+ * args : P .args ,
250
+ ** kwargs : P .kwargs ,
233
251
) -> DataFrame : ...
234
252
# error: overload 1 overlaps overload 2 because of different return types
235
253
@overload
0 commit comments