Creating Time Base component for Media Contribution #752

Merged
merged 16 commits into from
Jun 22, 2024
712 changes: 394 additions & 318 deletions docs/source/notebooks/mmm/mmm_tvp_example.ipynb

Large diffs are not rendered by default.

121 changes: 90 additions & 31 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
@@ -40,11 +40,11 @@
_get_saturation_function,
)
from pymc_marketing.mmm.lift_test import (
add_lift_measurements_to_likelihood,
add_lift_measurements_to_likelihood_from_saturation,
scale_lift_measurements,
)
from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget
from pymc_marketing.mmm.tvp import create_time_varying_intercept, infer_time_index
from pymc_marketing.mmm.tvp import create_time_varying_gp_multiplier, infer_time_index
from pymc_marketing.mmm.utils import (
apply_sklearn_transformer_across_dim,
create_new_spend_data,
@@ -81,6 +81,7 @@
adstock: str | AdstockTransformation,
saturation: str | SaturationTransformation,
time_varying_intercept: bool = False,
time_varying_media: bool = False,
model_config: dict | None = None,
sampler_config: dict | None = None,
validate_data: bool = True,
@@ -105,6 +106,13 @@
Type of saturation transformation to apply.
time_varying_intercept : bool, optional
Whether to consider time-varying intercept, by default False.
Because the `time-varying` variable is centered around 1 and acts as a multiplier,
the variable `base_intercept` now represents the mean of the time-varying intercept.
time_varying_media : bool, optional
Whether to consider time-varying media contributions, by default False.
Review comment (Collaborator): I think we should clarify that this is a global baseline, right? (Not a TVP per channel.)

Review comment (Contributor): Good idea.

Review comment (Contributor Author): Just to be sure we are aligned, it's not a baseline. The baseline is the variable channel_contributions; the time-varying media variable is more of a time scaling factor, adding variance to the mean.

The `time_varying_media` option creates a time-varying media variable centered around 1.
This variable acts as a global multiplier (scaling factor) for all channels,
meaning all media channels share the same latent fluctuation.
model_config : Dictionary, optional
dictionary of parameters that initialise model configuration.
Class-default defined by the user default_model_config method.
@@ -121,6 +129,7 @@
self.control_columns = control_columns
self.adstock_max_lag = adstock_max_lag
self.time_varying_intercept = time_varying_intercept
self.time_varying_media = time_varying_media

self.yearly_seasonality = yearly_seasonality
self.date_column = date_column
self.validate_data = validate_data
@@ -220,7 +229,7 @@
self.X: pd.DataFrame = X_data
self.y: pd.Series | np.ndarray = y

if self.time_varying_intercept:
if self.time_varying_intercept | self.time_varying_media:

self._time_index = np.arange(0, X.shape[0])
self._time_index_mid = X.shape[0] // 2
self._time_resolution = (
@@ -344,33 +353,66 @@
dims="date",
mutable=True,
)

if self.time_varying_intercept:
if self.time_varying_intercept | self.time_varying_media:

time_index = pm.Data(
"time_index",
self._time_index,
dims="date",
)
intercept_dist = get_distribution(

if self.time_varying_intercept:
intercept_distribution = get_distribution(

name=self.model_config["intercept"]["dist"]
)
intercept = create_time_varying_intercept(
time_index,
self._time_index_mid,
self._time_resolution,
intercept_dist,
self.model_config,
base_intercept = intercept_distribution(

name="base_intercept", **self.model_config["intercept"]["kwargs"]
)

intercept_latent_process = create_time_varying_gp_multiplier(

name="intercept",
dims="date",
time_index=time_index,
time_index_mid=self._time_index_mid,
time_resolution=self._time_resolution,
model_config=self.model_config,
)
intercept = pm.Deterministic(

name="intercept",
var=base_intercept * intercept_latent_process,
dims="date",
)
else:
intercept = create_distribution_from_config(
name="intercept", config=self.model_config
)

channel_contributions = pm.Deterministic(
name="channel_contributions",
var=self.forward_pass(x=channel_data_),
dims=("date", "channel"),
)
if self.time_varying_media:
base_channel_contributions = pm.Deterministic(

name="base_channel_contributions",
Review comment (Collaborator): Is "baseline" a better name than "base"? 🤔

Review comment (Contributor Author): I'm fine with either. I didn't make that change, but I can if you feel it's better!

Review comment (Contributor): I think "baseline" makes a bit more sense!

var=self.forward_pass(x=channel_data_),
dims=("date", "channel"),
)

media_latent_process = create_time_varying_gp_multiplier(

name="media",
Review comment (Collaborator): I think we need to specify that this is a global baseline, not per channel. So maybe "baseline_media" is a better name?

Review comment (Contributor Author): Related to the same comment: the media multiplier is not the baseline; it's more of a scaling factor! But let's agree on the wording 👀

dims="date",
time_index=time_index,
time_index_mid=self._time_index_mid,
time_resolution=self._time_resolution,
model_config=self.model_config,
)
channel_contributions = pm.Deterministic(

name="channel_contributions",
var=base_channel_contributions * media_latent_process[:, None],
dims=("date", "channel"),
)

else:
channel_contributions = pm.Deterministic(

name="channel_contributions",
var=self.forward_pass(x=channel_data_),
dims=("date", "channel"),
)

mu_var = intercept + channel_contributions.sum(axis=-1)
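
The shape logic in this hunk can be checked outside PyMC. The following is a minimal NumPy sketch, not the library's implementation: the array sizes and the hand-written multiplier values are made up for illustration, and plain arrays stand in for the model variables. It shows how a date-indexed multiplier centered around 1 broadcasts across the channel axis via `[:, None]`, and why `base_intercept` ends up as the mean of the time-varying intercept.

```python
import numpy as np

# Hypothetical sizes: 4 dates, 3 channels.
n_dates, n_channels = 4, 3

# Stand-in for the output of forward_pass: one value per (date, channel).
base_channel_contributions = np.arange(1.0, 13.0).reshape(n_dates, n_channels)

# Stand-in for the media latent process: one value per date, centered around 1.
media_latent_process = np.array([0.8, 1.0, 1.2, 1.0])

# [:, None] reshapes (4,) to (4, 1), so the same factor scales every channel
# on a given date.
channel_contributions = base_channel_contributions * media_latent_process[:, None]

# Stand-in for the intercept multiplier; its mean is 1, so the mean of the
# time-varying intercept equals base_intercept.
base_intercept = 0.5
intercept_latent_process = np.array([1.1, 0.9, 1.0, 1.0])
intercept = base_intercept * intercept_latent_process

# One expected value per date, mirroring mu_var in the diff.
mu = intercept + channel_contributions.sum(axis=-1)

print(channel_contributions.shape)  # (4, 3)
print(mu.shape)  # (4,)
```

Because the multiplier has no channel axis, it cannot express channel-specific drift; that matches the "global multiplier" wording in the docstring.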

@@ -465,7 +507,7 @@

@property
def default_model_config(self) -> dict:
base_config = {
base_config: dict[str, Any] = {

"intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"likelihood": {
"dist": "Normal",
@@ -483,15 +525,26 @@
"kwargs": {"mu": 0, "b": 1},
"dims": "fourier_mode",
},
"intercept_tvp_kwargs": {
}

if self.time_varying_intercept:
base_config["intercept_tvp_config"] = {

"m": 200,
"L": None,
"eta_lam": 1,
"ls_mu": None,
"ls_sigma": 10,
"cov_func": None,
},
}
}
if self.time_varying_media:
base_config["media_tvp_config"] = {

"m": 200,
"L": None,
"eta_lam": 1,
"ls_mu": None,
"ls_sigma": 10,
"cov_func": None,
}

for media_transform in [self.adstock, self.saturation]:
for param, config in media_transform.function_priors.items():
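
One consequence of this hunk is that the `*_tvp_config` keys only exist when the matching flag is on. The sketch below reproduces that behaviour with plain dictionaries; `default_tvp_config` and `build_config_sketch` are hypothetical helper names, not part of pymc-marketing, and only the dictionary values are taken from the diff.

```python
from typing import Any


def default_tvp_config() -> dict[str, Any]:
    """Return a fresh copy of the GP defaults shown in the diff."""
    return {
        "m": 200,
        "L": None,
        "eta_lam": 1,
        "ls_mu": None,
        "ls_sigma": 10,
        "cov_func": None,
    }


def build_config_sketch(
    time_varying_intercept: bool, time_varying_media: bool
) -> dict[str, Any]:
    """Hypothetical stand-in for default_model_config (only a subset of keys)."""
    config: dict[str, Any] = {
        "intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
    }
    # Each flag adds its own key; neither key exists when its flag is off.
    if time_varying_intercept:
        config["intercept_tvp_config"] = default_tvp_config()
    if time_varying_media:
        config["media_tvp_config"] = default_tvp_config()
    return config


media_only = build_config_sketch(time_varying_intercept=False, time_varying_media=True)
both = build_config_sketch(time_varying_intercept=True, time_varying_media=True)
```

Code that reads these keys would therefore need to check the flag (or use `dict.get`) rather than index unconditionally.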
@@ -712,7 +765,7 @@
if hasattr(self, "fourier_columns"):
data["fourier_data"] = self._get_fourier_models_data(X)

if self.time_varying_intercept:
if self.time_varying_intercept | self.time_varying_media:

data["time_index"] = infer_time_index(
X[self.date_column], self.X[self.date_column], self._time_resolution
)
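
The diff does not show the body of `infer_time_index`, so the following is only a sketch of what such a helper plausibly does, under the assumption that it counts steps from the first training date in units of the time resolution. `infer_time_index_sketch` is a hypothetical name and uses the standard library only.

```python
from datetime import date, timedelta


def infer_time_index_sketch(new_dates, train_dates, time_resolution_days):
    """Map each new date to an integer step relative to the first training date.

    Assumption: the index is the day offset from the first training date,
    floor-divided by the resolution in days.
    """
    origin = train_dates[0]
    return [(d - origin).days // time_resolution_days for d in new_dates]


# Weekly training data: indices 0..3, the same as np.arange(0, n_rows).
train = [date(2024, 1, 1) + timedelta(weeks=i) for i in range(4)]

# Two out-of-sample weeks continue the same index instead of restarting at 0.
future = [date(2024, 1, 29), date(2024, 2, 5)]

train_index = infer_time_index_sketch(train, train, 7)
future_index = infer_time_index_sketch(future, train, 7)
```

Under that assumption, new data gets indices that extend the training index, which is what a latent process defined over time needs for out-of-sample prediction.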
@@ -1671,7 +1724,7 @@
def add_lift_test_measurements(
self,
df_lift_test: pd.DataFrame,
dist: pm.Distribution = pm.Gamma,
dist: type[pm.Distribution] = pm.Gamma,
name: str = "lift_measurements",
) -> None:
"""Add lift tests to the model.
@@ -1770,14 +1823,18 @@
channel_transform=self.channel_transformer.transform,
target_transform=self.target_transformer.transform,
)
with self.model:
add_lift_measurements_to_likelihood(
df_lift_test=df_lift_test_scaled,
variable_mapping=self.saturation.variable_mapping,
saturation_function=self.saturation.function,
dist=dist,
name=name,
)
# This is coupled with the name of the
# latent process Deterministic
time_varying_var_name = (

"media_temporal_latent_multiplier" if self.time_varying_media else None
)
add_lift_measurements_to_likelihood_from_saturation(

df_lift_test=df_lift_test_scaled,
saturation=self.saturation,
time_varying_var_name=time_varying_var_name,
model=self.model,
dist=dist,
)

def _create_synth_dataset(
self,
@@ -2150,6 +2207,7 @@
channel_columns: list[str],
adstock_max_lag: int,
time_varying_intercept: bool = False,
time_varying_media: bool = False,
model_config: dict | None = None,
sampler_config: dict | None = None,
validate_data: bool = True,
@@ -2175,6 +2233,7 @@
channel_columns=channel_columns,
adstock_max_lag=adstock_max_lag,
time_varying_intercept=time_varying_intercept,
time_varying_media=time_varying_media,
model_config=model_config,
sampler_config=sampler_config,
validate_data=validate_data,
112 changes: 102 additions & 10 deletions pymc_marketing/mmm/lift_test.py
@@ -23,6 +23,8 @@
import pytensor.tensor as pt
from numpy import typing as npt

from pymc_marketing.mmm.components.saturation import SaturationTransformation


class MissingLiftTestError(Exception):
def __init__(self, missing_values: npt.NDArray[np.int_]) -> None:
@@ -212,12 +214,16 @@
)


SaturationFunc = Callable
VariableMapping = dict[str, str]


def add_lift_measurements_to_likelihood(
df_lift_test: pd.DataFrame,
variable_mapping,
saturation_function,
variable_mapping: VariableMapping,
saturation_function: SaturationFunc,
model: pm.Model | None = None,
dist=pm.Gamma,
dist: type[pm.Distribution] = pm.Gamma,
name: str = "lift_measurements",
) -> None:
"""Add lift measurements to the likelihood of the model.
@@ -239,7 +245,7 @@
Function that takes spend and returns saturation.
model : Optional[pm.Model], optional
PyMC model with arbitrary number of coordinates, by default None
dist : pm.Distribution, optional
dist : pm.Distribution class, optional
PyMC distribution to use for the likelihood, by default pm.Gamma
name : str, optional
Name of the likelihood, by default "lift_measurements"
@@ -315,12 +321,13 @@
x_before, x_after, partial_saturation_function
)

dist(
name=name,
mu=pt.abs(model_estimated_lift),
sigma=df_lift_test["sigma"].to_numpy(),
observed=np.abs(df_lift_test["delta_y"].to_numpy()),
)
with pm.modelcontext(model):
dist(

name=name,
mu=pt.abs(model_estimated_lift),
sigma=df_lift_test["sigma"].to_numpy(),
observed=np.abs(df_lift_test["delta_y"].to_numpy()),
)


def _swap_columns_and_last_index_level(df: pd.DataFrame) -> pd.DataFrame:
@@ -467,3 +474,88 @@
[df_lift_test_channel_scaled, df_target_scaled, df_sigma_scaled],
axis=1,
)


def create_time_varying_saturation(
saturation: SaturationTransformation,
time_varying_var_name: str,
) -> tuple[SaturationFunc, VariableMapping]:
"""Return function and variable mapping.

Parameters
----------
saturation : SaturationTransformation
Any SaturationTransformation instance.
time_varying_var_name : str
Name of the time-varying variable in model.

Returns
-------
tuple[SaturationFunc, VariableMapping]
Tuple of function and variable mapping to be used in
add_lift_measurements_to_likelihood function.

"""

def function(x, time_varying: pt.TensorVariable, **kwargs):
return time_varying * saturation.function(x, **kwargs)

variable_mapping = {

**saturation.variable_mapping,
"time_varying": time_varying_var_name,
}

return function, variable_mapping


def add_lift_measurements_to_likelihood_from_saturation(
df_lift_test: pd.DataFrame,
saturation: SaturationTransformation,
time_varying_var_name: str | None = None,
model: pm.Model | None = None,
dist: type[pm.Distribution] = pm.Gamma,
name: str = "lift_measurements",
) -> None:
"""Wrapper around add_lift_measurements_to_likelihood to work with
SaturationTransformation instances and time-varying variables.

Parameters
----------
df_lift_test : pd.DataFrame
DataFrame with lift test results with at least the following columns:
* `x`: x axis value of the lift test.
* `delta_x`: change in x axis value of the lift test.
* `delta_y`: change in y axis value of the lift test.
* `sigma`: standard deviation of the lift test.
saturation : SaturationTransformation
Any SaturationTransformation instance.
time_varying_var_name : str, optional
Name of the time-varying variable in model.
model : Optional[pm.Model], optional
PyMC model with arbitrary number of coordinates, by default None
dist : pm.Distribution class, optional
PyMC distribution to use for the likelihood, by default pm.Gamma
name : str, optional
Name of the likelihood, by default "lift_measurements"

"""

if time_varying_var_name:
saturation_function, variable_mapping = create_time_varying_saturation(

saturation=saturation,
# This is coupled with the name of the
# latent process Deterministic
time_varying_var_name=time_varying_var_name,
)
else:
saturation_function = saturation.function
variable_mapping = saturation.variable_mapping

add_lift_measurements_to_likelihood(

df_lift_test=df_lift_test,
variable_mapping=variable_mapping,
saturation_function=saturation_function,
dist=dist,
name=name,
model=model,
)
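
The wrapping pattern used by `create_time_varying_saturation` can be tried without PyMC. This is a minimal sketch under stated assumptions: `saturation_sketch`, `make_time_varying`, and the `"saturation_lam"` variable name are hypothetical stand-ins, and NumPy arrays replace tensor variables. Only the multiplier name `"media_temporal_latent_multiplier"` and the `"time_varying"` key come from the diff.

```python
import numpy as np


def saturation_sketch(x, lam):
    """Hypothetical saturation curve standing in for saturation.function."""
    return 1.0 - np.exp(-lam * x)


# Hypothetical mapping from function argument names to model variable names.
base_variable_mapping = {"lam": "saturation_lam"}


def make_time_varying(function, variable_mapping, time_varying_var_name):
    """Wrap a saturation function so its output is scaled by a multiplier."""

    def wrapped(x, time_varying, **kwargs):
        return time_varying * function(x, **kwargs)

    # Extend a copy of the mapping so the extra argument can be looked up
    # by name; the original mapping is left untouched.
    mapping = {**variable_mapping, "time_varying": time_varying_var_name}
    return wrapped, mapping


wrapped, mapping = make_time_varying(
    saturation_sketch,
    base_variable_mapping,
    "media_temporal_latent_multiplier",
)

x = np.array([0.0, 1.0, 2.0])
scaled = wrapped(x, time_varying=1.5, lam=0.7)
```

Keeping the multiplier as an ordinary named argument means the existing `add_lift_measurements_to_likelihood` can treat it like any other saturation parameter, which appears to be why the wrapper returns both the function and the extended mapping.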