Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve getting the query count in Airflow API endpoints #32630

Merged
merged 4 commits into from
Jul 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions airflow/api_connexion/endpoints/dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
from connexion import NoContent
from flask import g, request
from marshmallow import ValidationError
from sqlalchemy import func, select, update
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import or_

from airflow import DAG
from airflow.api_connexion import security
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters, get_query_count
from airflow.api_connexion.schemas.dag_schema import (
DAGCollection,
dag_detail_schema,
Expand Down Expand Up @@ -99,7 +99,7 @@ def get_dags(
cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
dags_query = dags_query.where(or_(*cond))

total_entries = session.scalar(select(func.count()).select_from(dags_query))
total_entries = get_query_count(dags_query, session=session)
dags_query = apply_sorting(dags_query, order_by, {}, allowed_attrs)
dags = session.scalars(dags_query.offset(offset).limit(limit)).all()

Expand Down Expand Up @@ -159,7 +159,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
dags_query = dags_query.where(or_(*cond))

total_entries = session.scalar(select(func.count()).select_from(dags_query))
total_entries = get_query_count(dags_query, session=session)

dags = session.scalars(dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit)).all()

Expand Down
12 changes: 9 additions & 3 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from flask import g
from flask_login import current_user
from marshmallow import ValidationError
from sqlalchemy import delete, func, or_, select
from sqlalchemy import delete, or_, select
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

Expand All @@ -35,7 +35,13 @@
from airflow.api_connexion import security
from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters
from airflow.api_connexion.parameters import (
apply_sorting,
check_limit,
format_datetime,
format_parameters,
get_query_count,
)
from airflow.api_connexion.schemas.dag_run_schema import (
DAGRunCollection,
clear_dagrun_form_schema,
Expand Down Expand Up @@ -166,7 +172,7 @@ def _fetch_dag_runs(
if updated_at_lte:
query = query.where(DagRun.updated_at <= updated_at_lte)

total_entries = session.scalar(select(func.count()).select_from(query))
total_entries = get_query_count(query, session=session)
to_replace = {"dag_run_id": "run_id"}
allowed_filter_attrs = [
"id",
Expand Down
6 changes: 3 additions & 3 deletions airflow/api_connexion/endpoints/dag_warning_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
# under the License.
from __future__ import annotations

from sqlalchemy import func, select
from sqlalchemy import select
from sqlalchemy.orm import Session

from airflow.api_connexion import security
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters, get_query_count
from airflow.api_connexion.schemas.dag_warning_schema import (
DagWarningCollection,
dag_warning_collection_schema,
Expand Down Expand Up @@ -54,7 +54,7 @@ def get_dag_warnings(
query = query.where(DagWarningModel.dag_id == dag_id)
if warning_type:
query = query.where(DagWarningModel.warning_type == warning_type)
total_entries = session.scalar(select(func.count()).select_from(query))
total_entries = get_query_count(query, session=session)
query = apply_sorting(query=query, order_by=order_by, allowed_attrs=allowed_filter_attrs)
dag_warnings = session.scalars(query.offset(offset).limit(limit)).all()
return dag_warning_collection_schema.dump(
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/dataset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from airflow.api_connexion import security
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters, get_query_count
from airflow.api_connexion.schemas.dataset_schema import (
DatasetCollection,
DatasetEventCollection,
Expand Down Expand Up @@ -112,7 +112,7 @@ def get_dataset_events(

query = query.options(subqueryload(DatasetEvent.created_dagruns))

total_entries = session.scalar(select(func.count()).select_from(query))
total_entries = get_query_count(query, session=session)
query = apply_sorting(query, order_by, {}, allowed_attrs)
events = session.scalars(query.offset(offset).limit(limit)).all()
return dataset_event_collection_schema.dump(
Expand Down
13 changes: 6 additions & 7 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
from typing import Any, Iterable, TypeVar

from marshmallow import ValidationError
from sqlalchemy import and_, func, or_, select
from sqlalchemy import and_, or_, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.sql import ClauseElement, Select

from airflow.api_connexion import security
from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
from airflow.api_connexion.exceptions import BadRequest, NotFound
from airflow.api_connexion.parameters import format_datetime, format_parameters
from airflow.api_connexion.parameters import format_datetime, format_parameters, get_query_count
from airflow.api_connexion.schemas.task_instance_schema import (
TaskInstanceCollection,
TaskInstanceReferenceCollection,
Expand Down Expand Up @@ -196,7 +196,7 @@ def get_mapped_task_instances(
)

# 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404
unfiltered_total_count = session.execute(select(func.count("*")).select_from(base_query)).scalar()
unfiltered_total_count = get_query_count(base_query, session=session)
if unfiltered_total_count == 0:
dag = get_airflow_app().dag_bag.get_dag(dag_id)
if not dag:
Expand Down Expand Up @@ -229,7 +229,7 @@ def get_mapped_task_instances(
base_query = _apply_array_filter(base_query, key=TI.queue, values=queue)

# Count elements before joining extra columns
total_entries = session.execute(select(func.count("*")).select_from(base_query)).scalar()
total_entries = get_query_count(base_query, session=session)

# Add SLA miss
entry_query = (
Expand Down Expand Up @@ -355,8 +355,7 @@ def get_task_instances(
base_query = _apply_array_filter(base_query, key=TI.queue, values=queue)

# Count elements before joining extra columns
count_query = select(func.count("*")).select_from(base_query)
total_entries = session.execute(count_query).scalar()
total_entries = get_query_count(base_query, session=session)

# Add join
entry_query = (
Expand Down Expand Up @@ -420,7 +419,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
base_query = _apply_array_filter(base_query, key=TI.queue, values=data["queue"])

# Count elements before joining extra columns
total_entries = session.execute(select(func.count("*")).select_from(base_query)).scalar()
total_entries = get_query_count(base_query, session=session)
# Add join
base_query = base_query.join(
SlaMiss,
Expand Down
6 changes: 3 additions & 3 deletions airflow/api_connexion/endpoints/xcom_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
import copy

from flask import g
from sqlalchemy import and_, func, select
from sqlalchemy import and_, select
from sqlalchemy.orm import Session

from airflow.api_connexion import security
from airflow.api_connexion.exceptions import BadRequest, NotFound
from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.api_connexion.parameters import check_limit, format_parameters, get_query_count
from airflow.api_connexion.schemas.xcom_schema import XComCollection, xcom_collection_schema, xcom_schema
from airflow.api_connexion.types import APIResponse
from airflow.models import DagRun as DR, XCom
Expand Down Expand Up @@ -75,7 +75,7 @@ def get_xcom_entries(
if xcom_key is not None:
query = query.where(XCom.key == xcom_key)
query = query.order_by(DR.execution_date, XCom.task_id, XCom.dag_id, XCom.key)
total_entries = session.execute(select(func.count()).select_from(query)).scalar()
total_entries = get_query_count(query, session=session)
query = session.scalars(query.offset(offset).limit(limit))
return xcom_collection_schema.dump(XComCollection(xcom_entries=query, total_entries=total_entries))

Expand Down
9 changes: 8 additions & 1 deletion airflow/api_connexion/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from functools import wraps
from typing import Any, Callable, Container, TypeVar, cast

import sqlalchemy.orm
from pendulum.parsing import ParserError
from sqlalchemy import text
from sqlalchemy import func, select, text
from sqlalchemy.sql import Select

from airflow.api_connexion.exceptions import BadRequest
Expand Down Expand Up @@ -125,3 +126,9 @@ def apply_sorting(
else:
order_by = f"{lstriped_orderby} asc"
return query.order_by(text(order_by))


def get_query_count(query_stmt: sqlalchemy.sql.selectable.Select, session: sqlalchemy.orm.Session) -> int:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be a good idea to put this in airflow.utils.db instead? Importing this in airflow.www.views feels wrong.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, is the order_by reset necessary?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be a good idea to put this in airflow.utils.db instead? Importing this in airflow.www.views feels wrong.

I agree, I just moved it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, is the order_by reset necessary?

TBH, I am not sure how sqlalchemy processes this and generates the query, but according to the documentation, it just selects from the subquery: <query> FROM (<subquery>) AS subquery, so if we compare this format with and without ORDER BY:

airflow=# EXPLAIN SELECT COUNT(1) FROM (SELECT * FROM dag_run WHERE state='failed' ORDER BY data_interval_start) AS subquery;
                              QUERY PLAN                              
----------------------------------------------------------------------
 Aggregate  (cost=1.04..1.05 rows=1 width=8)
   ->  Sort  (cost=1.02..1.03 rows=1 width=1459)
         Sort Key: dag_run.data_interval_start
         ->  Seq Scan on dag_run  (cost=0.00..1.01 rows=1 width=1459)
               Filter: ((state)::text = 'failed'::text)
(5 rows)

airflow=# EXPLAIN SELECT COUNT(1) FROM (SELECT * FROM dag_run WHERE state='failed') AS subquery;
                         QUERY PLAN                          
-------------------------------------------------------------
 Aggregate  (cost=1.01..1.02 rows=1 width=8)
   ->  Seq Scan on dag_run  (cost=0.00..1.01 rows=1 width=0)
         Filter: ((state)::text = 'failed'::text)
(3 rows)

There are not many rows in my DB (a fresh Breeze DB), so this slight difference could be much bigger in a big DB.

Personally, I always reset the order_by when I don't need it, both to be safe and to ensure better performance. WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it seems like a DB optimisation issue. Since this is in a function, adding the reset seems to have no drawbacks. Perhaps a comment explaining the query difference would prevent someone coming in and mistakenly assuming the order_by is superfluous.

"""Get count of query."""
count_stmt = select(func.count()).select_from(query_stmt.order_by(None).subquery())
return session.scalar(count_stmt)
13 changes: 6 additions & 7 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
set_dag_run_state_to_success,
set_state,
)
from airflow.api_connexion.parameters import get_query_count
from airflow.configuration import AIRFLOW_CONFIG, auth_manager, conf
from airflow.datasets import Dataset
from airflow.exceptions import (
Expand Down Expand Up @@ -759,7 +760,7 @@ def index(self):
dags_query = dags_query.where(DagModel.tags.any(DagTag.name.in_(arg_tags_filter)))

dags_query = dags_query.where(DagModel.dag_id.in_(filter_dag_ids))
filtered_dag_count = session.scalar(select(func.count()).select_from(dags_query))
filtered_dag_count = get_query_count(dags_query, session=session)
if filtered_dag_count == 0 and len(arg_tags_filter):
flash(
"No matching DAG tags found.",
Expand Down Expand Up @@ -811,8 +812,8 @@ def index(self):
status_count_active = is_paused_count.get(False, 0)
status_count_paused = is_paused_count.get(True, 0)

status_count_running = session.scalar(select(func.count()).select_from(running_dags))
status_count_failed = session.scalar(select(func.count()).select_from(failed_dags))
status_count_running = get_query_count(running_dags, session=session)
status_count_failed = get_query_count(failed_dags, session=session)

all_dags_count = status_count_active + status_count_paused
if arg_status_filter == "active":
Expand Down Expand Up @@ -951,9 +952,7 @@ def _iter_parsed_moved_data_table_names():
.where(Log.event == "robots")
.where(Log.dttm > (utcnow() - datetime.timedelta(days=7)))
)
robots_file_access_count = session.scalar(
select(func.count()).select_from(robots_file_access_count)
)
robots_file_access_count = get_query_count(robots_file_access_count, session=session)
if robots_file_access_count > 0:
flash(
Markup(
Expand Down Expand Up @@ -4171,7 +4170,7 @@ def audit_log(self, dag_id: str, session: Session = NEW_SESSION):
arg_sorting_direction = request.args.get("sorting_direction", default="desc")

logs_per_page = PAGE_SIZE
audit_logs_count = session.scalar(select(func.count()).select_from(query))
audit_logs_count = get_query_count(query, session=session)
num_of_pages = int(math.ceil(audit_logs_count / float(logs_per_page)))

start = current_page * logs_per_page
Expand Down