Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MessageDeduplicationId support to AWS SqsPublishOperator #45051

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions providers/src/airflow/providers/amazon/aws/hooks/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def send_message(
delay_seconds: int = 0,
message_attributes: dict | None = None,
message_group_id: str | None = None,
message_deduplication_id: str | None = None,
) -> dict:
"""
Send message to the queue.
Expand All @@ -71,6 +72,7 @@ def send_message(
:param delay_seconds: seconds to delay the message
:param message_attributes: additional attributes for the message (default: None)
:param message_group_id: This applies only to FIFO (first-in-first-out) queues. (default: None)
:param message_deduplication_id: This applies only to FIFO (first-in-first-out) queues.
:return: dict with the information about the message sent
"""
params = {
Expand All @@ -81,5 +83,7 @@ def send_message(
}
if message_group_id:
params["MessageGroupId"] = message_group_id
if message_deduplication_id:
params["MessageDeduplicationId"] = message_deduplication_id

return self.get_conn().send_message(**params)
6 changes: 6 additions & 0 deletions providers/src/airflow/providers/amazon/aws/operators/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class SqsPublishOperator(AwsBaseOperator[SqsHook]):
:param delay_seconds: message delay (templated) (default: 1 second)
:param message_group_id: This parameter applies only to FIFO (first-in-first-out) queues. (default: None)
For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message`
:param message_deduplication_id: This applies only to FIFO (first-in-first-out) queues.
For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message`
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
Expand All @@ -63,6 +65,7 @@ class SqsPublishOperator(AwsBaseOperator[SqsHook]):
"delay_seconds",
"message_attributes",
"message_group_id",
"message_deduplication_id",
)
template_fields_renderers = {"message_attributes": "json"}
ui_color = "#6ad3fa"
Expand All @@ -75,6 +78,7 @@ def __init__(
message_attributes: dict | None = None,
delay_seconds: int = 0,
message_group_id: str | None = None,
message_deduplication_id: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -83,6 +87,7 @@ def __init__(
self.delay_seconds = delay_seconds
self.message_attributes = message_attributes or {}
self.message_group_id = message_group_id
self.message_deduplication_id = message_deduplication_id

def execute(self, context: Context) -> dict:
"""
Expand All @@ -98,6 +103,7 @@ def execute(self, context: Context) -> dict:
delay_seconds=self.delay_seconds,
message_attributes=self.message_attributes,
message_group_id=self.message_group_id,
message_deduplication_id=self.message_deduplication_id,
)

self.log.info("send_message result: %s", result)
Expand Down
19 changes: 18 additions & 1 deletion providers/tests/amazon/aws/operators/test_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,20 @@ def test_execute_failure_fifo_queue(self, mocked_context):
with pytest.raises(ClientError, match=error_message):
op.execute(mocked_context)

@mock_aws
def test_deduplication_failure(self, mocked_context):
self.sqs_client.create_queue(
QueueName=FIFO_QUEUE_NAME, Attributes={"FifoQueue": "true", "ContentBasedDeduplication": "false"}
)

op = SqsPublishOperator(**self.default_op_kwargs, sqs_queue=FIFO_QUEUE_NAME, message_group_id="abc")
error_message = (
r"An error occurred \(InvalidParameterValue\) when calling the SendMessage operation: "
r"The queue should either have ContentBasedDeduplication enabled or MessageDeduplicationId provided explicitly"
)
with pytest.raises(ClientError, match=error_message):
op.execute(mocked_context)

@mock_aws
def test_execute_success_fifo_queue(self, mocked_context):
self.sqs_client.create_queue(
Expand All @@ -124,6 +138,9 @@ def test_execute_success_fifo_queue(self, mocked_context):

def test_template_fields(self):
operator = SqsPublishOperator(
**self.default_op_kwargs, sqs_queue=FIFO_QUEUE_NAME, message_group_id="abc"
**self.default_op_kwargs,
sqs_queue=FIFO_QUEUE_NAME,
message_group_id="abc",
message_deduplication_id="abc",
)
validate_template_fields(operator)