diff --git a/python/src/greenbids/tailor/core/app/routers/root.py b/python/src/greenbids/tailor/core/app/routers/root.py index f5a058e..2f26582 100644 --- a/python/src/greenbids/tailor/core/app/routers/root.py +++ b/python/src/greenbids/tailor/core/app/routers/root.py @@ -1,14 +1,14 @@ -from fastapi import APIRouter +import fastapi from greenbids.tailor.core import fabric from .. import resources -router = APIRouter(tags=["Main"]) +router = fastapi.APIRouter(tags=["Main"]) @router.put("/") async def get_buyers_probabilities( - fabrics: list[fabric.Fabric], -) -> list[fabric.Fabric]: + fabrics: list[fabric.PredictionInput], +) -> list[fabric.PredictionOutput]: """Compute the probability of the buyers to provide a bid. This must be called for each adcall. @@ -18,14 +18,18 @@ async def get_buyers_probabilities( return resources.get_instance().gb_model.get_buyers_probabilities(fabrics) -@router.post("/") +@router.post( + "/", + response_class=fastapi.Response, + status_code=fastapi.status.HTTP_204_NO_CONTENT, +) async def report_buyers_status( - fabrics: list[fabric.Fabric], -) -> list[fabric.Fabric]: + fabrics: list[fabric.ReportInput], +): """Train model according to actual outcome. This must NOT be called for each adcall, but only for exploration ones. All fields of the fabrics need to be set. Returns the same data than the input. """ - return resources.get_instance().gb_model.report_buyers_status(fabrics) + resources.get_instance().gb_model.report_buyers_status(fabrics) diff --git a/python/src/greenbids/tailor/core/fabric.py b/python/src/greenbids/tailor/core/fabric.py index 33ffb00..d1d2c57 100644 --- a/python/src/greenbids/tailor/core/fabric.py +++ b/python/src/greenbids/tailor/core/fabric.py @@ -1,23 +1,23 @@ +import typing import pydantic import pydantic.alias_generators -class _CamelSerialized(pydantic.BaseModel): - model_config = pydantic.ConfigDict( - alias_generator=pydantic.alias_generators.to_camel, - populate_by_name=True, - use_attribute_docstrings=True, - ) +_BaseConfig = pydantic.ConfigDict( + alias_generator=pydantic.alias_generators.to_camel, + populate_by_name=True, + use_attribute_docstrings=True, +) +class _BaseTypedDict(typing.TypedDict): + __pydantic_config__ = _BaseConfig # type: ignore -class FeatureMap(_CamelSerialized, pydantic.RootModel): - """Mapping describing the current opportunity.""" - root: dict[str, bool | int | float | bytes | str] = pydantic.Field( - default_factory=dict - ) +FeatureMap: typing.TypeAlias = dict[str, typing.Any] -class Prediction(_CamelSerialized): +class Prediction(pydantic.BaseModel): + model_config = _BaseConfig + """Result of the shaping process.""" score: float = -1 """Confidence score returned by the model""" @@ -33,18 +33,22 @@ def should_send(self) -> bool: return self.is_exploration or (self.score > self.threshold) -class GroundTruth(_CamelSerialized): +class GroundTruth(_BaseTypedDict): """Actual outcome of the opportunity""" - has_response: bool = True + has_response: typing.Annotated[bool, pydantic.Field(default=True)] """Did this opportunity lead to a valid buyer response?""" -class Fabric(_CamelSerialized): - """Main entity used to tailor the traffic. +class ReportInput(_BaseTypedDict): + feature_map: FeatureMap + prediction: Prediction + ground_truth: GroundTruth + + +class PredictionInput(_BaseTypedDict): + feature_map: FeatureMap - All fields are optional when irrelevant. - """ - feature_map: FeatureMap = pydantic.Field(default_factory=FeatureMap) - prediction: Prediction = pydantic.Field(default_factory=Prediction) - ground_truth: GroundTruth = pydantic.Field(default_factory=GroundTruth) +class PredictionOutput(_BaseTypedDict): + feature_map: FeatureMap + prediction: Prediction diff --git a/python/src/greenbids/tailor/core/models.py b/python/src/greenbids/tailor/core/models.py index fabf141..49a627f 100644 --- a/python/src/greenbids/tailor/core/models.py +++ b/python/src/greenbids/tailor/core/models.py @@ -16,15 +16,15 @@ class Model(ABC): @abstractmethod def get_buyers_probabilities( self, - fabrics: list[fabric.Fabric], - ) -> list[fabric.Fabric]: + fabrics: list[fabric.PredictionInput], + ) -> list[fabric.PredictionOutput]: raise NotImplementedError @abstractmethod def report_buyers_status( self, - fabrics: list[fabric.Fabric], - ) -> list[fabric.Fabric]: + fabrics: list[fabric.ReportInput], + ) -> None: raise NotImplementedError def dump(self, fp: typing.BinaryIO) -> None: @@ -43,17 +43,19 @@ def __init__(self): def get_buyers_probabilities( self, - fabrics: list[fabric.Fabric], - ) -> list[fabric.Fabric]: + fabrics: list[fabric.PredictionInput], + ) -> list[fabric.PredictionOutput]: prediction = fabric.Prediction(score=1, is_exploration=(random.random() < 0.2)) - return [f.model_copy(update=dict(prediction=prediction)) for f in fabrics] + return [ + fabric.PredictionOutput(feature_map=f["feature_map"], prediction=prediction) + for f in fabrics + ] def report_buyers_status( self, - fabrics: list[fabric.Fabric], - ) -> list[fabric.Fabric]: - self._logger.debug([f.feature_map.root for f in fabrics[:1]]) - return fabrics + fabrics: list[fabric.ReportInput], + ) -> None: + self._logger.debug([f.get("feature_map", {}) for f in fabrics[:1]]) ENTRY_POINTS_GROUP = "greenbids-tailor-models"