diff --git a/example/Example-Analysis.py b/example/Example-Analysis.py index ba0f4682..b9fa5ef3 100644 --- a/example/Example-Analysis.py +++ b/example/Example-Analysis.py @@ -309,3 +309,46 @@ # print(pf.data.loc[pf.data.index.year == 2017].head(3)) + +# + +# ## Momentum Indicators +# `FinQuant` provides a module `finquant.momentum_indicators` to compute and +# visualize a number of momentum indicators. Currently RSI (Relative Strength Index) +# and MACD (Moving Average Convergence Divergence) indicators are available. +# See below. + +# +# plot the RSI (Relative Strength Index) for disney stock proces +from finquant.momentum_indicators import relative_strength_index as rsi + +# get stock data for disney +dis = pf.get_stock("WIKI/DIS").data.copy(deep=True) + +# plot RSI - by default this plots RSI against the price in two graphs +rsi(dis) +plt.show() + +# plot RSI with custom arguments +rsi(dis, oversold = 20, overbought = 80) +plt.show() + +# plot RSI standalone graph +rsi(dis, oversold = 20, overbought = 80, standalone=True) +plt.show() + +# +# plot MACD for disney stock proces +from finquant.momentum_indicators import macd + +# plot MACD - by default this plots RSI against the price in two graphs +macd(dis) +plt.show() + +# plot MACD using custom arguments +macd(dis, longer_ema_window = 30, shorter_ema_window = 15, signal_ema_window = 10) +plt.show() + +# plot MACD standalone graph +macd(standlone = True) +plt.show() diff --git a/tests/test_momentum_indicators.py b/tests/test_momentum_indicators.py new file mode 100644 index 00000000..f3b091f6 --- /dev/null +++ b/tests/test_momentum_indicators.py @@ -0,0 +1,110 @@ +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from finquant.momentum_indicators import ( + relative_strength_index as rsi, + macd, +) + +plt.switch_backend("Agg") + +def test_rsi(): + x = np.sin(np.linspace(1, 10, 100)) + xlabel_orig = "Date" + ylabel_orig = "Price" + df = pd.DataFrame({"Stock": x}, index=np.linspace(1, 10, 100)) + df.index.name = "Date" + rsi(df) + # get data from axis object + ax = plt.gca() + # ax.lines[0] is the data we passed to plot_bollinger_band + line1 = ax.lines[0] + stock_plot = line1.get_xydata() + xlabel_plot = ax.get_xlabel() + ylabel_plot = ax.get_ylabel() + # tests + assert (df['Stock'].index.values == stock_plot[:, 0]).all() + assert (df["Stock"].values == stock_plot[:, 1]).all() + assert xlabel_orig == xlabel_plot + assert ylabel_orig == ylabel_plot + +def test_rsi_standalone(): + x = np.sin(np.linspace(1, 10, 100)) + xlabel_orig = "Date" + ylabel_orig = "RSI" + labels_orig = ['rsi'] + title_orig = 'RSI Plot' + df = pd.DataFrame({"Stock": x}, index=np.linspace(1, 10, 100)) + df.index.name = "Date" + rsi(df, standalone=True) + # get data from axis object + ax = plt.gca() + # ax.lines[2] is the RSI data + line1 = ax.lines[2] + rsi_plot = line1.get_xydata() + xlabel_plot = ax.get_xlabel() + ylabel_plot = ax.get_ylabel() + print (xlabel_plot, ylabel_plot) + # tests + assert (df['rsi'].index.values == rsi_plot[:, 0]).all() + # for comparing values, we need to remove nan + a, b = df['rsi'].values, rsi_plot[:, 1] + a, b = map(lambda x: x[~np.isnan(x)], (a, b)) + assert (a == b).all() + labels_plot = ax.get_legend_handles_labels()[1] + title_plot = ax.get_title() + assert labels_plot == labels_orig + assert xlabel_plot == xlabel_orig + assert ylabel_plot == ylabel_orig + assert title_plot == title_orig + +def test_macd(): + x = np.sin(np.linspace(1, 10, 100)) + xlabel_orig = "Date" + ylabel_orig = "Price" + df = pd.DataFrame({"Stock": x}, index=np.linspace(1, 10, 100)) + df.index.name = "Date" + macd(df) + # get data from axis object + ax = plt.gca() + # ax.lines[0] is the data we passed to plot_bollinger_band + line1 = ax.lines[0] + stock_plot = line1.get_xydata() + xlabel_plot = ax.get_xlabel() + ylabel_plot = ax.get_ylabel() + # tests + assert (df['Stock'].index.values == stock_plot[:, 0]).all() + assert (df["Stock"].values == stock_plot[:, 1]).all() + assert xlabel_orig == xlabel_plot + assert ylabel_orig == ylabel_plot + +def test_macd_standalone(): + labels_orig = ['MACD', 'diff', 'SIGNAL'] + x = np.sin(np.linspace(1, 10, 100)) + xlabel_orig = "Date" + ylabel_orig = "MACD" + df = pd.DataFrame({"Stock": x}, index=np.linspace(1, 10, 100)) + df.index.name = "Date" + macd(df, standalone=True) + # get data from axis object + ax = plt.gca() + labels_plot = ax.get_legend_handles_labels()[1] + xlabel_plot = ax.get_xlabel() + ylabel_plot = ax.get_ylabel() + assert labels_plot == labels_orig + assert xlabel_plot == xlabel_orig + assert ylabel_plot == ylabel_orig + # ax.lines[0] is macd data + # ax.lines[1] is diff data + # ax.lines[2] is macd_s data + # tests + for i, key in ((0, 'macd'), (1, 'diff'), (2, 'macd_s')): + line = ax.lines[i] + data_plot = line.get_xydata() + # tests + assert (df[key].index.values == data_plot[:, 0]).all() + # for comparing values, we need to remove nan + a, b = df[key].values, data_plot[:, 1] + a, b = map(lambda x: x[~np.isnan(x)], (a, b)) + assert (a == b).all()