From 99e7db32d79e8b2f3b02ed2ed454262074aac646 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Wed, 5 Mar 2025 11:55:32 +0100 Subject: [PATCH] fixup! fixup! Replace TaskInstanceNote composite primary key with TI.id --- airflow/models/taskinstance.py | 4 +++- tests/models/test_taskinstance.py | 13 +++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index eada4e6fc8611e..9e74ea91b13a53 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -3614,12 +3614,12 @@ def clear_db_references(self, session: Session): from airflow.models.renderedtifields import RenderedTaskInstanceFields tables: list[type[TaskInstanceDependencies]] = [ - TaskInstanceNote, TaskReschedule, XCom, RenderedTaskInstanceFields, TaskMap, ] + tables_by_id: list[type[Base]] = [TaskInstanceNote] for table in tables: session.execute( delete(table).where( @@ -3629,6 +3629,8 @@ def clear_db_references(self, session: Session): table.map_index == self.map_index, ) ) + for table in tables_by_id: + session.execute(delete(table).where(table.id == self.id)) @classmethod def duration_expression_update( diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 5b73f3ed3fa65a..fa502b2f45d0cd 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3993,15 +3993,14 @@ def test_clear_db_references(self, session, create_task_instance): for table in tables: assert session.query(table).count() == 1 - filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index) - ti_note = session.query(TaskInstanceNote).filter_by(**filter_kwargs).one() + ti_note = session.query(TaskInstanceNote).filter_by(id=ti.id).one() assert ti_note.content == "sample note" ti.clear_db_references(session) for table in tables: assert session.query(table).count() == 0 - assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None + assert session.query(TaskInstanceNote).filter_by(id=ti.id).one_or_none() is None def test_skipped_task_call_on_skipped_callback(self, dag_maker): def raise_skip_exception(): @@ -4969,16 +4968,14 @@ def test_taskinstance_with_note(create_task_instance, session): session.add(ti) session.commit() - filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index) - - ti_note: TaskInstanceNote = session.query(TaskInstanceNote).filter_by(**filter_kwargs).one() + ti_note: TaskInstanceNote = session.query(TaskInstanceNote).filter_by(id=ti.id).one() assert ti_note.content == "ti with note" session.delete(ti) session.commit() - assert session.query(TaskInstance).filter_by(**filter_kwargs).one_or_none() is None - assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None + assert session.query(TaskInstance).filter_by(id=ti.id).one_or_none() is None + assert session.query(TaskInstanceNote).filter_by(id=ti.id).one_or_none() is None def test__refresh_from_db_should_not_increment_try_number(dag_maker, session):