Skip to content

Commit

Permalink
momentum indicators (2nd PR) #132 (#132)
Browse files Browse the repository at this point in the history
Ref -> #119
  • Loading branch information
pythonhacker authored Aug 15, 2023
1 parent ba864a8 commit afe4d60
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 38 deletions.
43 changes: 43 additions & 0 deletions example/Example-Analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,46 @@
# <codecell>

print(pf.data.loc[pf.data.index.year == 2017].head(3))

# <markdowncell>

# ## Momentum Indicators
# `FinQuant` provides a module `finquant.momentum_indicators` to compute and
# visualize a number of momentum indicators. Currently RSI (Relative Strength Index)
# and MACD (Moving Average Convergence Divergence) indicators are available.
# See below.

# <codecell>
# plot the RSI (Relative Strength Index) for disney stock proces
from finquant.momentum_indicators import relative_strength_index as rsi

# get stock data for disney
dis = pf.get_stock("WIKI/DIS").data.copy(deep=True)

# plot RSI - by default this plots RSI against the price in two graphs
rsi(dis)
plt.show()

# plot RSI with custom arguments
rsi(dis, oversold = 20, overbought = 80)
plt.show()

# plot RSI standalone graph
rsi(dis, oversold = 20, overbought = 80, standalone=True)
plt.show()

# <codecell>
# plot MACD for disney stock proces
from finquant.momentum_indicators import macd

# plot MACD - by default this plots RSI against the price in two graphs
macd(dis)
plt.show()

# plot MACD using custom arguments
macd(dis, longer_ema_window = 30, shorter_ema_window = 15, signal_ema_window = 10)
plt.show()

