Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up grid_data endpoint by 10x #24284

Merged
merged 4 commits into from
Jun 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion airflow/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from flask.json import JSONEncoder

from airflow.utils.timezone import convert_to_utc, is_naive

try:
import numpy as np
except ImportError:
Expand All @@ -45,7 +47,9 @@ def __init__(self, *args, **kwargs):
def _default(obj):
"""Convert dates and numpy objects in a json serializable format."""
if isinstance(obj, datetime):
return obj.strftime('%Y-%m-%dT%H:%M:%SZ')
if is_naive(obj):
obj = convert_to_utc(obj)
return obj.isoformat()
elif isinstance(obj, date):
return obj.strftime('%Y-%m-%d')
elif isinstance(obj, Decimal):
Expand Down
14 changes: 8 additions & 6 deletions airflow/www/static/js/grid/components/InstanceTooltip.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const InstanceTooltip = ({
const summary = [];

const numMap = finalStatesMap();
let numMapped = 0;
if (isGroup) {
group.children.forEach((child) => {
const taskInstance = child.instances.find((ti) => ti.runId === runId);
Expand All @@ -44,9 +45,10 @@ const InstanceTooltip = ({
}
});
} else if (isMapped && mappedStates) {
mappedStates.forEach((s) => {
const stateKey = s || 'no_status';
if (numMap.has(stateKey)) numMap.set(stateKey, numMap.get(stateKey) + 1);
Object.keys(mappedStates).forEach((stateKey) => {
const num = mappedStates[stateKey];
numMapped += num;
numMap.set(stateKey || 'no_status', num);
});
}

Expand All @@ -68,12 +70,12 @@ const InstanceTooltip = ({
{group.tooltip && (
<Text>{group.tooltip}</Text>
)}
{isMapped && !!mappedStates.length && (
{isMapped && numMapped > 0 && (
<Text>
{mappedStates.length}
{numMapped}
{' '}
mapped task
{mappedStates.length > 1 && 's'}
{numMapped > 1 && 's'}
</Text>
)}
<Text>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ describe('Test Task InstanceTooltip', () => {
const { getByText } = render(
<InstanceTooltip
group={{ isMapped: true }}
instance={{ ...instance, mappedStates: ['success', 'success'] }}
instance={{ ...instance, mappedStates: { success: 2 } }}
/>,
{ wrapper: Wrapper },
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ const Details = ({ instance, group, operator }) => {
} = group;

const numMap = finalStatesMap();
let numMapped = 0;
if (isGroup) {
children.forEach((child) => {
const taskInstance = child.instances.find((ti) => ti.runId === runId);
Expand All @@ -59,9 +60,10 @@ const Details = ({ instance, group, operator }) => {
}
});
} else if (isMapped && mappedStates) {
mappedStates.forEach((s) => {
const stateKey = s || 'no_status';
if (numMap.has(stateKey)) numMap.set(stateKey, numMap.get(stateKey) + 1);
Object.keys(mappedStates).forEach((stateKey) => {
const num = mappedStates[stateKey];
numMapped += num;
numMap.set(stateKey || 'no_status', num);
});
}

Expand Down Expand Up @@ -92,11 +94,11 @@ const Details = ({ instance, group, operator }) => {
<br />
</>
)}
{mappedStates && mappedStates.length > 0 && (
{mappedStates && numMapped > 0 && (
<Text>
{mappedStates.length}
{numMapped}
{' '}
{mappedStates.length === 1 ? 'Task ' : 'Tasks '}
{numMapped === 1 ? 'Task ' : 'Tasks '}
Mapped
</Text>
)}
Expand Down
49 changes: 0 additions & 49 deletions airflow/www/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,9 @@
from pygments import highlight, lexers
from pygments.formatters import HtmlFormatter
from sqlalchemy.ext.associationproxy import AssociationProxy
from sqlalchemy.orm import Session

from airflow import models
from airflow.models import errors
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.code_utils import get_python_source
Expand Down Expand Up @@ -129,53 +127,6 @@ def get_mapped_summary(parent_instance, task_instances):
}


def get_task_summaries(task, dag_runs: List[DagRun], session: Session) -> List[Dict[str, Any]]:
    """
    Return one summary dict per task-instance row of ``task`` across ``dag_runs``.

    Only rows whose ``map_index`` is ``<= 0`` are queried. A row with a
    non-negative ``map_index`` is handed to ``get_mapped_summary`` together with
    the result of ``get_mapped_instances``; every other row is converted to a
    plain dict here.

    :param task: object exposing ``dag_id`` and ``task_id`` used to filter rows
    :param dag_runs: dag runs whose ``run_id`` values restrict the query
    :param session: session the query is issued on
    :return: list of summary dicts, ordered by ``run_id`` ascending
    """
    selected_columns = (
        TaskInstance.dag_id,
        TaskInstance.task_id,
        TaskInstance.run_id,
        TaskInstance.map_index,
        TaskInstance.state,
        TaskInstance.start_date,
        TaskInstance.end_date,
        TaskInstance._try_number,
    )
    rows = (
        session.query(*selected_columns)
        .filter(
            TaskInstance.dag_id == task.dag_id,
            TaskInstance.run_id.in_([dag_run.run_id for dag_run in dag_runs]),
            TaskInstance.task_id == task.task_id,
            # Only get normal task instances or the first mapped task
            TaskInstance.map_index <= 0,
        )
        .order_by(TaskInstance.run_id.asc())
    )

    summaries = []
    for row in rows:
        if row.map_index > -1:
            # Mapped task: summarise all of its mapped instances instead.
            mapped_instances = get_mapped_instances(row, session)
            summaries.append(get_mapped_summary(row, task_instances=mapped_instances))
            continue

        # A stored try number of 0 is reported as 1 unless the state is in State.running.
        if row._try_number != 0 or row.state in State.running:
            try_count = row._try_number
        else:
            try_count = row._try_number + 1

        summaries.append(
            {
                'task_id': row.task_id,
                'run_id': row.run_id,
                'map_index': row.map_index,
                'state': row.state,
                'start_date': datetime_to_string(row.start_date),
                'end_date': datetime_to_string(row.end_date),
                'try_number': try_count,
            }
        )

    return summaries


def encode_dag_run(dag_run: Optional[models.DagRun]) -> Optional[Dict[str, Any]]:
if not dag_run:
return None
Expand Down
186 changes: 137 additions & 49 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#
import collections
import copy
import itertools
import json
import logging
import math
Expand Down Expand Up @@ -252,64 +253,151 @@ def _safe_parse_datetime(v):
abort(400, f"Invalid datetime: {v!r}")


def task_group_to_grid(task_item_or_group, dag, dag_runs, session):
def dag_to_grid(dag, dag_runs, session):
"""
Create a nested dict representation of this TaskGroup and its children used to construct
the Graph.
Create a nested dict representation of the DAG's TaskGroup and its children
used to construct the Graph and Grid views.
"""
if isinstance(task_item_or_group, AbstractOperator):
return {
'id': task_item_or_group.task_id,
'instances': wwwutils.get_task_summaries(task_item_or_group, dag_runs, session),
'label': task_item_or_group.label,
'extra_links': task_item_or_group.extra_links,
'is_mapped': task_item_or_group.is_mapped,
}
query = (
session.query(
TaskInstance.task_id,
TaskInstance.run_id,
TaskInstance.state,
sqla.func.count(sqla.func.coalesce(TaskInstance.state, sqla.literal('no_status'))).label(
'state_count'
),
sqla.func.min(TaskInstance.start_date).label('start_date'),
sqla.func.max(TaskInstance.end_date).label('end_date'),
sqla.func.max(TaskInstance._try_number).label('_try_number'),
)
.filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id.in_([dag_run.run_id for dag_run in dag_runs]),
)
.group_by(TaskInstance.task_id, TaskInstance.run_id, TaskInstance.state)
.order_by(TaskInstance.task_id, TaskInstance.run_id)
)

# Task Group
task_group = task_item_or_group
grouped_tis = {task_id: list(tis) for task_id, tis in itertools.groupby(query, key=lambda ti: ti.task_id)}

def task_group_to_grid(item, dag_runs, grouped_tis):
if isinstance(item, AbstractOperator):

def _get_summary(task_instance):
    """
    Convert one task-instance row into a plain summary dict.

    The reported ``try_number`` is the row's ``_try_number``, bumped by one when
    it is 0 and the state is not in ``State.running``.
    """
    stored_tries = task_instance._try_number
    if stored_tries != 0 or task_instance.state in State.running:
        try_count = stored_tries
    else:
        try_count = stored_tries + 1

    summary = {
        'task_id': task_instance.task_id,
        'run_id': task_instance.run_id,
        'state': task_instance.state,
        'start_date': task_instance.start_date,
        'end_date': task_instance.end_date,
    }
    summary['try_number'] = try_count
    return summary

def _mapped_summary(ti_summaries):
    """
    Merge per-state aggregate rows into one record per run and yield each record.

    ``ti_summaries`` is consumed once, in order. Consecutive rows sharing a
    ``run_id`` are folded into a single dict; a new record starts whenever the
    ``run_id`` differs from the previous row's, so rows of the same run must be
    adjacent in the input.

    Each row must expose ``task_id``, ``run_id``, ``state``, ``state_count``,
    ``start_date`` and ``end_date``.

    :return: generator of dicts with keys ``task_id``, ``run_id``, ``start_date``
        (earliest non-empty), ``end_date`` (latest non-empty), ``mapped_states``
        (state -> count) and ``state`` (first entry of ``wwwutils.priority``
        present in ``mapped_states``, else ``None``).
    """
    run_id = None
    record = None

    def set_overall_state(record):
        # Overall state is the highest-priority state that has at least one instance.
        for state in wwwutils.priority:
            if state in record['mapped_states']:
                record['state'] = state
                break
        if None in record['mapped_states']:
            # When turning the dict into JSON we can't have None as a key, so use the string that
            # the UI does
            record['mapped_states']['no_status'] = record['mapped_states'].pop(None)

    for ti_summary in ti_summaries:
        # NOTE: a None state is deliberately left as-is here. The previous
        # ``ti_summary.state == 'no_status'`` statement was a discarded comparison
        # (no effect); the None -> 'no_status' rename happens in set_overall_state.
        if run_id != ti_summary.run_id:
            run_id = ti_summary.run_id
            if record is not None:
                # The previous run is complete: finalise and emit it.
                set_overall_state(record)
                yield record
            record = {
                'task_id': ti_summary.task_id,
                'run_id': run_id,
                'start_date': ti_summary.start_date,
                'end_date': ti_summary.end_date,
                'mapped_states': {ti_summary.state: ti_summary.state_count},
                'state': None,  # We change this before yielding
            }
            continue
        # Same run as the previous row: widen the date range, ignoring empty dates.
        record['start_date'] = min(
            filter(None, [record['start_date'], ti_summary.start_date]), default=None
        )
        record['end_date'] = max(
            filter(None, [record['end_date'], ti_summary.end_date]), default=None
        )
        record['mapped_states'][ti_summary.state] = ti_summary.state_count
    if record is not None:
        # Emit the final run, which no later row triggers.
        set_overall_state(record)
        yield record

if item.is_mapped:
instances = list(_mapped_summary(grouped_tis.get(item.task_id, [])))
else:
instances = list(map(_get_summary, grouped_tis.get(item.task_id, [])))

return {
'id': item.task_id,
'instances': instances,
'label': item.label,
'extra_links': item.extra_links,
'is_mapped': item.is_mapped,
}

children = [task_group_to_grid(child, dag, dag_runs, session) for child in task_group.topological_sort()]
# Task Group
task_group = item

def get_summary(dag_run, children):
child_instances = [child['instances'] for child in children if 'instances' in child]
child_instances = [
item for sublist in child_instances for item in sublist if item['run_id'] == dag_run.run_id
children = [
task_group_to_grid(child, dag_runs, grouped_tis) for child in task_group.topological_sort()
]

children_start_dates = [item['start_date'] for item in child_instances if item]
children_end_dates = [item['end_date'] for item in child_instances if item]
children_states = [item['state'] for item in child_instances if item]

group_state = None
for state in wwwutils.priority:
if state in children_states:
group_state = state
break
group_start_date = wwwutils.datetime_to_string(
min((timezone.parse(date) for date in children_start_dates if date), default=None)
)
group_end_date = wwwutils.datetime_to_string(
max((timezone.parse(date) for date in children_end_dates if date), default=None)
)
def get_summary(dag_run, children):
child_instances = [child['instances'] for child in children if 'instances' in child]
child_instances = [
item for sublist in child_instances for item in sublist if item['run_id'] == dag_run.run_id
]

children_start_dates = (item['start_date'] for item in child_instances if item)
children_end_dates = (item['end_date'] for item in child_instances if item)
children_states = {item['state'] for item in child_instances if item}

group_state = None
for state in wwwutils.priority:
if state in children_states:
group_state = state
break
group_start_date = min(filter(None, children_start_dates), default=None)
group_end_date = max(filter(None, children_end_dates), default=None)

return {
'task_id': task_group.group_id,
'run_id': dag_run.run_id,
'state': group_state,
'start_date': group_start_date,
'end_date': group_end_date,
}

group_summaries = [get_summary(dr, children) for dr in dag_runs]

return {
'task_id': task_group.group_id,
'run_id': dag_run.run_id,
'state': group_state,
'start_date': group_start_date,
'end_date': group_end_date,
'id': task_group.group_id,
'label': task_group.label,
'children': children,
'tooltip': task_group.tooltip,
'instances': group_summaries,
}

group_summaries = [get_summary(dr, children) for dr in dag_runs]

return {
'id': task_group.group_id,
'label': task_group.label,
'children': children,
'tooltip': task_group.tooltip,
'instances': group_summaries,
}
return task_group_to_grid(dag.task_group, dag_runs, grouped_tis)


def task_group_to_dict(task_item_or_group):
Expand Down Expand Up @@ -3535,12 +3623,12 @@ def grid_data(self):
dag_runs.reverse()
encoded_runs = [wwwutils.encode_dag_run(dr) for dr in dag_runs]
data = {
'groups': task_group_to_grid(dag.task_group, dag, dag_runs, session),
'groups': dag_to_grid(dag, dag_runs, session),
'dag_runs': encoded_runs,
}
# avoid spaces to reduce payload size
return (
htmlsafe_json_dumps(data, separators=(',', ':')),
htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
{'Content-Type': 'application/json; charset=utf-8'},
)

Expand Down
Loading