|
| 1 | +from typing import Union |
| 2 | +from pandas import DataFrame as PdDataFrame |
| 3 | +from polars import DataFrame as PlDataFrame |
| 4 | +import polars as pl |
| 5 | +from pyindicators.exceptions import PyIndicatorException |
| 6 | + |
| 7 | + |
| 8 | +def cci( |
| 9 | + data: Union[PdDataFrame, PlDataFrame], |
| 10 | + high_column='High', |
| 11 | + low_column='Low', |
| 12 | + close_column='Close', |
| 13 | + period=20, |
| 14 | + result_column='CCI' |
| 15 | +) -> Union[PdDataFrame, PlDataFrame]: |
| 16 | + """ |
| 17 | + Calculate the Commodity Channel Index (CCI) for a price series. |
| 18 | +
|
| 19 | + Args: |
| 20 | + data: Input DataFrame (pandas or polars). |
| 21 | + high_column: Name of the column with high prices. |
| 22 | + low_column: Name of the column with low prices. |
| 23 | + close_column: Name of the column with close prices. |
| 24 | + period: Lookback period for CCI calculation. |
| 25 | + result_column: Name of the result column to store CCI values. |
| 26 | +
|
| 27 | + Returns the original DataFrame with a new column for CCI. |
| 28 | + """ |
| 29 | + if isinstance(data, PdDataFrame): |
| 30 | + # Calculate CCI for pandas DataFrame |
| 31 | + typical_price = (data[high_column] + |
| 32 | + data[low_column] + data[close_column]) / 3 |
| 33 | + sma = typical_price.rolling(window=period).mean() |
| 34 | + mad = (typical_price - sma).abs().rolling(window=period).mean() |
| 35 | + data[result_column] = (typical_price - sma) / (0.015 * mad) |
| 36 | + return data |
| 37 | + |
| 38 | + elif isinstance(data, PlDataFrame): |
| 39 | + # Calculate CCI for polars DataFrame |
| 40 | + typical_price = (pl.col(high_column) |
| 41 | + + pl.col(low_column) |
| 42 | + + pl.col(close_column)) / 3 |
| 43 | + sma = typical_price.rolling_mean(window_size=period) |
| 44 | + mad = (typical_price - sma).abs().rolling_mean(window_size=period) |
| 45 | + return data.with_columns( |
| 46 | + (typical_price - sma) / (0.015 * mad).alias(result_column) |
| 47 | + ) |
| 48 | + |
| 49 | + else: |
| 50 | + raise PyIndicatorException( |
| 51 | + "Input data must be a pandas or polars DataFrame." |
| 52 | + ) |
0 commit comments