# plot MACD standalone graph
macd(standlone = True)
plt.show()
62 changes: 24 additions & 38 deletions finquant/momentum_indicators.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
""" This module provides function(s) to compute momentum indicators
used in technical analysis such as RSI """
used in technical analysis such as RSI, MACD etc. """

import matplotlib.pyplot as plt
import pandas as pd

def relative_strength_index(data, window_length: int = 14, oversold: int = 30,
overbought: int = 70, standalone: bool = False) -> None:

def relative_strength_index(
data,
window_length: int = 14,
oversold: int = 30,
overbought: int = 70,
standalone: bool = False,
) -> None:
"""Computes and visualizes a RSI graph,
""" Computes and visualizes a RSI graph,
plotted along with the prices in another sub-graph
for comparison.
Expand Down Expand Up @@ -73,33 +68,24 @@ def relative_strength_index(
# Single plot
fig = plt.figure()
ax = fig.add_subplot(111)
ax.axhline(y=oversold, color="g", linestyle="--")
ax.axhline(y=overbought, color="r", linestyle="--")
data["rsi"].plot(ylabel="RSI", xlabel="Date", ax=ax, grid=True)
ax.axhline(y = oversold, color = 'g', linestyle = '--')
ax.axhline(y = overbought, color = 'r', linestyle ='--')
data['rsi'].plot(ylabel = 'RSI', xlabel = 'Date', ax = ax, grid = True)
plt.title("RSI Plot")
plt.legend()
else:
# RSI against price in 2 plots
fig, ax = plt.subplots(2, 1, sharex=True, sharey=False)
ax[0].axhline(y=oversold, color="g", linestyle="--")
ax[0].axhline(y=overbought, color="r", linestyle="--")
ax[0].set_title("RSI + Price Plot")
ax[0].axhline(y = oversold, color = 'g', linestyle = '--')
ax[0].axhline(y = overbought, color = 'r', linestyle ='--')
ax[0].set_title('RSI + Price Plot')
# plot 2 graphs in 2 colors
colors = plt.rcParams["axes.prop_cycle"]()
data["rsi"].plot(
ylabel="RSI", ax=ax[0], grid=True, color=next(colors)["color"], legend=True
)
data[stock].plot(
xlabel="Date",
ylabel="Price",
ax=ax[1],
grid=True,
color=next(colors)["color"],
legend=True,
)
data['rsi'].plot(ylabel = 'RSI', ax = ax[0], grid = True, color = next(colors)["color"], legend=True)
data[stock].plot(xlabel = 'Date', ylabel = 'Price', ax = ax[1], grid = True,
color = next(colors)["color"], legend = True)
plt.legend()



def macd(
data,
longer_ema_window: int = 26,
Expand Down Expand Up @@ -135,8 +121,8 @@ def macd(
if longer_ema_window < shorter_ema_window:
raise ValueError("longer ema window should be > shorter ema window")
if longer_ema_window < signal_ema_window:
raise ValueError("longer ema window should be > signal ema window")

raise ValueError("longer ema window should be > signal ema window")
# converting data to pd.DataFrame if it is a pd.Series (for subsequent function calls):
if isinstance(data, pd.Series):
data = data.to_frame()
Expand Down Expand Up @@ -183,11 +169,11 @@ def macd(
ax=ax, grid=True, label="SIGNAL", color="red", linewidth=1.5, legend=True
)

for i in range(len(hist)):
if hist[i] < 0:
ax.bar(data.index[i], hist[i], color="orange")
for i, key in enumerate(hist.index):
if hist[key] < 0:
ax.bar(data.index[i], hist[key], color = 'orange')
else:
ax.bar(data.index[i], hist[i], color="black")
ax.bar(data.index[i], hist[key], color = 'black')
else:
# RSI against price in 2 plots
fig, ax = plt.subplots(2, 1, sharex=True, sharey=False)
Expand All @@ -209,11 +195,11 @@ def macd(
ax=ax[0], grid=True, label="SIGNAL", color="red", linewidth=1.5, legend=True
)

for i in range(len(hist)):
if hist[i] < 0:
ax[0].bar(data.index[i], hist[i], color="orange")
for i, key in enumerate(hist.index):
if hist[key] < 0:
ax.bar(data.index[i], hist[key], color = 'orange')
else:
ax[0].bar(data.index[i], hist[i], color="black")
ax.bar(data.index[i], hist[key], color = 'black')

data[stock].plot(
xlabel="Date",
Expand Down
110 changes: 110 additions & 0 deletions tests/test_momentum_indicators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from finquant.momentum_indicators import (
relative_strength_index as rsi,
macd,
)

plt.switch_backend("Agg")

def test_rsi():
x = np.sin(np.linspace(1, 10, 100))
xlabel_orig = "Date"
ylabel_orig = "Price"
df = pd.DataFrame({"Stock": x}, index=np.linspace(1, 10, 100))
df.index.name = "Date"
rsi(df)
# get data from axis object
ax = plt.gca()
# ax.lines[0] is the data we passed to plot_bollinger_band
line1 = ax.lines[0]
stock_plot = line1.get_xydata()
xlabel_plot = ax.get_xlabel()
ylabel_plot = ax.get_ylabel()
# tests
assert (df['Stock'].index.values == stock_plot[:, 0]).all()
assert (df["Stock"].values == stock_plot[:, 1]).all()
assert xlabel_orig == xlabel_plot
assert ylabel_orig == ylabel_plot

def test_rsi_standalone():
x = np.sin(np.linspace(1, 10, 100))
xlabel_orig = "Date"
ylabel_orig = "RSI"
labels_orig = ['rsi']
title_orig = 'RSI Plot'
df = pd.DataFrame({"Stock": x}, index=np.linspace(1, 10, 100))
df.index.name = "Date"
rsi(df, standalone=True)
# get data from axis object
ax = plt.gca()
# ax.lines[2] is the RSI data
line1 = ax.lines[2]
rsi_plot = line1.get_xydata()
xlabel_plot = ax.get_xlabel()
ylabel_plot = ax.get_ylabel()
print (xlabel_plot, ylabel_plot)
# tests
assert (df['rsi'].index.values == rsi_plot[:, 0]).all()
# for comparing values, we need to remove nan
a, b = df['rsi'].values, rsi_plot[:, 1]
a, b = map(lambda x: x[~np.isnan(x)], (a, b))
assert (a == b).all()
labels_plot = ax.get_legend_handles_labels()[1]
title_plot = ax.get_title()
assert labels_plot == labels_orig
assert xlabel_plot == xlabel_orig
assert ylabel_plot == ylabel_orig
assert title_plot == title_orig

def test_macd():
x = np.sin(np.linspace(1, 10, 100))
xlabel_orig = "Date"
ylabel_orig = "Price"
df = pd.DataFrame({"Stock": x}, index=np.linspace(1, 10, 100))
df.index.name = "Date"
macd(df)
# get data from axis object
ax = plt.gca()
# ax.lines[0] is the data we passed to plot_bollinger_band
line1 = ax.lines[0]
stock_plot = line1.get_xydata()
xlabel_plot = ax.get_xlabel()
ylabel_plot = ax.get_ylabel()
# tests
assert (df['Stock'].index.values == stock_plot[:, 0]).all()
assert (df["Stock"].values == stock_plot[:, 1]).all()
assert xlabel_orig == xlabel_plot
assert ylabel_orig == ylabel_plot

def test_macd_standalone():
labels_orig = ['MACD', 'diff', 'SIGNAL']
x = np.sin(np.linspace(1, 10, 100))
xlabel_orig = "Date"
ylabel_orig = "MACD"
df = pd.DataFrame({"Stock": x}, index=np.linspace(1, 10, 100))
df.index.name = "Date"
macd(df, standalone=True)
# get data from axis object
ax = plt.gca()
labels_plot = ax.get_legend_handles_labels()[1]
xlabel_plot = ax.get_xlabel()
ylabel_plot = ax.get_ylabel()
assert labels_plot == labels_orig
assert xlabel_plot == xlabel_orig
assert ylabel_plot == ylabel_orig
# ax.lines[0] is macd data
# ax.lines[1] is diff data
# ax.lines[2] is macd_s data
# tests
for i, key in ((0, 'macd'), (1, 'diff'), (2, 'macd_s')):
line = ax.lines[i]
data_plot = line.get_xydata()
# tests
assert (df[key].index.values == data_plot[:, 0]).all()
# for comparing values, we need to remove nan
a, b = df[key].values, data_plot[:, 1]
a, b = map(lambda x: x[~np.isnan(x)], (a, b))
assert (a == b).all()

0 comments on commit afe4d60

Please sign in to comment.