Skip to content

Commit ffb60f6

Browse files
committed
fix: Metrics thread-safety refactor and Batch.commit idempotency fix
1 parent 67c682e commit ffb60f6

File tree

10 files changed

+271
-118
lines changed

10 files changed

+271
-118
lines changed

.kokoro/presubmit/presubmit.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
# Only run a subset of all nox sessions
44
env_vars: {
55
key: "NOX_SESSION"
6-
value: "unit-3.9 unit-3.12 cover docs docfx"
6+
value: "unit-3.10 unit-3.12 cover docs docfx"
77
}

google/cloud/spanner_v1/batch.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Context manager for Cloud Spanner batched writes."""
16+
1617
import functools
1718
from typing import List, Optional
1819

@@ -27,7 +28,6 @@
2728
_metadata_with_prefix,
2829
_metadata_with_leader_aware_routing,
2930
_merge_Transaction_Options,
30-
AtomicCounter,
3131
)
3232
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
3333
from google.cloud.spanner_v1 import RequestOptions
@@ -252,9 +252,6 @@ def wrapped_method():
252252
max_commit_delay=max_commit_delay,
253253
request_options=request_options,
254254
)
255-
# This code is retried due to ABORTED, hence nth_request
256-
# should be increased. attempt can only be increased if
257-
# we encounter UNAVAILABLE or INTERNAL.
258255
call_metadata, error_augmenter = database.with_error_augmentation(
259256
getattr(database, "_next_nth_request", 0),
260257
1,
@@ -376,8 +373,6 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
376373
observability_options=getattr(database, "observability_options", None),
377374
metadata=metadata,
378375
) as span, MetricsCapture():
379-
attempt = AtomicCounter(0)
380-
nth_request = getattr(database, "_next_nth_request", 0)
381376

382377
def wrapped_method():
383378
batch_write_request = BatchWriteRequest(
@@ -390,8 +385,8 @@ def wrapped_method():
390385
api.batch_write,
391386
request=batch_write_request,
392387
metadata=database.metadata_with_request_id(
393-
nth_request,
394-
attempt.increment(),
388+
getattr(database, "_next_nth_request", 0),
389+
1,
395390
metadata,
396391
span,
397392
),

google/cloud/spanner_v1/client.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
* a :class:`~google.cloud.spanner_v1.instance.Instance` owns a
2424
:class:`~google.cloud.spanner_v1.database.Database`
2525
"""
26+
2627
import grpc
2728
import os
2829
import logging
@@ -108,6 +109,42 @@ def _get_spanner_enable_builtin_metrics_env():
108109
return os.getenv(SPANNER_DISABLE_BUILTIN_METRICS_ENV_VAR) != "true"
109110

110111

112+
def _initialize_metrics(project, credentials):
113+
"""
114+
Initializes the Spanner built-in metrics.
115+
116+
This function sets up the OpenTelemetry MeterProvider and the SpannerMetricsTracerFactory.
117+
It uses a lock to ensure that initialization happens only once.
118+
"""
119+
global _metrics_monitor_initialized
120+
if not _metrics_monitor_initialized:
121+
with _metrics_monitor_lock:
122+
if not _metrics_monitor_initialized:
123+
meter_provider = metrics.NoOpMeterProvider()
124+
try:
125+
if not _get_spanner_emulator_host():
126+
meter_provider = MeterProvider(
127+
metric_readers=[
128+
PeriodicExportingMetricReader(
129+
CloudMonitoringMetricsExporter(
130+
project_id=project,
131+
credentials=credentials,
132+
),
133+
export_interval_millis=METRIC_EXPORT_INTERVAL_MS,
134+
),
135+
]
136+
)
137+
metrics.set_meter_provider(meter_provider)
138+
SpannerMetricsTracerFactory()
139+
_metrics_monitor_initialized = True
140+
except Exception as e:
141+
# log is already defined at module level
142+
log.warning(
143+
"Failed to initialize Spanner built-in metrics. Error: %s",
144+
e,
145+
)
146+
147+
111148
class Client(ClientWithProject):
112149
"""Client for interacting with Cloud Spanner API.
113150
@@ -255,38 +292,12 @@ def __init__(
255292
"http://" in self._emulator_host or "https://" in self._emulator_host
256293
):
257294
warnings.warn(_EMULATOR_HOST_HTTP_SCHEME)
258-
# Check flag to enable Spanner builtin metrics
259-
global _metrics_monitor_initialized
260295
if (
261296
_get_spanner_enable_builtin_metrics_env()
262297
and not disable_builtin_metrics
263298
and HAS_GOOGLE_CLOUD_MONITORING_INSTALLED
264299
):
265-
if not _metrics_monitor_initialized:
266-
with _metrics_monitor_lock:
267-
if not _metrics_monitor_initialized:
268-
meter_provider = metrics.NoOpMeterProvider()
269-
try:
270-
if not _get_spanner_emulator_host():
271-
meter_provider = MeterProvider(
272-
metric_readers=[
273-
PeriodicExportingMetricReader(
274-
CloudMonitoringMetricsExporter(
275-
project_id=project,
276-
credentials=credentials,
277-
),
278-
export_interval_millis=METRIC_EXPORT_INTERVAL_MS,
279-
),
280-
]
281-
)
282-
metrics.set_meter_provider(meter_provider)
283-
SpannerMetricsTracerFactory()
284-
_metrics_monitor_initialized = True
285-
except Exception as e:
286-
log.warning(
287-
"Failed to initialize Spanner built-in metrics. Error: %s",
288-
e,
289-
)
300+
_initialize_metrics(project, credentials)
290301
else:
291302
SpannerMetricsTracerFactory(enabled=False)
292303

google/cloud/spanner_v1/metrics/metrics_capture.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
performance monitoring.
2121
"""
2222

