Skip to content

Commit

Permalink
Fix function call and return handling in test
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr committed Sep 23, 2024
1 parent 39b3c51 commit a534051
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
44 changes: 28 additions & 16 deletions airflow/datasets/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,11 @@ def _add_one(dataset_alias: DatasetAlias) -> DatasetAliasModel:
self.notify_dataset_alias_created(dataset_alias=dataset_alias)
return models

@classmethod
def _slow_path_add_dataset_alias_association(
cls,
alias_names: Iterable[str],
dataset_model: DatasetModel,
) -> None:
# For databases not supporting ON CONFLICT DO NOTHING, we need to fetch
# the existing names to figure out what we can add.
existing = {alias.name for alias in dataset_model.aliases}
dataset_model.aliases.extend(DatasetAliasModel(name=n) for n in alias_names if n not in existing)

@classmethod
def _postgres_add_dataset_alias_association(
cls,
alias_names: Iterable[str],
dataset_id: int,
dataset: DatasetModel,
*,
session: Session,
) -> None:
Expand All @@ -116,11 +105,29 @@ def _postgres_add_dataset_alias_association(
insert(DatasetAliasModel.datasets.prop.secondary)
.from_select(
["alias_id", "dataset_id"],
select(DatasetAliasModel.id, dataset_id).where(DatasetAliasModel.name.in_(alias_names)),
select(DatasetAliasModel.id, dataset.id).where(DatasetAliasModel.name.in_(alias_names)),
)
.on_conflict_do_nothing()
)

@classmethod
def _slow_path_add_dataset_alias_association(
cls,
alias_names: Collection[str],
dataset: DatasetModel,
*,
session: Session,
) -> None:
# For databases not supporting ON CONFLICT DO NOTHING, we need to fetch
# the existing names to figure out what we can add.
already_related = {m.name for m in dataset.aliases}
existing_not_related = {
m.name: m
for m in session.scalars(select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias_names)))
if m.name not in already_related
}
dataset.aliases.extend(existing_not_related.get(n, DatasetAliasModel(name=n)) for n in alias_names)

@classmethod
@internal_api_call
@provide_session
Expand All @@ -145,18 +152,23 @@ def register_dataset_change(
dataset_model = session.scalar(
select(DatasetModel)
.where(DatasetModel.uri == dataset.uri)
.options(joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag))
.options(
joinedload(DatasetModel.aliases),
joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag),
)
)
if not dataset_model:
cls.logger().warning("DatasetModel %s not found", dataset)
return None

if session.bind.dialect.name == "postgresql":
cls._postgres_add_dataset_alias_association(
(alias.name for alias in aliases), dataset_model.id, session=session
(alias.name for alias in aliases), dataset_model, session=session
)
else:
cls._slow_path_add_dataset_alias_association((alias.name for alias in aliases), dataset_model)
cls._slow_path_add_dataset_alias_association(
{alias.name for alias in aliases}, dataset_model, session=session
)

event_kwargs = {
"dataset_id": dataset_model.id,
Expand Down
5 changes: 3 additions & 2 deletions tests/datasets/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,9 @@ def test_create_datasets_notifies_dataset_listener(self, session):

ds = Dataset(uri="test_dataset_uri_3")

dsm = dsem.create_datasets([ds], session)
dsms = dsem.create_datasets([ds], session=session)

# Ensure the listener was notified
assert len(dataset_listener.created) == 1
assert dataset_listener.created[0].uri == ds.uri == dsm.uri
assert len(dsms) == 1
assert dataset_listener.created[0].uri == ds.uri == dsms[0].uri

0 comments on commit a534051

Please sign in to comment.