Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add deferrable mode to the PubSubPullOperator #45835

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 56 additions & 19 deletions providers/src/airflow/providers/google/cloud/operators/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,12 @@
SchemaSettings,
)

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.pubsub import PubSubHook
from airflow.providers.google.cloud.links.pubsub import PubSubSubscriptionLink, PubSubTopicLink
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID

if TYPE_CHECKING:
Expand Down Expand Up @@ -746,6 +749,9 @@ class PubSubPullOperator(GoogleCloudBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: If True, run the task in the deferrable mode.
:param poll_interval: Time (seconds) to wait between two consecutive attempts to pull messages
when running in deferrable mode. The default is 300 seconds.
"""

template_fields: Sequence[str] = (
Expand All @@ -764,6 +770,8 @@ def __init__(
messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
poll_interval: int = 300,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -774,32 +782,61 @@ def __init__(
self.ack_messages = ack_messages
self.messages_callback = messages_callback
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.poll_interval = poll_interval

def execute(self, context: Context) -> list:
hook = PubSubHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

pulled_messages = hook.pull(
project_id=self.project_id,
subscription=self.subscription,
max_messages=self.max_messages,
return_immediately=True,
)

handle_messages = self.messages_callback or self._default_message_callback

ret = handle_messages(pulled_messages, context)
if self.deferrable:
self.defer(
trigger=PubsubPullTrigger(
subscription=self.subscription,
project_id=self.project_id,
max_messages=self.max_messages,
ack_messages=self.ack_messages,
gcp_conn_id=self.gcp_conn_id,
poke_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
),
method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
)
else:
hook = PubSubHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

if pulled_messages and self.ack_messages:
hook.acknowledge(
pulled_messages = hook.pull(
project_id=self.project_id,
subscription=self.subscription,
messages=pulled_messages,
max_messages=self.max_messages,
return_immediately=True,
)

return ret
handle_messages = self.messages_callback or self._default_message_callback

ret = handle_messages(pulled_messages, context)

if pulled_messages and self.ack_messages:
hook.acknowledge(
project_id=self.project_id,
subscription=self.subscription,
messages=pulled_messages,
)

return ret

def execute_complete(self, context: Context, event: dict[str, Any]):
    """
    Resume the task once the trigger has fired and produce the task's return value.

    :param context: Airflow task context; forwarded unchanged to ``messages_callback``.
    :param event: Payload emitted by the trigger. It must contain the keys ``"status"``
        and ``"message"``. On success ``"message"`` carries the pulled messages; otherwise
        it carries the failure description.
    :return: The value returned by ``messages_callback`` when one is configured,
        otherwise ``event["message"]`` exactly as delivered by the trigger.
    :raises AirflowException: If ``event["status"]`` is anything other than ``"success"``.
    """
    if event["status"] != "success":
        # A non-success event fails the task, so record it at error level.
        self.log.error("Pulling messages failed: %s", event["message"])
        raise AirflowException(event["message"])

    self.log.info("Pulled messages: %s", event["message"])
    if self.messages_callback:
        # The callback receives converted messages rather than the raw event payload.
        # NOTE(review): _convert_to_received_messages is not visible in this diff --
        # confirm the operator (not only the sensor) defines it.
        received_messages = self._convert_to_received_messages(event["message"])
        return self.messages_callback(received_messages, context)

    return event["message"]

def _default_message_callback(
self,
Expand Down
20 changes: 20 additions & 0 deletions providers/tests/google/cloud/operators/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from typing import Any
from unittest import mock

import pytest
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.exceptions import TaskDeferred
from airflow.providers.google.cloud.operators.pubsub import (
PubSubCreateSubscriptionOperator,
PubSubCreateTopicOperator,
Expand Down Expand Up @@ -337,3 +339,21 @@ def messages_callback(
messages_callback.assert_called_once()

assert response == messages_callback_return_value

@mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")
def test_execute_deferred(self, mock_hook, create_task_instance_of_operator):
    """
    Assert that the task is deferred to a PubsubPullTrigger when the
    PubSubPullOperator is executed with deferrable=True.
    """
    # Imported locally: only this test needs the trigger class.
    from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger

    ti = create_task_instance_of_operator(
        PubSubPullOperator,
        dag_id="dag_id",
        task_id=TASK_ID,
        project_id=TEST_PROJECT,
        subscription=TEST_SUBSCRIPTION,
        deferrable=True,
    )

    with pytest.raises(TaskDeferred) as exc:
        ti.task.execute(mock.MagicMock())

    assert isinstance(exc.value.trigger, PubsubPullTrigger), "Trigger is not a PubsubPullTrigger"
    # Deferring must bypass the synchronous pull path entirely, so no hook is built.
    mock_hook.assert_not_called()
Loading