Skip to content

Commit

Permalink
make hill pass through the origin (#920)
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz authored Aug 12, 2024
1 parent b070375 commit 285f704
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 14 deletions.
6 changes: 3 additions & 3 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,9 @@ class HillSaturation(SaturationTransformation):
function = hill_saturation

default_priors = {
"sigma": Prior("HalfNormal", sigma=2),
"beta": Prior("HalfNormal", sigma=2),
"lam": Prior("HalfNormal", sigma=2),
"sigma": Prior("HalfNormal", sigma=1.5),
"beta": Prior("HalfNormal", sigma=1.5),
"lam": Prior("HalfNormal", sigma=1.5),
}


Expand Down
11 changes: 7 additions & 4 deletions pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,10 +908,10 @@ def hill_saturation(
r"""Hill Saturation Function
.. math::
f(x) = \frac{\sigma}{1 + e^{-\beta(x - \lambda)}}
f(x) = \frac{\sigma}{1 + e^{-\beta(x - \lambda)}} - \frac{\sigma}{1 + e^{\beta\lambda}}
where:
- :math:`\sigma` is the maximum value (upper asymptote)
- :math:`\sigma` is the upper asymptote
- :math:`\beta` is the slope parameter
- :math:`\lambda` is the transition point on the X-axis
- :math:`x` is the independent variable
Expand All @@ -920,7 +920,9 @@ def hill_saturation(
used to describe the saturation effect in biological systems. The curve is
characterized by its sigmoidal shape, representing a gradual transition from
a low, nearly zero level to a high plateau, the maximum value the function
will approach as the independent variable grows large.
will approach as the independent variable grows large. In this implementation,
we add an offset to the sigmoid function to ensure that the function always passes
through the origin as we expect zero spend to result in zero contribution.
.. plot::
:context: close-figs
Expand Down Expand Up @@ -968,6 +970,7 @@ def hill_saturation(
plt.ylabel('Hill Saturation')
plt.tight_layout()
plt.show()
Parameters
----------
x : float or array-like
Expand All @@ -987,7 +990,7 @@ def hill_saturation(
float or array-like
The value of the Hill function for each input value of x.
"""
return sigma / (1 + pt.exp(-beta * (x - lam)))
return sigma / (1 + pt.exp(-beta * (x - lam))) - sigma / (1 + pt.exp(beta * lam))


def root_saturation(
Expand Down
30 changes: 23 additions & 7 deletions tests/mmm/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,11 +466,23 @@ def test_michaelis_menten(self, x, alpha, lam, expected):
(3, 2, -1),
],
)
def test_monotonicity(self, sigma, beta, lam):
def test_hill_monotonicity(self, sigma, beta, lam):
x = np.linspace(-10, 10, 100)
y = hill_saturation(x, sigma, beta, lam).eval()
assert np.all(np.diff(y) >= 0), "The function is not monotonic."

@pytest.mark.parametrize(
"sigma, beta, lam",
[
(1, 1, 0),
(2, 0.5, 1),
(3, 2, -1),
],
)
def test_hill_zero(self, sigma, beta, lam):
y = hill_saturation(0, sigma, beta, lam).eval()
assert y == pytest.approx(0.0)

@pytest.mark.parametrize(
"x, sigma, beta, lam",
[
Expand All @@ -479,7 +491,7 @@ def test_monotonicity(self, sigma, beta, lam):
(-3, 3, 2, -1),
],
)
def test_sigma_upper_bound(self, x, sigma, beta, lam):
def test_hill_sigma_upper_bound(self, x, sigma, beta, lam):
y = hill_saturation(x, sigma, beta, lam).eval()
assert y <= sigma, f"The output {y} exceeds the upper bound sigma {sigma}."

Expand All @@ -491,11 +503,13 @@ def test_sigma_upper_bound(self, x, sigma, beta, lam):
(-1, 3, 2, -1, 1.5),
],
)
def test_behavior_at_lambda(self, x, sigma, beta, lam, expected):
def test_hill_behavior_at_lambda(self, x, sigma, beta, lam, expected):
y = hill_saturation(x, sigma, beta, lam).eval()
offset = sigma / (1 + np.exp(beta * lam))
expected_with_offset = expected - offset
np.testing.assert_almost_equal(
y,
expected,
expected_with_offset,
decimal=5,
err_msg="The function does not behave as expected at lambda.",
)
Expand All @@ -508,7 +522,7 @@ def test_behavior_at_lambda(self, x, sigma, beta, lam, expected):
(np.array([1, 2, 3]), 3, 2, 2),
],
)
def test_vectorized_input(self, x, sigma, beta, lam):
def test_hill_vectorized_input(self, x, sigma, beta, lam):
y = hill_saturation(x, sigma, beta, lam).eval()
assert (
y.shape == x.shape
Expand All @@ -522,12 +536,14 @@ def test_vectorized_input(self, x, sigma, beta, lam):
(3, 2, -1),
],
)
def test_asymptotic_behavior(self, sigma, beta, lam):
def test_hill_asymptotic_behavior(self, sigma, beta, lam):
x = 1e6 # A very large value to approximate infinity
y = hill_saturation(x, sigma, beta, lam).eval()
offset = sigma / (1 + np.exp(beta * lam))
expected = sigma - offset
np.testing.assert_almost_equal(
y,
sigma,
expected,
decimal=5,
err_msg="The function does not approach sigma as x approaches infinity.",
)
Expand Down

0 comments on commit 285f704

Please sign in to comment.