Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove sink_only from get_forecaster signature #203

Merged
merged 3 commits into from
Jan 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 14 additions & 19 deletions src/pymgrid/forecast/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pymgrid.utils.space import ModuleSpace


def get_forecaster(forecaster, observation_space, forecast_shape, sink_only=False, time_series=None, increase_uncertainty=False):
def get_forecaster(forecaster, observation_space, forecast_shape, time_series=None, increase_uncertainty=False):
"""
Get the forecasting function for the time series module.

Expand All @@ -32,13 +32,13 @@ def get_forecaster(forecaster, observation_space, forecast_shape, sink_only=Fals

* If ``None``, no forecast.

forecast_shape : int or tuple of int
Expected shape of forecasts. If an integer, will return forecasts of shape (shape, 1).

observation_space : :class:`ModuleSpace <pymgrid.utils.space.ModuleSpace>`
Observation space; used to determine values to pad missing forecasts when we are forecasting past the
end of the time series.

sink_only : bool
Whether the module is a sink and is not a source.

time_series: ndarray[float] or None, default None.
The underlying time series, used to validate UserDefinedForecaster.
Only used if callable(forecaster).
Expand All @@ -55,27 +55,22 @@ def get_forecaster(forecaster, observation_space, forecast_shape, sink_only=Fals
"""

if forecaster is None:
return NoForecaster(observation_space, forecast_shape, sink_only)
return NoForecaster(observation_space, forecast_shape)
elif isinstance(forecaster, (UserDefinedForecaster, OracleForecaster, GaussianNoiseForecaster)):
return forecaster
elif callable(forecaster):
return UserDefinedForecaster(forecaster, observation_space, forecast_shape, sink_only, time_series)
return UserDefinedForecaster(forecaster, observation_space, forecast_shape, time_series)
elif forecaster == "oracle":
return OracleForecaster(observation_space, forecast_shape, sink_only)
return OracleForecaster(observation_space, forecast_shape)
elif is_number(forecaster):
return GaussianNoiseForecaster(
forecaster,
observation_space,
forecast_shape,
sink_only,
increase_uncertainty=increase_uncertainty
)
return GaussianNoiseForecaster(forecaster, observation_space, forecast_shape,
increase_uncertainty=increase_uncertainty)
else:
raise ValueError(f"Unable to parse forecaster of type {type(forecaster)}")


class Forecaster:
def __init__(self, observation_space, forecast_shape, sink_only):
def __init__(self, observation_space, forecast_shape):
self._observation_space = observation_space
self._forecast_shaped_space = self._get_forecast_shaped_space(forecast_shape)
self._fill_arr = (self._observation_space.unnormalized.high + self._observation_space.unnormalized.low) / 2
Expand Down Expand Up @@ -176,7 +171,7 @@ def __repr__(self):


class UserDefinedForecaster(Forecaster):
def __init__(self, forecaster_function, observation_space, forecast_shape, sink_only, time_series):
def __init__(self, forecaster_function, observation_space, forecast_shape, time_series):
self.is_vectorized_forecaster, self.cast_to_arr = \
_validate_callable_forecaster(forecaster_function, time_series)

Expand All @@ -185,7 +180,7 @@ def __init__(self, forecaster_function, observation_space, forecast_shape, sink_

self._forecaster = forecaster_function

super().__init__(observation_space, forecast_shape, sink_only)
super().__init__(observation_space, forecast_shape)

def _cast_to_arr(self, forecast, val_c_n):
if self.cast_to_arr:
Expand All @@ -203,8 +198,8 @@ def _forecast(self, val_c, val_c_n, n):


class GaussianNoiseForecaster(Forecaster):
def __init__(self, noise_std, observation_space, forecast_shape, sink_only, increase_uncertainty=False):
super().__init__(observation_space, forecast_shape, sink_only)
def __init__(self, noise_std, observation_space, forecast_shape, increase_uncertainty=False):
super().__init__(observation_space, forecast_shape)

self.input_noise_std = noise_std
self.increase_uncertainty = increase_uncertainty
Expand Down
2 changes: 0 additions & 2 deletions src/pymgrid/modules/base/timeseries/base_timeseries_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def __init__(self,
self._forecaster = get_forecaster(forecaster,
self._get_observation_spaces(),
forecast_shape=(self.forecast_horizon, len(self.state_components)),
sink_only=self.is_sink and not self.is_source,
time_series=self.time_series,
increase_uncertainty=forecaster_increase_uncertainty)

Expand Down Expand Up @@ -228,7 +227,6 @@ def set_forecaster(self,
self._forecaster = get_forecaster(forecaster,
self._observation_space,
(self.forecast_horizon, len(self.state_components)),
self.is_sink and not self.is_source,
self.time_series,
increase_uncertainty=forecaster_increase_uncertainty)

Expand Down
Loading