23+
from contextvars import Token
24+
2325
from .spanner_metrics_tracer_factory import SpannerMetricsTracerFactory
2426

2527

@@ -30,6 +32,9 @@ class MetricsCapture:
3032
the start and completion of metrics tracing for a given operation.
3133
"""
3234

35+
_token: Token
36+
"""Token to reset the context variable after the operation completes."""
37+
3338
def __enter__(self):
3439
"""Enter the runtime context related to this object.
3540
@@ -45,11 +50,11 @@ def __enter__(self):
4550
return self
4651

4752
# Define a new metrics tracer for the new operation
48-
SpannerMetricsTracerFactory.current_metrics_tracer = (
49-
factory.create_metrics_tracer()
50-
)
51-
if SpannerMetricsTracerFactory.current_metrics_tracer:
52-
SpannerMetricsTracerFactory.current_metrics_tracer.record_operation_start()
53+
# Set the context var and keep the token for reset
54+
tracer = factory.create_metrics_tracer()
55+
self._token = SpannerMetricsTracerFactory.set_current_tracer(tracer)
56+
if tracer:
57+
tracer.record_operation_start()
5358
return self
5459

5560
def __exit__(self, exc_type, exc_value, traceback):
@@ -70,6 +75,11 @@ def __exit__(self, exc_type, exc_value, traceback):
7075
if not SpannerMetricsTracerFactory().enabled:
7176
return False
7277

73-
if SpannerMetricsTracerFactory.current_metrics_tracer:
74-
SpannerMetricsTracerFactory.current_metrics_tracer.record_operation_completion()
78+
tracer = SpannerMetricsTracerFactory.get_current_tracer()
79+
if tracer:
80+
tracer.record_operation_completion()
81+
82+
# Reset the context var using the token
83+
if getattr(self, "_token", None):
84+
SpannerMetricsTracerFactory.reset_current_tracer(self._token)
7585
return False # Propagate the exception if any

