Skip to content

Commit 8ea7d84

Browse files
committed
feat(spanner): add ClientContext support to options
This change adds support for ClientContext in Options and ensures it is propagated to ExecuteSql, Read, Commit, and BeginTransaction requests. It aligns with go/spanner-client-scoped-session-state design. ClientContext allows passing opaque, RPC-scoped side-channel information (like application-level user context) to Spanner. This implementation supports setting ClientContext at the Client, Database, and Request levels, with request-level options taking precedence. Key changes: - Added ClientContext to types/spanner.py and exposed it. - Updated Client.__init__ to accept a default client_context. - Added helpers for merging ClientContext with correct precedence. - Updated Snapshot, Transaction, Batch, and Database wrappers to propagate the context. - Added comprehensive unit tests in tests/unit/test_client_context.py.
1 parent 7e79920 commit 8ea7d84

File tree

10 files changed

+605
-28
lines changed

10 files changed

+605
-28
lines changed

google/cloud/spanner_v1/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from .types.spanner import BatchWriteRequest
3939
from .types.spanner import BatchWriteResponse
4040
from .types.spanner import BeginTransactionRequest
41+
from .types.spanner import ClientContext
4142
from .types.spanner import CommitRequest
4243
from .types.spanner import CreateSessionRequest
4344
from .types.spanner import DeleteSessionRequest
@@ -110,6 +111,7 @@
110111
"BatchWriteRequest",
111112
"BatchWriteResponse",
112113
"BeginTransactionRequest",
114+
"ClientContext",
113115
"CommitRequest",
114116
"CommitResponse",
115117
"CreateSessionRequest",

google/cloud/spanner_v1/_helpers.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from google.cloud._helpers import _date_from_iso8601_date
3535
from google.cloud.spanner_v1.types import ExecuteSqlRequest
3636
from google.cloud.spanner_v1.types import TransactionOptions
37+
from google.cloud.spanner_v1.types import ClientContext
38+
from google.cloud.spanner_v1.types import RequestOptions
3739
from google.cloud.spanner_v1.data_types import JsonObject, Interval
3840
from google.cloud.spanner_v1.request_id_header import (
3941
with_request_id,
@@ -191,6 +193,74 @@ def _merge_query_options(base, merge):
191193
return combined
192194

193195

196+
def _merge_client_context(base, merge):
197+
"""Merge higher precedence ClientContext with current ClientContext.
198+
199+
:type base: :class:`~google.cloud.spanner_v1.types.ClientContext`
200+
or :class:`dict` or None
201+
:param base: The current ClientContext that is intended for use.
202+
203+
:type merge: :class:`~google.cloud.spanner_v1.types.ClientContext`
204+
or :class:`dict` or None
205+
:param merge:
206+
The ClientContext that has a higher priority than base. These options
207+
should overwrite the fields in base.
208+
209+
:rtype: :class:`~google.cloud.spanner_v1.types.ClientContext`
210+
or None
211+
:returns:
212+
ClientContext object formed by merging the two given ClientContexts.
213+
"""
214+
if base is None and merge is None:
215+
return None
216+
217+
combined = base or ClientContext()
218+
if type(combined) is dict:
219+
combined = ClientContext(combined)
220+
221+
merge = merge or ClientContext()
222+
if type(merge) is dict:
223+
merge = ClientContext(merge)
224+
225+
type(combined).pb(combined).MergeFrom(type(merge).pb(merge))
226+
if not combined.secure_context:
227+
return None
228+
return combined
229+
230+
231+
def _merge_request_options(request_options, client_context):
232+
"""Merge RequestOptions and ClientContext.
233+
234+
:type request_options: :class:`~google.cloud.spanner_v1.types.RequestOptions`
235+
or :class:`dict` or None
236+
:param request_options: The current RequestOptions that is intended for use.
237+
238+
:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
239+
or :class:`dict` or None
240+
:param client_context:
241+
The ClientContext to merge into request_options.
242+
243+
:rtype: :class:`~google.cloud.spanner_v1.types.RequestOptions`
244+
or None
245+
:returns:
246+
RequestOptions object formed by merging the given ClientContext.
247+
"""
248+
if request_options is None and client_context is None:
249+
return None
250+
251+
if request_options is None:
252+
request_options = RequestOptions()
253+
elif type(request_options) is dict:
254+
request_options = RequestOptions(request_options)
255+
256+
if client_context:
257+
request_options.client_context = _merge_client_context(
258+
client_context, request_options.client_context
259+
)
260+
261+
return request_options
262+
263+
194264
def _assert_numeric_precision_and_scale(value):
195265
"""
196266
Asserts that input numeric field is within Spanner supported range.

google/cloud/spanner_v1/batch.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
_metadata_with_prefix,
2828
_metadata_with_leader_aware_routing,
2929
_merge_Transaction_Options,
30+
_merge_client_context,
31+
_merge_request_options,
3032
AtomicCounter,
3133
)
3234
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
@@ -36,6 +38,7 @@
3638
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
3739
from google.api_core.exceptions import InternalServerError
3840
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
41+
from google.cloud.spanner_v1.types import ClientContext
3942
import time
4043

4144
DEFAULT_RETRY_TIMEOUT_SECS = 30
@@ -46,9 +49,14 @@ class _BatchBase(_SessionWrapper):
4649
4750
:type session: :class:`~google.cloud.spanner_v1.session.Session`
4851
:param session: the session used to perform the commit
52+
53+
:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
54+
or :class:`dict`
55+
:param client_context: (Optional) Client context to use for all requests made
56+
by this batch.
4957
"""
5058

51-
def __init__(self, session):
59+
def __init__(self, session, client_context=None):
5260
super(_BatchBase, self).__init__(session)
5361

5462
self._mutations: List[Mutation] = []
@@ -58,6 +66,13 @@ def __init__(self, session):
5866
"""Timestamp at which the batch was successfully committed."""
5967
self.commit_stats: Optional[CommitResponse.CommitStats] = None
6068

69+
if client_context is not None:
70+
if type(client_context) is dict:
71+
client_context = ClientContext(client_context)
72+
elif not isinstance(client_context, ClientContext):
73+
raise TypeError("client_context must be a ClientContext or a dict")
74+
self._client_context = client_context
75+
6176
def insert(self, table, columns, values):
6277
"""Insert one or more new table rows.
6378
@@ -226,10 +241,14 @@ def commit(
226241
txn_options,
227242
)
228243

244+
client_context = _merge_client_context(
245+
database._instance._client._client_context, self._client_context
246+
)
247+
request_options = _merge_request_options(request_options, client_context)
248+
229249
if request_options is None:
230250
request_options = RequestOptions()
231-
elif type(request_options) is dict:
232-
request_options = RequestOptions(request_options)
251+
233252
request_options.transaction_tag = self.transaction_tag
234253

235254
# Request tags are not supported for commit requests.
@@ -316,13 +335,25 @@ class MutationGroups(_SessionWrapper):
316335
317336
:type session: :class:`~google.cloud.spanner_v1.session.Session`
318337
:param session: the session used to perform the commit
338+
339+
:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
340+
or :class:`dict`
341+
:param client_context: (Optional) Client context to use for all requests made
342+
by this mutation group.
319343
"""
320344

321-
def __init__(self, session):
345+
def __init__(self, session, client_context=None):
322346
super(MutationGroups, self).__init__(session)
323347
self._mutation_groups: List[MutationGroup] = []
324348
self.committed: bool = False
325349

350+
if client_context is not None:
351+
if type(client_context) is dict:
352+
client_context = ClientContext(client_context)
353+
elif not isinstance(client_context, ClientContext):
354+
raise TypeError("client_context must be a ClientContext or a dict")
355+
self._client_context = client_context
356+
326357
def group(self):
327358
"""Returns a new `MutationGroup` to which mutations can be added."""
328359
mutation_group = BatchWriteRequest.MutationGroup()
@@ -364,10 +395,13 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
364395
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
365396
)
366397

