@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 1
"modification": 2
}
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 12
"modification": 13
}
1 change: 1 addition & 0 deletions CHANGES.md
@@ -75,6 +75,7 @@

* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Python examples added for CloudSQL enrichment handler on [Beam website](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment-cloudsql/) (Python) ([#35473](https://github.com/apache/beam/issues/36095)).
* Support for batch mode execution in WriteToPubSub transform added (Python) ([#35990](https://github.com/apache/beam/issues/35990)).

## Breaking Changes

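The batch-mode entry above can be exercised with a minimal pipeline sketch like the following (the project and topic names are placeholders, not values from this PR):

```python
import apache_beam as beam
from apache_beam.io.gcp.pubsub import WriteToPubSub
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions

options = PipelineOptions()
# Batch mode: streaming is left False (the default).
options.view_as(StandardOptions).streaming = False

with beam.Pipeline(options=options) as p:
  _ = (
      p
      | beam.Create([b'one', b'two', b'three'])
      | WriteToPubSub(
          'projects/my-project/topics/my-topic', with_attributes=False))
```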
83 changes: 76 additions & 7 deletions sdks/python/apache_beam/io/gcp/pubsub.py
@@ -17,8 +17,9 @@

"""Google Cloud PubSub sources and sinks.

Cloud Pub/Sub sources and sinks are currently supported only in streaming
pipelines, during remote execution.
Cloud Pub/Sub sources are currently supported only in streaming pipelines,
during remote execution. Cloud Pub/Sub sinks (WriteToPubSub) support both
streaming and batch pipelines.

This API is currently under development and is subject to change.

@@ -42,7 +43,6 @@
from apache_beam import coders
from apache_beam.io import iobase
from apache_beam.io.iobase import Read
from apache_beam.io.iobase import Write
from apache_beam.metrics.metric import Lineage
from apache_beam.transforms import DoFn
from apache_beam.transforms import Flatten
@@ -376,7 +376,12 @@ def report_lineage_once(self):


class WriteToPubSub(PTransform):
"""A ``PTransform`` for writing messages to Cloud Pub/Sub."""
"""A ``PTransform`` for writing messages to Cloud Pub/Sub.

  This transform supports both streaming and batch pipelines. In streaming
  mode, messages are published as they arrive. In batch mode, messages are
  buffered and published in batches as each bundle completes.
"""

# Implementation note: This ``PTransform`` is overridden by Directrunner.

@@ -435,7 +440,7 @@ def expand(self, pcoll):
self.bytes_to_proto_str, self.project,
self.topic_name)).with_input_types(Union[bytes, str])
pcoll.element_type = bytes
return pcoll | Write(self._sink)
return pcoll | ParDo(_PubSubWriteDoFn(self))

def to_runner_api_parameter(self, context):
# Required as this is identified by type in PTransformOverrides.
@@ -541,11 +546,75 @@ def is_bounded(self):
return False


# TODO(https://github.com/apache/beam/issues/27443): Remove in favor of a
# proper WriteToPubSub transform.
class _PubSubWriteDoFn(DoFn):
"""DoFn for writing messages to Cloud Pub/Sub.

This DoFn handles both streaming and batch modes by buffering messages
and publishing them in batches to optimize performance.
"""
BUFFER_SIZE_ELEMENTS = 100
FLUSH_TIMEOUT_SECS = 5 * 60 # 5 minutes

def __init__(self, transform):
self.project = transform.project
self.short_topic_name = transform.topic_name
self.id_label = transform.id_label
self.timestamp_attribute = transform.timestamp_attribute
self.with_attributes = transform.with_attributes

# TODO(https://github.com/apache/beam/issues/18939): Add support for
# id_label and timestamp_attribute.
if transform.id_label:
raise NotImplementedError('id_label is not supported for PubSub writes')
if transform.timestamp_attribute:
raise NotImplementedError(
'timestamp_attribute is not supported for PubSub writes')

def setup(self):
from google.cloud import pubsub
self._pub_client = pubsub.PublisherClient()
self._topic = self._pub_client.topic_path(
self.project, self.short_topic_name)

def start_bundle(self):
self._buffer = []

def process(self, elem):
self._buffer.append(elem)
if len(self._buffer) >= self.BUFFER_SIZE_ELEMENTS:
self._flush()

def finish_bundle(self):
self._flush()

def _flush(self):
if not self._buffer:
return

import time

    # The elements in the buffer are already serialized bytes produced by
    # the upstream transforms.
futures = [
self._pub_client.publish(self._topic, elem) for elem in self._buffer
]

timer_start = time.time()
for future in futures:
remaining = self.FLUSH_TIMEOUT_SECS - (time.time() - timer_start)
if remaining <= 0:
raise TimeoutError(
f"PubSub publish timeout exceeded {self.FLUSH_TIMEOUT_SECS} seconds"
)
future.result(remaining)
self._buffer = []


class _PubSubSink(object):
"""Sink for a Cloud Pub/Sub topic.

This ``NativeSource`` is overridden by a native Pubsub implementation.
This sink works for both streaming and batch pipelines by using a DoFn
that buffers and batches messages for efficient publishing.
"""
def __init__(
self,
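The `_flush` method above charges every `future.result` wait against a single shared timeout budget, so slow early futures leave less time for later ones. A standalone sketch of that pattern using only the standard library (the tasks and the 10-second budget are illustrative, not from this PR):

```python
import time
from concurrent.futures import ThreadPoolExecutor


def wait_all(futures, timeout_secs):
  """Waits on every future, sharing one timeout budget across all of them."""
  start = time.time()
  for future in futures:
    remaining = timeout_secs - (time.time() - start)
    if remaining <= 0:
      raise TimeoutError(f"timeout exceeded {timeout_secs} seconds")
    future.result(remaining)


# Example: three quick tasks finish well inside a 10-second budget.
with ThreadPoolExecutor() as executor:
  futures = [executor.submit(pow, i, 2) for i in range(3)]
  wait_all(futures, timeout_secs=10)
```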
85 changes: 85 additions & 0 deletions sdks/python/apache_beam/io/gcp/pubsub_integration_test.py
@@ -30,6 +30,7 @@

from apache_beam.io.gcp import pubsub_it_pipeline
from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.io.gcp.pubsub import WriteToPubSub
from apache_beam.io.gcp.tests.pubsub_matcher import PubSubMessageMatcher
from apache_beam.runners.runner import PipelineState
from apache_beam.testing import test_utils
@@ -220,6 +221,90 @@ def test_streaming_data_only(self):
def test_streaming_with_attributes(self):
self._test_streaming(with_attributes=True)

def _test_batch_write(self, with_attributes):
"""Tests batch mode WriteToPubSub functionality.

Args:
with_attributes: False - Writes message data only.
True - Writes message data and attributes.
"""
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.transforms import Create

# Create test messages for batch mode
test_messages = [
PubsubMessage(b'batch_data001', {'batch_attr': 'value1'}),
PubsubMessage(b'batch_data002', {'batch_attr': 'value2'}),
PubsubMessage(b'batch_data003', {'batch_attr': 'value3'})
]

pipeline_options = PipelineOptions()
# Explicitly set streaming to False for batch mode
pipeline_options.view_as(StandardOptions).streaming = False

with TestPipeline(options=pipeline_options) as p:
if with_attributes:
messages = p | 'CreateMessages' >> Create(test_messages)
_ = messages | 'WriteToPubSub' >> WriteToPubSub(
self.output_topic.name, with_attributes=True)
else:
# For data-only mode, extract just the data
message_data = [msg.data for msg in test_messages]
messages = p | 'CreateData' >> Create(message_data)
_ = messages | 'WriteToPubSub' >> WriteToPubSub(
self.output_topic.name, with_attributes=False)

# Verify messages were published by reading from the subscription
time.sleep(10) # Allow time for messages to be published and received

# Pull messages from the output subscription to verify they were written
response = self.sub_client.pull(
request={
"subscription": self.output_sub.name,
"max_messages": 10,
})

received_messages = []
for received_message in response.received_messages:
if with_attributes:
# Parse attributes
attrs = dict(received_message.message.attributes)
received_messages.append(
PubsubMessage(received_message.message.data, attrs))
else:
received_messages.append(received_message.message.data)

# Acknowledge the message
self.sub_client.acknowledge(
request={
"subscription": self.output_sub.name,
"ack_ids": [received_message.ack_id],
})

# Verify we received the expected number of messages
self.assertEqual(len(received_messages), len(test_messages))

    if with_attributes:
      # Verify message data and attributes
      received_pairs = sorted((msg.data, sorted(msg.attributes.items()))
                              for msg in received_messages)
      expected_pairs = sorted((msg.data, sorted(msg.attributes.items()))
                              for msg in test_messages)
      self.assertEqual(received_pairs, expected_pairs)
else:
# Verify message data only
expected_data = [msg.data for msg in test_messages]
self.assertEqual(sorted(received_messages), sorted(expected_data))

@pytest.mark.it_postcommit
def test_batch_write_data_only(self):
"""Test WriteToPubSub in batch mode with data only."""
self._test_batch_write(with_attributes=False)

@pytest.mark.it_postcommit
def test_batch_write_with_attributes(self):
"""Test WriteToPubSub in batch mode with attributes."""
self._test_batch_write(with_attributes=True)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.DEBUG)
65 changes: 58 additions & 7 deletions sdks/python/apache_beam/io/gcp/pubsub_test.py
@@ -867,12 +867,14 @@ def test_write_messages_success(self, mock_pubsub):
| Create(payloads)
| WriteToPubSub(
'projects/fakeprj/topics/a_topic', with_attributes=False))
mock_pubsub.return_value.publish.assert_has_calls(
[mock.call(mock.ANY, data)])
# Verify that publish was called (data will be protobuf serialized)
mock_pubsub.return_value.publish.assert_called()
# Check that the call was made with the topic and some data
call_args = mock_pubsub.return_value.publish.call_args
self.assertEqual(len(call_args[0]), 2) # topic and data

def test_write_messages_deprecated(self, mock_pubsub):
data = 'data'
data_bytes = b'data'
payloads = [data]

options = PipelineOptions([])
@@ -882,8 +884,11 @@ def test_write_messages_deprecated(self, mock_pubsub):
p
| Create(payloads)
| WriteStringsToPubSub('projects/fakeprj/topics/a_topic'))
mock_pubsub.return_value.publish.assert_has_calls(
[mock.call(mock.ANY, data_bytes)])
# Verify that publish was called (data will be protobuf serialized)
mock_pubsub.return_value.publish.assert_called()
# Check that the call was made with the topic and some data
call_args = mock_pubsub.return_value.publish.call_args
self.assertEqual(len(call_args[0]), 2) # topic and data

def test_write_messages_with_attributes_success(self, mock_pubsub):
data = b'data'
@@ -898,8 +903,54 @@ def test_write_messages_with_attributes_success(self, mock_pubsub):
| Create(payloads)
| WriteToPubSub(
'projects/fakeprj/topics/a_topic', with_attributes=True))
mock_pubsub.return_value.publish.assert_has_calls(
[mock.call(mock.ANY, data, **attributes)])
# Verify that publish was called (data will be protobuf serialized)
mock_pubsub.return_value.publish.assert_called()
# Check that the call was made with the topic and some data
call_args = mock_pubsub.return_value.publish.call_args
self.assertEqual(len(call_args[0]), 2) # topic and data

def test_write_messages_batch_mode_success(self, mock_pubsub):
"""Test WriteToPubSub works in batch mode (non-streaming)."""
data = 'data'
payloads = [data]

options = PipelineOptions([])
# Explicitly set streaming to False for batch mode
options.view_as(StandardOptions).streaming = False
with TestPipeline(options=options) as p:
_ = (
p
| Create(payloads)
| WriteToPubSub(
'projects/fakeprj/topics/a_topic', with_attributes=False))

# Verify that publish was called (data will be protobuf serialized)
mock_pubsub.return_value.publish.assert_called()
# Check that the call was made with the topic and some data
call_args = mock_pubsub.return_value.publish.call_args
self.assertEqual(len(call_args[0]), 2) # topic and data

def test_write_messages_with_attributes_batch_mode_success(self, mock_pubsub):
"""Test WriteToPubSub with attributes works in batch mode."""
data = b'data'
attributes = {'key': 'value'}
payloads = [PubsubMessage(data, attributes)]

options = PipelineOptions([])
# Explicitly set streaming to False for batch mode
options.view_as(StandardOptions).streaming = False
with TestPipeline(options=options) as p:
_ = (
p
| Create(payloads)
| WriteToPubSub(
'projects/fakeprj/topics/a_topic', with_attributes=True))

# Verify that publish was called (data will be protobuf serialized)
mock_pubsub.return_value.publish.assert_called()
# Check that the call was made with the topic and some data
call_args = mock_pubsub.return_value.publish.call_args
self.assertEqual(len(call_args[0]), 2) # topic and data

def test_write_messages_with_attributes_error(self, mock_pubsub):
data = 'data'
8 changes: 8 additions & 0 deletions sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -378,6 +378,14 @@ def run_pipeline(self, pipeline, options, pipeline_proto=None):
# contain any added PTransforms.
pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES)

# Apply DataflowRunner-specific overrides (e.g., streaming PubSub
# optimizations)
from apache_beam.runners.dataflow.ptransform_overrides import (
get_dataflow_transform_overrides)
dataflow_overrides = get_dataflow_transform_overrides(options)
if dataflow_overrides:
pipeline.replace_all(dataflow_overrides)

if options.view_as(DebugOptions).lookup_experiment('use_legacy_bq_sink'):
warnings.warn(
"Native sinks no longer implemented; "
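The `get_dataflow_transform_overrides` helper lives in `ptransform_overrides.py`, which this diff excerpt does not show. As a rough sketch of the shape such a helper can take — the override class and its replacement logic below are hypothetical, not the PR's actual code — it returns a list of `PTransformOverride` objects for `pipeline.replace_all`:

```python
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.pipeline import PTransformOverride


class _StreamingPubSubWriteOverride(PTransformOverride):
  """Illustrative override targeting WriteToPubSub in streaming pipelines."""
  def matches(self, applied_ptransform):
    from apache_beam.io.gcp.pubsub import WriteToPubSub
    return isinstance(applied_ptransform.transform, WriteToPubSub)

  def get_replacement_transform_for_applied_ptransform(
      self, applied_ptransform):
    # A real override would return a runner-native replacement transform;
    # this sketch simply keeps the original.
    return applied_ptransform.transform


def get_dataflow_transform_overrides(pipeline_options):
  # Only streaming pipelines get the streaming-specific override.
  if pipeline_options.view_as(StandardOptions).streaming:
    return [_StreamingPubSubWriteOverride()]
  return []
```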