Skip to content

Commit

Permalink
Add better typing in the db clean utils (#42341)
Browse files: browse the repository at this point in the history
  • Loading branch information
jedcunningham authored Sep 19, 2024
1 parent 447b221 commit 5b0b830
Showing 1 changed file with 31 additions and 27 deletions.
58 changes: 31 additions & 27 deletions airflow/utils/db_cleanup.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -129,7 +129,7 @@ def readable_config(self):
config_dict: dict[str, _TableConfig] = {x.orm_model.name: x for x in sorted(config_list)}


def _check_for_rows(*, query: Query, print_rows=False):
def _check_for_rows(*, query: Query, print_rows: bool = False) -> int:
num_entities = query.count()
print(f"Found {num_entities} rows meeting deletion criteria.")
if print_rows:
Expand All @@ -142,7 +142,7 @@ def _check_for_rows(*, query: Query, print_rows=False):
return num_entities


def _dump_table_to_file(*, target_table, file_path, export_format, session):
def _dump_table_to_file(*, target_table: str, file_path: str, export_format: str, session: Session) -> None:
if export_format == "csv":
with open(file_path, "w") as f:
csv_writer = csv.writer(f)
Expand All @@ -153,7 +153,7 @@ def _dump_table_to_file(*, target_table, file_path, export_format, session):
raise AirflowException(f"Export format {export_format} is not supported.")


def _do_delete(*, query, orm_model, skip_archive, session):
def _do_delete(*, query: Query, orm_model: Base, skip_archive: bool, session: Session) -> None:
import re2

print("Performing Delete...")
Expand Down Expand Up @@ -203,7 +203,9 @@ def _do_delete(*, query, orm_model, skip_archive, session):
print("Finished Performing Delete")


def _subquery_keep_last(*, recency_column, keep_last_filters, group_by_columns, max_date_colname, session):
def _subquery_keep_last(
*, recency_column, keep_last_filters, group_by_columns, max_date_colname, session: Session
) -> Query:
subquery = select(*group_by_columns, func.max(recency_column).label(max_date_colname))

if keep_last_filters is not None:
Expand Down Expand Up @@ -238,10 +240,10 @@ def _build_query(
keep_last,
keep_last_filters,
keep_last_group_by,
clean_before_timestamp,
session,
clean_before_timestamp: DateTime,
session: Session,
**kwargs,
):
) -> Query:
base_table_alias = "base"
base_table = aliased(orm_model, name=base_table_alias)
query = session.query(base_table).with_entities(text(f"{base_table_alias}.*"))
Expand Down Expand Up @@ -276,13 +278,13 @@ def _cleanup_table(
keep_last,
keep_last_filters,
keep_last_group_by,
clean_before_timestamp,
dry_run=True,
verbose=False,
skip_archive=False,
session,
clean_before_timestamp: DateTime,
dry_run: bool = True,
verbose: bool = False,
skip_archive: bool = False,
session: Session,
**kwargs,
):
) -> None:
print()
if dry_run:
print(f"Performing dry run for table {orm_model.name}")
Expand All @@ -305,7 +307,7 @@ def _cleanup_table(
session.commit()


def _confirm_delete(*, date: DateTime, tables: list[str]):
def _confirm_delete(*, date: DateTime, tables: list[str]) -> None:
for_tables = f" for tables {tables!r}" if tables else ""
question = (
f"You have requested that we purge all data prior to {date}{for_tables}.\n"
Expand All @@ -319,7 +321,7 @@ def _confirm_delete(*, date: DateTime, tables: list[str]):
raise SystemExit("User did not confirm; exiting.")


def _confirm_drop_archives(*, tables: list[str]):
def _confirm_drop_archives(*, tables: list[str]) -> None:
# if length of tables is greater than 3, show the total count
if len(tables) > 3:
text_ = f"{len(tables)} archived tables prefixed with {ARCHIVE_TABLE_PREFIX}"
Expand All @@ -341,13 +343,13 @@ def _confirm_drop_archives(*, tables: list[str]):
raise SystemExit("User did not confirm; exiting.")


def _print_config(*, configs: dict[str, _TableConfig]):
def _print_config(*, configs: dict[str, _TableConfig]) -> None:
data = [x.readable_config for x in configs.values()]
AirflowConsole().print_as_table(data=data)


@contextmanager
def _suppress_with_logging(table, session):
def _suppress_with_logging(table: str, session: Session):
"""
Suppresses errors but logs them.
Expand All @@ -363,7 +365,7 @@ def _suppress_with_logging(table, session):
session.rollback()


def _effective_table_names(*, table_names: list[str] | None):
def _effective_table_names(*, table_names: list[str] | None) -> tuple[set[str], dict[str, _TableConfig]]:
desired_table_names = set(table_names or config_dict)
effective_config_dict = {k: v for k, v in config_dict.items() if k in desired_table_names}
effective_table_names = set(effective_config_dict)
Expand All @@ -377,7 +379,7 @@ def _effective_table_names(*, table_names: list[str] | None):
return effective_table_names, effective_config_dict


def _get_archived_table_names(table_names, session):
def _get_archived_table_names(table_names: list[str] | None, session: Session) -> list[str]:
inspector = inspect(session.bind)
db_table_names = [x for x in inspector.get_table_names() if x.startswith(ARCHIVE_TABLE_PREFIX)]
effective_table_names, _ = _effective_table_names(table_names=table_names)
Expand All @@ -400,7 +402,7 @@ def run_cleanup(
confirm: bool = True,
skip_archive: bool = False,
session: Session = NEW_SESSION,
):
) -> None:
"""
Purges old records in airflow metadata database.
Expand Down Expand Up @@ -450,13 +452,13 @@ def run_cleanup(

@provide_session
def export_archived_records(
export_format,
output_path,
table_names=None,
drop_archives=False,
needs_confirm=True,
export_format: str,
output_path: str,
table_names: list[str] | None = None,
drop_archives: bool = False,
needs_confirm: bool = True,
session: Session = NEW_SESSION,
):
) -> None:
"""Export archived data to the given output path in the given format."""
archived_table_names = _get_archived_table_names(table_names, session)
# If user chose to drop archives, check there are archive tables that exists
Expand All @@ -482,7 +484,9 @@ def export_archived_records(


@provide_session
def drop_archived_tables(table_names, needs_confirm, session):
def drop_archived_tables(
table_names: list[str] | None, needs_confirm: bool, session: Session = NEW_SESSION
) -> None:
"""Drop archived tables."""
archived_table_names = _get_archived_table_names(table_names, session)
if needs_confirm and archived_table_names:
Expand Down

0 comments on commit 5b0b830

Please sign in to comment.