398+
client_context = _merge_client_context(
399+
database._instance._client._client_context, self._client_context
400+
)
401+
request_options = _merge_request_options(request_options, client_context)
402+
367403
if request_options is None:
368404
request_options = RequestOptions()
369-
elif type(request_options) is dict:
370-
request_options = RequestOptions(request_options)
371405

372406
with trace_call(
373407
name="CloudSpanner.batch_write",

google/cloud/spanner_v1/client.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from google.cloud.spanner_v1 import __version__
4949
from google.cloud.spanner_v1 import ExecuteSqlRequest
5050
from google.cloud.spanner_v1 import DefaultTransactionOptions
51+
from google.cloud.spanner_v1.types import ClientContext
5152
from google.cloud.spanner_v1._helpers import _merge_query_options
5253
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
5354
from google.cloud.spanner_v1.instance import Instance
@@ -184,6 +185,10 @@ class Client(ClientWithProject):
184185
:param disable_builtin_metrics: (Optional) Default False. Set to True to disable
185186
the Spanner built-in metrics collection and exporting.
186187
188+
:type client_context: :class:`~google.cloud.spanner_v1.types.RequestOptions.ClientContext`
189+
or :class:`dict`
190+
:param client_context: (Optional) Client context to use for all requests made by this client.
191+
187192
:raises: :class:`ValueError <exceptions.ValueError>` if both ``read_only``
188193
and ``admin`` are :data:`True`
189194
"""
@@ -210,6 +215,7 @@ def __init__(
210215
default_transaction_options: Optional[DefaultTransactionOptions] = None,
211216
experimental_host=None,
212217
disable_builtin_metrics=False,
218+
client_context=None,
213219
):
214220
self._emulator_host = _get_spanner_emulator_host()
215221
self._experimental_host = experimental_host
@@ -247,6 +253,13 @@ def __init__(
247253
# Environment flag config has higher precedence than application config.
248254
self._query_options = _merge_query_options(query_options, env_query_options)
249255

256+
if client_context is not None:
257+
if type(client_context) is dict:
258+
client_context = ClientContext(client_context)
259+
elif not isinstance(client_context, ClientContext):
260+
raise TypeError("client_context must be a ClientContext or a dict")
261+
self._client_context = client_context
262+
250263
if self._emulator_host is not None and (
251264
"http://" in self._emulator_host or "https://" in self._emulator_host
252265
):

0 commit comments

Comments
 (0)