|
| 1 | +from typing import List |
| 2 | + |
| 3 | +def calculate_rsi(prices: List[float], period: int = 14) -> List[float]: |
| 4 | + """ |
| 5 | + Calculate the Relative Strength Index (RSI) for a given list of prices. |
| 6 | +
|
| 7 | + RSI is a momentum oscillator that measures the speed and change of price movements. |
| 8 | + It is typically used to identify overbought or oversold conditions in a market. |
| 9 | +
|
| 10 | + Args: |
| 11 | + prices (List[float]): A list of prices for a financial asset. |
| 12 | + period (int): The number of periods to use in the calculation (default is 14). |
| 13 | +
|
| 14 | + Returns: |
| 15 | + List[float]: A list of RSI values corresponding to the input price data. |
| 16 | +
|
| 17 | + Example: |
| 18 | + >>> rsi_values = calculate_rsi([44.0, 44.15, 44.09, 44.20, 44.30, 44.25, 44.40, 44.35, 44.50, 44.60, 44.55, 44.75, 44.80, 44.70, 44.85], 14) |
| 19 | + >>> print(rsi_values) # doctest: +ELLIPSIS |
| 20 | + [78.91..., 80.99...] |
| 21 | + |
| 22 | + Reference: |
| 23 | + https://en.wikipedia.org/wiki/Relative_strength_index |
| 24 | + """ |
| 25 | + # Validate that there are enough prices to calculate RSI |
| 26 | + if len(prices) < period: |
| 27 | + raise ValueError("Not enough price data to calculate RSI.") |
| 28 | + |
| 29 | + gains = [] |
| 30 | + losses = [] |
| 31 | + |
| 32 | + # Calculate price changes between consecutive days |
| 33 | + for i in range(1, len(prices)): |
| 34 | + delta = prices[i] - prices[i - 1] |
| 35 | + gains.append(max(0, delta)) |
| 36 | + losses.append(max(0, -delta)) |
| 37 | + |
| 38 | + # Initial averages for gain and loss |
| 39 | + avg_gain = float(sum(gains[:period]) / period) |
| 40 | + avg_loss = float(sum(losses[:period]) / period) |
| 41 | + |
| 42 | + # Initialize RSI list and first RSI value |
| 43 | + rsi_values: List[float] = [] |
| 44 | + |
| 45 | + if avg_loss == 0: |
| 46 | + rsi = 100.0 |
| 47 | + else: |
| 48 | + rs = avg_gain / avg_loss |
| 49 | + rsi = 100.0 - (100.0 / (1.0 + rs)) |
| 50 | + |
| 51 | + rsi_values.append(rsi) |
| 52 | + |
| 53 | + # Calculate subsequent RSI values |
| 54 | + for i in range(period, len(prices)): |
| 55 | + delta = prices[i] - prices[i - 1] |
| 56 | + gain = max(0, delta) |
| 57 | + loss = max(0, -delta) |
| 58 | + |
| 59 | + avg_gain = (avg_gain * (period - 1) + gain) / period |
| 60 | + avg_loss = (avg_loss * (period - 1) + loss) / period |
| 61 | + |
| 62 | + if avg_loss == 0: |
| 63 | + rsi = 100.0 |
| 64 | + else: |
| 65 | + rs = avg_gain / avg_loss |
| 66 | + rsi = 100.0 - (100.0 / (1.0 + rs)) |
| 67 | + |
| 68 | + rsi_values.append(rsi) |
| 69 | + |
| 70 | + return rsi_values |
| 71 | + |
| 72 | + |
| 73 | +if __name__ == "__main__": |
| 74 | + prices = [44.0, 44.15, 44.09, 44.20, 44.30, 44.25, 44.40, 44.35, 44.50, 44.60, 44.55, 44.75, 44.80, 44.70, 44.85] |
| 75 | + rsi = calculate_rsi(prices, 14) |
| 76 | + print("RSI Values:", rsi) |
0 commit comments