Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

model.fit doesn't remove prior samples #741

Merged
merged 9 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
raise ValueError(f"Column {required_col} has duplicate entries")

def __repr__(self):
if self.model is None:
if not hasattr(self, "model"):

Check warning on line 61 in pymc_marketing/clv/models/basic.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/models/basic.py#L61

Added line #L61 was not covered by tests
return self._model_type
else:
return f"{self._model_type}\n{self.model.str_repr()}"
Expand Down
12 changes: 8 additions & 4 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,16 @@
def prior(self) -> Dataset:
if self.idata is None or "prior" not in self.idata:
raise RuntimeError(
"The model hasn't been fit yet, call .sample_prior_predictive() with extend_idata=True first"
"The model hasn't been sampled yet, call .sample_prior_predictive() first"
)
return self.idata["prior"]

@property
def prior_predictive(self) -> az.InferenceData:
def prior_predictive(self) -> Dataset:
if self.idata is None or "prior_predictive" not in self.idata:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
raise RuntimeError(

Check warning on line 277 in pymc_marketing/mmm/base.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/base.py#L277

Added line #L277 was not covered by tests
"The model hasn't been sampled yet, call .sample_prior_predictive() first"
)
return self.idata["prior_predictive"]

@property
Expand All @@ -286,7 +288,9 @@
@property
def posterior_predictive(self) -> Dataset:
if self.idata is None or "posterior_predictive" not in self.idata:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
raise RuntimeError(

Check warning on line 291 in pymc_marketing/mmm/base.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/base.py#L291

Added line #L291 was not covered by tests
"The model hasn't been fit yet, call .sample_posterior_predictive() first"
)
return self.idata["posterior_predictive"]

