-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
67 lines (47 loc) · 1.91 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import datetime
from pathlib import Path
import joblib
import pandas as pd
import yfinance as yf
from prophet import Prophet
import argparse
# Absolute directory containing this file. resolve(strict=True) raises if the
# path does not exist. Saved model files (<ticker>.joblib) live in this directory.
BASE_DIR = Path(__file__).resolve(strict=True).parent
# Calendar date captured once, when this module is imported. It does NOT
# advance while the process keeps running, so it can be stale in a
# long-lived process.
TODAY = datetime.date.today()
def train(ticker="MSFT"):
    """Fit a Prophet model on a ticker's daily prices and save it to disk.

    Downloads price history from 2020-01-01 up to the current date via
    yfinance, fits a default ``Prophet`` model on it, and writes the fitted
    model to ``<BASE_DIR>/<ticker>.joblib``.

    Args:
        ticker: Symbol passed to ``yfinance.download``.

    Raises:
        ValueError: If the download returns no rows for ``ticker``.
    """
    # Resolve the end date at call time so a long-running process does not
    # keep training against the date on which the module was first imported.
    end_date = datetime.date.today().strftime("%Y-%m-%d")
    data = yf.download(ticker, start="2020-01-01", end=end_date)
    if data is None or data.empty:
        raise ValueError(f"No price data returned for ticker {ticker!r}")

    history = data.reset_index()
    # NOTE(review): newer yfinance releases appear to return MultiIndex
    # columns (field, ticker) even for a single symbol; keep only the field
    # level so plain column names work. Confirm against the pinned version.
    if isinstance(history.columns, pd.MultiIndex):
        history.columns = history.columns.get_level_values(0)

    # Prefer the adjusted close; fall back to "Close" when the download does
    # not provide a separate "Adj Close" column.
    price_column = "Adj Close" if "Adj Close" in history.columns else "Close"

    # Prophet expects exactly two columns: "ds" (timestamp) and "y" (value).
    df_forecast = history[["Date", price_column]].rename(
        columns={"Date": "ds", price_column: "y"}
    )

    model = Prophet()
    model.fit(df_forecast)
    joblib.dump(model, Path(BASE_DIR) / f"{ticker}.joblib")
def predict(ticker="MSFT", days=7):
    """Forecast the next ``days`` calendar days with a previously saved model.

    Args:
        ticker: Symbol whose model was saved as ``<BASE_DIR>/<ticker>.joblib``.
        days: Number of days after the current date to forecast.

    Returns:
        ``False`` if no saved model file exists for ``ticker``. Otherwise a
        list with one dict per forecast day (the last ``days`` rows of the
        Prophet forecast frame); the list is empty when ``days`` is not
        positive.
    """
    model_file = Path(BASE_DIR) / f"{ticker}.joblib"
    if not model_file.exists():
        return False

    # tail() with a negative count would return almost the whole historical
    # frame, so a non-positive horizon yields no forecast rows at all.
    if days <= 0:
        return []

    # NOTE(review): joblib.load unpickles the file; only load model files
    # written by this module.
    model = joblib.load(model_file)

    # Resolve "today" at call time so a long-running process does not keep
    # forecasting from the date on which the module was first imported.
    future = datetime.date.today() + datetime.timedelta(days=days)
    # ISO format avoids any month/day ambiguity when pandas parses the bound.
    dates = pd.date_range(start="2020-01-01", end=future.isoformat())
    forecast = model.predict(pd.DataFrame({"ds": dates}))

    # Optional diagnostic plots (disabled):
    # model.plot(forecast).savefig(f"{ticker}_plot.png")
    # model.plot_components(forecast).savefig(f"{ticker}_plot_components.png")

    return forecast.tail(days).to_dict("records")
def convert(prediction_list):
    """Map forecast records to ``{"MM/DD/YYYY": trend}``.

    Args:
        prediction_list: Iterable of dicts, each with a ``"ds"`` value that
            supports ``strftime`` and a ``"trend"`` value.

    Returns:
        A dict keyed by the formatted date of each record, holding that
        record's ``"trend"`` value. If two records share a date, the later
        one wins.
    """
    return {
        record["ds"].strftime("%m/%d/%Y"): record["trend"]
        for record in prediction_list
    }
# if __name__ == "__main__":
# parser = argparse.ArgumentParser(description='Predict')
# parser.add_argument('--ticker', type=str, default='MSFT', help='Stock Ticker')
# parser.add_argument('--days', type=int, default=7, help='Number of days to predict')
# args = parser.parse_args()
# train(args.ticker)
# prediction_list = predict(ticker=args.ticker, days=args.days)
# output = convert(prediction_list)
# print(output)