Refactor _register_dataset_changes (apache#42343)
uranusjr authored and ellisms committed Nov 13, 2024
1 parent 39843c7 commit 7020d50
Showing 9 changed files with 128 additions and 68 deletions.
37 changes: 16 additions & 21 deletions airflow/dag_processing/collection.py
@@ -299,21 +299,15 @@ def add_datasets(self, *, session: Session) -> dict[str, DatasetModel]:
             dm.uri: dm
             for dm in session.scalars(select(DatasetModel).where(DatasetModel.uri.in_(self.datasets)))
         }
-
-        def _resolve_dataset_addition() -> Iterator[DatasetModel]:
-            for uri, dataset in self.datasets.items():
-                try:
-                    dm = orm_datasets[uri]
-                except KeyError:
-                    dm = orm_datasets[uri] = DatasetModel.from_public(dataset)
-                    yield dm
-                else:
-                    # The orphaned flag was bulk-set to True before parsing, so we
-                    # don't need to handle rows in the db without a public entry.
-                    dm.is_orphaned = expression.false()
-                    dm.extra = dataset.extra
-
-        dataset_manager.create_datasets(list(_resolve_dataset_addition()), session=session)
+        for model in orm_datasets.values():
+            model.is_orphaned = expression.false()
+        orm_datasets.update(
+            (model.uri, model)
+            for model in dataset_manager.create_datasets(
+                [dataset for uri, dataset in self.datasets.items() if uri not in orm_datasets],
+                session=session,
+            )
+        )
         return orm_datasets
 
     def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]:
@@ -326,12 +320,13 @@ def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]:
                 select(DatasetAliasModel).where(DatasetAliasModel.name.in_(self.dataset_aliases))
             )
         }
-        for name, alias in self.dataset_aliases.items():
-            try:
-                da = orm_aliases[name]
-            except KeyError:
-                da = orm_aliases[name] = DatasetAliasModel.from_public(alias)
-                session.add(da)
+        orm_aliases.update(
+            (model.name, model)
+            for model in dataset_manager.create_dataset_aliases(
+                [alias for name, alias in self.dataset_aliases.items() if name not in orm_aliases],
+                session=session,
+            )
+        )
         return orm_aliases
 
     def add_dag_dataset_references(
77 changes: 62 additions & 15 deletions airflow/datasets/manager.py
@@ -17,15 +17,14 @@
 # under the License.
 from __future__ import annotations
 
-from collections.abc import Iterable
+from collections.abc import Collection, Iterable
 from typing import TYPE_CHECKING
 
 from sqlalchemy import exc, select
 from sqlalchemy.orm import joinedload
 
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf
-from airflow.datasets import Dataset
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.dagbag import DagPriorityParsingRequest
 from airflow.models.dataset import (
@@ -43,6 +42,7 @@
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
+    from airflow.datasets import Dataset, DatasetAlias
     from airflow.models.dag import DagModel
     from airflow.models.taskinstance import TaskInstance
 
@@ -58,12 +58,51 @@ class DatasetManager(LoggingMixin):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None:
+    def create_datasets(self, datasets: list[Dataset], *, session: Session) -> list[DatasetModel]:
         """Create new datasets."""
-        for dataset_model in dataset_models:
-            session.add(dataset_model)
-        for dataset_model in dataset_models:
-            self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))
+
+        def _add_one(dataset: Dataset) -> DatasetModel:
+            model = DatasetModel.from_public(dataset)
+            session.add(model)
+            self.notify_dataset_created(dataset=dataset)
+            return model
+
+        return [_add_one(d) for d in datasets]
+
+    def create_dataset_aliases(
+        self,
+        dataset_aliases: list[DatasetAlias],
+        *,
+        session: Session,
+    ) -> list[DatasetAliasModel]:
+        """Create new dataset aliases."""
+
+        def _add_one(dataset_alias: DatasetAlias) -> DatasetAliasModel:
+            model = DatasetAliasModel.from_public(dataset_alias)
+            session.add(model)
+            self.notify_dataset_alias_created(dataset_alias=dataset_alias)
+            return model
+
+        return [_add_one(a) for a in dataset_aliases]
+
+    @classmethod
+    def _add_dataset_alias_association(
+        cls,
+        alias_names: Collection[str],
+        dataset: DatasetModel,
+        *,
+        session: Session,
+    ) -> None:
+        already_related = {m.name for m in dataset.aliases}
+        existing_aliases = {
+            m.name: m
+            for m in session.scalars(select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias_names)))
+        }
+        dataset.aliases.extend(
+            existing_aliases.get(name, DatasetAliasModel(name=name))
+            for name in alias_names
+            if name not in already_related
+        )
 
     @classmethod
     @internal_api_call
@@ -74,8 +113,9 @@ def register_dataset_change(
         task_instance: TaskInstance | None = None,
         dataset: Dataset,
         extra=None,
-        session: Session = NEW_SESSION,
+        aliases: Collection[DatasetAlias] = (),
         source_alias_names: Iterable[str] | None = None,
+        session: Session = NEW_SESSION,
         **kwargs,
     ) -> DatasetEvent | None:
         """
@@ -88,24 +128,27 @@
         dataset_model = session.scalar(
             select(DatasetModel)
             .where(DatasetModel.uri == dataset.uri)
-            .options(joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag))
+            .options(
+                joinedload(DatasetModel.aliases),
+                joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag),
+            )
         )
         if not dataset_model:
             cls.logger().warning("DatasetModel %s not found", dataset)
             return None
 
+        cls._add_dataset_alias_association({alias.name for alias in aliases}, dataset_model, session=session)
+
         event_kwargs = {
             "dataset_id": dataset_model.id,
             "extra": extra,
         }
         if task_instance:
             event_kwargs.update(
-                {
-                    "source_task_id": task_instance.task_id,
-                    "source_dag_id": task_instance.dag_id,
-                    "source_run_id": task_instance.run_id,
-                    "source_map_index": task_instance.map_index,
-                }
+                source_task_id=task_instance.task_id,
+                source_dag_id=task_instance.dag_id,
+                source_run_id=task_instance.run_id,
+                source_map_index=task_instance.map_index,
             )
 
         dataset_event = DatasetEvent(**event_kwargs)
@@ -155,6 +198,10 @@ def notify_dataset_created(self, dataset: Dataset):
"""Run applicable notification actions when a dataset is created."""
get_listener_manager().hook.on_dataset_created(dataset=dataset)

def notify_dataset_alias_created(self, dataset_alias: DatasetAlias):
"""Run applicable notification actions when a dataset alias is created."""
get_listener_manager().hook.on_dataset_alias_created(dataset_alias=dataset_alias)

@classmethod
def notify_dataset_changed(cls, dataset: Dataset):
"""Run applicable notification actions when a dataset is changed."""
Expand Down
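
To illustrate the reworked ``register_dataset_change`` API above, here is a rough sketch of a call using the new ``aliases`` parameter. The URI, alias name, and extra payload are made-up values, and this is an internal API, so treat it as illustrative rather than a supported usage pattern:

    from airflow.datasets import Dataset, DatasetAlias
    from airflow.datasets.manager import dataset_manager
    from airflow.utils.session import create_session

    with create_session() as session:
        # Links the dataset row to the given aliases, then records one DatasetEvent.
        event = dataset_manager.register_dataset_change(
            dataset=Dataset(uri="s3://bucket/key"),
            aliases=[DatasetAlias(name="example-alias")],
            extra={"rows": 42},
            session=session,
        )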
9 changes: 8 additions & 1 deletion airflow/listeners/spec/dataset.py
@@ -22,7 +22,7 @@
 from pluggy import HookspecMarker
 
 if TYPE_CHECKING:
-    from airflow.datasets import Dataset
+    from airflow.datasets import Dataset, DatasetAlias
 
 hookspec = HookspecMarker("airflow")
 
@@ -34,6 +34,13 @@ def on_dataset_created(
"""Execute when a new dataset is created."""


@hookspec
def on_dataset_alias_created(
dataset_alias: DatasetAlias,
):
"""Execute when a new dataset alias is created."""


@hookspec
def on_dataset_changed(
dataset: Dataset,
Expand Down
6 changes: 6 additions & 0 deletions airflow/models/dataset.py
@@ -138,6 +138,9 @@ def __eq__(self, other):
         else:
             return NotImplemented
 
+    def to_public(self) -> DatasetAlias:
+        return DatasetAlias(name=self.name)
+
 
 class DatasetModel(Base):
     """
@@ -200,6 +203,9 @@ def __hash__(self):
     def __repr__(self):
         return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
 
+    def to_public(self) -> Dataset:
+        return Dataset(uri=self.uri, extra=self.extra)
+
 
 class DagScheduleDatasetAliasReference(Base):
     """References from a DAG to a dataset alias of which it is a consumer."""
51 changes: 23 additions & 28 deletions airflow/models/taskinstance.py
@@ -31,7 +31,7 @@
 from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Mapping, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Generator, Iterable, Mapping, Tuple
 from urllib.parse import quote
 
 import dill
@@ -89,7 +89,7 @@
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
 from airflow.models.dagbag import DagBag
-from airflow.models.dataset import DatasetAliasModel, DatasetModel
+from airflow.models.dataset import DatasetModel
 from airflow.models.log import Log
 from airflow.models.param import process_params
 from airflow.models.renderedtifields import get_serialized_template_fields
@@ -2893,7 +2893,7 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Session) -> None:
         # One task only triggers one dataset event for each dataset with the same extra.
         # This tuple[dataset uri, extra] to sets alias names mapping is used to find whether
         # there're datasets with same uri but different extra that we need to emit more than one dataset events.
-        dataset_tuple_to_alias_names_mapping: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
+        dataset_alias_names: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
         for obj in self.task.outlets or []:
             self.log.debug("outlet obj %s", obj)
             # Lineage can have other types of objects besides datasets
@@ -2908,42 +2908,37 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Session) -> None:
                 for dataset_alias_event in events[obj].dataset_alias_events:
                     dataset_alias_name = dataset_alias_event["source_alias_name"]
                     dataset_uri = dataset_alias_event["dest_dataset_uri"]
-                    extra = dataset_alias_event["extra"]
-                    frozen_extra = frozenset(extra.items())
+                    frozen_extra = frozenset(dataset_alias_event["extra"].items())
+                    dataset_alias_names[(dataset_uri, frozen_extra)].add(dataset_alias_name)
 
-                    dataset_tuple_to_alias_names_mapping[(dataset_uri, frozen_extra)].add(dataset_alias_name)
+        class _DatasetModelCache(Dict[str, DatasetModel]):
+            log = self.log
 
-        dataset_objs_cache: dict[str, DatasetModel] = {}
-        for (uri, extra_items), alias_names in dataset_tuple_to_alias_names_mapping.items():
-            if uri not in dataset_objs_cache:
-                dataset_obj = session.scalar(select(DatasetModel).where(DatasetModel.uri == uri).limit(1))
-                dataset_objs_cache[uri] = dataset_obj
-            else:
-                dataset_obj = dataset_objs_cache[uri]
-
-            if not dataset_obj:
-                dataset_obj = DatasetModel(uri=uri)
-                dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
-                self.log.warning("Created a new %r as it did not exist.", dataset_obj)
+            def __missing__(self, key: str) -> DatasetModel:
+                (dataset_obj,) = dataset_manager.create_datasets([Dataset(uri=key)], session=session)
                 session.flush()
-                dataset_objs_cache[uri] = dataset_obj
-
-            for alias in alias_names:
-                alias_obj = session.scalar(
-                    select(DatasetAliasModel).where(DatasetAliasModel.name == alias).limit(1)
-                )
-                dataset_obj.aliases.append(alias_obj)
+                self.log.warning("Created a new %r as it did not exist.", dataset_obj)
+                self[key] = dataset_obj
+                return dataset_obj
 
-            extra = {k: v for k, v in extra_items}
+        dataset_objs_cache = _DatasetModelCache(
+            (dataset_obj.uri, dataset_obj)
+            for dataset_obj in session.scalars(
+                select(DatasetModel).where(DatasetModel.uri.in_(uri for uri, _ in dataset_alias_names))
+            )
+        )
+        for (uri, extra_items), alias_names in dataset_alias_names.items():
+            dataset_obj = dataset_objs_cache[uri]
             self.log.info(
                 'Creating event for %r through aliases "%s"',
                 dataset_obj,
                 ", ".join(alias_names),
            )
             dataset_manager.register_dataset_change(
                 task_instance=self,
-                dataset=dataset_obj,
-                extra=extra,
+                dataset=dataset_obj.to_public(),
+                aliases=[DatasetAlias(name) for name in alias_names],
+                extra=dict(extra_items),
                 session=session,
                 source_alias_names=alias_names,
             )
@@ -95,6 +95,7 @@ Dataset Events
 --------------
 
 - ``on_dataset_created``
+- ``on_dataset_alias_created``
 - ``on_dataset_changed``
 
 Dataset events occur when Dataset management operations are run.
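
As a rough sketch, a listener module consuming the new ``on_dataset_alias_created`` hook could look like this. The module name is hypothetical, and registration through a plugin's ``listeners`` attribute is assumed to follow the existing listener documentation:

    # my_listener.py (hypothetical module registered via an Airflow plugin)
    from airflow.datasets import DatasetAlias
    from airflow.listeners import hookimpl


    @hookimpl
    def on_dataset_alias_created(dataset_alias: DatasetAlias):
        # Called once for each newly created dataset alias.
        print(f"dataset alias created: {dataset_alias.name}")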
1 change: 1 addition & 0 deletions newsfragments/42343.feature.rst
@@ -0,0 +1 @@
+New function ``create_dataset_aliases`` added to DatasetManager for DatasetAlias creation.
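
A minimal usage sketch of the new function, assuming the module-level ``dataset_manager`` accessor and the ``create_session`` helper (the alias name is made up):

    from airflow.datasets import DatasetAlias
    from airflow.datasets.manager import dataset_manager
    from airflow.utils.session import create_session

    with create_session() as session:
        # Returns the created DatasetAliasModel rows and notifies listeners.
        alias_models = dataset_manager.create_dataset_aliases(
            [DatasetAlias(name="example-alias")], session=session
        )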
7 changes: 7 additions & 0 deletions newsfragments/42343.significant.rst
@@ -0,0 +1,7 @@
+``DatasetManager.create_datasets`` now takes ``Dataset`` objects
+
+This function previously accepted a list of ``DatasetModel`` objects. It now
+receives ``Dataset`` objects instead. A list of ``DatasetModel`` objects is
+created inside, and returned by, the function.
+
+Also, the ``session`` argument is now keyword-only.
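
A hedged before/after sketch of the call-site change described above (the URI is made up, and an active session is assumed via ``create_session``):

    from airflow.datasets import Dataset
    from airflow.datasets.manager import dataset_manager
    from airflow.utils.session import create_session

    with create_session() as session:
        # Before: dataset_manager.create_datasets([DatasetModel(uri=...)], session)
        # After: pass public Dataset objects; the DatasetModel rows are created
        # inside the function and returned, and session is keyword-only.
        models = dataset_manager.create_datasets(
            [Dataset(uri="s3://bucket/key")], session=session
        )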
7 changes: 4 additions & 3 deletions tests/datasets/test_manager.py
@@ -169,10 +169,11 @@ def test_create_datasets_notifies_dataset_listener(self, session):
         dataset_listener.clear()
         get_listener_manager().add_listener(dataset_listener)
 
-        dsm = DatasetModel(uri="test_dataset_uri_3")
+        ds = Dataset(uri="test_dataset_uri_3")
 
-        dsem.create_datasets([dsm], session)
+        dsms = dsem.create_datasets([ds], session=session)
 
         # Ensure the listener was notified
         assert len(dataset_listener.created) == 1
-        assert dataset_listener.created[0].uri == dsm.uri
+        assert len(dsms) == 1
+        assert dataset_listener.created[0].uri == ds.uri == dsms[0].uri
