diff --git a/examples/case_study/ECNN_DNDGRG3M086SBEA.pt b/examples/case_study/ECNN_DNDGRG3M086SBEA.pt new file mode 100644 index 0000000..91bb0f2 Binary files /dev/null and b/examples/case_study/ECNN_DNDGRG3M086SBEA.pt differ diff --git a/examples/case_study/Readme.md b/examples/case_study/Readme.md new file mode 100644 index 0000000..c08d5a3 --- /dev/null +++ b/examples/case_study/Readme.md @@ -0,0 +1,31 @@ +# FRED-MD Case Study +This case study shows a possible usage of the Error Correction Neural Network (ECNN) on a popular data set. First, we have to build the Python environment and download FRED-MD. Afterwards, we can run the benchmark and use the trained ECNN to create a heatmap forecast and a sensitivity analysis. + + +## Environment +Create environment: ``conda env create -f environment.yml`` +Activate environment: ``conda activate prosper_nn`` + +Additionally install *tqdm*: ``pip install tqdm`` + + +## Data +Download monthly data *2024-01.csv* from +https://files.stlouisfed.org/files/htdocs/fred-md/monthly/2024-01.csv +and place it in the current directory. + +## Files +Short description of the files in the case study: +* ``config.py`` Stores the hyperparameters. +* ``fredmd.py`` Creates a PyTorch dataset with the option to select train, validation or test set. All timeseries are transformed to log-differences and rolling origins are created. +* ``models.py`` Defines the benchmark models. First a context vector is created either by an *Elman*, *GRU* or *LSTM* model. Afterwards, three different forecast approaches are implemented: + * *Direct*: Maps the context vector to the forecast for all forecast horizons with an affine transformation. + * *Recursive*: From the context vector, states for each step in the forecast horizon are created with the same recurrent model that is used to create the context vector. Finally, for each step in the forecast horizon, the state is mapped to the forecast with an affine linear transformation. + * *Sequence to Sequence*: Similar approach as *Recursive*, but a second recurrent model of the same type is used to create the states in the forecast horizon. +* ``benchmark.py`` Creates dataset classes of FRED-MD for training, validation and testing. Afterwards all models are trained on the dataset and the trained ECNN is saved. The model state that performed best on the validation set is evaluated on the test set. The forecast performance is assessed individually for each forecast step. +* ``visualization.py`` Uses the saved ECNN to create a heatmap forecast and a sensitivity analysis. + +## Replicate +1) Run the ``benchmark.py`` to train an ECNN ensemble and ten other benchmark models. The trained ECNN is saved. + +2) Run ``visualization.py`` to use the saved ECNN to create a heatmap forecast and a sensitivity analysis. \ No newline at end of file diff --git a/examples/case_study/benchmark.py b/examples/case_study/benchmark.py new file mode 100644 index 0000000..e6aead8 --- /dev/null +++ b/examples/case_study/benchmark.py @@ -0,0 +1,227 @@ +# %% +import sys +import os + +sys.path.append(os.path.abspath("../..")) +sys.path.append(os.path.abspath("..")) +sys.path.append(os.path.abspath(".")) + + +# %% +import torch +import torch.nn as nn +import pandas as pd +from tqdm import tqdm +from pathlib import Path + +from prosper_nn.models.ecnn import ECNN +from prosper_nn.models.ensemble import Ensemble +from models import RNN_direct, RNN_recursive, RNN_S2S, Naive +from fredmd import Dataset_Fredmd + +from config import ( + past_horizon, + forecast_horizon, + train_test_split_period, + batch_size, + n_epochs, + patience, + n_evaluate_targets, + n_features_Y, + n_models, +) + +# %% +torch.manual_seed(0) + + +# %% Training +def train_model( + model: nn.Module, + dataloader: torch.utils.data.DataLoader, + dataset_val: torch.utils.data.Dataset, + n_epochs: int, + patience: int, +): + optimizer = torch.optim.Adam(model.parameters()) + smallest_val_loss = torch.inf + epoch_smallest_val = 0 + val_features_past, val_target_past, val_target_future = ( + dataset_val.get_all_rolling_origins() + ) + epochs = tqdm(range(n_epochs)) + + for epoch in epochs: + train_loss = 0 + for features_past, target_past, target_future in dataloader: + target_past = target_past.transpose(1, 0) + target_future = target_future.transpose(1, 0) + features_past = features_past.transpose(1, 0) + + model.zero_grad() + + forecasts = get_forecast(model, features_past, target_past) + + assert forecasts.shape == target_future.shape + loss = nn.functional.mse_loss(forecasts, target_future) + loss.backward() + train_loss += loss.detach() + optimizer.step() + + # Validation loss + forecasts_val = get_forecast(model, val_features_past, val_target_past) + val_loss = nn.functional.mse_loss(forecasts_val[0], val_target_future[0]).item() + epochs.set_postfix( + {"val_loss": round(val_loss, 3), "train_loss": round(train_loss.item(), 3)} + ) + + # Save and later use model with best validation loss + if val_loss < smallest_val_loss: + print(f"Save model_state at epoch {epoch}") + best_model_state = model.state_dict() + smallest_val_loss = val_loss + epoch_smallest_val = epoch + + # Early Stopping + if epoch >= epoch_smallest_val + patience: + print(f"No validation improvement since {patience} epochs -> Stop Training") + model.load_state_dict(best_model_state) + return + + model.load_state_dict(best_model_state) + + +def get_forecast( + model: nn.Module, features_past: torch.Tensor, target_past: torch.Tensor +) -> torch.Tensor: + model_type = model.models[0] + + # Select input + if isinstance(model_type, ECNN): + input = (features_past, target_past) + else: + input = (features_past,) + + ensemble_output = model(*input) + mean = ensemble_output[-1] + + # Extract forecasts + if isinstance(model_type, ECNN): + _, forecasts = torch.split(mean, past_horizon) + else: + forecasts = mean + return forecasts + + +def evaluate_model(model: nn.Module, dataset: torch.utils.data.Dataset) -> pd.DataFrame: + model.eval() + losses = [] + + for features_past, target_past, target_future in dataset: + features_past = features_past.unsqueeze(1) + target_past = target_past.unsqueeze(1) + + with torch.no_grad(): + forecasts = get_forecast(model, features_past, target_past) + forecasts = forecasts.squeeze(1) + assert forecasts.shape == target_future.shape + losses.append( + [ + nn.functional.mse_loss(forecasts[i], target_future[i]).item() + for i in range(forecast_horizon) + ] + ) + return pd.DataFrame(losses) + + +# %% Get Data + +fredmd = Dataset_Fredmd( + past_horizon, + forecast_horizon, + split_date=train_test_split_period, + data_type="train", +) +fredmd_val = Dataset_Fredmd( + past_horizon, + forecast_horizon, + split_date=train_test_split_period, + data_type="val", +) +fredmd_test = Dataset_Fredmd( + past_horizon, + forecast_horizon, + split_date=train_test_split_period, + data_type="test", +) + +# %% Run benchmark +n_features_U = len(fredmd.features) +n_state_neurons = n_features_U + n_features_Y + +overall_losses = {} + +for target in fredmd.features[:n_evaluate_targets]: + fredmd.target = target + fredmd_val.target = target + fredmd_test.target = target + + # Error Correction Neural Network (ECNN) + ecnn = ECNN( + n_state_neurons=n_state_neurons, + n_features_U=n_features_U, + n_features_Y=n_features_Y, + past_horizon=past_horizon, + forecast_horizon=forecast_horizon, + ) + + # Define an Ensemble for better forecasts, heatmap visualization and sensitivity analysis + ecnn_ensemble = Ensemble(model=ecnn, n_models=n_models).double() + benchmark_models = {"ECNN": ecnn_ensemble} + + # Compare to further Recurrent Neural Networks + for forecast_module in [RNN_direct, RNN_recursive, RNN_S2S]: + for recurrent_cell_type in ["elman", "gru", "lstm"]: + model = forecast_module( + n_features_U, + n_state_neurons, + n_features_Y, + forecast_horizon, + recurrent_cell_type, + ) + ensemble = Ensemble(model=model, n_models=n_models).double() + benchmark_models[f"{recurrent_cell_type}_{model.forecast_method}"] = ( + ensemble + ) + + # Train models + dataloader = torch.utils.data.DataLoader( + fredmd, batch_size=batch_size, shuffle=True + ) + + for name, model in benchmark_models.items(): + print(f"### Train {name} ###") + train_model(model, dataloader, fredmd_val, n_epochs, patience) + + if target == "DNDGRG3M086SBEA": + torch.save( + benchmark_models["ECNN"], Path(__file__).parent / f"ECNN_{target}.pt" + ) + + # Test + # Additionally, compare with the naive no-change forecast + benchmark_models["Naive"] = Ensemble( + Naive(past_horizon, forecast_horizon, n_features_Y), n_models + ) + + all_losses = { + name: evaluate_model(model, fredmd_test) + for name, model in benchmark_models.items() + } + overall_losses[target] = pd.concat(all_losses) + +overall_losses = pd.concat(overall_losses) +overall_losses.to_csv(Path(__file__).parent / f"overall_losses.csv") +mean_overall_losses = overall_losses.groupby(level=1).mean() +mean_overall_losses.to_csv(Path(__file__).parent / f"mean_overall_losses.csv") +print(mean_overall_losses) diff --git a/examples/case_study/config.py b/examples/case_study/config.py new file mode 100644 index 0000000..a00b7a5 --- /dev/null +++ b/examples/case_study/config.py @@ -0,0 +1,16 @@ +import pandas as pd + +n_evaluate_targets = 19 + +past_horizon = 24 +forecast_horizon = 3 +train_test_split_period = pd.Period("2011-01-01", freq="M") + +# Model Training +batch_size = 32 +n_epochs = 50 +patience = 10 + +# Model Parameters +n_features_Y = 1 +n_models = 25 diff --git a/examples/case_study/fredmd.py b/examples/case_study/fredmd.py new file mode 100644 index 0000000..582408c --- /dev/null +++ b/examples/case_study/fredmd.py @@ -0,0 +1,202 @@ +from pathlib import Path +import torch +import pandas as pd +import numpy as np +from typing import Tuple + + +class Dataset_Fredmd(torch.utils.data.Dataset): + """ + Creates a PyTorch suitable data set for FRED-MD. The data is transformed to + log-differences and rolling origins are created. Afterwards each rolling origin + is scaled independently and data is split into training, validation and test. + """ + + def __init__( + self, + past_horizon: int, + forecast_horizon: int, + split_date: pd.Period, + data_type: str = "test", + target: str = "CPIAUCSL", + ): + assert data_type in ["train", "val", "test"] + self.past_horizon = past_horizon + self.forecast_horizon = forecast_horizon + self.window_size = past_horizon + forecast_horizon + self.split_date = split_date + + # Select variables from "prices" group without 'OILPRICEx' + self.features = [ + "WPSFD49207", + "WPSFD49502", + "WPSID61", + "WPSID62", + # "OILPRICEx", + "PPICMM", + "CPIAUCSL", + "CPIAPPSL", + "CPITRNSL", + "CPIMEDSL", + "CUSR0000SAC", + "CUSR0000SAD", + "CUSR0000SAS", + "CPIULFSL", + "CUSR0000SA0L2", + "CUSR0000SA0L5", + "PCEPI", + "DDURRG3M086SBEA", + "DNDGRG3M086SBEA", + "DSERRG3M086SBEA", + ] + self.target = target + self.original_data = self.get_data() + self.n_rolling_origins = len(self.original_data) - self.window_size + + df = self.preprocess(self.original_data) + rolling_origins = self.get_rolling_origins(df) + self.mean, self.std = self.get_scales(rolling_origins) + df_train, df_val, df_test = self.train_test_split( + rolling_origins, + ) + if data_type == "train": + self.df = df_train + elif data_type == "val": + self.df = df_val + else: + self.df = df_test + + def __len__(self) -> int: + return self.df.index.get_level_values(0).nunique() + + def get_scales( + self, rolling_origins: pd.DataFrame + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + mean = rolling_origins.groupby("rolling_origin_start_date").apply( + lambda x: x.head(self.past_horizon).mean() + ) + std = rolling_origins.groupby("rolling_origin_start_date").apply( + lambda x: x.head(self.past_horizon).std() + ) + assert ( + (std != 0).all().all() + ), "Standard deviation is zero and will lead to NaNs" + return mean, std + + def split_past_future( + self, timeseries: pd.DataFrame + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + past = timeseries.head(self.past_horizon) + future = timeseries.tail(self.forecast_horizon) + return past, future + + def scale( + self, + past: pd.DataFrame, + future: pd.DataFrame, + rolling_origin_start_date: pd.Period, + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + mean = self.mean.loc[rolling_origin_start_date] + std = self.std.loc[rolling_origin_start_date] + past = (past - mean) / std + future = (future - mean) / std + return past, future + + def rescale( + self, forecast: torch.Tensor, rolling_origin_start_date: pd.Period + ) -> torch.Tensor: + mean = self.mean.loc[rolling_origin_start_date, self.target] + std = self.std.loc[rolling_origin_start_date, self.target] + mean = torch.tensor(mean) + std = torch.tensor(std) + forecast = forecast * std + mean + return forecast + + def get_one_rolling_origin( + self, rolling_origin_start_date: pd.Period + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + timeseries = self.df.loc[rolling_origin_start_date] + past, future = self.split_past_future(timeseries) + past, future = self.scale(past, future, rolling_origin_start_date) + + assert past.notnull().all().all() + assert future.notnull().all().all() + + features_past = torch.tensor(past[self.features].values) + target_past = torch.tensor(past[self.target].values).unsqueeze(1) + target_future = torch.tensor(future[self.target].values).unsqueeze(1) + return features_past, target_past, target_future + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + rolling_origin_start_date = self.df.index.get_level_values(0).unique()[idx] + features_past, target_past, target_future = self.get_one_rolling_origin( + rolling_origin_start_date + ) + return features_past, target_past, target_future + + def get_data(self) -> pd.DataFrame: + path = Path(__file__).parent + df = pd.read_csv( + path / "2024-01.csv", + parse_dates=["sasdate"], + index_col="sasdate", + usecols=["sasdate", self.target] + self.features, + ) + df = df.drop("Transform:") + df.index = pd.PeriodIndex(df.index, freq="M") + return df + + def get_rolling_origins(self, df: pd.DataFrame) -> pd.DataFrame: + rolling_origins = [ + df.iloc[i : i + self.window_size] for i in range(self.n_rolling_origins) + ] + rolling_origins = pd.concat( + rolling_origins, keys=df.index[: self.n_rolling_origins] + ) + rolling_origins.index.rename("rolling_origin_start_date", level=0, inplace=True) + return rolling_origins + + def train_test_split( + self, rolling_origins: pd.DataFrame + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + # Avoid data leakage + last_date_val_rolling_origin = self.split_date - self.window_size + last_date_train_rolling_origin = last_date_val_rolling_origin - 12 + df_train = rolling_origins.loc[:last_date_train_rolling_origin] + df_val = rolling_origins.loc[ + last_date_train_rolling_origin + + self.forecast_horizon : last_date_val_rolling_origin + ] + df_test = rolling_origins.loc[ + last_date_val_rolling_origin + self.forecast_horizon : + ] + return df_train, df_val, df_test + + def preprocess(self, df: pd.DataFrame) -> pd.DataFrame: + df = df.apply(np.log).diff() + df = df.iloc[1:] + return df + + def postprocess( + self, forecast: torch.Tensor, rolling_origin_start_date: pd.Period + ) -> torch.Tensor: + start_value = self.original_data.loc[rolling_origin_start_date, self.target] + start_value = torch.tensor(start_value) + forecast = start_value * torch.exp(torch.cumsum(forecast, dim=1)) + return forecast + + def get_all_rolling_origins( + self, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + all_targets_past = [] + all_targets_future = [] + all_features_past = [] + for idx in range(self.__len__()): + features_past, target_past, target_future = self.__getitem__(idx) + all_targets_past.append(target_past) + all_targets_future.append(target_future) + all_features_past.append(features_past) + all_targets_past = torch.stack(all_targets_past, dim=1) + all_targets_future = torch.stack(all_targets_future, dim=1) + all_features_past = torch.stack(all_features_past, dim=1) + return all_features_past, all_targets_past, all_targets_future diff --git a/examples/case_study/models.py b/examples/case_study/models.py new file mode 100644 index 0000000..84f10a5 --- /dev/null +++ b/examples/case_study/models.py @@ -0,0 +1,243 @@ +from typing import Union, Tuple + +import torch +import torch.nn as nn + + +class Benchmark_RNN(nn.Module): + """ + Parent class to create various RNNs based on Elman, GRU and LSTM cells. + Additionally, the forecast methods direct, recursive and + sequence to sequence (S2S) are possible. + For all approaches the past_target is merged to the past_features to enable + an autoregressive part in the models. + """ + + def __init__( + self, + n_features_U: int, + n_state_neurons: int, + n_features_Y: int, + forecast_horizon: int, + recurrent_cell_type: str, + ): + super(Benchmark_RNN, self).__init__() + self.n_features_Y = n_features_Y + self.forecast_horizon = forecast_horizon + self.n_state_neurons = n_state_neurons + self.recurrent_cell_type = recurrent_cell_type + + self.cell = self.get_recurrent_cell() + self.rnn = self.cell(input_size=n_features_U, hidden_size=n_state_neurons) + self.state_output = nn.Linear( + in_features=n_state_neurons, out_features=self.output_size_linear_decoder + ) + self.init_state = self.set_init_state() + + def forward(self, past_features: torch.Tensor) -> torch.Tensor: + batchsize = past_features.size(1) + + init_state = self.repeat_init_state(batchsize) + output_rnn = self.rnn(past_features, init_state) + return output_rnn + + def set_init_state(self) -> Union[nn.Parameter, Tuple[nn.Parameter, nn.Parameter]]: + dtype = torch.float64 + if self.recurrent_cell_type == "lstm": + init_state = ( + nn.Parameter( + torch.rand(1, self.n_state_neurons, dtype=dtype), requires_grad=True + ), + nn.Parameter( + torch.rand(1, self.n_state_neurons, dtype=dtype), requires_grad=True + ), + ) + else: + init_state = nn.Parameter( + torch.rand(1, self.n_state_neurons, dtype=dtype), requires_grad=True + ) + return init_state + + def get_recurrent_cell(self) -> nn.Module: + if self.recurrent_cell_type == "elman": + cell = nn.RNN + elif self.recurrent_cell_type == "gru": + cell = nn.GRU + elif self.recurrent_cell_type == "lstm": + cell = nn.LSTM + else: + raise ValueError( + f"recurrent_cell_type {self.recurrent_cell_type} not available." + ) + return cell + + def repeat_init_state( + self, batchsize: int + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.recurrent_cell_type == "lstm": + return self.init_state[0].repeat(batchsize, 1).unsqueeze( + 0 + ), self.init_state[1].repeat(batchsize, 1).unsqueeze(0) + else: + return self.init_state.repeat(batchsize, 1).unsqueeze(0) + + +class RNN_direct(Benchmark_RNN): + """ + Encodes the data of the past horizon into a context vector with a RNN. + Afterwards, the context vector is mapped to the forecasts of all + forecast steps in the forecast horizon by an affine linear transformation. + + .. math:: + + s_0, \dots, s_T = rnn(y_0, \dots, y_T; s_0) + + \hat{y}_{T+1}, \dots, \hat{y}_{T+\tau} = A \cdot s_{T} + with A \in \mathbb{R}^{n_features_Y \cdot forecast_horizon \times n_state_neurons} + """ + + def __init__( + self, + n_features_U: int, + n_state_neurons: int, + n_features_Y: int, + forecast_horizon: int, + recurrent_cell_type: str, + ): + self.forecast_method = "direct" + self.output_size_linear_decoder = n_features_Y * forecast_horizon + super(RNN_direct, self).__init__( + n_features_U, + n_state_neurons, + n_features_Y, + forecast_horizon, + recurrent_cell_type, + ) + + def forward(self, past_features: torch.Tensor) -> torch.Tensor: + output_rnn = super(RNN_direct, self).forward(past_features) + + context_vector = output_rnn[0][-1] + forecast = self.state_output(context_vector) + forecast = forecast.reshape(self.forecast_horizon, -1, self.n_features_Y) + return forecast + + +class RNN_recursive(Benchmark_RNN): + """ + Encodes the data of the past horizon into a context vector with a RNN. + The context vector is given to the same RNN to predict states for the forecast horizon. + Each state in the forecast horizon is then decoded by an affine linear + transformation A. + + .. math:: + + s_0, \dots, s_T = rnn(y_0, \dots, y_T; s_0) + s_{T+1}, ..., s_{T+\tau} = rnn(0_{T+1}, \dots, 0_{T+\tau}; s_T) + + \hat{y}_{T+i} = A \cdot s_{T+i} for i=1,...,\tau + with A \in \mathbb{R}^{n_features_Y \times n_state_neurons} + + """ + + def __init__( + self, + n_features_U: int, + n_state_neurons: int, + n_features_Y: int, + forecast_horizon: int, + recurrent_cell_type: str, + ): + self.forecast_method = "recursive" + self.output_size_linear_decoder = n_features_Y + super(RNN_recursive, self).__init__( + n_features_U, + n_state_neurons, + n_features_Y, + forecast_horizon, + recurrent_cell_type, + ) + + def forward(self, past_features: torch.Tensor) -> torch.Tensor: + # add zeros as RNN input for forecast horizon + future_zeros_features = torch.zeros_like(past_features)[: self.forecast_horizon] + features = torch.cat([past_features, future_zeros_features], dim=0) + + output_rnn = super(RNN_recursive, self).forward(features) + future_states = output_rnn[0][-self.forecast_horizon :] + forecast = self.state_output(future_states) + return forecast + + +class RNN_S2S(Benchmark_RNN): + """ + Encodes the data of the past horizon into a context vector with a RNN. + The context vector is given to another RNN of the same recurrent cell type + to predict a state for each step in the forecast horizon. + Each state in the forecast horizon is then decoded by an affine linear + transformation A. + + .. math:: + + s_0, \dots, s_T = rnn(y_0, \dots, y_T; s_0) + s_{T+1}, ..., s_{T+\tau} = \tilde{rnn}(0_{T+1}, \dots, 0_{T+\tau}; s_T) + + \hat{y}_{T+i} = A \cdot s_{T+i} for i=1,...,\tau + with A \in \mathbb{R}^{n_features_Y \times n_state_neurons} + """ + + def __init__( + self, + n_features_U: int, + n_state_neurons: int, + n_features_Y: int, + forecast_horizon: int, + recurrent_cell_type: str, + ): + self.forecast_method = "s2s" + self.output_size_linear_decoder = n_features_Y + super(RNN_S2S, self).__init__( + n_features_U, + n_state_neurons, + n_features_Y, + forecast_horizon, + recurrent_cell_type, + ) + self.decoder = self.cell(input_size=n_features_U, hidden_size=n_state_neurons) + + def forward(self, past_features: torch.Tensor) -> torch.Tensor: + output_rnn = super(RNN_S2S, self).forward(past_features) + + # add dummy zeros as RNN input for forecast horizon + future_zeros_features = torch.zeros_like(past_features)[: self.forecast_horizon] + + context_vector = output_rnn[1] + states_decoder = self.decoder(future_zeros_features, context_vector)[0] + forecast = self.state_output(states_decoder) + return forecast + + +class Naive(nn.Module): + """ + Model that predicts zero changes. Implemented so that it can be used like the + other benchmark recurrent neural networks. + + .. math:: + + \hat{y}_{T+i} = 0 for i=1,...,\tau + """ + + def __init__( + self, past_horizon: int, forecast_horizon: int, n_features_Y: int + ) -> None: + super().__init__() + self.past_horizon = past_horizon + self.forecast_horizon = forecast_horizon + self.n_features_Y = n_features_Y + + def forward(self, past_features: torch.Tensor) -> torch.Tensor: + return torch.zeros( + self.forecast_horizon, + past_features.size(1), + self.n_features_Y, + ) diff --git a/examples/case_study/visualization.py b/examples/case_study/visualization.py new file mode 100644 index 0000000..e66a91f --- /dev/null +++ b/examples/case_study/visualization.py @@ -0,0 +1,115 @@ +# %% +import sys +import os + +sys.path.append(os.path.abspath("../..")) +sys.path.append(os.path.abspath("..")) +sys.path.append(os.path.abspath(".")) + + +# %% + +import torch +import pandas as pd +import numpy as np +from pathlib import Path + +import matplotlib +import matplotlib.pyplot as plt +import seaborn as sns + +from prosper_nn.utils import visualize_forecasts, sensitivity_analysis +from fredmd import Dataset_Fredmd + +from config import past_horizon, forecast_horizon, train_test_split_period + + +# %% Select target +target = "DNDGRG3M086SBEA" + +# %% +fredmd_test = Dataset_Fredmd( + past_horizon, + forecast_horizon, + split_date=train_test_split_period, + data_type="test", + target=target, +) + +# %% Visualization of the ECNN ensemble Heatmap + +ecnn_ensemble = torch.load(Path(__file__).parent / f"ECNN_{target}.pt") + +rolling_origin_start_date = pd.Period("2011-01-01", freq="M") +ecnn_ensemble.eval() +features_past, target_past, target_future = fredmd_test.get_one_rolling_origin( + rolling_origin_start_date +) +target_past = target_past.unsqueeze(1) +features_past = features_past.unsqueeze(1) + +with torch.no_grad(): + forecasts = ecnn_ensemble(features_past, target_past) + +forecasts = forecasts[:-1, past_horizon:, 0] +forecasts = fredmd_test.rescale(forecasts, rolling_origin_start_date) +forecasts = fredmd_test.postprocess(forecasts, rolling_origin_start_date) + +torch.save(forecasts, Path(__file__).parent / "forecasts.pt") + +# %% Figure 6: Uncertainty Heatmap + +matplotlib.rcParams.update({"font.size": 15}) + +start_point = torch.tensor( + fredmd_test.original_data.loc[rolling_origin_start_date - 1, fredmd_test.target] +) +visualize_forecasts.heatmap_forecasts( + forecasts.squeeze(2), + start_point=start_point, + sigma=0.25, + num_interp=30, + window_height=200, + xlabel="Forecast step", + save_at=Path(__file__).parent / f"heatmap_{target}.pdf", + title="Heatmap ensemble forecast", +) + + +# %% Figure 7: Sensitivity Analysis of the Input Features +matplotlib.rcParams.update({"font.size": 12}) + +all_ros_features_past, all_ros_targets_past, _ = fredmd_test.get_all_rolling_origins() +all_ros_features_past = all_ros_features_past.transpose(1, 0).unsqueeze(2) +all_ros_targets_past = all_ros_targets_past.transpose(1, 0).unsqueeze(2) + +for model in ecnn_ensemble.models: + model.batchsize = 1 + +sensitivity = sensitivity_analysis.calculate_sensitivity_analysis( + ecnn_ensemble, + *(all_ros_features_past, all_ros_targets_past), + output_neuron=(-1, past_horizon + 1, 0, 0), + batchsize=1, +) + +restricted_sensitivity_matrix = sensitivity[:, -1].squeeze(1) +labels = fredmd_test.features + [fredmd_test.target] + +fig = sns.heatmap( + restricted_sensitivity_matrix.T, + center=0, + cmap="coolwarm", + robust=True, + cbar_kws={"label": r"$\frac{\partial\ output}{\partial\ input}$"}, + vmin=-torch.max(abs(restricted_sensitivity_matrix)), + vmax=torch.max(abs(restricted_sensitivity_matrix)), + rasterized=True, +) +plt.xlabel("Rolling Origins") +plt.ylabel("Features") +plt.yticks(ticks=0.5 + np.arange(len(labels)), labels=labels, rotation=0) +plt.title(f"Sensitivity of '{fredmd_test.target}'s one step forecast") +plt.grid(visible=False) +plt.tight_layout() +plt.savefig(Path(__file__).parent / f"sensitivity_{target}.pdf")