75 changes: 68 additions & 7 deletions sdks/python/apache_beam/io/gcp/pubsub.py
@@ -17,8 +17,9 @@

"""Google Cloud PubSub sources and sinks.

Cloud Pub/Sub sources and sinks are currently supported only in streaming
pipelines, during remote execution.
Cloud Pub/Sub sources are currently supported only in streaming pipelines,
during remote execution. Cloud Pub/Sub sinks (WriteToPubSub) support both
streaming and batch pipelines.

This API is currently under development and is subject to change.

@@ -42,7 +43,6 @@
from apache_beam import coders
from apache_beam.io import iobase
from apache_beam.io.iobase import Read
from apache_beam.io.iobase import Write
from apache_beam.metrics.metric import Lineage
from apache_beam.transforms import DoFn
from apache_beam.transforms import Flatten
@@ -376,7 +376,12 @@ def report_lineage_once(self):


class WriteToPubSub(PTransform):
"""A ``PTransform`` for writing messages to Cloud Pub/Sub."""
"""A ``PTransform`` for writing messages to Cloud Pub/Sub.

This transform supports both streaming and batch pipelines. In streaming
mode, messages are written continuously as they arrive. In batch mode,
messages are buffered and published in batches as each bundle is processed.
"""

# Implementation note: This ``PTransform`` is overridden by Directrunner.
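
For reference, a minimal usage sketch of the batch write path described in the docstring above; it is not part of this diff, and the project, topic, and payloads are hypothetical:

import apache_beam as beam
from apache_beam.io.gcp.pubsub import WriteToPubSub
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions

options = PipelineOptions()
options.view_as(StandardOptions).streaming = False  # batch mode

with beam.Pipeline(options=options) as p:
  _ = (
      p
      # Raw bytes payloads, matching with_attributes=False.
      | beam.Create([b'payload-1', b'payload-2'])
      | WriteToPubSub(
          'projects/my-project/topics/my-topic', with_attributes=False))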

@@ -435,7 +440,7 @@ def expand(self, pcoll):
self.bytes_to_proto_str, self.project,
self.topic_name)).with_input_types(Union[bytes, str])
pcoll.element_type = bytes
return pcoll | Write(self._sink)
return pcoll | ParDo(_PubSubWriteDoFn(self))

def to_runner_api_parameter(self, context):
# Required as this is identified by type in PTransformOverrides.
@@ -541,11 +546,67 @@ def is_bounded(self):
return False


# TODO(BEAM-27443): Remove in favor of a proper WriteToPubSub transform.
class _PubSubWriteDoFn(DoFn):
"""DoFn for writing messages to Cloud Pub/Sub.

This DoFn handles both streaming and batch modes by buffering messages
and publishing them in batches to optimize performance.
"""
BUFFER_SIZE_ELEMENTS = 100
FLUSH_TIMEOUT_SECS = BUFFER_SIZE_ELEMENTS * 0.5

def __init__(self, transform):
self.project = transform.project
self.short_topic_name = transform.topic_name
self.id_label = transform.id_label
self.timestamp_attribute = transform.timestamp_attribute
self.with_attributes = transform.with_attributes

# TODO(https://github.com/apache/beam/issues/18939): Add support for
# id_label and timestamp_attribute.
if transform.id_label:
raise NotImplementedError('id_label is not supported for PubSub writes')
if transform.timestamp_attribute:
raise NotImplementedError(
'timestamp_attribute is not supported for PubSub writes')

def start_bundle(self):
self._buffer = []

def process(self, elem):
self._buffer.append(elem)
if len(self._buffer) >= self.BUFFER_SIZE_ELEMENTS:
self._flush()

def finish_bundle(self):
self._flush()

def _flush(self):
if not self._buffer:
return

# Imported here so the google-cloud-pubsub client library is only required
# when the write actually executes.
from google.cloud import pubsub
import time

pub_client = pubsub.PublisherClient()
topic = pub_client.topic_path(self.project, self.short_topic_name)

# The elements in the buffer are already serialized bytes produced by the
# preceding transforms.
futures = [pub_client.publish(topic, elem) for elem in self._buffer]

timer_start = time.time()
for future in futures:
remaining = self.FLUSH_TIMEOUT_SECS - (time.time() - timer_start)
future.result(remaining)
Review comment (Contributor):
Why is there a flush timeout? Completing processing without waiting for all of the messages to be consumed by Pub/Sub could lead to data loss.

Reply (@liferoad, Contributor, author — Sep 3, 2025):
I added the timeout exception, which should trigger Dataflow to retry for batch jobs. The idea is to avoid getting stuck when publishing the messages.

self._buffer = []
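
To make the retry semantics from the review thread above concrete, here is a small illustrative sketch of why raising on a publish timeout fails the bundle (so a batch runner retries it) rather than dropping data. It is not part of this diff; the helper name and arguments are assumptions, and it only relies on Pub/Sub publish futures behaving like standard concurrent.futures futures:

import concurrent.futures

def wait_for_publish_or_fail(future, remaining_secs):
  """Block until one publish future completes, or fail the bundle by raising."""
  try:
    # result(timeout) raises concurrent.futures.TimeoutError if the publish
    # has not completed within the remaining budget.
    return future.result(remaining_secs)
  except concurrent.futures.TimeoutError:
    # Propagating the exception fails the bundle; runners such as Dataflow
    # then retry the bundle instead of silently losing unpublished messages.
    raise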


class _PubSubSink(object):
"""Sink for a Cloud Pub/Sub topic.

This ``NativeSource`` is overridden by a native Pubsub implementation.
This sink supports both streaming and batch pipelines. Messages are
buffered and published in batches by a ``DoFn``; runners may override this
with a native Pub/Sub implementation.
"""
def __init__(
self,
85 changes: 85 additions & 0 deletions sdks/python/apache_beam/io/gcp/pubsub_integration_test.py
@@ -30,6 +30,7 @@

from apache_beam.io.gcp import pubsub_it_pipeline
from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.io.gcp.pubsub import WriteToPubSub
from apache_beam.io.gcp.tests.pubsub_matcher import PubSubMessageMatcher
from apache_beam.runners.runner import PipelineState
from apache_beam.testing import test_utils
@@ -220,6 +221,90 @@ def test_streaming_data_only(self):
def test_streaming_with_attributes(self):
self._test_streaming(with_attributes=True)

def _test_batch_write(self, with_attributes):
"""Tests batch mode WriteToPubSub functionality.

Args:
with_attributes: False - Writes message data only.
True - Writes message data and attributes.
"""
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.transforms import Create

# Create test messages for batch mode
test_messages = [
PubsubMessage(b'batch_data001', {'batch_attr': 'value1'}),
PubsubMessage(b'batch_data002', {'batch_attr': 'value2'}),
PubsubMessage(b'batch_data003', {'batch_attr': 'value3'})
]

pipeline_options = PipelineOptions()
# Explicitly set streaming to False for batch mode
pipeline_options.view_as(StandardOptions).streaming = False

with TestPipeline(options=pipeline_options) as p:
if with_attributes:
messages = p | 'CreateMessages' >> Create(test_messages)
_ = messages | 'WriteToPubSub' >> WriteToPubSub(
self.output_topic.name, with_attributes=True)
else:
# For data-only mode, extract just the data
message_data = [msg.data for msg in test_messages]
messages = p | 'CreateData' >> Create(message_data)
_ = messages | 'WriteToPubSub' >> WriteToPubSub(
self.output_topic.name, with_attributes=False)

# Verify messages were published by reading from the subscription
time.sleep(10) # Allow time for messages to be published and received

# Pull messages from the output subscription to verify they were written
response = self.sub_client.pull(
request={
"subscription": self.output_sub.name,
"max_messages": 10,
})

received_messages = []
for received_message in response.received_messages:
if with_attributes:
# Parse attributes
attrs = dict(received_message.message.attributes)
received_messages.append(
PubsubMessage(received_message.message.data, attrs))
else:
received_messages.append(received_message.message.data)

# Acknowledge the message
self.sub_client.acknowledge(
request={
"subscription": self.output_sub.name,
"ack_ids": [received_message.ack_id],
})

# Verify we received the expected number of messages
self.assertEqual(len(received_messages), len(test_messages))

if with_attributes:
# Verify message data and attributes
received_pairs = sorted(
(msg.data, sorted(msg.attributes.items())) for msg in received_messages)
expected_pairs = sorted(
(msg.data, sorted(msg.attributes.items())) for msg in test_messages)
self.assertEqual(received_pairs, expected_pairs)
else:
# Verify message data only
expected_data = [msg.data for msg in test_messages]
self.assertEqual(sorted(received_messages), sorted(expected_data))

@pytest.mark.it_postcommit
def test_batch_write_data_only(self):
"""Test WriteToPubSub in batch mode with data only."""
self._test_batch_write(with_attributes=False)

@pytest.mark.it_postcommit
def test_batch_write_with_attributes(self):
"""Test WriteToPubSub in batch mode with attributes."""
self._test_batch_write(with_attributes=True)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.DEBUG)
65 changes: 58 additions & 7 deletions sdks/python/apache_beam/io/gcp/pubsub_test.py
@@ -867,12 +867,14 @@ def test_write_messages_success(self, mock_pubsub):
| Create(payloads)
| WriteToPubSub(
'projects/fakeprj/topics/a_topic', with_attributes=False))
mock_pubsub.return_value.publish.assert_has_calls(
[mock.call(mock.ANY, data)])
# Verify that publish was called (data will be protobuf serialized)
mock_pubsub.return_value.publish.assert_called()
# Check that the call was made with the topic and some data
call_args = mock_pubsub.return_value.publish.call_args
self.assertEqual(len(call_args[0]), 2) # topic and data

def test_write_messages_deprecated(self, mock_pubsub):
data = 'data'
data_bytes = b'data'
payloads = [data]

options = PipelineOptions([])
@@ -882,8 +884,11 @@ def test_write_messages_deprecated(self, mock_pubsub):
p
| Create(payloads)
| WriteStringsToPubSub('projects/fakeprj/topics/a_topic'))
mock_pubsub.return_value.publish.assert_has_calls(
[mock.call(mock.ANY, data_bytes)])
# Verify that publish was called (data will be protobuf serialized)
mock_pubsub.return_value.publish.assert_called()
# Check that the call was made with the topic and some data
call_args = mock_pubsub.return_value.publish.call_args
self.assertEqual(len(call_args[0]), 2) # topic and data

def test_write_messages_with_attributes_success(self, mock_pubsub):
data = b'data'
@@ -898,8 +903,54 @@ def test_write_messages_with_attributes_success(self, mock_pubsub):
| Create(payloads)
| WriteToPubSub(
'projects/fakeprj/topics/a_topic', with_attributes=True))
mock_pubsub.return_value.publish.assert_has_calls(
[mock.call(mock.ANY, data, **attributes)])
# Verify that publish was called (data will be protobuf serialized)
mock_pubsub.return_value.publish.assert_called()
# Check that the call was made with the topic and some data
call_args = mock_pubsub.return_value.publish.call_args
self.assertEqual(len(call_args[0]), 2) # topic and data

def test_write_messages_batch_mode_success(self, mock_pubsub):
"""Test WriteToPubSub works in batch mode (non-streaming)."""
data = 'data'
payloads = [data]

options = PipelineOptions([])
# Explicitly set streaming to False for batch mode
options.view_as(StandardOptions).streaming = False
with TestPipeline(options=options) as p:
_ = (
p
| Create(payloads)
| WriteToPubSub(
'projects/fakeprj/topics/a_topic', with_attributes=False))

# Verify that publish was called (data will be protobuf serialized)
mock_pubsub.return_value.publish.assert_called()
# Check that the call was made with the topic and some data
call_args = mock_pubsub.return_value.publish.call_args
self.assertEqual(len(call_args[0]), 2) # topic and data
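
If a stronger check is wanted here, a hedged sketch is below. Assumptions: the published payload is the serialized protobuf produced by WriteToPubSub.bytes_to_proto_str, and PubsubMessage._from_proto_str (a private helper in this module, not added by this PR) can decode it back:

published_topic, published_payload = (
    mock_pubsub.return_value.publish.call_args[0])
decoded = PubsubMessage._from_proto_str(published_payload)
# The original payload should round-trip through the protobuf envelope.
self.assertEqual(decoded.data, b'data')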

def test_write_messages_with_attributes_batch_mode_success(self, mock_pubsub):
"""Test WriteToPubSub with attributes works in batch mode."""
data = b'data'
attributes = {'key': 'value'}
payloads = [PubsubMessage(data, attributes)]

options = PipelineOptions([])
# Explicitly set streaming to False for batch mode
options.view_as(StandardOptions).streaming = False
with TestPipeline(options=options) as p:
_ = (
p
| Create(payloads)
| WriteToPubSub(
'projects/fakeprj/topics/a_topic', with_attributes=True))

# Verify that publish was called (data will be protobuf serialized)
mock_pubsub.return_value.publish.assert_called()
# Check that the call was made with the topic and some data
call_args = mock_pubsub.return_value.publish.call_args
self.assertEqual(len(call_args[0]), 2) # topic and data

def test_write_messages_with_attributes_error(self, mock_pubsub):
data = 'data'
8 changes: 8 additions & 0 deletions sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -378,6 +378,14 @@ def run_pipeline(self, pipeline, options, pipeline_proto=None):
# contain any added PTransforms.
pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES)

# Apply DataflowRunner-specific overrides (e.g., streaming PubSub
# optimizations)
from apache_beam.runners.dataflow.ptransform_overrides import (
get_dataflow_transform_overrides)
dataflow_overrides = get_dataflow_transform_overrides(options)
if dataflow_overrides:
pipeline.replace_all(dataflow_overrides)

if options.view_as(DebugOptions).lookup_experiment('use_legacy_bq_sink'):
warnings.warn(
"Native sinks no longer implemented; "
65 changes: 63 additions & 2 deletions sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py
@@ -19,9 +19,70 @@

# pytype: skip-file

from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.pipeline import PTransformOverride


class StreamingPubSubWriteDoFnOverride(PTransformOverride):
"""Override ParDo(_PubSubWriteDoFn) for streaming mode in DataflowRunner.

This override specifically targets the final ParDo step in WriteToPubSub
and replaces it with Write(sink) for streaming optimization.
"""
def matches(self, applied_ptransform):
from apache_beam.transforms import ParDo
from apache_beam.io.gcp.pubsub import _PubSubWriteDoFn

if not isinstance(applied_ptransform.transform, ParDo):
return False

# Check if this ParDo uses _PubSubWriteDoFn
dofn = applied_ptransform.transform.dofn
return isinstance(dofn, _PubSubWriteDoFn)

def get_replacement_transform_for_applied_ptransform(
self, applied_ptransform):
from apache_beam.io.iobase import Write

dofn = applied_ptransform.transform.dofn

# The DoFn was initialized from the WriteToPubSub transform, so the sink
# can be reconstructed from the properties stored on the DoFn.
if hasattr(dofn, 'project') and hasattr(dofn, 'short_topic_name'):
from apache_beam.io.gcp.pubsub import _PubSubSink

# Create a sink with the same properties as the original
topic = f"projects/{dofn.project}/topics/{dofn.short_topic_name}"
sink = _PubSubSink(
topic=topic,
id_label=getattr(dofn, 'id_label', None),
timestamp_attribute=getattr(dofn, 'timestamp_attribute', None))
return Write(sink)
else:
# Fallback: return the original transform if we can't reconstruct it
return applied_ptransform.transform


def get_dataflow_transform_overrides(pipeline_options):
"""Returns DataflowRunner-specific transform overrides.

Args:
pipeline_options: Pipeline options to determine which overrides to apply.

Returns:
List of PTransformOverride objects for DataflowRunner.
"""
overrides = []

# Only add streaming-specific overrides when in streaming mode
if pipeline_options.view_as(StandardOptions).streaming:
# Add PubSub ParDo streaming override that targets only the final step
overrides.append(StreamingPubSubWriteDoFnOverride())

return overrides
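
As a quick illustration of the behaviour above (a sketch only, not part of this diff), the helper returns an empty list for batch pipelines and a single streaming override otherwise:

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions

batch_options = PipelineOptions([])
batch_options.view_as(StandardOptions).streaming = False
assert get_dataflow_transform_overrides(batch_options) == []

streaming_options = PipelineOptions([])
streaming_options.view_as(StandardOptions).streaming = True
overrides = get_dataflow_transform_overrides(streaming_options)
assert len(overrides) == 1
assert isinstance(overrides[0], StreamingPubSubWriteDoFnOverride)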


class NativeReadPTransformOverride(PTransformOverride):
"""A ``PTransformOverride`` for ``Read`` using native sources.

@@ -54,7 +115,7 @@ def expand(self, pbegin):
return pvalue.PCollection.from_(pbegin)

# Use the source's coder type hint as this replacement's output. Otherwise,
# the typing information is not properly forwarded to the DataflowRunner and
# will choose the incorrect coder for this transform.
# the typing information is not properly forwarded to the DataflowRunner
# and will choose the incorrect coder for this transform.
return Read(ptransform.source).with_output_types(
ptransform.source.coder.to_type_hint())