Skip to content

Commit

Permalink
fixup! fixup! Replace TaskInstanceNote composite primary key with TI.id
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Mar 5, 2025
1 parent 7dc6d1b commit 99e7db3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
4 changes: 3 additions & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3614,12 +3614,12 @@ def clear_db_references(self, session: Session):
from airflow.models.renderedtifields import RenderedTaskInstanceFields

tables: list[type[TaskInstanceDependencies]] = [
TaskInstanceNote,
TaskReschedule,
XCom,
RenderedTaskInstanceFields,
TaskMap,
]
tables_by_id: list[type[Base]] = [TaskInstanceNote]
for table in tables:
session.execute(
delete(table).where(
Expand All @@ -3629,6 +3629,8 @@ def clear_db_references(self, session: Session):
table.map_index == self.map_index,
)
)
for table in tables_by_id:
session.execute(delete(table).where(table.id == self.id))

@classmethod
def duration_expression_update(
Expand Down
13 changes: 5 additions & 8 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3993,15 +3993,14 @@ def test_clear_db_references(self, session, create_task_instance):
for table in tables:
assert session.query(table).count() == 1

filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index)
ti_note = session.query(TaskInstanceNote).filter_by(**filter_kwargs).one()
ti_note = session.query(TaskInstanceNote).filter_by(id=ti.id).one()
assert ti_note.content == "sample note"

ti.clear_db_references(session)
for table in tables:
assert session.query(table).count() == 0

assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None
assert session.query(TaskInstanceNote).filter_by(id=ti.id).one_or_none() is None

def test_skipped_task_call_on_skipped_callback(self, dag_maker):
def raise_skip_exception():
Expand Down Expand Up @@ -4969,16 +4968,14 @@ def test_taskinstance_with_note(create_task_instance, session):
session.add(ti)
session.commit()

filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index)

ti_note: TaskInstanceNote = session.query(TaskInstanceNote).filter_by(**filter_kwargs).one()
ti_note: TaskInstanceNote = session.query(TaskInstanceNote).filter_by(id=ti.id).one()
assert ti_note.content == "ti with note"

session.delete(ti)
session.commit()

assert session.query(TaskInstance).filter_by(**filter_kwargs).one_or_none() is None
assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None
assert session.query(TaskInstance).filter_by(id=ti.id).one_or_none() is None
assert session.query(TaskInstanceNote).filter_by(id=ti.id).one_or_none() is None


def test__refresh_from_db_should_not_increment_try_number(dag_maker, session):
Expand Down

0 comments on commit 99e7db3

Please sign in to comment.