From 0f52eec6fbb6053a53313cc86c1792a746b5c307 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 19 Sep 2024 11:06:02 +0800 Subject: [PATCH 01/11] Reduce database calls in _register_dataset_changes Instead of fetching DatasetModel one by one, do a bulk fetch into a dict to save roundtrips to the database. --- airflow/datasets/manager.py | 4 +-- airflow/models/dataset.py | 3 +++ airflow/models/taskinstance.py | 49 ++++++++++++++++------------------ 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index 19f6913fffbeb..eb51848c3ab91 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -25,7 +25,6 @@ from airflow.api_internal.internal_api_call import internal_api_call from airflow.configuration import conf -from airflow.datasets import Dataset from airflow.listeners.listener import get_listener_manager from airflow.models.dagbag import DagPriorityParsingRequest from airflow.models.dataset import ( @@ -43,6 +42,7 @@ if TYPE_CHECKING: from sqlalchemy.orm.session import Session + from airflow.datasets import Dataset from airflow.models.dag import DagModel from airflow.models.taskinstance import TaskInstance @@ -63,7 +63,7 @@ def create_datasets(self, dataset_models: list[DatasetModel], session: Session) for dataset_model in dataset_models: session.add(dataset_model) for dataset_model in dataset_models: - self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra)) + self.notify_dataset_created(dataset=dataset_model.to_public()) @classmethod @internal_api_call diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index 5033da48a3059..642c486d238a3 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -200,6 +200,9 @@ def __hash__(self): def __repr__(self): return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})" + def to_public(self) -> Dataset: + return Dataset(uri=self.uri, extra=self.extra) + class DagScheduleDatasetAliasReference(Base): """References from a DAG to a dataset alias of which it is a consumer.""" diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 954e5ed4d0c80..225281bef44a4 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -31,7 +31,7 @@ from contextlib import nullcontext from datetime import timedelta from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Mapping, Tuple +from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Generator, Iterable, Mapping, Tuple from urllib.parse import quote import dill @@ -2893,7 +2893,7 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Se # One task only triggers one dataset event for each dataset with the same extra. # This tuple[dataset uri, extra] to sets alias names mapping is used to find whether # there're datasets with same uri but different extra that we need to emit more than one dataset events. 
- dataset_tuple_to_alias_names_mapping: dict[tuple[str, frozenset], set[str]] = defaultdict(set) + dataset_alias_names: dict[tuple[str, frozenset], set[str]] = defaultdict(set) for obj in self.task.outlets or []: self.log.debug("outlet obj %s", obj) # Lineage can have other types of objects besides datasets @@ -2908,33 +2908,30 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Se for dataset_alias_event in events[obj].dataset_alias_events: dataset_alias_name = dataset_alias_event["source_alias_name"] dataset_uri = dataset_alias_event["dest_dataset_uri"] - extra = dataset_alias_event["extra"] - frozen_extra = frozenset(extra.items()) + frozen_extra = frozenset(dataset_alias_event["extra"].items()) + dataset_alias_names[(dataset_uri, frozen_extra)].add(dataset_alias_name) - dataset_tuple_to_alias_names_mapping[(dataset_uri, frozen_extra)].add(dataset_alias_name) + class _DatasetModelCache(Dict[str, DatasetModel]): + log = self.log - dataset_objs_cache: dict[str, DatasetModel] = {} - for (uri, extra_items), alias_names in dataset_tuple_to_alias_names_mapping.items(): - if uri not in dataset_objs_cache: - dataset_obj = session.scalar(select(DatasetModel).where(DatasetModel.uri == uri).limit(1)) - dataset_objs_cache[uri] = dataset_obj - else: - dataset_obj = dataset_objs_cache[uri] - - if not dataset_obj: - dataset_obj = DatasetModel(uri=uri) + def __missing__(self, key: str) -> DatasetModel: + dataset_obj = self[key] = DatasetModel(uri=uri) dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session) - self.log.warning("Created a new %r as it did not exist.", dataset_obj) session.flush() - dataset_objs_cache[uri] = dataset_obj - - for alias in alias_names: - alias_obj = session.scalar( - select(DatasetAliasModel).where(DatasetAliasModel.name == alias).limit(1) - ) - dataset_obj.aliases.append(alias_obj) + self.log.warning("Created a new %r as it did not exist.", dataset_obj) + return dataset_obj - extra = {k: v for k, v in extra_items} + dataset_objs_cache = _DatasetModelCache( + (dataset_obj.uri, dataset_obj) + for dataset_obj in session.scalars( + select(DatasetModel).where(DatasetModel.uri.in_(uri for uri, _ in dataset_alias_names)) + ) + ) + for (uri, extra_items), alias_names in dataset_alias_names.items(): + dataset_obj = dataset_objs_cache[uri] + dataset_obj.aliases.extend( + session.scalars(select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias_names))) + ) self.log.info( 'Creating event for %r through aliases "%s"', dataset_obj, @@ -2942,8 +2939,8 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Se ) dataset_manager.register_dataset_change( task_instance=self, - dataset=dataset_obj, - extra=extra, + dataset=dataset_obj.to_public(), + extra=dict(extra_items), session=session, source_alias_names=alias_names, ) From a65e3e7719dcf0702c84847ea2126b1c84d2291c Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 19 Sep 2024 14:08:21 +0800 Subject: [PATCH 02/11] Pass alias information to DatasetManager --- airflow/datasets/manager.py | 17 ++++++++--------- airflow/models/dataset.py | 3 +++ airflow/models/taskinstance.py | 8 +++++--- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index eb51848c3ab91..ee5909f67c980 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -17,7 +17,7 @@ # under the License. 
from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Collection, Iterable from typing import TYPE_CHECKING from sqlalchemy import exc, select @@ -42,7 +42,7 @@ if TYPE_CHECKING: from sqlalchemy.orm.session import Session - from airflow.datasets import Dataset + from airflow.datasets import Dataset, DatasetAlias from airflow.models.dag import DagModel from airflow.models.taskinstance import TaskInstance @@ -74,8 +74,9 @@ def register_dataset_change( task_instance: TaskInstance | None = None, dataset: Dataset, extra=None, - session: Session = NEW_SESSION, + aliases: Collection[DatasetAlias] = (), source_alias_names: Iterable[str] | None = None, + session: Session = NEW_SESSION, **kwargs, ) -> DatasetEvent | None: """ @@ -100,12 +101,10 @@ def register_dataset_change( } if task_instance: event_kwargs.update( - { - "source_task_id": task_instance.task_id, - "source_dag_id": task_instance.dag_id, - "source_run_id": task_instance.run_id, - "source_map_index": task_instance.map_index, - } + source_task_id=task_instance.task_id, + source_dag_id=task_instance.dag_id, + source_run_id=task_instance.run_id, + source_map_index=task_instance.map_index, ) dataset_event = DatasetEvent(**event_kwargs) diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index 642c486d238a3..489d6b68a6f15 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -138,6 +138,9 @@ def __eq__(self, other): else: return NotImplemented + def to_public(self) -> DatasetAlias: + return DatasetAlias(name=self.name) + class DatasetModel(Base): """ diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 225281bef44a4..b685de19b378f 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2929,9 +2929,10 @@ def __missing__(self, key: str) -> DatasetModel: ) for (uri, extra_items), alias_names in dataset_alias_names.items(): dataset_obj = dataset_objs_cache[uri] - dataset_obj.aliases.extend( - session.scalars(select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias_names))) - ) + aliases = session.scalars( + select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias_names)) + ).all() + dataset_obj.aliases.extend(aliases) self.log.info( 'Creating event for %r through aliases "%s"', dataset_obj, @@ -2940,6 +2941,7 @@ def __missing__(self, key: str) -> DatasetModel: dataset_manager.register_dataset_change( task_instance=self, dataset=dataset_obj.to_public(), + aliases=[a.to_public() for a in aliases], extra=dict(extra_items), session=session, source_alias_names=alias_names, From a167889fcd33f942f1b8d4c9a10dbb13539f4464 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 19 Sep 2024 14:23:47 +0800 Subject: [PATCH 03/11] Move dataset alias association into dataset manager --- airflow/datasets/manager.py | 4 ++++ airflow/models/taskinstance.py | 8 ++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index ee5909f67c980..f19dbd025a85c 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -95,6 +95,10 @@ def register_dataset_change( cls.logger().warning("DatasetModel %s not found", dataset) return None + dataset_model.aliases = session.scalars( + select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias.name for alias in aliases)) + ) + event_kwargs = { "dataset_id": dataset_model.id, "extra": extra, diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 
b685de19b378f..235d01bac9030 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -89,7 +89,7 @@
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
 from airflow.models.dagbag import DagBag
-from airflow.models.dataset import DatasetAliasModel, DatasetModel
+from airflow.models.dataset import DatasetModel
 from airflow.models.log import Log
 from airflow.models.param import process_params
 from airflow.models.renderedtifields import get_serialized_template_fields
@@ -2929,10 +2929,6 @@ def __missing__(self, key: str) -> DatasetModel:
         )
         for (uri, extra_items), alias_names in dataset_alias_names.items():
             dataset_obj = dataset_objs_cache[uri]
-            aliases = session.scalars(
-                select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias_names))
-            ).all()
-            dataset_obj.aliases.extend(aliases)
             self.log.info(
                 'Creating event for %r through aliases "%s"',
                 dataset_obj,
@@ -2941,7 +2937,7 @@ def __missing__(self, key: str) -> DatasetModel:
             dataset_manager.register_dataset_change(
                 task_instance=self,
                 dataset=dataset_obj.to_public(),
-                aliases=[a.to_public() for a in aliases],
+                aliases=[DatasetAlias(name) for name in alias_names],
                 extra=dict(extra_items),
                 session=session,
                 source_alias_names=alias_names,

From b8888bb0be214269ddcabaecfa81d924157dbc0f Mon Sep 17 00:00:00 2001
From: Tzu-ping Chung
Date: Thu, 19 Sep 2024 15:29:51 +0800
Subject: [PATCH 04/11] Create dataset things only in DatasetManager

Prior to this commit, we already created DatasetModel rows only inside
the manager. This also changes DatasetAliasModel to only be created
inside create_dataset_aliases, and to only be associated with
DatasetEvent in register_dataset_change.

All the dataset manager functions are also changed to accept only
public-facing dataset classes instead of ORM models.

The register_dataset_change function now takes an additional keyword
argument 'aliases' that is a list of dataset aliases associated with
the DatasetEvent to be created.

---
 airflow/dag_processing/collection.py | 37 ++++++++++++----------------
 airflow/datasets/manager.py | 37 ++++++++++++++++++++++++----
 airflow/models/taskinstance.py | 4 +--
 newsfragments/99999.feature.rst | 1 +
 newsfragments/99999.significant.rst | 7 ++++++
 tests/datasets/test_manager.py | 6 ++---
 6 files changed, 61 insertions(+), 31 deletions(-)
 create mode 100644 newsfragments/99999.feature.rst
 create mode 100644 newsfragments/99999.significant.rst

diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index 3f75e0b23bbfd..abbb5fe0a1a25 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -291,21 +291,15 @@ def add_datasets(self, *, session: Session) -> dict[str, DatasetModel]:
             dm.uri: dm
             for dm in session.scalars(select(DatasetModel).where(DatasetModel.uri.in_(self.datasets)))
         }
-
-        def _resolve_dataset_addition() -> Iterator[DatasetModel]:
-            for uri, dataset in self.datasets.items():
-                try:
-                    dm = orm_datasets[uri]
-                except KeyError:
-                    dm = orm_datasets[uri] = DatasetModel.from_public(dataset)
-                    yield dm
-                else:
-                    # The orphaned flag was bulk-set to True before parsing, so we
-                    # don't need to handle rows in the db without a public entry.
- dm.is_orphaned = expression.false() - dm.extra = dataset.extra - - dataset_manager.create_datasets(list(_resolve_dataset_addition()), session=session) + for model in orm_datasets.values(): + model.is_orphaned = expression.false() + orm_datasets.update( + (model.uri, model) + for model in dataset_manager.create_datasets( + [dataset for uri, dataset in self.datasets.items() if uri not in orm_datasets], + session=session, + ) + ) return orm_datasets def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]: @@ -318,12 +312,13 @@ def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasMode select(DatasetAliasModel).where(DatasetAliasModel.name.in_(self.dataset_aliases)) ) } - for name, alias in self.dataset_aliases.items(): - try: - da = orm_aliases[name] - except KeyError: - da = orm_aliases[name] = DatasetAliasModel.from_public(alias) - session.add(da) + orm_aliases.update( + (model.name, model) + for model in dataset_manager.create_dataset_aliases( + [alias for name, alias in self.dataset_aliases.items() if name not in orm_aliases], + session=session, + ) + ) return orm_aliases def add_dag_dataset_references( diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index f19dbd025a85c..e96c7dbf78680 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -58,12 +58,36 @@ class DatasetManager(LoggingMixin): def __init__(self, **kwargs): super().__init__(**kwargs) - def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None: + def create_datasets(self, datasets: list[Dataset], *, session: Session) -> list[DatasetModel]: """Create new datasets.""" - for dataset_model in dataset_models: - session.add(dataset_model) - for dataset_model in dataset_models: - self.notify_dataset_created(dataset=dataset_model.to_public()) + + def _add_one(dataset: Dataset) -> DatasetModel: + model = DatasetModel.from_public(dataset) + session.add(model) + return model + + models = [_add_one(d) for d in datasets] + for dataset in datasets: + self.notify_dataset_created(dataset=dataset) + return models + + def create_dataset_aliases( + self, + dataset_aliases: list[DatasetAlias], + *, + session: Session, + ) -> list[DatasetAliasModel]: + """Create new dataset aliases.""" + + def _add_one(dataset_alias: DatasetAlias) -> DatasetAliasModel: + model = DatasetAliasModel.from_public(dataset_alias) + session.add(model) + return model + + models = [_add_one(a) for a in dataset_aliases] + for dataset_alias in dataset_aliases: + self.notify_dataset_alias_created(dataset_alias=dataset_alias) + return models @classmethod @internal_api_call @@ -158,6 +182,9 @@ def notify_dataset_created(self, dataset: Dataset): """Run applicable notification actions when a dataset is created.""" get_listener_manager().hook.on_dataset_created(dataset=dataset) + def notify_dataset_alias_created(self, dataset_alias: DatasetAlias): + get_listener_manager().hook.on_dataset_alias_created(dataset_alias=dataset_alias) + @classmethod def notify_dataset_changed(cls, dataset: Dataset): """Run applicable notification actions when a dataset is changed.""" diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 235d01bac9030..d3300207abfdc 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2915,10 +2915,10 @@ class _DatasetModelCache(Dict[str, DatasetModel]): log = self.log def __missing__(self, key: str) -> DatasetModel: - dataset_obj = self[key] = DatasetModel(uri=uri) - 
dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
+                (dataset_obj,) = dataset_manager.create_datasets([Dataset(uri=key)], session=session)
                 session.flush()
                 self.log.warning("Created a new %r as it did not exist.", dataset_obj)
+                self[key] = dataset_obj
                 return dataset_obj

         dataset_objs_cache = _DatasetModelCache(
diff --git a/newsfragments/99999.feature.rst b/newsfragments/99999.feature.rst
new file mode 100644
index 0000000000000..8a7cdf335a06e
--- /dev/null
+++ b/newsfragments/99999.feature.rst
@@ -0,0 +1 @@
+New function ``create_dataset_aliases`` added to DatasetManager for DatasetAlias creation.
diff --git a/newsfragments/99999.significant.rst b/newsfragments/99999.significant.rst
new file mode 100644
index 0000000000000..d9e1ba6b1229b
--- /dev/null
+++ b/newsfragments/99999.significant.rst
@@ -0,0 +1,7 @@
+``DatasetManager.create_datasets`` now takes ``Dataset`` objects
+
+This function previously accepted a list of ``DatasetModel`` objects. It
+now receives ``Dataset`` objects instead. A list of ``DatasetModel``
+objects is created inside the function and returned.
+
+Also, the ``session`` argument is now keyword-only.
diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py
index 1e7b4fda40cee..e1b8e8b2bc130 100644
--- a/tests/datasets/test_manager.py
+++ b/tests/datasets/test_manager.py
@@ -169,10 +169,10 @@ def test_create_datasets_notifies_dataset_listener(self, session):
         dataset_listener.clear()
         get_listener_manager().add_listener(dataset_listener)

-        dsm = DatasetModel(uri="test_dataset_uri_3")
+        ds = Dataset(uri="test_dataset_uri_3")

-        dsem.create_datasets([dsm], session)
+        dsm = dsem.create_datasets([ds], session)

         # Ensure the listener was notified
         assert len(dataset_listener.created) == 1
-        assert dataset_listener.created[0].uri == dsm.uri
+        assert dataset_listener.created[0].uri == ds.uri == dsm.uri

From 7bd6274ecb79b9df41fb06e556dd3b475ed37146 Mon Sep 17 00:00:00 2001
From: Tzu-ping Chung
Date: Thu, 19 Sep 2024 15:29:51 +0800
Subject: [PATCH 05/11] Create dataset things only in DatasetManager

Prior to this commit, we already created DatasetModel rows only inside
the manager. This also changes DatasetAliasModel to only be created
inside create_dataset_aliases, and to only be associated with
DatasetEvent in register_dataset_change.

All the dataset manager functions are also changed to accept only
public-facing dataset classes instead of ORM models.

The register_dataset_change function now takes an additional keyword
argument 'aliases' that is a list of dataset aliases associated with
the DatasetEvent to be created.
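
As an illustrative sketch of the resulting call shape (not part of this
diff; 'ti' and 'session' stand in for a real TaskInstance and an open
SQLAlchemy session), callers now pass public-facing objects throughout:

    from airflow.datasets import Dataset, DatasetAlias
    from airflow.datasets.manager import dataset_manager

    # Record a dataset event, attaching the aliases it was reached through.
    dataset_manager.register_dataset_change(
        task_instance=ti,
        dataset=Dataset(uri="s3://bucket/key"),
        aliases=[DatasetAlias(name="my-alias")],
        extra={"rows": 42},
        session=session,
    )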
--- airflow/datasets/manager.py | 1 + airflow/listeners/spec/dataset.py | 9 ++++++++- .../administration-and-deployment/listeners.rst | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index e96c7dbf78680..1a652dff07c82 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -183,6 +183,7 @@ def notify_dataset_created(self, dataset: Dataset): get_listener_manager().hook.on_dataset_created(dataset=dataset) def notify_dataset_alias_created(self, dataset_alias: DatasetAlias): + """Run applicable notification actions when a dataset alias is created.""" get_listener_manager().hook.on_dataset_alias_created(dataset_alias=dataset_alias) @classmethod diff --git a/airflow/listeners/spec/dataset.py b/airflow/listeners/spec/dataset.py index 214ddad3ffb13..eee1a10dd7d89 100644 --- a/airflow/listeners/spec/dataset.py +++ b/airflow/listeners/spec/dataset.py @@ -22,7 +22,7 @@ from pluggy import HookspecMarker if TYPE_CHECKING: - from airflow.datasets import Dataset + from airflow.datasets import Dataset, DatasetAlias hookspec = HookspecMarker("airflow") @@ -34,6 +34,13 @@ def on_dataset_created( """Execute when a new dataset is created.""" +@hookspec +def on_dataset_alias_created( + dataset_alias: DatasetAlias, +): + """Execute when a new dataset alias is created.""" + + @hookspec def on_dataset_changed( dataset: Dataset, diff --git a/docs/apache-airflow/administration-and-deployment/listeners.rst b/docs/apache-airflow/administration-and-deployment/listeners.rst index 34909e225aaa9..4926b12ed6c6d 100644 --- a/docs/apache-airflow/administration-and-deployment/listeners.rst +++ b/docs/apache-airflow/administration-and-deployment/listeners.rst @@ -95,6 +95,7 @@ Dataset Events -------------- - ``on_dataset_created`` +- ``on_dataset_alias_created`` - ``on_dataset_changed`` Dataset events occur when Dataset management operations are run. From a106a8ad4ded37160a40dc81e166ae480902b7bd Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 19 Sep 2024 17:08:55 +0800 Subject: [PATCH 06/11] Efficient single query insert into through table --- airflow/datasets/manager.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index 1a652dff07c82..28a8d6feca6a3 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -119,8 +119,15 @@ def register_dataset_change( cls.logger().warning("DatasetModel %s not found", dataset) return None - dataset_model.aliases = session.scalars( - select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias.name for alias in aliases)) + # This INSERTs directly into the association table of the many-to-many + # dataset_model.aliases relationship. I don't know how to do it in ORM. 
+ session.execute( + DatasetAliasModel.datasets.prop.secondary.insert().from_select( + ["alias_id", "dataset_id"], + select(DatasetAliasModel.id, dataset_model.id).where( + DatasetAliasModel.name.in_(alias.name for alias in aliases) + ), + ) ) event_kwargs = { From 8179e9b30e7788b192aeaf00ec09a9ba1de192fa Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 20 Sep 2024 15:24:48 +0800 Subject: [PATCH 07/11] Conditionally use the performant query --- airflow/datasets/manager.py | 46 +++++++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index 28a8d6feca6a3..e4684fcb92d68 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -89,6 +89,38 @@ def _add_one(dataset_alias: DatasetAlias) -> DatasetAliasModel: self.notify_dataset_alias_created(dataset_alias=dataset_alias) return models + @classmethod + def _slow_path_add_dataset_alias_association( + cls, + alias_names: Iterable[str], + dataset_model: DatasetModel, + ) -> None: + # For databases not supporting ON CONFLICT DO NOTHING, we need to fetch + # the existing names to figure out what we can add. + existing = {alias.name for alias in dataset_model.aliases} + dataset_model.aliases.extend(DatasetAliasModel(name=n) for n in alias_names if n not in existing) + + @classmethod + def _postgres_add_dataset_alias_association( + cls, + alias_names: Iterable[str], + dataset_id: int, + *, + session: Session, + ) -> None: + from sqlalchemy.dialects.postgresql import insert + + # This INSERTs directly into the association table of the many-to-many + # dataset_model.aliases relationship. I don't know how to do it in ORM. + session.execute( + insert(DatasetAliasModel.datasets.prop.secondary) + .from_select( + ["alias_id", "dataset_id"], + select(DatasetAliasModel.id, dataset_id).where(DatasetAliasModel.name.in_(alias_names)), + ) + .on_conflict_do_nothing() + ) + @classmethod @internal_api_call @provide_session @@ -119,16 +151,12 @@ def register_dataset_change( cls.logger().warning("DatasetModel %s not found", dataset) return None - # This INSERTs directly into the association table of the many-to-many - # dataset_model.aliases relationship. I don't know how to do it in ORM. 
- session.execute( - DatasetAliasModel.datasets.prop.secondary.insert().from_select( - ["alias_id", "dataset_id"], - select(DatasetAliasModel.id, dataset_model.id).where( - DatasetAliasModel.name.in_(alias.name for alias in aliases) - ), + if session.bind.dialect.name == "postgresql": + cls._postgres_add_dataset_alias_association( + (alias.name for alias in aliases), dataset_model.id, session=session ) - ) + else: + cls._slow_path_add_dataset_alias_association((alias.name for alias in aliases), dataset_model) event_kwargs = { "dataset_id": dataset_model.id, From 39b3c5147e9f8ea57094f17c4088303e99b04c6f Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 20 Sep 2024 15:29:25 +0800 Subject: [PATCH 08/11] Change news fragment name --- newsfragments/{99999.feature.rst => 42343.feature.rst} | 0 newsfragments/{99999.significant.rst => 42343.significant.rst} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename newsfragments/{99999.feature.rst => 42343.feature.rst} (100%) rename newsfragments/{99999.significant.rst => 42343.significant.rst} (100%) diff --git a/newsfragments/99999.feature.rst b/newsfragments/42343.feature.rst similarity index 100% rename from newsfragments/99999.feature.rst rename to newsfragments/42343.feature.rst diff --git a/newsfragments/99999.significant.rst b/newsfragments/42343.significant.rst similarity index 100% rename from newsfragments/99999.significant.rst rename to newsfragments/42343.significant.rst From a534051fd044756677a069bbe5f296c824877bbe Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 20 Sep 2024 16:04:36 +0800 Subject: [PATCH 09/11] Fix function call and return handling in test --- airflow/datasets/manager.py | 44 +++++++++++++++++++++------------- tests/datasets/test_manager.py | 5 ++-- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index e4684fcb92d68..b1cbee5b2c387 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -89,22 +89,11 @@ def _add_one(dataset_alias: DatasetAlias) -> DatasetAliasModel: self.notify_dataset_alias_created(dataset_alias=dataset_alias) return models - @classmethod - def _slow_path_add_dataset_alias_association( - cls, - alias_names: Iterable[str], - dataset_model: DatasetModel, - ) -> None: - # For databases not supporting ON CONFLICT DO NOTHING, we need to fetch - # the existing names to figure out what we can add. - existing = {alias.name for alias in dataset_model.aliases} - dataset_model.aliases.extend(DatasetAliasModel(name=n) for n in alias_names if n not in existing) - @classmethod def _postgres_add_dataset_alias_association( cls, alias_names: Iterable[str], - dataset_id: int, + dataset: DatasetModel, *, session: Session, ) -> None: @@ -116,11 +105,29 @@ def _postgres_add_dataset_alias_association( insert(DatasetAliasModel.datasets.prop.secondary) .from_select( ["alias_id", "dataset_id"], - select(DatasetAliasModel.id, dataset_id).where(DatasetAliasModel.name.in_(alias_names)), + select(DatasetAliasModel.id, dataset.id).where(DatasetAliasModel.name.in_(alias_names)), ) .on_conflict_do_nothing() ) + @classmethod + def _slow_path_add_dataset_alias_association( + cls, + alias_names: Collection[str], + dataset: DatasetModel, + *, + session: Session, + ) -> None: + # For databases not supporting ON CONFLICT DO NOTHING, we need to fetch + # the existing names to figure out what we can add. 
+ already_related = {m.name for m in dataset.aliases} + existing_not_related = { + m.name: m + for m in session.scalars(select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias_names))) + if m.name not in already_related + } + dataset.aliases.extend(existing_not_related.get(n, DatasetAliasModel(name=n)) for n in alias_names) + @classmethod @internal_api_call @provide_session @@ -145,7 +152,10 @@ def register_dataset_change( dataset_model = session.scalar( select(DatasetModel) .where(DatasetModel.uri == dataset.uri) - .options(joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag)) + .options( + joinedload(DatasetModel.aliases), + joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag), + ) ) if not dataset_model: cls.logger().warning("DatasetModel %s not found", dataset) @@ -153,10 +163,12 @@ def register_dataset_change( if session.bind.dialect.name == "postgresql": cls._postgres_add_dataset_alias_association( - (alias.name for alias in aliases), dataset_model.id, session=session + (alias.name for alias in aliases), dataset_model, session=session ) else: - cls._slow_path_add_dataset_alias_association((alias.name for alias in aliases), dataset_model) + cls._slow_path_add_dataset_alias_association( + {alias.name for alias in aliases}, dataset_model, session=session + ) event_kwargs = { "dataset_id": dataset_model.id, diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py index e1b8e8b2bc130..d3013aef60c29 100644 --- a/tests/datasets/test_manager.py +++ b/tests/datasets/test_manager.py @@ -171,8 +171,9 @@ def test_create_datasets_notifies_dataset_listener(self, session): ds = Dataset(uri="test_dataset_uri_3") - dsm = dsem.create_datasets([ds], session) + dsms = dsem.create_datasets([ds], session=session) # Ensure the listener was notified assert len(dataset_listener.created) == 1 - assert dataset_listener.created[0].uri == ds.uri == dsm.uri + assert len(dsms) == 1 + assert dataset_listener.created[0].uri == ds.uri == dsms[0].uri From bbc228964c7108d077c79f870b9075846764a5ca Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Mon, 23 Sep 2024 13:32:13 +0800 Subject: [PATCH 10/11] Keep it simple(r) --- airflow/datasets/manager.py | 43 +++++++------------------------------ 1 file changed, 8 insertions(+), 35 deletions(-) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index b1cbee5b2c387..83907e5297b30 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -90,43 +90,23 @@ def _add_one(dataset_alias: DatasetAlias) -> DatasetAliasModel: return models @classmethod - def _postgres_add_dataset_alias_association( - cls, - alias_names: Iterable[str], - dataset: DatasetModel, - *, - session: Session, - ) -> None: - from sqlalchemy.dialects.postgresql import insert - - # This INSERTs directly into the association table of the many-to-many - # dataset_model.aliases relationship. I don't know how to do it in ORM. - session.execute( - insert(DatasetAliasModel.datasets.prop.secondary) - .from_select( - ["alias_id", "dataset_id"], - select(DatasetAliasModel.id, dataset.id).where(DatasetAliasModel.name.in_(alias_names)), - ) - .on_conflict_do_nothing() - ) - - @classmethod - def _slow_path_add_dataset_alias_association( + def _add_dataset_alias_association( cls, alias_names: Collection[str], dataset: DatasetModel, *, session: Session, ) -> None: - # For databases not supporting ON CONFLICT DO NOTHING, we need to fetch - # the existing names to figure out what we can add. 
already_related = {m.name for m in dataset.aliases} - existing_not_related = { + existing_aliases = { m.name: m for m in session.scalars(select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias_names))) - if m.name not in already_related } - dataset.aliases.extend(existing_not_related.get(n, DatasetAliasModel(name=n)) for n in alias_names) + dataset.aliases.extend( + existing_aliases.get(name, DatasetAliasModel(name=name)) + for name in alias_names + if name not in already_related + ) @classmethod @internal_api_call @@ -161,14 +141,7 @@ def register_dataset_change( cls.logger().warning("DatasetModel %s not found", dataset) return None - if session.bind.dialect.name == "postgresql": - cls._postgres_add_dataset_alias_association( - (alias.name for alias in aliases), dataset_model, session=session - ) - else: - cls._slow_path_add_dataset_alias_association( - {alias.name for alias in aliases}, dataset_model, session=session - ) + cls._add_dataset_alias_association({alias.name for alias in aliases}, dataset_model, session=session) event_kwargs = { "dataset_id": dataset_model.id, From 59d75c51ab05a19393483b4463b26a52b12a5d2e Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 25 Sep 2024 08:48:33 +0800 Subject: [PATCH 11/11] Notify the listener while we're adding things --- airflow/datasets/manager.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index 83907e5297b30..c5ebb2e6d7eff 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -64,12 +64,10 @@ def create_datasets(self, datasets: list[Dataset], *, session: Session) -> list[ def _add_one(dataset: Dataset) -> DatasetModel: model = DatasetModel.from_public(dataset) session.add(model) + self.notify_dataset_created(dataset=dataset) return model - models = [_add_one(d) for d in datasets] - for dataset in datasets: - self.notify_dataset_created(dataset=dataset) - return models + return [_add_one(d) for d in datasets] def create_dataset_aliases( self, @@ -82,12 +80,10 @@ def create_dataset_aliases( def _add_one(dataset_alias: DatasetAlias) -> DatasetAliasModel: model = DatasetAliasModel.from_public(dataset_alias) session.add(model) + self.notify_dataset_alias_created(dataset_alias=dataset_alias) return model - models = [_add_one(a) for a in dataset_aliases] - for dataset_alias in dataset_aliases: - self.notify_dataset_alias_created(dataset_alias=dataset_alias) - return models + return [_add_one(a) for a in dataset_aliases] @classmethod def _add_dataset_alias_association(
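
A minimal listener sketch exercising the on_dataset_alias_created hook
introduced in this series (the module name and print body are
illustrative; only the hook name and its dataset_alias argument come
from the spec added in airflow/listeners/spec/dataset.py):

    # my_dataset_listener.py -- registered through an Airflow plugin
    from airflow.listeners import hookimpl

    @hookimpl
    def on_dataset_alias_created(dataset_alias):
        # DatasetManager.notify_dataset_alias_created fires this hook
        # for every alias created via create_dataset_aliases.
        print(f"dataset alias created: {dataset_alias.name}")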