Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve getting the query count in Airflow API endpoints #32630

Merged
merged 4 commits into from
Jul 22, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions airflow/api_connexion/endpoints/dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from connexion import NoContent
from flask import g, request
from marshmallow import ValidationError
from sqlalchemy import func, select, update
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import or_

Expand All @@ -41,6 +41,7 @@
from airflow.models.dag import DagModel, DagTag
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session


Expand Down Expand Up @@ -99,7 +100,7 @@ def get_dags(
cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
dags_query = dags_query.where(or_(*cond))

total_entries = session.scalar(select(func.count()).select_from(dags_query))
total_entries = get_query_count(dags_query, session=session)
dags_query = apply_sorting(dags_query, order_by, {}, allowed_attrs)
dags = session.scalars(dags_query.offset(offset).limit(limit)).all()

Expand Down Expand Up @@ -159,7 +160,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
dags_query = dags_query.where(or_(*cond))

total_entries = session.scalar(select(func.count()).select_from(dags_query))
total_entries = get_query_count(dags_query, session=session)

dags = session.scalars(dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit)).all()

Expand Down
12 changes: 9 additions & 3 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from flask import g
from flask_login import current_user
from marshmallow import ValidationError
from sqlalchemy import delete, func, or_, select
from sqlalchemy import delete, or_, select
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

