import warnings
# Install a catch-all "ignore" filter before importing sktime, so every
# warning raised from here on (including during imports) is suppressed.
warnings.simplefilter("ignore")
from sktime.datasets import load_longley

# Keep only the second element returned by load_longley(); per the original
# note it is a 16 x 5 frame (previewed by head() below).
_, y = load_longley()
y.head()
# Output of y.head():
# | Period | GNPDEFL | GNP      | UNEMP  | ARMED  | POP      |
# |--------|---------|----------|--------|--------|----------|
# | 1947   | 83.0    | 234289.0 | 2356.0 | 1590.0 | 107608.0 |
# | 1948   | 88.5    | 259426.0 | 2325.0 | 1456.0 | 108632.0 |
# | 1949   | 88.2    | 258054.0 | 3682.0 | 1616.0 | 109773.0 |
# | 1950   | 89.5    | 284599.0 | 3351.0 | 1650.0 | 110929.0 |
# | 1951   | 96.2    | 328975.0 | 2099.0 | 3099.0 | 112075.0 |
from sktime.forecasting.model_selection import temporal_train_test_split

# Temporal split: the last 4 rows (years) become the test set, everything
# before them the training set.
y_train, y_test = temporal_train_test_split(y, test_size=4)

from sktime.forecasting.model_selection import (
    ForecastingGridSearchCV,
    SlidingWindowSplitter,
)
from sktime.forecasting.var import VAR

# Start from a VAR with all-default hyperparameters and display them; the
# grid search below overrides 'ic' and 'trend'.
forecaster = VAR()
forecaster.get_params()
# Output of forecaster.get_params():
# {'dates': None, 'freq': None, 'ic': None, 'maxlags': None, 'method': 'ols',
#  'missing': 'none', 'random_state': None, 'trend': 'c', 'verbose': False}
# Candidate values for VAR's 'ic' and 'trend' hyperparameters (every
# combination is tried).
param_grid = {
    "ic": ["aic", "fpe", "hqic", "bic", None],
    "trend": ["c", "ct", "ctt", "n"],
}

# Each CV fold trains on a sliding window of 10 consecutive periods.
# NOTE(review): the splitter's fh is left at its default, while the final
# forecast below covers 4 periods -- confirm tuning on the default horizon is
# intended. maxlags also stays None (see get_params output), so the lag order
# used on a 10-period window of 5 series is left to the underlying VAR fit;
# confirm it is identifiable.
cv = SlidingWindowSplitter(window_length=10)

gscv = ForecastingGridSearchCV(
    forecaster,
    param_grid=param_grid,
    cv=cv,
    strategy="refit",
)
# fit() returns the search object itself; best_params_ shows the winner.
fine_tuning = gscv.fit(y_train)
fine_tuning.best_params_
# Output of fine_tuning.best_params_:
# {'ic': None, 'trend': 'n'}
import numpy as np
# Forecast every period of the hold-out set. The horizon is derived from
# y_test rather than hard-coded, so it stays in sync with test_size used in
# temporal_train_test_split above.
fh = np.arange(1, len(y_test) + 1)
y_pred = fine_tuning.predict(fh)
# evaluation: non-symmetric MAPE, reported separately for each column
# (multioutput="raw_values") instead of averaged into one number.
from sktime.performance_metrics.forecasting import mean_absolute_percentage_error
mean_absolute_percentage_error(y_test, y_pred, symmetric=False, multioutput="raw_values")
# Output (MAPE, one value per column of y_test):
# array([0.01965311, 0.01369328, 0.13334094, 0.08359001, 0.00636989])
import matplotlib.pyplot as plt
def _to_plot_index(index):
    """Return *index* in a form usable as the x-values of ``Axes.plot``.

    An index exposing ``to_timestamp`` (e.g. a pandas ``PeriodIndex``) is
    converted with it. Any other index (``DatetimeIndex``, ``RangeIndex``,
    ...) is returned unchanged instead of raising ``AttributeError``.
    """
    to_timestamp = getattr(index, "to_timestamp", None)
    return index if to_timestamp is None else to_timestamp()


def get_plots(y_train, y_test, y_pred):
    """Draw one figure per column comparing training data, test data and forecasts.

    Parameters
    ----------
    y_train, y_test, y_pred : pandas.DataFrame
        Training observations, held-out observations and predictions. Every
        column of ``y_train`` must also exist in ``y_test`` and ``y_pred``
        (indexing a missing column raises ``KeyError``).

    Returns
    -------
    None
        Figures are created with ``plt.subplots`` and are not closed here.
    """
    # (frame, matplotlib format string, legend label) in plotting order.
    series_styles = (
        (y_train, "bo-", "y_train"),
        (y_test, "go-", "y_test"),
        (y_pred, "yo-", "y_pred"),
    )
    for column in y_train.columns:
        _, ax = plt.subplots(figsize=(8, 6))
        for frame, fmt, label in series_styles:
            ax.plot(_to_plot_index(frame.index), frame[column], fmt, label=label)
        # Legend entries follow plotting order: train (blue), test (green),
        # forecast (yellow).
        ax.legend()
        ax.set_title(column)
        ax.set_ylabel(column)
# visualization: one figure per column overlaying the training data, the
# held-out test data and the forecast.
get_plots(y_train=y_train, y_test=y_test, y_pred=y_pred)