Skip to content

Commit 6285398

Browse files
committedFeb 13, 2025·
Optimize RSI indicator with Numba JIT compilation
- Implement Numba-accelerated RSI calculation function - Replace vectorized implementation with loop-based Numba implementation - Improve computational efficiency for Relative Strength Index calculation - Maintain consistent function interface and return types
1 parent f1e9191 commit 6285398

File tree

1 file changed

+53
-67
lines changed

1 file changed

+53
-67
lines changed
 

‎jesse/indicators/rsi.py

+53-67
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,61 @@
11
import numpy as np
22
from typing import Union
3+
from numba import njit
34

45
from jesse.helpers import get_candle_source, slice_candles
56

67

8+
@njit
9+
def _rsi(p: np.ndarray, period: int) -> np.ndarray:
10+
"""
11+
Compute the Relative Strength Index using a loop and Wilder's smoothing.
12+
"""
13+
n = len(p)
14+
rsi_arr = np.full(n, np.nan)
15+
if n < period + 1:
16+
return rsi_arr
17+
# Calculate differences between consecutive prices.
18+
diff = np.empty(n - 1)
19+
for i in range(n - 1):
20+
diff[i] = p[i+1] - p[i]
21+
22+
# Compute initial average gain and loss over the first 'period' differences.
23+
sum_gain = 0.0
24+
sum_loss = 0.0
25+
for i in range(period):
26+
change = diff[i]
27+
if change > 0:
28+
sum_gain += change
29+
else:
30+
sum_loss += -change
31+
avg_gain = sum_gain / period
32+
avg_loss = sum_loss / period
33+
34+
# Compute first RSI value at index 'period'
35+
if avg_loss == 0:
36+
rsi_arr[period] = 100.0
37+
else:
38+
rs = avg_gain / avg_loss
39+
rsi_arr[period] = 100 - (100 / (1 + rs))
40+
41+
# Recursively update average gain and loss and compute subsequent RSI values.
42+
for i in range(period, n - 1):
43+
change = diff[i]
44+
gain = change if change > 0 else 0.0
45+
loss = -change if change < 0 else 0.0
46+
avg_gain = (avg_gain * (period - 1) + gain) / period
47+
avg_loss = (avg_loss * (period - 1) + loss) / period
48+
if avg_loss == 0:
49+
rsi_arr[i+1] = 100.0
50+
else:
51+
rs = avg_gain / avg_loss
52+
rsi_arr[i+1] = 100 - (100 / (1 + rs))
53+
return rsi_arr
54+
55+
756
def rsi(candles: np.ndarray, period: int = 14, source_type: str = "close", sequential: bool = False) -> Union[float, np.ndarray]:
857
"""
9-
RSI - Relative Strength Index
58+
RSI - Relative Strength Index using Numba for optimization
1059
1160
:param candles: np.ndarray
1261
:param period: int - default: 14
@@ -20,70 +69,7 @@ def rsi(candles: np.ndarray, period: int = 14, source_type: str = "close", seque
2069
else:
2170
candles = slice_candles(candles, sequential)
2271
source = get_candle_source(candles, source_type=source_type)
23-
24-
# Convert source to a numpy array of floats
72+
2573
p = np.asarray(source, dtype=float)
26-
n = len(p)
27-
28-
# Not enough data to compute RSI if we don't have at least period+1 price points
29-
if n < period + 1:
30-
return np.nan if not sequential else np.full(n, np.nan)
31-
32-
# Compute price differences
33-
delta = np.diff(p)
34-
gains = np.where(delta > 0, delta, 0)
35-
losses = np.where(delta < 0, -delta, 0)
36-
37-
# Initialize average gains and losses arrays (length equal to len(gains))
38-
avg_gain = np.zeros_like(gains)
39-
avg_loss = np.zeros_like(losses)
40-
41-
# Vectorized computation of average gains and losses using matrix operations to mimic Wilder's smoothing method
42-
alpha = 1/period
43-
# Compute the initial simple averages
44-
A0_gain = np.mean(gains[:period])
45-
A0_loss = np.mean(losses[:period])
46-
# The number of smoothed values to compute (for gains and losses) equals the length of gains from index (period-1) to end
47-
# Since gains has length (n-1), we want L = (n-1) - (period-1) = n - period
48-
L = len(gains) - (period - 1)
49-
50-
# Vectorized smoothing for avg_gain:
51-
# x_gain: gains from index 'period' onward, length = (n-1) - period = L - 1
52-
x_gain = gains[period:]
53-
if L > 1:
54-
t_values = np.arange(1, L) # t=1,...,L-1
55-
# Build lower triangular matrix of shape (L-1, L-1) with elements (1-alpha)^(t-1-k)
56-
M_gain = np.tril((1 - alpha) ** (np.subtract.outer(np.arange(L-1), np.arange(L-1))))
57-
# weighted sum: for each t from 1 to L-1, sum_{k=0}^{t-1} (1-alpha)^(t-1-k)*x_gain[k]
58-
weighted_sum_gain = M_gain.dot(x_gain)
59-
A_gain = np.empty(L)
60-
A_gain[0] = A0_gain
61-
A_gain[1:] = A0_gain * ((1 - alpha) ** t_values) + alpha * weighted_sum_gain
62-
else:
63-
A_gain = np.array([A0_gain])
64-
avg_gain[period - 1:] = A_gain
65-
66-
# Vectorized smoothing for avg_loss:
67-
x_loss = losses[period:]
68-
if L > 1:
69-
t_values = np.arange(1, L)
70-
M_loss = np.tril((1 - alpha) ** (np.subtract.outer(np.arange(L-1), np.arange(L-1))))
71-
weighted_sum_loss = M_loss.dot(x_loss)
72-
A_loss = np.empty(L)
73-
A_loss[0] = A0_loss
74-
A_loss[1:] = A0_loss * ((1 - alpha) ** t_values) + alpha * weighted_sum_loss
75-
else:
76-
A_loss = np.array([A0_loss])
77-
avg_loss[period - 1:] = A_loss
78-
79-
# Prepare RSI result array of same length as price array, fill initial values with NaN
80-
rsi_values = np.full(n, np.nan)
81-
82-
# Vectorized computation of RSI for indices period to end
83-
with np.errstate(divide='ignore', invalid='ignore'):
84-
rs = avg_gain[period - 1:] / avg_loss[period - 1:]
85-
rsi_comp = 100 - 100 / (1 + rs)
86-
rsi_comp = np.where(avg_loss[period - 1:] == 0, 100.0, rsi_comp)
87-
rsi_values[period:] = rsi_comp
88-
89-
return rsi_values if sequential else rsi_values[-1]
74+
result = _rsi(p, period)
75+
return result if sequential else result[-1]

0 commit comments

Comments
 (0)
Please sign in to comment.