|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from collections.abc import ( |
| 4 | + Callable, |
| 5 | + Hashable, |
| 6 | +) |
| 7 | +from typing import ( |
| 8 | + TYPE_CHECKING, |
| 9 | + Any, |
| 10 | +) |
| 11 | + |
| 12 | +from pandas.core.series import Series |
| 13 | + |
| 14 | +if TYPE_CHECKING: |
| 15 | + from pandas import DataFrame |
| 16 | + |
| 17 | + |
| 18 | +# Used only for generating the str repr of expressions. |
| 19 | +_OP_SYMBOLS = { |
| 20 | + "__add__": "+", |
| 21 | + "__radd__": "+", |
| 22 | + "__sub__": "-", |
| 23 | + "__rsub__": "-", |
| 24 | + "__mul__": "*", |
| 25 | + "__rmul__": "*", |
| 26 | + "__truediv__": "/", |
| 27 | + "__rtruediv__": "/", |
| 28 | + "__floordiv__": "//", |
| 29 | + "__rfloordiv__": "//", |
| 30 | + "__mod__": "%", |
| 31 | + "__rmod__": "%", |
| 32 | + "__ge__": ">=", |
| 33 | + "__gt__": ">", |
| 34 | + "__le__": "<=", |
| 35 | + "__lt__": "<", |
| 36 | + "__eq__": "==", |
| 37 | + "__ne__": "!=", |
| 38 | +} |
| 39 | + |
| 40 | + |
| 41 | +def _parse_args(df: DataFrame, *args: Any) -> tuple[Series]: |
| 42 | + # Parse `args`, evaluating any expressions we encounter. |
| 43 | + return tuple([x(df) if isinstance(x, Expression) else x for x in args]) |
| 44 | + |
| 45 | + |
| 46 | +def _parse_kwargs(df: DataFrame, **kwargs: Any) -> dict[str, Any]: |
| 47 | + # Parse `kwargs`, evaluating any expressions we encounter. |
| 48 | + return { |
| 49 | + key: val(df) if isinstance(val, Expression) else val |
| 50 | + for key, val in kwargs.items() |
| 51 | + } |
| 52 | + |
| 53 | + |
| 54 | +def _pretty_print_args_kwargs(*args: Any, **kwargs: Any) -> str: |
| 55 | + inputs_repr = ", ".join( |
| 56 | + arg._repr_str if isinstance(arg, Expression) else repr(arg) for arg in args |
| 57 | + ) |
| 58 | + kwargs_repr = ", ".join( |
| 59 | + f"{k}={v._repr_str if isinstance(v, Expression) else v!r}" |
| 60 | + for k, v in kwargs.items() |
| 61 | + ) |
| 62 | + |
| 63 | + all_args = [] |
| 64 | + if inputs_repr: |
| 65 | + all_args.append(inputs_repr) |
| 66 | + if kwargs_repr: |
| 67 | + all_args.append(kwargs_repr) |
| 68 | + |
| 69 | + return ", ".join(all_args) |
| 70 | + |
| 71 | + |
| 72 | +class Expression: |
| 73 | + """ |
| 74 | + Class representing a deferred column. |
| 75 | +
|
| 76 | + This is not meant to be instantiated directly. Instead, use :meth:`pandas.col`. |
| 77 | + """ |
| 78 | + |
| 79 | + def __init__(self, func: Callable[[DataFrame], Any], repr_str: str) -> None: |
| 80 | + self._func = func |
| 81 | + self._repr_str = repr_str |
| 82 | + |
| 83 | + def __call__(self, df: DataFrame) -> Any: |
| 84 | + return self._func(df) |
| 85 | + |
| 86 | + def _with_binary_op(self, op: str, other: Any) -> Expression: |
| 87 | + op_symbol = _OP_SYMBOLS.get(op, op) |
| 88 | + |
| 89 | + if isinstance(other, Expression): |
| 90 | + if op.startswith("__r"): |
| 91 | + repr_str = f"({other._repr_str} {op_symbol} {self._repr_str})" |
| 92 | + else: |
| 93 | + repr_str = f"({self._repr_str} {op_symbol} {other._repr_str})" |
| 94 | + return Expression(lambda df: getattr(self(df), op)(other(df)), repr_str) |
| 95 | + else: |
| 96 | + if op.startswith("__r"): |
| 97 | + repr_str = f"({other!r} {op_symbol} {self._repr_str})" |
| 98 | + else: |
| 99 | + repr_str = f"({self._repr_str} {op_symbol} {other!r})" |
| 100 | + return Expression(lambda df: getattr(self(df), op)(other), repr_str) |
| 101 | + |
| 102 | + # Binary ops |
| 103 | + def __add__(self, other: Any) -> Expression: |
| 104 | + return self._with_binary_op("__add__", other) |
| 105 | + |
| 106 | + def __radd__(self, other: Any) -> Expression: |
| 107 | + return self._with_binary_op("__radd__", other) |
| 108 | + |
| 109 | + def __sub__(self, other: Any) -> Expression: |
| 110 | + return self._with_binary_op("__sub__", other) |
| 111 | + |
| 112 | + def __rsub__(self, other: Any) -> Expression: |
| 113 | + return self._with_binary_op("__rsub__", other) |
| 114 | + |
| 115 | + def __mul__(self, other: Any) -> Expression: |
| 116 | + return self._with_binary_op("__mul__", other) |
| 117 | + |
| 118 | + def __rmul__(self, other: Any) -> Expression: |
| 119 | + return self._with_binary_op("__rmul__", other) |
| 120 | + |
| 121 | + def __truediv__(self, other: Any) -> Expression: |
| 122 | + return self._with_binary_op("__truediv__", other) |
| 123 | + |
| 124 | + def __rtruediv__(self, other: Any) -> Expression: |
| 125 | + return self._with_binary_op("__rtruediv__", other) |
| 126 | + |
| 127 | + def __floordiv__(self, other: Any) -> Expression: |
| 128 | + return self._with_binary_op("__floordiv__", other) |
| 129 | + |
| 130 | + def __rfloordiv__(self, other: Any) -> Expression: |
| 131 | + return self._with_binary_op("__rfloordiv__", other) |
| 132 | + |
| 133 | + def __ge__(self, other: Any) -> Expression: |
| 134 | + return self._with_binary_op("__ge__", other) |
| 135 | + |
| 136 | + def __gt__(self, other: Any) -> Expression: |
| 137 | + return self._with_binary_op("__gt__", other) |
| 138 | + |
| 139 | + def __le__(self, other: Any) -> Expression: |
| 140 | + return self._with_binary_op("__le__", other) |
| 141 | + |
| 142 | + def __lt__(self, other: Any) -> Expression: |
| 143 | + return self._with_binary_op("__lt__", other) |
| 144 | + |
| 145 | + def __eq__(self, other: object) -> Expression: # type: ignore[override] |
| 146 | + return self._with_binary_op("__eq__", other) |
| 147 | + |
| 148 | + def __ne__(self, other: object) -> Expression: # type: ignore[override] |
| 149 | + return self._with_binary_op("__ne__", other) |
| 150 | + |
| 151 | + def __mod__(self, other: Any) -> Expression: |
| 152 | + return self._with_binary_op("__mod__", other) |
| 153 | + |
| 154 | + def __rmod__(self, other: Any) -> Expression: |
| 155 | + return self._with_binary_op("__rmod__", other) |
| 156 | + |
| 157 | + def __array_ufunc__( |
| 158 | + self, ufunc: Callable[..., Any], method: str, *inputs: Any, **kwargs: Any |
| 159 | + ) -> Expression: |
| 160 | + def func(df: DataFrame) -> Any: |
| 161 | + parsed_inputs = _parse_args(df, *inputs) |
| 162 | + parsed_kwargs = _parse_kwargs(df, *kwargs) |
| 163 | + return ufunc(*parsed_inputs, **parsed_kwargs) |
| 164 | + |
| 165 | + args_str = _pretty_print_args_kwargs(*inputs, **kwargs) |
| 166 | + repr_str = f"{ufunc.__name__}({args_str})" |
| 167 | + |
| 168 | + return Expression(func, repr_str) |
| 169 | + |
| 170 | + # Everything else |
| 171 | + def __getattr__(self, attr: str, /) -> Any: |
| 172 | + if attr in Series._accessors: |
| 173 | + return NamespaceExpression(self, attr) |
| 174 | + |
| 175 | + def func(df: DataFrame, *args: Any, **kwargs: Any) -> Any: |
| 176 | + parsed_args = _parse_args(df, *args) |
| 177 | + parsed_kwargs = _parse_kwargs(df, **kwargs) |
| 178 | + return getattr(self(df), attr)(*parsed_args, **parsed_kwargs) |
| 179 | + |
| 180 | + def wrapper(*args: Any, **kwargs: Any) -> Expression: |
| 181 | + args_str = _pretty_print_args_kwargs(*args, **kwargs) |
| 182 | + repr_str = f"{self._repr_str}.{attr}({args_str})" |
| 183 | + |
| 184 | + return Expression(lambda df: func(df, *args, **kwargs), repr_str) |
| 185 | + |
| 186 | + return wrapper |
| 187 | + |
| 188 | + def __repr__(self) -> str: |
| 189 | + return self._repr_str or "Expr(...)" |
| 190 | + |
| 191 | + |
| 192 | +class NamespaceExpression: |
| 193 | + def __init__(self, func: Expression, namespace: str) -> None: |
| 194 | + self._func = func |
| 195 | + self._namespace = namespace |
| 196 | + |
| 197 | + def __call__(self, df: DataFrame) -> Any: |
| 198 | + return self._func(df) |
| 199 | + |
| 200 | + def __getattr__(self, attr: str) -> Any: |
| 201 | + if isinstance(getattr(getattr(Series, self._namespace), attr), property): |
| 202 | + repr_str = f"{self._func._repr_str}.{self._namespace}.{attr}" |
| 203 | + return Expression( |
| 204 | + lambda df: getattr(getattr(self(df), self._namespace), attr), |
| 205 | + repr_str, |
| 206 | + ) |
| 207 | + |
| 208 | + def func(df: DataFrame, *args: Any, **kwargs: Any) -> Any: |
| 209 | + parsed_args = _parse_args(df, *args) |
| 210 | + parsed_kwargs = _parse_kwargs(df, **kwargs) |
| 211 | + return getattr(getattr(self(df), self._namespace), attr)( |
| 212 | + *parsed_args, **parsed_kwargs |
| 213 | + ) |
| 214 | + |
| 215 | + def wrapper(*args: Any, **kwargs: Any) -> Expression: |
| 216 | + args_str = _pretty_print_args_kwargs(*args, **kwargs) |
| 217 | + repr_str = f"{self._func._repr_str}.{self._namespace}.{attr}({args_str})" |
| 218 | + return Expression(lambda df: func(df, *args, **kwargs), repr_str) |
| 219 | + |
| 220 | + return wrapper |
| 221 | + |
| 222 | + |
| 223 | +def col(col_name: Hashable) -> Expression: |
| 224 | + """ |
| 225 | + Generate deferred object representing a column of a DataFrame. |
| 226 | +
|
| 227 | + Any place which accepts ``lambda df: df[col_name]``, such as |
| 228 | + :meth:`DataFrame.assign` or :meth:`DataFrame.loc`, can also accept |
| 229 | + ``pd.col(col_name)``. |
| 230 | +
|
| 231 | + Parameters |
| 232 | + ---------- |
| 233 | + col_name : Hashable |
| 234 | + Column name. |
| 235 | +
|
| 236 | + Returns |
| 237 | + ------- |
| 238 | + `pandas.api.typing.Expression` |
| 239 | + A deferred object representing a column of a DataFrame. |
| 240 | +
|
| 241 | + See Also |
| 242 | + -------- |
| 243 | + DataFrame.query : Query columns of a dataframe using string expressions. |
| 244 | +
|
| 245 | + Examples |
| 246 | + -------- |
| 247 | +
|
| 248 | + You can use `col` in `assign`. |
| 249 | +
|
| 250 | + >>> df = pd.DataFrame({"name": ["beluga", "narwhal"], "speed": [100, 110]}) |
| 251 | + >>> df.assign(name_titlecase=pd.col("name").str.title()) |
| 252 | + name speed name_titlecase |
| 253 | + 0 beluga 100 Beluga |
| 254 | + 1 narwhal 110 Narwhal |
| 255 | +
|
| 256 | + You can also use it for filtering. |
| 257 | +
|
| 258 | + >>> df.loc[pd.col("speed") > 105] |
| 259 | + name speed |
| 260 | + 1 narwhal 110 |
| 261 | + """ |
| 262 | + if not isinstance(col_name, Hashable): |
| 263 | + msg = f"Expected Hashable, got: {type(col_name)}" |
| 264 | + raise TypeError(msg) |
| 265 | + |
| 266 | + def func(df: DataFrame) -> Series: |
| 267 | + if col_name not in df.columns: |
| 268 | + columns_str = str(df.columns.tolist()) |
| 269 | + max_len = 90 |
| 270 | + if len(columns_str) > max_len: |
| 271 | + columns_str = columns_str[:max_len] + "...]" |
| 272 | + |
| 273 | + msg = ( |
| 274 | + f"Column '{col_name}' not found in given DataFrame.\n\n" |
| 275 | + f"Hint: did you mean one of {columns_str} instead?" |
| 276 | + ) |
| 277 | + raise ValueError(msg) |
| 278 | + return df[col_name] |
| 279 | + |
| 280 | + return Expression(func, f"col({col_name!r})") |
| 281 | + |
| 282 | + |
| 283 | +__all__ = ["Expression", "col"] |
0 commit comments