def plot_prior_predictive(
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1829,7 +1829,7 @@
model.add_lift_test_measurements(df_lift_test)

"""
if self.model is None:
if not hasattr(self, "model"):

Check warning on line 1832 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L1832

Added line #L1832 was not covered by tests
raise RuntimeError(
"The model has not been built yet. Please, build the model first."
)
Expand Down
70 changes: 40 additions & 30 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
self.model_config = (
self.default_model_config | model_config
) # parameters for priors etc.
self.model: pm.Model | None = None # Set by build_model
self.model: pm.Model
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we don't have to check self.model is None all the time...

self.idata: az.InferenceData | None = None # idata is generated during fitting
self.is_fitted_ = False

Expand Down Expand Up @@ -458,19 +458,22 @@
if self.X is None or self.y is None:
raise ValueError("X and y must be set before calling build_model!")

if self.model is None:
if not hasattr(self, "model"):

Check warning on line 461 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L461

Added line #L461 was not covered by tests
self.build_model(self.X, self.y)

sampler_config = self.sampler_config.copy()
sampler_config["progressbar"] = progressbar
sampler_config["random_seed"] = random_seed
sampler_config.update(**kwargs)

sampler_config.update(**kwargs)
if self.model is not None:
with self.model:
sampler_args = {**self.sampler_config, **kwargs}
self.idata = pm.sample(**sampler_args)
sampler_args = {**self.sampler_config, **kwargs}
with self.model:
idata = pm.sample(**sampler_args)

Check warning on line 471 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L469-L471

Added lines #L469 - L471 were not covered by tests

if self.idata:
self.idata.extend(idata, join="right")

Check warning on line 474 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L473-L474

Added lines #L473 - L474 were not covered by tests
else:
self.idata = idata

Check warning on line 476 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L476

Added line #L476 was not covered by tests

X_df = pd.DataFrame(X, columns=X.columns)
combined_data = pd.concat([X_df, y_df], axis=1)
Expand Down Expand Up @@ -537,7 +540,7 @@
X_pred,
y_pred=None,
samples: int | None = None,
extend_idata: bool = False,
extend_idata: bool = True,
combined: bool = True,
**kwargs,
):
Expand All @@ -552,7 +555,7 @@
Number of samples from the prior parameter distributions to generate.
If not set, uses sampler_config['draws'] if that is available, otherwise defaults to 500.
extend_idata : Boolean determining whether the predictions should be added to inference data object.
Defaults to False.
Defaults to True.
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
Defaults to True.
**kwargs: Additional arguments to pass to pymc.sample_prior_predictive
Expand All @@ -567,21 +570,19 @@
if samples is None:
samples = self.sampler_config.get("draws", 500)

if self.model is None:
if not hasattr(self, "model"):

Check warning on line 573 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L573

Added line #L573 was not covered by tests
self.build_model(X_pred, y_pred)

self._data_setter(X_pred, y_pred)
if self.model is not None:
with self.model: # sample with new input data
prior_pred: az.InferenceData = pm.sample_prior_predictive(
samples, **kwargs
)
self.set_idata_attrs(prior_pred)
if extend_idata:
if self.idata is not None:
self.idata.extend(prior_pred, join="right")
else:
self.idata = prior_pred
with self.model: # sample with new input data
prior_pred: az.InferenceData = pm.sample_prior_predictive(samples, **kwargs)
self.set_idata_attrs(prior_pred)

Check warning on line 579 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L577-L579

Added lines #L577 - L579 were not covered by tests

if extend_idata:
if self.idata is not None:
self.idata.extend(prior_pred, join="right")

Check warning on line 583 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L581-L583

Added lines #L581 - L583 were not covered by tests
else:
self.idata = prior_pred

Check warning on line 585 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L585

Added line #L585 was not covered by tests

prior_predictive_samples = az.extract(
prior_pred, "prior_predictive", combined=combined
Expand All @@ -590,7 +591,11 @@
return prior_predictive_samples

def sample_posterior_predictive(
self, X_pred, extend_idata: bool = True, combined: bool = True, **kwargs
self,
X_pred,
extend_idata: bool = True,
combined: bool = True,
**sample_posterior_predictive_kwargs,
):
"""
Sample from the model's posterior predictive distribution.
Expand All @@ -603,7 +608,7 @@
Defaults to True.
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
Defaults to True.
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
**sample_posterior_predictive_kwargs: Additional arguments to pass to pymc.sample_posterior_predictive

Returns
-------
Expand All @@ -612,16 +617,21 @@
"""
self._data_setter(X_pred)

with self.model: # type: ignore
post_pred = pm.sample_posterior_predictive(self.idata, **kwargs)
if extend_idata:
self.idata.extend(post_pred, join="right") # type: ignore
with self.model:
post_pred = pm.sample_posterior_predictive(

Check warning on line 621 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L620-L621

Added lines #L620 - L621 were not covered by tests
self.idata, **sample_posterior_predictive_kwargs
)

if extend_idata:
self.idata.extend(post_pred, join="right") # type: ignore

Check warning on line 626 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L625-L626

Added lines #L625 - L626 were not covered by tests

posterior_predictive_samples = az.extract(
post_pred, "posterior_predictive", combined=combined
variable_name = (

Check warning on line 628 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L628

Added line #L628 was not covered by tests
"predictions"
if sample_posterior_predictive_kwargs.get("predictions")
else "posterior_predictive"
)

return posterior_predictive_samples
return az.extract(post_pred, variable_name, combined=combined)

Check warning on line 634 in pymc_marketing/model_builder.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/model_builder.py#L634

Added line #L634 was not covered by tests

def get_params(self, deep=True):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def set_model_fit(model: CLVModel, fit: InferenceData | Dataset):
assert "posterior" in fit.groups()
else:
fit = InferenceData(posterior=fit)
if model.model is None:
if not hasattr(model, "model"):
model.build_model()
model.idata = fit
model.idata.add_groups(fit_data=model.data.to_xarray())
Expand Down
6 changes: 4 additions & 2 deletions tests/mmm/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,9 @@ def test_calling_prior_predictive_before_fit_raises_error(test_mmm, toy_X, toy_y
test_mmm.idata = None
with pytest.raises(
RuntimeError,
match=re.escape("The model hasn't been fit yet, call .fit() first"),
match=re.escape(
"The model hasn't been sampled yet, call .sample_prior_predictive() first"
),
):
test_mmm.prior_predictive

Expand All @@ -297,7 +299,7 @@ def test_calling_prior_before_sample_prior_predictive_raises_error(
with pytest.raises(
RuntimeError,
match=re.escape(
"The model hasn't been fit yet, call .sample_prior_predictive() with extend_idata=True first"
"The model hasn't been sampled yet, call .sample_prior_predictive() first",
),
):
test_mmm.prior
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def _save_input_params(self, idata):
def output_var(self):
return "output"

def _data_setter(self, X: pd.Series, y: pd.Series = None):
def _data_setter(self, X: pd.DataFrame, y: pd.Series = None):
with self.model:
pm.set_data({"x": X.values})
pm.set_data({"x": X["input"].values})
if y is not None:
y = y.values if isinstance(y, pd.Series) else y
pm.set_data({"y_data": y})
Expand Down Expand Up @@ -195,8 +195,8 @@ def test_save_load(fitted_model_instance):
assert fitted_model_instance.id == test_builder2.id
x_pred = rng.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
pred1 = fitted_model_instance.predict(prediction_data["input"])
pred2 = test_builder2.predict(prediction_data["input"])
pred1 = fitted_model_instance.predict(prediction_data)
pred2 = test_builder2.predict(prediction_data)
assert pred1.shape == pred2.shape
temp.close()

Expand Down Expand Up @@ -230,9 +230,9 @@ def test_fit(fitted_model_instance):
assert fitted_model_instance.idata.posterior.dims["draw"] == 100

prediction_data = pd.DataFrame({"input": rng.uniform(low=0, high=1, size=100)})
fitted_model_instance.predict(prediction_data["input"])
fitted_model_instance.predict(prediction_data)
post_pred = fitted_model_instance.sample_posterior_predictive(
prediction_data["input"], extend_idata=True, combined=True
prediction_data, extend_idata=True, combined=True
)
assert (
post_pred[fitted_model_instance.output_var].shape[0]
Expand All @@ -256,7 +256,7 @@ def test_predict(fitted_model_instance):
rng = np.random.default_rng(42)
x_pred = rng.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
pred = fitted_model_instance.predict(prediction_data["input"])
pred = fitted_model_instance.predict(prediction_data)
# Perform elementwise comparison using numpy
assert type(pred) == np.ndarray
assert len(pred) > 0
Expand All @@ -269,7 +269,7 @@ def test_sample_posterior_predictive(fitted_model_instance, combined):
x_pred = rng.uniform(low=0, high=1, size=n_pred)
prediction_data = pd.DataFrame({"input": x_pred})
pred = fitted_model_instance.sample_posterior_predictive(
prediction_data["input"], combined=combined, extend_idata=True
prediction_data, combined=combined, extend_idata=True
)
chains = fitted_model_instance.idata.sample_stats.dims["chain"]
draws = fitted_model_instance.idata.sample_stats.dims["draw"]
Expand Down Expand Up @@ -313,7 +313,7 @@ def test_sample_xxx_predictive_keeps_second(
method_name = f"sample_{name}"
method = getattr(fitted_model_instance, method_name)

X_pred = toy_X["input"]
X_pred = toy_X

kwargs = {
"X_pred": X_pred,
Expand All @@ -329,3 +329,26 @@ def test_sample_xxx_predictive_keeps_second(

sample = getattr(fitted_model_instance.idata, name)
xr.testing.assert_allclose(sample, second_sample)


def test_prediction_kwarg(fitted_model_instance, toy_X):
result = fitted_model_instance.sample_posterior_predictive(
toy_X,
extend_idata=True,
predictions=True,
)
assert "predictions" in fitted_model_instance.idata
assert "predictions_constant_data" in fitted_model_instance.idata

assert isinstance(result, xr.Dataset)


def test_fit_after_prior_keeps_prior(toy_X, toy_y):
model = ModelBuilderTest()
model.sample_prior_predictive(toy_X)
assert "prior" in model.idata
assert "prior_predictive" in model.idata

model.fit(X=toy_X, y=toy_y, chains=1, draws=100, tune=100)
assert "prior" in model.idata
assert "prior_predictive" in model.idata
Loading