
Refactor _register_dataset_changes #42343

Merged · 11 commits · Sep 25, 2024
37 changes: 16 additions & 21 deletions airflow/dag_processing/collection.py
@@ -291,21 +291,15 @@ def add_datasets(self, *, session: Session) -> dict[str, DatasetModel]:
dm.uri: dm
for dm in session.scalars(select(DatasetModel).where(DatasetModel.uri.in_(self.datasets)))
}

def _resolve_dataset_addition() -> Iterator[DatasetModel]:
for uri, dataset in self.datasets.items():
try:
dm = orm_datasets[uri]
except KeyError:
dm = orm_datasets[uri] = DatasetModel.from_public(dataset)
yield dm
else:
# The orphaned flag was bulk-set to True before parsing, so we
# don't need to handle rows in the db without a public entry.
dm.is_orphaned = expression.false()
dm.extra = dataset.extra

dataset_manager.create_datasets(list(_resolve_dataset_addition()), session=session)
for model in orm_datasets.values():
model.is_orphaned = expression.false()
orm_datasets.update(
(model.uri, model)
for model in dataset_manager.create_datasets(
[dataset for uri, dataset in self.datasets.items() if uri not in orm_datasets],
session=session,
)
)
return orm_datasets

def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]:
@@ -318,12 +312,13 @@ def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]:
select(DatasetAliasModel).where(DatasetAliasModel.name.in_(self.dataset_aliases))
)
}
for name, alias in self.dataset_aliases.items():
try:
da = orm_aliases[name]
except KeyError:
da = orm_aliases[name] = DatasetAliasModel.from_public(alias)
session.add(da)
orm_aliases.update(
(model.name, model)
for model in dataset_manager.create_dataset_aliases(
[alias for name, alias in self.dataset_aliases.items() if name not in orm_aliases],
session=session,
)
)
return orm_aliases

def add_dag_dataset_references(
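Both helpers in this file now follow the same shape: look up existing rows by their natural key, hand only the missing public objects to the dataset manager, and fold the returned models back into the lookup dict. A minimal generic sketch of that pattern (all names hypothetical, not part of this diff):

def merge_missing(existing, wanted, create_many, key):
    # Create models only for keys not already present, then merge the
    # returned models back into the lookup table the caller returns.
    missing = [obj for k, obj in wanted.items() if k not in existing]
    existing.update((key(model), model) for model in create_many(missing))
    return existing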
81 changes: 66 additions & 15 deletions airflow/datasets/manager.py
@@ -17,15 +17,14 @@
# under the License.
from __future__ import annotations

from collections.abc import Iterable
from collections.abc import Collection, Iterable
from typing import TYPE_CHECKING

from sqlalchemy import exc, select
from sqlalchemy.orm import joinedload

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.listeners.listener import get_listener_manager
from airflow.models.dagbag import DagPriorityParsingRequest
from airflow.models.dataset import (
@@ -43,6 +42,7 @@
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.datasets import Dataset, DatasetAlias
from airflow.models.dag import DagModel
from airflow.models.taskinstance import TaskInstance

@@ -58,12 +58,55 @@ class DatasetManager(LoggingMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None:
def create_datasets(self, datasets: list[Dataset], *, session: Session) -> list[DatasetModel]:
"""Create new datasets."""
for dataset_model in dataset_models:
session.add(dataset_model)
for dataset_model in dataset_models:
self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))

def _add_one(dataset: Dataset) -> DatasetModel:
model = DatasetModel.from_public(dataset)
session.add(model)
return model

models = [_add_one(d) for d in datasets]
for dataset in datasets:
self.notify_dataset_created(dataset=dataset)
return models

def create_dataset_aliases(
self,
dataset_aliases: list[DatasetAlias],
*,
session: Session,
) -> list[DatasetAliasModel]:
"""Create new dataset aliases."""

def _add_one(dataset_alias: DatasetAlias) -> DatasetAliasModel:
model = DatasetAliasModel.from_public(dataset_alias)
session.add(model)
return model

models = [_add_one(a) for a in dataset_aliases]
for dataset_alias in dataset_aliases:
self.notify_dataset_alias_created(dataset_alias=dataset_alias)
return models

@classmethod
def _add_dataset_alias_association(
cls,
alias_names: Collection[str],
dataset: DatasetModel,
*,
session: Session,
) -> None:
already_related = {m.name for m in dataset.aliases}
existing_aliases = {
m.name: m
for m in session.scalars(select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias_names)))
}
dataset.aliases.extend(
existing_aliases.get(name, DatasetAliasModel(name=name))
for name in alias_names
if name not in already_related
)

@classmethod
@internal_api_call
@@ -74,8 +117,9 @@ def register_dataset_change(
task_instance: TaskInstance | None = None,
dataset: Dataset,
extra=None,
session: Session = NEW_SESSION,
aliases: Collection[DatasetAlias] = (),
source_alias_names: Iterable[str] | None = None,
session: Session = NEW_SESSION,
**kwargs,
) -> DatasetEvent | None:
"""
@@ -88,24 +132,27 @@
dataset_model = session.scalar(
select(DatasetModel)
.where(DatasetModel.uri == dataset.uri)
.options(joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag))
.options(
joinedload(DatasetModel.aliases),
joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag),
)
)
if not dataset_model:
cls.logger().warning("DatasetModel %s not found", dataset)
return None

cls._add_dataset_alias_association({alias.name for alias in aliases}, dataset_model, session=session)

event_kwargs = {
"dataset_id": dataset_model.id,
"extra": extra,
}
if task_instance:
event_kwargs.update(
{
"source_task_id": task_instance.task_id,
"source_dag_id": task_instance.dag_id,
"source_run_id": task_instance.run_id,
"source_map_index": task_instance.map_index,
}
source_task_id=task_instance.task_id,
source_dag_id=task_instance.dag_id,
source_run_id=task_instance.run_id,
source_map_index=task_instance.map_index,
)

dataset_event = DatasetEvent(**event_kwargs)
@@ -155,6 +202,10 @@ def notify_dataset_created(self, dataset: Dataset):
"""Run applicable notification actions when a dataset is created."""
get_listener_manager().hook.on_dataset_created(dataset=dataset)

def notify_dataset_alias_created(self, dataset_alias: DatasetAlias):
"""Run applicable notification actions when a dataset alias is created."""
get_listener_manager().hook.on_dataset_alias_created(dataset_alias=dataset_alias)

@classmethod
def notify_dataset_changed(cls, dataset: Dataset):
"""Run applicable notification actions when a dataset is changed."""
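With the new ``aliases`` parameter, callers can associate alias names with the dataset while registering the change, as ``_register_dataset_changes`` does further down. A hedged usage sketch — ``ti``, ``session``, and the URI are assumed from the caller's context, and the module-level ``dataset_manager`` singleton matches its use in collection.py above:

from airflow.datasets import Dataset, DatasetAlias
from airflow.datasets.manager import dataset_manager

# ti and session come from the surrounding task/ORM context (assumed here).
dataset_manager.register_dataset_change(
    task_instance=ti,
    dataset=Dataset(uri="s3://bucket/key"),
    aliases=[DatasetAlias(name="nightly-export")],
    extra={"rows": 42},
    session=session,
)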
9 changes: 8 additions & 1 deletion airflow/listeners/spec/dataset.py
@@ -22,7 +22,7 @@
from pluggy import HookspecMarker

if TYPE_CHECKING:
from airflow.datasets import Dataset
from airflow.datasets import Dataset, DatasetAlias

hookspec = HookspecMarker("airflow")

@@ -34,6 +34,13 @@ def on_dataset_created(
"""Execute when a new dataset is created."""


@hookspec
def on_dataset_alias_created(
dataset_alias: DatasetAlias,
):
"""Execute when a new dataset alias is created."""


@hookspec
def on_dataset_changed(
dataset: Dataset,
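A minimal plugin-side sketch of the new hook; the ``hookimpl`` marker follows Airflow's existing listener convention (an assumption here, not shown in this diff):

from airflow.listeners import hookimpl


@hookimpl
def on_dataset_alias_created(dataset_alias):
    # Fires once per alias passed to DatasetManager.create_dataset_aliases().
    print(f"Dataset alias created: {dataset_alias.name}")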
6 changes: 6 additions & 0 deletions airflow/models/dataset.py
@@ -138,6 +138,9 @@ def __eq__(self, other):
else:
return NotImplemented

def to_public(self) -> DatasetAlias:
return DatasetAlias(name=self.name)


class DatasetModel(Base):
"""
@@ -200,6 +203,9 @@ def __hash__(self):
def __repr__(self):
return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"

def to_public(self) -> Dataset:
return Dataset(uri=self.uri, extra=self.extra)


class DagScheduleDatasetAliasReference(Base):
"""References from a DAG to a dataset alias of which it is a consumer."""
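The new ``to_public`` methods mirror the existing ``from_public`` constructors, so converting back and forth preserves the user-facing fields. An illustrative sketch (URI and extra values are hypothetical):

from airflow.datasets import Dataset
from airflow.models.dataset import DatasetModel

model = DatasetModel.from_public(Dataset(uri="s3://bucket/key", extra={"team": "data"}))
public = model.to_public()
assert (public.uri, public.extra) == (model.uri, model.extra)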
51 changes: 23 additions & 28 deletions airflow/models/taskinstance.py
@@ -31,7 +31,7 @@
from contextlib import nullcontext
from datetime import timedelta
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Mapping, Tuple
from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Generator, Iterable, Mapping, Tuple
from urllib.parse import quote

import dill
@@ -89,7 +89,7 @@
from airflow.listeners.listener import get_listener_manager
from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
from airflow.models.dagbag import DagBag
from airflow.models.dataset import DatasetAliasModel, DatasetModel
from airflow.models.dataset import DatasetModel
from airflow.models.log import Log
from airflow.models.param import process_params
from airflow.models.renderedtifields import get_serialized_template_fields
@@ -2893,7 +2893,7 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Se
# A task triggers at most one dataset event per (dataset URI, extra) pair.
# This mapping from (uri, frozen extra) to alias names lets us detect datasets
# that share a URI but carry different extras, for which multiple events must be emitted.
dataset_tuple_to_alias_names_mapping: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
dataset_alias_names: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
for obj in self.task.outlets or []:
self.log.debug("outlet obj %s", obj)
# Lineage can have other types of objects besides datasets
@@ -2908,42 +2908,37 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Se
for dataset_alias_event in events[obj].dataset_alias_events:
dataset_alias_name = dataset_alias_event["source_alias_name"]
dataset_uri = dataset_alias_event["dest_dataset_uri"]
extra = dataset_alias_event["extra"]
frozen_extra = frozenset(extra.items())
frozen_extra = frozenset(dataset_alias_event["extra"].items())
dataset_alias_names[(dataset_uri, frozen_extra)].add(dataset_alias_name)

dataset_tuple_to_alias_names_mapping[(dataset_uri, frozen_extra)].add(dataset_alias_name)
class _DatasetModelCache(Dict[str, DatasetModel]):
log = self.log

dataset_objs_cache: dict[str, DatasetModel] = {}
for (uri, extra_items), alias_names in dataset_tuple_to_alias_names_mapping.items():
if uri not in dataset_objs_cache:
dataset_obj = session.scalar(select(DatasetModel).where(DatasetModel.uri == uri).limit(1))
dataset_objs_cache[uri] = dataset_obj
else:
dataset_obj = dataset_objs_cache[uri]

if not dataset_obj:
dataset_obj = DatasetModel(uri=uri)
dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
self.log.warning("Created a new %r as it did not exist.", dataset_obj)
def __missing__(self, key: str) -> DatasetModel:
(dataset_obj,) = dataset_manager.create_datasets([Dataset(uri=key)], session=session)
session.flush()
dataset_objs_cache[uri] = dataset_obj

for alias in alias_names:
alias_obj = session.scalar(
select(DatasetAliasModel).where(DatasetAliasModel.name == alias).limit(1)
)
dataset_obj.aliases.append(alias_obj)
self.log.warning("Created a new %r as it did not exist.", dataset_obj)
self[key] = dataset_obj
return dataset_obj

extra = {k: v for k, v in extra_items}
dataset_objs_cache = _DatasetModelCache(
(dataset_obj.uri, dataset_obj)
for dataset_obj in session.scalars(
select(DatasetModel).where(DatasetModel.uri.in_(uri for uri, _ in dataset_alias_names))
)
)
for (uri, extra_items), alias_names in dataset_alias_names.items():
dataset_obj = dataset_objs_cache[uri]
self.log.info(
'Creating event for %r through aliases "%s"',
dataset_obj,
", ".join(alias_names),
)
dataset_manager.register_dataset_change(
task_instance=self,
dataset=dataset_obj,
extra=extra,
dataset=dataset_obj.to_public(),
aliases=[DatasetAlias(name) for name in alias_names],
extra=dict(extra_items),
session=session,
source_alias_names=alias_names,
)
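The ``_DatasetModelCache`` above replaces the manual cache-miss branch with ``dict.__missing__``: a failed lookup creates the value, memoizes it, and returns it. A generic, runnable sketch of the idiom with a stand-in factory:

def expensive_create(key):
    # Stand-in for the DatasetModel creation done in the real cache.
    return f"model-for-{key}"


class CreatingCache(dict):
    def __missing__(self, key):
        # Runs only on a cache miss: build, store, and return the value.
        value = self[key] = expensive_create(key)
        return value


cache = CreatingCache()
assert cache["s3://bucket/key"] == "model-for-s3://bucket/key"
assert "s3://bucket/key" in cache  # memoized after first access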
@@ -95,6 +95,7 @@ Dataset Events
--------------

- ``on_dataset_created``
- ``on_dataset_alias_created``
- ``on_dataset_changed``

Dataset events occur when Dataset management operations are run.
1 change: 1 addition & 0 deletions newsfragments/42343.feature.rst
@@ -0,0 +1 @@
New function ``create_dataset_aliases`` added to DatasetManager for DatasetAlias creation.
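A hedged usage sketch of the new function (alias name hypothetical; ``session`` and the ``dataset_manager`` singleton assumed from the caller's context):

from airflow.datasets import DatasetAlias

alias_models = dataset_manager.create_dataset_aliases(
    [DatasetAlias(name="nightly-export")],
    session=session,
)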
7 changes: 7 additions & 0 deletions newsfragments/42343.significant.rst
@@ -0,0 +1,7 @@
``DatasetManager.create_datasets`` now takes ``Dataset`` objects

This function previously accepted a list of ``DatasetModel`` objects. It now
receives ``Dataset`` objects instead; the ``DatasetModel`` objects are created
inside the function and returned.

Also, the ``session`` argument is now keyword-only.
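An illustrative before/after for callers of ``create_datasets`` (URI hypothetical):

# Before this change: callers built the ORM models themselves.
dataset_manager.create_datasets([DatasetModel(uri="s3://bucket/key")], session)

# After: pass public Dataset objects; the created models are returned,
# and session is keyword-only.
models = dataset_manager.create_datasets([Dataset(uri="s3://bucket/key")], session=session)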
7 changes: 4 additions & 3 deletions tests/datasets/test_manager.py
@@ -169,10 +169,11 @@ def test_create_datasets_notifies_dataset_listener(self, session):
dataset_listener.clear()
get_listener_manager().add_listener(dataset_listener)

dsm = DatasetModel(uri="test_dataset_uri_3")
ds = Dataset(uri="test_dataset_uri_3")

dsem.create_datasets([dsm], session)
dsms = dsem.create_datasets([ds], session=session)

# Ensure the listener was notified
assert len(dataset_listener.created) == 1
assert dataset_listener.created[0].uri == dsm.uri
assert len(dsms) == 1
assert dataset_listener.created[0].uri == ds.uri == dsms[0].uri