Commit
Merge pull request #17 from Fraunhofer-IIS/hcnn_compressed
Add hcnn compressed
bknico-iis authored Mar 11, 2024
2 parents 79aa83a + c9f5ff3 commit ec3aa4f
Showing 4 changed files with 637 additions and 0 deletions.
76 changes: 76 additions & 0 deletions examples/hcnn_compressed_example.py
@@ -0,0 +1,76 @@
import sys, os

sys.path.append(os.path.abspath(".."))
sys.path.append(os.path.abspath("."))

import torch
import torch.nn as nn
import torch.optim as optim

from prosper_nn.models.hcnn_compressed import HCNN_compressed
import prosper_nn.utils.generate_time_series_data as gtsd
import prosper_nn.utils.create_input_ecnn_hcnn as ci

# %% Define network parameters
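# The n_features_sup support series are compressed to n_features_sup_comp channels
# inside the model; the n_features_task series are the variables to forecast.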

n_data = 10
batchsize = 1
n_batches = 2
n_state_neurons = 10
n_features_task = 2
n_features_sup = 5
n_features_sup_comp = 3
past_horizon = 10
forecast_horizon = 5


# %% Create data and targets
target_task = torch.zeros((past_horizon, batchsize, n_features_task))
target_support = torch.zeros((past_horizon, batchsize, n_features_sup_comp))

# generate data with "unknown" variables U
support, task = gtsd.sample_data(
    n_data, n_features_Y=n_features_sup, n_features_U=n_features_task
)

# Slice both the task and the support series into batches over the past horizon
batches_task = ci.create_input(task, past_horizon, batchsize)
batches_support = ci.create_input(support, past_horizon, batchsize)

# %% Initialize HCNN_compressed
hcnn_model_compressed = HCNN_compressed(
    n_state_neurons,
    n_features_task,
    n_features_sup,
    n_features_sup_comp,
    past_horizon,
    forecast_horizon,
)


# %% Train model
optimizer = optim.Adam(hcnn_model_compressed.parameters(), lr=0.001)
loss_function = nn.MSELoss()

epochs = 10

total_loss = epochs * [0]

for epoch in range(epochs):
    for batch_index in range(batches_task.shape[0]):
        hcnn_model_compressed.zero_grad()

        output_task, output_support = hcnn_model_compressed(
            batches_task[batch_index], batches_support[batch_index]
        )

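        # output_task spans past_horizon + forecast_horizon steps: the first
        # past_horizon entries are past errors (trained towards zero targets),
        # the remaining entries are the forecasts.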
        past_error_task, forecast_task = torch.split(output_task, past_horizon)
        past_error_support = output_support[:past_horizon]

        loss_task = loss_function(past_error_task, target_task)
        loss_support = loss_function(past_error_support, target_support)

        loss = loss_task + loss_support
        loss.backward()
        optimizer.step()
        total_loss[epoch] += loss.detach()
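
The loop above trains on the past errors only; after training, the forecast part of the output can be read off directly. A minimal evaluation sketch, reusing the variables defined in the example (hypothetical, not part of this commit):

# Hypothetical evaluation step: rerun the model and keep the forecast part.
hcnn_model_compressed.eval()
with torch.no_grad():
    output_task, _ = hcnn_model_compressed(batches_task[0], batches_support[0])
forecast_task = output_task[past_horizon:]  # (forecast_horizon, batchsize, n_features_task)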
1 change: 1 addition & 0 deletions prosper_nn/models/hcnn_compressed/__init__.py
@@ -0,0 +1 @@
from .hcnn_compressed import *
210 changes: 210 additions & 0 deletions prosper_nn/models/hcnn_compressed/hcnn_cell_compressed.py
@@ -0,0 +1,210 @@
import torch.nn as nn
import torch
import torch.nn.utils.prune as prune
from typing import Optional, Type
from operator import xor


class HCNNCell_compressed(nn.Module):
    """
    The HCNNCell_compressed class models one forecast step in a Historical
    Consistent Neural Network with compressed support input.
    By applying the cell recursively, a full HCNN network can be built.
    """

    def __init__(
        self,
        n_state_neurons: int,
        n_features_task: int,
        n_features_sup: int,
        n_features_sup_comp: int,
        sparsity: float = 0.0,
        activation: Type[torch.autograd.Function] = torch.tanh,
        teacher_forcing: float = 1,
    ):
        """
        Parameters
        ----------
        n_state_neurons : int
            The dimension of the state in the HCNN Cell. It must be a positive integer.
        n_features_task : int
            The number of task variables to predict at each time step.
            It must be a positive integer.
        n_features_sup : int
            The number of support variables given as input at each time step.
            It must be a positive integer.
        n_features_sup_comp : int
            The size to which the support variables are compressed at each time step.
            It must be a positive integer.
        sparsity : float
            The share of weights that are set to zero in the matrix A.
            These weights are not trainable and therefore always zero.
            For big matrices (dimension > 50) this can be necessary to guarantee
            numerical stability and increases the long-term memory of the model.
        activation : Type[torch.autograd.Function]
            The activation function that is applied on the output of the hidden layers.
            The same function is used on all hidden layers.
            No function is applied if no function is given.
        teacher_forcing : float
            The probability that teacher forcing is applied for a single state neuron.
            In each time step this is repeated and therefore enforces stochastic learning
            if the value is smaller than 1.

        Returns
        -------
        None
        """
        super(HCNNCell_compressed, self).__init__()
        self.n_state_neurons = n_state_neurons
        self.n_features_task = n_features_task
        self.n_features_sup = n_features_sup
        self.n_features_sup_comp = n_features_sup_comp
        self.sparsity = sparsity
        self.activation = activation
        self.teacher_forcing = teacher_forcing

        if isinstance(activation, str) and activation == "torch.tanh":
            self.activation = torch.tanh

        self.A = nn.Linear(
            in_features=self.n_state_neurons,
            out_features=self.n_state_neurons,
            bias=False,
        )
        self.E = nn.Linear(
            in_features=self.n_features_sup,
            out_features=self.n_features_sup_comp,
            bias=False,
        )

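        # Fixed selection matrices: eye_task reads the first n_features_task state
        # neurons as the task expectation; eye_support reads the last
        # n_features_sup_comp state neurons as the compressed-support expectation.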
        self.eye_task = nn.Parameter(
            torch.eye(
                self.n_features_task,
                self.n_state_neurons,
            ),
            requires_grad=False,
        )

        self.eye_support = nn.Parameter(
            torch.cat(
                (
                    torch.zeros(
                        self.n_features_sup_comp,
                        (self.n_state_neurons - self.n_features_sup_comp),
                    ),
                    torch.eye(self.n_features_sup_comp, self.n_features_sup_comp),
                ),
                1,
            ),
            requires_grad=False,
        )

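        # Partial teacher forcing: dropout with p = 1 - teacher_forcing zeroes each
        # error component with that probability, so a component is teacher-forced
        # with probability teacher_forcing.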
        self.ptf_dropout = nn.Dropout(1 - self.teacher_forcing)

        if self.sparsity > 0:
            prune.random_unstructured(self.A, name="weight", amount=self.sparsity)


    def forward(
        self,
        state: torch.Tensor,
        observation_task: Optional[torch.Tensor] = None,
        observation_support: Optional[torch.Tensor] = None,
    ):
        """
        Parameters
        ----------
        state : torch.Tensor
            The previous state of the HCNN. shape = (batch_size, n_state_neurons)
        observation_task : torch.Tensor
            The data for the given time step that should be predicted from the supports.
            It has the shape = (batchsize, n_features_task).
            It is an optional variable. If no variable is given, the observation is
            not subtracted from the expectation to create the output variable.
            Additionally, no teacher forcing is applied on the state vector.
        observation_support : torch.Tensor
            The data for the given time step that is compressed and then used to
            learn observation_task. It has the shape = (batchsize, n_features_sup).
            It is an optional variable. If no variable is given, the observation is
            not subtracted from the expectation to create the output variable.
            Additionally, no teacher forcing is applied on the state vector.

        Returns
        -------
        state : torch.Tensor
            The updated state of the HCNN.
        output_task : torch.Tensor
            The task output of the HCNN Cell. If an observation_task is given,
            it is calculated as expectation_task minus observation_task.
            If no observation_task is given, the output is equal to the expectation.
        output_support : torch.Tensor
            The support output, calculated analogously from expectation_support
            and the compressed observation_support.
        """

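        # Read the expectations out of the state via the fixed selection matrices.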
        expectation_task = torch.mm(state, self.eye_task.T)
        expectation_support = torch.mm(state, self.eye_support.T)

        if observation_task is not None and observation_support is not None:
            support_compressed = self.E(observation_support)

            output_task = expectation_task - observation_task
            output_support = expectation_support - support_compressed

            teacher_forcing_task = torch.mm(
                self.ptf_dropout(output_task), self.eye_task
            )
            teacher_forcing_support = torch.mm(
                self.ptf_dropout(output_support), self.eye_support
            )

            state = self.activation(
                state - teacher_forcing_task - teacher_forcing_support
            )

        elif xor(observation_task is None, observation_support is None):
            # Exactly one of the two observations is set -> raise an error.
            self.set_task_and_support_error(observation_task, observation_support)

        else:  # Forecast step: no observations available.
            output_task = expectation_task
            output_support = expectation_support
            state = self.activation(state)

        state = self.A(state)
        return state, output_task, output_support

    def set_teacher_forcing(self, teacher_forcing: float) -> None:
        """
        Set teacher forcing to a specific value, both in the dropout layer and
        as an instance attribute.

        Parameters
        ----------
        teacher_forcing : float
            The value teacher forcing is set to in the cell.

        Returns
        -------
        None
        """
        if (teacher_forcing < 0) or (teacher_forcing > 1):
            raise ValueError(
                "{} is not a valid number for teacher_forcing. "
                "It must be a value in the interval [0, 1].".format(teacher_forcing)
            )
        self.teacher_forcing = teacher_forcing
        self.ptf_dropout.p = 1 - teacher_forcing

    def set_task_and_support_error(self, observation_task, observation_support) -> None:
        """
        observation_task and observation_support must either both be set or both
        be None. This helper raises a ValueError naming the missing one.

        Parameters
        ----------
        observation_task, observation_support

        Returns
        -------
        None
        """
        if observation_task is None:
            raise ValueError(
                "observation_task is None. Either provide both observation_task "
                "and observation_support, or neither."
            )
        elif observation_support is None:
            raise ValueError(
                "observation_support is None. Either provide both observation_task "
                "and observation_support, or neither."
            )
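
The fourth changed file, prosper_nn/models/hcnn_compressed/hcnn_compressed.py (the HCNN_compressed model the example imports), is not rendered on this page. As the cell's docstring notes, the full network is built by applying the cell recursively; the following is a minimal illustrative unrolling under assumed shapes, not the implementation from this commit:

# Illustrative sketch only - not code from this commit.
import torch

past_horizon, forecast_horizon, batchsize = 10, 5, 1
cell = HCNNCell_compressed(
    n_state_neurons=10, n_features_task=2, n_features_sup=5, n_features_sup_comp=3
)
state = torch.zeros(batchsize, 10)  # initial state

# Hypothetical observations: (horizon, batchsize, n_features_{task,sup})
obs_task = torch.zeros(past_horizon, batchsize, 2)
obs_support = torch.zeros(past_horizon, batchsize, 5)

outputs_task = []
for t in range(past_horizon):  # past: teacher forcing with both observations
    state, output_task, output_support = cell(state, obs_task[t], obs_support[t])
    outputs_task.append(output_task)
for t in range(forecast_horizon):  # future: no observations, outputs are forecasts
    state, output_task, output_support = cell(state)
    outputs_task.append(output_task)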