Expand All @@ -35,7 +35,12 @@
from airflow.api_connexion import security
from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters
from airflow.api_connexion.parameters import (
apply_sorting,
check_limit,
format_datetime,
format_parameters,
)
from airflow.api_connexion.schemas.dag_run_schema import (
DAGRunCollection,
clear_dagrun_form_schema,
Expand All @@ -57,6 +62,7 @@
from airflow.models import DagModel, DagRun
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.db import get_query_count
from airflow.utils.log.action_logger import action_event_from_permission
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState
Expand Down Expand Up @@ -166,7 +172,7 @@ def _fetch_dag_runs(
if updated_at_lte:
query = query.where(DagRun.updated_at <= updated_at_lte)

total_entries = session.scalar(select(func.count()).select_from(query))
total_entries = get_query_count(query, session=session)
to_replace = {"dag_run_id": "run_id"}
allowed_filter_attrs = [
"id",
Expand Down
5 changes: 3 additions & 2 deletions airflow/api_connexion/endpoints/dag_warning_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from sqlalchemy import func, select
from sqlalchemy import select
from sqlalchemy.orm import Session

from airflow.api_connexion import security
Expand All @@ -28,6 +28,7 @@
from airflow.api_connexion.types import APIResponse
from airflow.models.dagwarning import DagWarning as DagWarningModel
from airflow.security import permissions
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session


Expand All @@ -54,7 +55,7 @@ def get_dag_warnings(
query = query.where(DagWarningModel.dag_id == dag_id)
if warning_type:
query = query.where(DagWarningModel.warning_type == warning_type)
total_entries = session.scalar(select(func.count()).select_from(query))
total_entries = get_query_count(query, session=session)
query = apply_sorting(query=query, order_by=order_by, allowed_attrs=allowed_filter_attrs)
dag_warnings = session.scalars(query.offset(offset).limit(limit)).all()
return dag_warning_collection_schema.dump(
Expand Down
3 changes: 2 additions & 1 deletion airflow/api_connexion/endpoints/dataset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from airflow.api_connexion.types import APIResponse
from airflow.models.dataset import DatasetEvent, DatasetModel
from airflow.security import permissions
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session


Expand Down Expand Up @@ -112,7 +113,7 @@ def get_dataset_events(

query = query.options(subqueryload(DatasetEvent.created_dagruns))

total_entries = session.scalar(select(func.count()).select_from(query))
total_entries = get_query_count(query, session=session)
query = apply_sorting(query, order_by, {}, allowed_attrs)
events = session.scalars(query.offset(offset).limit(limit)).all()
return dataset_event_collection_schema.dump(
Expand Down
12 changes: 6 additions & 6 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Iterable, TypeVar

from marshmallow import ValidationError
from sqlalchemy import and_, func, or_, select
from sqlalchemy import and_, or_, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.sql import ClauseElement, Select
Expand Down Expand Up @@ -48,6 +48,7 @@
from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, TaskInstanceState

Expand Down Expand Up @@ -196,7 +197,7 @@ def get_mapped_task_instances(
)

# 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404
unfiltered_total_count = session.execute(select(func.count("*")).select_from(base_query)).scalar()
unfiltered_total_count = get_query_count(base_query, session=session)
if unfiltered_total_count == 0:
dag = get_airflow_app().dag_bag.get_dag(dag_id)
if not dag:
Expand Down Expand Up @@ -229,7 +230,7 @@ def get_mapped_task_instances(
base_query = _apply_array_filter(base_query, key=TI.queue, values=queue)

# Count elements before joining extra columns
total_entries = session.execute(select(func.count("*")).select_from(base_query)).scalar()
total_entries = get_query_count(base_query, session=session)

# Add SLA miss
entry_query = (
Expand Down Expand Up @@ -355,8 +356,7 @@ def get_task_instances(
base_query = _apply_array_filter(base_query, key=TI.queue, values=queue)

# Count elements before joining extra columns
count_query = select(func.count("*")).select_from(base_query)
total_entries = session.execute(count_query).scalar()
total_entries = get_query_count(base_query, session=session)

# Add join
entry_query = (
Expand Down Expand Up @@ -420,7 +420,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
base_query = _apply_array_filter(base_query, key=TI.queue, values=data["queue"])

# Count elements before joining extra columns
total_entries = session.execute(select(func.count("*")).select_from(base_query)).scalar()
total_entries = get_query_count(base_query, session=session)
# Add join
base_query = base_query.join(
SlaMiss,
Expand Down
5 changes: 3 additions & 2 deletions airflow/api_connexion/endpoints/xcom_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import copy

from flask import g
from sqlalchemy import and_, func, select
from sqlalchemy import and_, select
from sqlalchemy.orm import Session

from airflow.api_connexion import security
Expand All @@ -31,6 +31,7 @@
from airflow.security import permissions
from airflow.settings import conf
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session


Expand Down Expand Up @@ -75,7 +76,7 @@ def get_xcom_entries(
if xcom_key is not None:
query = query.where(XCom.key == xcom_key)
query = query.order_by(DR.execution_date, XCom.task_id, XCom.dag_id, XCom.key)
total_entries = session.execute(select(func.count()).select_from(query)).scalar()
total_entries = get_query_count(query, session=session)
query = session.scalars(query.offset(offset).limit(limit))
return xcom_collection_schema.dump(XComCollection(xcom_entries=query, total_entries=total_entries))

Expand Down
9 changes: 9 additions & 0 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
from sqlalchemy.orm import Query, Session
from sqlalchemy.sql.selectable import Select

from airflow.models.base import Base
from airflow.models.connection import Connection
Expand Down Expand Up @@ -1872,3 +1873,11 @@ def get_sqla_model_classes():
return [mapper.class_ for mapper in Base.registry.mappers]
except AttributeError:
return Base._decl_class_registry.values()


def get_query_count(query_stmt: Select, session: Session) -> int:
"""Get count of query."""
# Remove ORDER BY clause from the subquery statement since it's unnecessary for count
# in order to improve the query performance.
hussein-awala marked this conversation as resolved.
Show resolved Hide resolved
count_stmt = select(func.count()).select_from(query_stmt.order_by(None).subquery())
return session.scalar(count_stmt)
13 changes: 6 additions & 7 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.dag_edges import dag_edges
from airflow.utils.dates import infer_time_unit, scale_time_units
from airflow.utils.db import get_query_count
from airflow.utils.docs import get_doc_url_for_provider, get_docs_url
from airflow.utils.helpers import alchemy_to_dict, exactly_one
from airflow.utils.log import secrets_masker
Expand Down Expand Up @@ -759,7 +760,7 @@ def index(self):
dags_query = dags_query.where(DagModel.tags.any(DagTag.name.in_(arg_tags_filter)))

dags_query = dags_query.where(DagModel.dag_id.in_(filter_dag_ids))
filtered_dag_count = session.scalar(select(func.count()).select_from(dags_query))
filtered_dag_count = get_query_count(dags_query, session=session)
if filtered_dag_count == 0 and len(arg_tags_filter):
flash(
"No matching DAG tags found.",
Expand Down Expand Up @@ -811,8 +812,8 @@ def index(self):
status_count_active = is_paused_count.get(False, 0)
status_count_paused = is_paused_count.get(True, 0)

status_count_running = session.scalar(select(func.count()).select_from(running_dags))
status_count_failed = session.scalar(select(func.count()).select_from(failed_dags))
status_count_running = get_query_count(running_dags, session=session)
status_count_failed = get_query_count(failed_dags, session=session)

all_dags_count = status_count_active + status_count_paused
if arg_status_filter == "active":
Expand Down Expand Up @@ -951,9 +952,7 @@ def _iter_parsed_moved_data_table_names():
.where(Log.event == "robots")
.where(Log.dttm > (utcnow() - datetime.timedelta(days=7)))
)
robots_file_access_count = session.scalar(
select(func.count()).select_from(robots_file_access_count)
)
robots_file_access_count = get_query_count(robots_file_access_count, session=session)
if robots_file_access_count > 0:
flash(
Markup(
Expand Down Expand Up @@ -4171,7 +4170,7 @@ def audit_log(self, dag_id: str, session: Session = NEW_SESSION):
arg_sorting_direction = request.args.get("sorting_direction", default="desc")

logs_per_page = PAGE_SIZE
audit_logs_count = session.scalar(select(func.count()).select_from(query))
audit_logs_count = get_query_count(query, session=session)
num_of_pages = int(math.ceil(audit_logs_count / float(logs_per_page)))

start = current_page * logs_per_page
Expand Down