-
Notifications
You must be signed in to change notification settings - Fork 248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Creating Time Base component for Media Contribution #752
Changes from 13 commits
ca1e8ac
72cbb35
081f9a1
403b095
2e475c7
d9dd569
25e6865
493fe84
b44871e
1ab59da
4fedacc
b6d6b55
4f695a7
96c1092
eba67dc
3fc7caf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,11 +40,11 @@ | |
_get_saturation_function, | ||
) | ||
from pymc_marketing.mmm.lift_test import ( | ||
add_lift_measurements_to_likelihood, | ||
add_lift_measurements_to_likelihood_from_saturation, | ||
scale_lift_measurements, | ||
) | ||
from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget | ||
from pymc_marketing.mmm.tvp import create_time_varying_intercept, infer_time_index | ||
from pymc_marketing.mmm.tvp import create_time_varying_gp_multiplier, infer_time_index | ||
from pymc_marketing.mmm.utils import ( | ||
apply_sklearn_transformer_across_dim, | ||
create_new_spend_data, | ||
|
@@ -81,6 +81,7 @@ | |
adstock: str | AdstockTransformation, | ||
saturation: str | SaturationTransformation, | ||
time_varying_intercept: bool = False, | ||
time_varying_media: bool = False, | ||
model_config: dict | None = None, | ||
sampler_config: dict | None = None, | ||
validate_data: bool = True, | ||
|
@@ -105,6 +106,13 @@ | |
Type of saturation transformation to apply. | ||
time_varying_intercept : bool, optional | ||
Whether to consider time-varying intercept, by default False. | ||
Because the `time-varying` variable is centered around 1 and acts as a multiplier, | ||
the variable `base_intercept` now represents the mean of the time-varying intercept. | ||
time_varying_media : bool, optional | ||
Whether to consider time-varying media contributions, by default False. | ||
The `time-varying-media` creates a time media variable centered around 1, | ||
this variable acts as a global multiplier (scaling factor) for all channels, | ||
meaning all media channels share the same latent fluctiation. | ||
model_config : Dictionary, optional | ||
dictionary of parameters that initialise model configuration. | ||
Class-default defined by the user default_model_config method. | ||
|
@@ -121,6 +129,7 @@ | |
self.control_columns = control_columns | ||
self.adstock_max_lag = adstock_max_lag | ||
self.time_varying_intercept = time_varying_intercept | ||
self.time_varying_media = time_varying_media | ||
self.yearly_seasonality = yearly_seasonality | ||
self.date_column = date_column | ||
self.validate_data = validate_data | ||
|
@@ -220,7 +229,7 @@ | |
self.X: pd.DataFrame = X_data | ||
self.y: pd.Series | np.ndarray = y | ||
|
||
if self.time_varying_intercept: | ||
if self.time_varying_intercept | self.time_varying_media: | ||
self._time_index = np.arange(0, X.shape[0]) | ||
self._time_index_mid = X.shape[0] // 2 | ||
self._time_resolution = ( | ||
|
@@ -344,33 +353,66 @@ | |
dims="date", | ||
mutable=True, | ||
) | ||
|
||
if self.time_varying_intercept: | ||
if self.time_varying_intercept | self.time_varying_media: | ||
time_index = pm.Data( | ||
"time_index", | ||
self._time_index, | ||
dims="date", | ||
) | ||
intercept_dist = get_distribution( | ||
|
||
if self.time_varying_intercept: | ||
intercept_distribution = get_distribution( | ||
name=self.model_config["intercept"]["dist"] | ||
) | ||
intercept = create_time_varying_intercept( | ||
time_index, | ||
self._time_index_mid, | ||
self._time_resolution, | ||
intercept_dist, | ||
self.model_config, | ||
base_intercept = intercept_distribution( | ||
name="base_intercept", **self.model_config["intercept"]["kwargs"] | ||
) | ||
|
||
intercept_latent_process = create_time_varying_gp_multiplier( | ||
name="intercept", | ||
dims="date", | ||
time_index=time_index, | ||
time_index_mid=self._time_index_mid, | ||
time_resolution=self._time_resolution, | ||
model_config=self.model_config, | ||
) | ||
intercept = pm.Deterministic( | ||
name="intercept", | ||
var=base_intercept * intercept_latent_process, | ||
dims="date", | ||
) | ||
else: | ||
intercept = create_distribution_from_config( | ||
name="intercept", config=self.model_config | ||
) | ||
|
||
channel_contributions = pm.Deterministic( | ||
name="channel_contributions", | ||
var=self.forward_pass(x=channel_data_), | ||
dims=("date", "channel"), | ||
) | ||
if self.time_varying_media: | ||
base_channel_contributions = pm.Deterministic( | ||
name="base_channel_contributions", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is "baseline" a better name than "base" ? 🤔 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm good with both, I didn't made it but I can if you feel is better! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Think |
||
var=self.forward_pass(x=channel_data_), | ||
dims=("date", "channel"), | ||
) | ||
|
||
media_latent_process = create_time_varying_gp_multiplier( | ||
name="media", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we need to specify this is a global baseline go per channel. So maybe "baseline_media" is a better name? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Related to the same comment the media multiplier is not the baseline, is more a scaling factor! But let's agree on the words 👀 |
||
dims="date", | ||
time_index=time_index, | ||
time_index_mid=self._time_index_mid, | ||
time_resolution=self._time_resolution, | ||
model_config=self.model_config, | ||
) | ||
channel_contributions = pm.Deterministic( | ||
name="channel_contributions", | ||
var=base_channel_contributions * media_latent_process[:, None], | ||
cetagostini marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dims=("date", "channel"), | ||
) | ||
|
||
else: | ||
channel_contributions = pm.Deterministic( | ||
name="channel_contributions", | ||
var=self.forward_pass(x=channel_data_), | ||
dims=("date", "channel"), | ||
) | ||
|
||
mu_var = intercept + channel_contributions.sum(axis=-1) | ||
|
||
|
@@ -465,7 +507,7 @@ | |
|
||
@property | ||
def default_model_config(self) -> dict: | ||
base_config = { | ||
base_config: dict[str, Any] = { | ||
"intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, | ||
"likelihood": { | ||
"dist": "Normal", | ||
|
@@ -483,15 +525,26 @@ | |
"kwargs": {"mu": 0, "b": 1}, | ||
"dims": "fourier_mode", | ||
}, | ||
"intercept_tvp_kwargs": { | ||
} | ||
|
||
if self.time_varying_intercept: | ||
base_config["intercept_tvp_config"] = { | ||
"m": 200, | ||
"L": None, | ||
"eta_lam": 1, | ||
"ls_mu": None, | ||
"ls_sigma": 10, | ||
"cov_func": None, | ||
}, | ||
} | ||
} | ||
if self.time_varying_media: | ||
base_config["media_tvp_config"] = { | ||
"m": 200, | ||
"L": None, | ||
"eta_lam": 1, | ||
"ls_mu": None, | ||
"ls_sigma": 10, | ||
"cov_func": None, | ||
} | ||
|
||
for media_transform in [self.adstock, self.saturation]: | ||
for param, config in media_transform.function_priors.items(): | ||
|
@@ -712,7 +765,7 @@ | |
if hasattr(self, "fourier_columns"): | ||
data["fourier_data"] = self._get_fourier_models_data(X) | ||
|
||
if self.time_varying_intercept: | ||
if self.time_varying_intercept | self.time_varying_media: | ||
data["time_index"] = infer_time_index( | ||
X[self.date_column], self.X[self.date_column], self._time_resolution | ||
) | ||
|
@@ -1671,7 +1724,7 @@ | |
def add_lift_test_measurements( | ||
self, | ||
df_lift_test: pd.DataFrame, | ||
dist: pm.Distribution = pm.Gamma, | ||
dist: type[pm.Distribution] = pm.Gamma, | ||
name: str = "lift_measurements", | ||
) -> None: | ||
"""Add lift tests to the model. | ||
|
@@ -1770,14 +1823,18 @@ | |
channel_transform=self.channel_transformer.transform, | ||
target_transform=self.target_transformer.transform, | ||
) | ||
with self.model: | ||
add_lift_measurements_to_likelihood( | ||
df_lift_test=df_lift_test_scaled, | ||
variable_mapping=self.saturation.variable_mapping, | ||
saturation_function=self.saturation.function, | ||
dist=dist, | ||
name=name, | ||
) | ||
# This is coupled with the name of the | ||
# latent process Deterministic | ||
time_varying_var_name = ( | ||
"media_temporal_latent_multiplier" if self.time_varying_media else None | ||
) | ||
add_lift_measurements_to_likelihood_from_saturation( | ||
df_lift_test=df_lift_test_scaled, | ||
saturation=self.saturation, | ||
time_varying_var_name=time_varying_var_name, | ||
model=self.model, | ||
dist=dist, | ||
) | ||
|
||
def _create_synth_dataset( | ||
self, | ||
|
@@ -2150,6 +2207,7 @@ | |
channel_columns: list[str], | ||
adstock_max_lag: int, | ||
time_varying_intercept: bool = False, | ||
time_varying_media: bool = False, | ||
model_config: dict | None = None, | ||
sampler_config: dict | None = None, | ||
validate_data: bool = True, | ||
|
@@ -2175,6 +2233,7 @@ | |
channel_columns=channel_columns, | ||
adstock_max_lag=adstock_max_lag, | ||
time_varying_intercept=time_varying_intercept, | ||
time_varying_media=time_varying_media, | ||
model_config=model_config, | ||
sampler_config=sampler_config, | ||
validate_data=validate_data, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should clarify is a global baseline right? (not tvp per channel)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to be sure we are align, it's not a baseline. The baseline is the variable
channel_contributions
, the media varying variable would be more a time scaling factor adding variance to the mean.