From f3151491b318075d6e23e542b6a0150d665b8b47 Mon Sep 17 00:00:00 2001 From: Sachin Arora <61383002+arorasachin9@users.noreply.github.com> Date: Mon, 23 Dec 2024 19:23:02 +0530 Subject: [PATCH] Add support for custom celery configs (#45038) * #45037: Support for additional celery config directly from airflow.cfg file * Added unit test for the additional config and addressed comments * Added unit test for the additional config and addressed comments * Addressed comments: Added sample config in the docs as well as reference link for all the celery configs. * Addressed comments: Added sample config in the docs as well as reference link for all the celery configs. * Fixed the Unit tests * Fixed the Unit tests * Fixed the Unit tests --------- Co-authored-by: Sachin Arora --- .../providers/celery/executors/default_celery.py | 3 +++ providers/src/airflow/providers/celery/provider.yaml | 11 +++++++++++ .../tests/celery/executors/test_celery_executor.py | 9 +++++++++ 3 files changed, 23 insertions(+) diff --git a/providers/src/airflow/providers/celery/executors/default_celery.py b/providers/src/airflow/providers/celery/executors/default_celery.py index 75f8cc2bfdf43..20c307a77b04f 100644 --- a/providers/src/airflow/providers/celery/executors/default_celery.py +++ b/providers/src/airflow/providers/celery/executors/default_celery.py @@ -69,6 +69,8 @@ def _broker_supports_visibility_timeout(url): log.debug("Value for celery result_backend not found. 
Using sql_alchemy_conn with db+ prefix.") result_backend = f'db+{conf.get("database", "SQL_ALCHEMY_CONN")}' +extra_celery_config = conf.getjson("celery", "extra_celery_config", fallback={}) + DEFAULT_CELERY_CONFIG = { "accept_content": ["json"], "event_serializer": "json", @@ -85,6 +87,7 @@ def _broker_supports_visibility_timeout(url): ), "worker_concurrency": conf.getint("celery", "WORKER_CONCURRENCY", fallback=16), "worker_enable_remote_control": conf.getboolean("celery", "worker_enable_remote_control", fallback=True), + **(extra_celery_config if isinstance(extra_celery_config, dict) else {}), } diff --git a/providers/src/airflow/providers/celery/provider.yaml b/providers/src/airflow/providers/celery/provider.yaml index 906a76130456f..5a45989804614 100644 --- a/providers/src/airflow/providers/celery/provider.yaml +++ b/providers/src/airflow/providers/celery/provider.yaml @@ -330,6 +330,17 @@ config: type: string example: ~ default: "False" + extra_celery_config: + description: | + Extra Celery configuration options to apply to the Celery worker. + Any Celery configuration option can be added to this setting and it + will be applied when the Celery worker starts, e.g. 
{"worker_max_tasks_per_child": 10} + See also: + https://docs.celeryq.dev/en/stable/userguide/configuration.html#configuration-and-defaults + version_added: ~ + type: string + example: ~ + default: "{{}}" celery_broker_transport_options: description: | This section is for specifying options which can be passed to the diff --git a/providers/tests/celery/executors/test_celery_executor.py b/providers/tests/celery/executors/test_celery_executor.py index 7dc918082b62a..7a33e0cfbc17c 100644 --- a/providers/tests/celery/executors/test_celery_executor.py +++ b/providers/tests/celery/executors/test_celery_executor.py @@ -399,3 +399,12 @@ def test_celery_task_acks_late_loaded_from_string(): # reload celery conf to apply the new config importlib.reload(default_celery) assert default_celery.DEFAULT_CELERY_CONFIG["task_acks_late"] is False + + +@conf_vars({("celery", "extra_celery_config"): '{"worker_max_tasks_per_child": 10}'}) +def test_celery_extra_celery_config_loaded_from_string(): + import importlib + + # reload celery conf to apply the new config + importlib.reload(default_celery) + assert default_celery.DEFAULT_CELERY_CONFIG["worker_max_tasks_per_child"] == 10