引き続き、UdemyでPython for Time Series Data Analysisの学習を進めている。クラスのシラバスに沿って、これまでにPython/numpy/pandasの基本学習は完了したので、ここからはTime Series Dataの解析・予測手法へ移っていく。内容は大ボリュームで、最終的には全てを学習するつもりだが、まずは簡単に導入できて予測結果を評価できる手法を整備するために、FacebookのProphetライブラリの学習を進める。
Prophetは基本的には、日、週、月、年といった周期を持つ時系列データの予測に使用できるようだ。入力データの列名を、時間はds、予測対象の値はyとし、dsをdatetime型に変換しておくことが重要。これだけで簡単に予測が得られる。
from fbprophet import Prophet import matplotlib.pyplot as plt from datetime import datetime %matplotlib inline df = pd.read_csv('./UDEMY_TSA_FINAL/Data/BeerWineLiquor.csv') df.info df.columns=['ds','y'] df.head() df['ds'] = pd.to_datetime(df['ds']) df.head() df.info() m = Prophet() m.fit(df) future = m.make_future_dataframe(periods=24,freq='MS') future.tail() df.tail() len(df) len(future) forcast = m.predict(future) forcast.head() forcast.columns forcast[['ds','yhat_lower','yhat_upper','yhat']].tail(12) m.plot(forcast) xmin = datetime(2014,1,1) xmax = datetime(2021,1,1) plt.xlim(xmin,xmax) forcast.plot('ds','yhat') m.plot_components(forcast)
plot_componentsを使えば、トレンドと季節性の成分も簡単に分離できる。
一定期間のデータで学習を行い、その先の期間を予測して、テストデータとの誤差(RMSE)や交差検証で評価する例。
# Hold-out evaluation of a Prophet model (Jupyter notebook cells): fit on the
# first 576 rows, forecast 12 further periods, compare against the remaining
# rows with RMSE, then run Prophet's built-in cross-validation.
# Bare expressions such as `df.head()` only display output when they are the
# last statement of a notebook cell.
from datetime import datetime

import pandas as pd
from fbprophet import Prophet
from fbprophet.diagnostics import cross_validation, performance_metrics
from fbprophet.plot import plot_cross_validation_metric
from statsmodels.tools.eval_measures import rmse

df = pd.read_csv('./UDEMY_TSA_FINAL/Data/Miles_Traveled.csv')
df.info()
df.columns = ['ds', 'y']
df['ds'] = pd.to_datetime(df['ds'])
df.head()
df.plot(x='ds', y='y')
len(df)

# Rows 0..575 are used for training; everything after that is the test set.
# NOTE(review): the 12-period forecast below assumes the test set has exactly
# 12 rows — confirm against len(df).
train = df.iloc[:576]
test = df.iloc[576:]

m = Prophet()
m.fit(train)
future = m.make_future_dataframe(periods=12, freq='MS')
forecast = m.predict(future)
forecast.tail()

# Overlay the forecast and the held-out values on one axes.
ax = forecast.plot(x='ds', y='yhat', label='Predictions', legend=True,
                   figsize=(12, 8))
xmin = datetime(2018, 1, 1)
xmax = datetime(2019, 1, 1)
test.plot(x='ds', y='y', label='True Test Data', legend=True, ax=ax,
          xlim=(xmin, xmax))

# The last 12 forecast rows are the periods appended by make_future_dataframe.
predictions = forecast.iloc[-12:]['yhat']
predictions
rmse(predictions, test['y'])
test.mean()

# Cross-validation windows, passed to Prophet as '<n> days' strings.
initial = f'{5 * 365} days'
initial
period = f'{5 * 365} days'
horizon = '365 days'

df_cv = cross_validation(m, initial=initial, period=period, horizon=horizon)
df_cv.head()
len(df_cv)
performance_metrics(df_cv)
plot_cross_validation_metric(df_cv, metric='rmse')
トレンドの転換点(changepoint)を探し、予測プロットに重ねて表示する例。
# Changepoint visualisation (Jupyter notebook cell): fit Prophet on the full
# series, forecast 12 further periods, and overlay the model's changepoints
# on the forecast plot.
import pandas as pd
from fbprophet import Prophet
from fbprophet.plot import add_changepoints_to_plot

csv_path = './UDEMY_TSA_FINAL/Data/HospitalityEmployees.csv'
df = pd.read_csv(csv_path)
df.columns = ['ds', 'y']
df['ds'] = pd.to_datetime(df['ds'])
df.plot(x='ds', y='y', figsize=(12, 10))

m = Prophet()
m.fit(df)

future = m.make_future_dataframe(periods=12, freq='MS')
forecast = m.predict(future)

# Draw the forecast, then add the changepoint markers to the same axes.
fig = m.plot(forecast)
forecast_axes = fig.gca()
a = add_changepoints_to_plot(forecast_axes, m, forecast)
季節性とトレンドを分析し、将来予測を行う例。最後に、季節性を乗法的(seasonality_mode='multiplicative')に設定したモデルでも予測している。
# Seasonality analysis (Jupyter notebook cells): fit a default Prophet model,
# plot the forecast, its components and its changepoints, then refit with
# multiplicative seasonality and plot that forecast.
# Bare expressions such as `df.head()` only display output when they are the
# last statement of a notebook cell.
import pandas as pd
from fbprophet import Prophet
from fbprophet.plot import add_changepoints_to_plot

df = pd.read_csv('./UDEMY_TSA_FINAL/Data/airline_passengers.csv')
df.head()
df.columns = ['ds', 'y']
df['ds'] = pd.to_datetime(df['ds'])

# Model with Prophet's default settings.
m = Prophet()
m.fit(df)
future = m.make_future_dataframe(periods=50, freq='MS')
forecast = m.predict(future)

fig = m.plot(forecast)
fig = m.plot_components(forecast)

# Forecast plot again, this time with the changepoints overlaid.
fig = m.plot(forecast)
a = add_changepoints_to_plot(fig.gca(), m, forecast)

# Same data and horizon, with seasonality_mode='multiplicative'.
m = Prophet(seasonality_mode='multiplicative')
m.fit(df)
future = m.make_future_dataframe(periods=50, freq='MS')
forecast = m.predict(future)
fig = m.plot(forecast)
ひとまずこのProphet(fbprophet)を用いて株価予測を行えば、Algorithm Trade Systemの枠組み開発を進められそうだ。