Skip to content

Commit

Permalink
Specify task/support variables as data
Browse files Browse the repository at this point in the history
  • Loading branch information
bknico-iis committed Mar 11, 2024
1 parent 253a007 commit c9f5ff3
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions prosper_nn/models/hcnn_compressed/hcnn_cell_compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,24 +103,24 @@ def __init__(
def forward(
self,
state: torch.Tensor,
task: Optional[torch.Tensor] = None,
support: Optional[torch.Tensor] = None,
observation_task: Optional[torch.Tensor] = None,
observation_support: Optional[torch.Tensor] = None,
):
"""
Parameters
----------
state : torch.Tensor
The previous state of the HCNN. shape = (batch_size, n_state_neurons)
task : torch.Tensor
The task is the data for the given timestamp which should be predicted from supports.
observation_task : torch.Tensor
The observation_task is the data for the given timestamp which should be predicted from supports.
It has the
shape = (batchsize, n_features_task).
It is an optional variable. If no variable is given,
the observation is not subtracted
from the expectation to create the output variable.
Additionally, no teacher forcing is applied on the state vector.
support : torch.Tensor
The support is the data for the given timestamp which is compressed and then used to learn task.
observation_support : torch.Tensor
The observation_support is the data for the given timestamp which is compressed and then used to learn observation_task.
It has the
shape = (batchsize, n_features_sup).
It is an optional variable. If no variable is given,
Expand All @@ -134,18 +134,18 @@ def forward(
state : torch.Tensor
The updated state of the HCNN.
output_task: torch.Tensor
The output of the HCNN Cell. If a task is given,
this output is calculated by the expectation_task minus the task.
If no task is given, the output is equal to the expectation.
The output of the HCNN Cell. If a observation_task is given,
this output is calculated by the expectation_task minus the observation_task.
If no observation_task is given, the output is equal to the expectation.
"""

expectation_task = torch.mm(state, self.eye_task.T)
expectation_support = torch.mm(state, self.eye_support.T)

if task is not None and support is not None:
support_compressed = self.E(support)
if observation_task is not None and observation_support is not None:
support_compressed = self.E(observation_support)

output_task = expectation_task - task
output_task = expectation_task - observation_task
output_support = expectation_support - support_compressed

teacher_forcing_task = torch.mm(
Expand All @@ -159,8 +159,8 @@ def forward(
state - teacher_forcing_task - teacher_forcing_support
)

elif xor(task is None, support is None): # XOR only one of them is set
self.set_task_and_support_error(task, support)
elif xor(observation_task is None, observation_support is None): # XOR only one of them is set
self.set_task_and_support_error(observation_task, observation_support)

else: # Forecasts
output_task = expectation_task
Expand Down Expand Up @@ -190,21 +190,21 @@ def set_teacher_forcing(self, teacher_forcing: float) -> None:
self.teacher_forcing = teacher_forcing
self.ptf_dropout.p = 1 - teacher_forcing

def set_task_and_support_error(self, task, support) -> None:
def set_task_and_support_error(self, observation_task, observation_support) -> None:
"""
The task and support tensors should either both be set or both be not set.
The observation_task and observation_support tensors should either both be set or both be not set.
This is used to check and throw the error if either of them is empty and reminds to set that.
Parameters
----------
task, support
observation_task, observation_support
Returns
-------
None
"""
if task is None:
raise ValueError("Task is empty and please set it")
elif support is None:
raise ValueError("Support is empty and please set it")
if observation_task is None:
raise ValueError("observation_task is empty and please set it")
elif observation_support is None:
raise ValueError("observation_support is empty and please set it")

0 comments on commit c9f5ff3

Please sign in to comment.