google/cloud/spanner_v1/metrics/metrics_interceptor.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -97,22 +97,17 @@ def _set_metrics_tracer_attributes(self, resources: Dict[str, str]) -> None:
9797
Args:
9898
resources (Dict[str, str]): A dictionary containing project, instance, and database information.
9999
"""
100-
if SpannerMetricsTracerFactory.current_metrics_tracer is None:
100+
tracer = SpannerMetricsTracerFactory.get_current_tracer()
101+
if tracer is None:
101102
return
102103

103104
if resources:
104105
if "project" in resources:
105-
SpannerMetricsTracerFactory.current_metrics_tracer.set_project(
106-
resources["project"]
107-
)
106+
tracer.set_project(resources["project"])
108107
if "instance" in resources:
109-
SpannerMetricsTracerFactory.current_metrics_tracer.set_instance(
110-
resources["instance"]
111-
)
108+
tracer.set_instance(resources["instance"])
112109
if "database" in resources:
113-
SpannerMetricsTracerFactory.current_metrics_tracer.set_database(
114-
resources["database"]
115-
)
110+
tracer.set_database(resources["database"])
116111

117112
def intercept(self, invoked_method, request_or_iterator, call_details):
118113
"""Intercept gRPC calls to collect metrics.
@@ -126,31 +121,32 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
126121
The RPC response
127122
"""
128123
factory = SpannerMetricsTracerFactory()
129-
if (
130-
SpannerMetricsTracerFactory.current_metrics_tracer is None
131-
or not factory.enabled
132-
):
124+
tracer = SpannerMetricsTracerFactory.get_current_tracer()
125+
if tracer is None or not factory.enabled:
133126
return invoked_method(request_or_iterator, call_details)
134127

135128
# Setup Metric Tracer attributes from call details
136-
## Extract Project / Instance / Databse from header information
137-
resources = self._extract_resource_from_path(call_details.metadata)
138-
self._set_metrics_tracer_attributes(resources)
129+
## Extract Project / Instance / Database from header information if not already set
130+
if not (
131+
tracer.client_attributes.get("project_id")
132+
and tracer.client_attributes.get("instance_id")
133+
and tracer.client_attributes.get("database")
134+
):
135+
resources = self._extract_resource_from_path(call_details.metadata)
136+
self._set_metrics_tracer_attributes(resources)
139137

140138
## Format method to be be spanner.<method name>
141139
method_name = self._remove_prefix(
142140
call_details.method, SPANNER_METHOD_PREFIX
143141
).replace("/", ".")
144142

145-
SpannerMetricsTracerFactory.current_metrics_tracer.set_method(method_name)
146-
SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_start()
143+
tracer.set_method(method_name)
144+
tracer.record_attempt_start()
147145
response = invoked_method(request_or_iterator, call_details)
148-
SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_completion()
146+
tracer.record_attempt_completion()
149147

150148
# Process and send GFE metrics if enabled
151-
if SpannerMetricsTracerFactory.current_metrics_tracer.gfe_enabled:
149+
if tracer.gfe_enabled:
152150
metadata = response.initial_metadata()
153-
SpannerMetricsTracerFactory.current_metrics_trace.record_gfe_metrics(
154-
metadata
155-
)
151+
tracer.record_gfe_metrics(metadata)
156152
return response

google/cloud/spanner_v1/metrics/spanner_metrics_tracer_factory.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import logging
2121
from .constants import SPANNER_SERVICE_NAME
22+
import contextvars
2223

2324
try:
2425
import mmh3
@@ -43,7 +44,9 @@ class SpannerMetricsTracerFactory(MetricsTracerFactory):
4344
"""A factory for creating SpannerMetricsTracer instances."""
4445

4546
_metrics_tracer_factory: "SpannerMetricsTracerFactory" = None
46-
current_metrics_tracer: MetricsTracer = None
47+
_current_metrics_tracer_ctx = contextvars.ContextVar(
48+
"current_metrics_tracer", default=None
49+
)
4750

4851
def __new__(
4952
cls, enabled: bool = True, gfe_enabled: bool = False
@@ -80,10 +83,22 @@ def __new__(
8083
cls._metrics_tracer_factory.gfe_enabled = gfe_enabled
8184

8285
if cls._metrics_tracer_factory.enabled != enabled:
83-
cls._metrics_tracer_factory.enabeld = enabled
86+
cls._metrics_tracer_factory.enabled = enabled
8487

8588
return cls._metrics_tracer_factory
8689

90+
@staticmethod
91+
def get_current_tracer() -> MetricsTracer:
92+
return SpannerMetricsTracerFactory._current_metrics_tracer_ctx.get()
93+
94+
@staticmethod
95+
def set_current_tracer(tracer: MetricsTracer) -> contextvars.Token:
96+
return SpannerMetricsTracerFactory._current_metrics_tracer_ctx.set(tracer)
97+
98+
@staticmethod
99+
def reset_current_tracer(token: contextvars.Token):
100+
SpannerMetricsTracerFactory._current_metrics_tracer_ctx.reset(token)
101+
87102
@staticmethod
88103
def _generate_client_uid() -> str:
89104
"""Generate a client UID in the form of uuidv4@pid@hostname.

tests/unit/conftest.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from unittest.mock import patch
17+
18+
19+
@pytest.fixture(autouse=True)
20+
def mock_periodic_exporting_metric_reader():
21+
"""Globally mock PeriodicExportingMetricReader to prevent real network calls."""
22+
with patch(
23+
"google.cloud.spanner_v1.client.PeriodicExportingMetricReader"
24+
) as mock_client_reader, patch(
25+
"opentelemetry.sdk.metrics.export.PeriodicExportingMetricReader"
26+
):
27+
yield mock_client_reader

0 commit comments

Comments
 (0)