From a73e1663e8cabf3f9650f21b70d19c8d6440a526 Mon Sep 17 00:00:00 2001 From: Subham Sinha Date: Wed, 24 Jun 2026 12:03:04 +0530 Subject: [PATCH] fix(metrics): correct GFE metrics extraction and enable by default --- .../spanner_v1/metrics/metrics_interceptor.py | 272 +++++++++++++++++- .../spanner_v1/metrics/metrics_tracer.py | 79 ++++- .../metrics/metrics_tracer_factory.py | 6 + .../metrics/spanner_metrics_tracer_factory.py | 7 +- .../spanner/transports/grpc_asyncio.py | 17 +- .../mockserver_tests/test_gfe_metrics.py | 206 +++++++++++++ .../tests/unit/test_metrics_interceptor.py | 11 +- .../tests/unit/test_metrics_tracer.py | 33 +++ 8 files changed, 610 insertions(+), 21 deletions(-) create mode 100644 packages/google-cloud-spanner/tests/mockserver_tests/test_gfe_metrics.py diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_interceptor.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_interceptor.py index 3e38c4e0191d..0ad9752f14a8 100644 --- a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_interceptor.py +++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_interceptor.py @@ -14,14 +14,19 @@ """Interceptor for collecting Cloud Spanner metrics.""" +import inspect +import logging import re -from typing import Dict +from typing import Any, Dict +import grpc from grpc_interceptor import ClientInterceptor from .constants import GOOGLE_CLOUD_RESOURCE_KEY, SPANNER_METHOD_PREFIX from .spanner_metrics_tracer_factory import SpannerMetricsTracerFactory +logger = logging.getLogger(__name__) + class MetricsInterceptor(ClientInterceptor): """Interceptor that collects metrics for Cloud Spanner operations.""" @@ -88,6 +93,8 @@ def _set_metrics_tracer_attributes(self, resources: Dict[str, str]) -> None: if "database" in resources: tracer.set_database(resources["database"]) + + def intercept(self, invoked_method, request_or_iterator, call_details): """Intercept gRPC calls to collect metrics. @@ -122,10 +129,265 @@ def intercept(self, invoked_method, request_or_iterator, call_details): tracer.set_method(method_name) tracer.record_attempt_start() response = invoked_method(request_or_iterator, call_details) - tracer.record_attempt_completion() - # Process and send GFE metrics if enabled - if tracer.gfe_enabled: - metadata = response.initial_metadata() + return _wrap_response(response, tracer) + + +def _wrap_response(response: Any, tracer: Any) -> Any: + """Wraps the response if it is streaming, or records metrics immediately if unary.""" + if hasattr(response, "__next__"): + return _StreamingResponseWrapper(response, tracer) + else: + # Unary call: execute completion and record metrics immediately + try: + tracer.record_attempt_completion() + metadata = [] + if hasattr(response, "initial_metadata"): + try: + metadata.extend(response.initial_metadata() or []) + except Exception as e: + logger.warning(f"Failed to retrieve initial metadata: {e}") + if hasattr(response, "trailing_metadata"): + try: + metadata.extend(response.trailing_metadata() or []) + except Exception as e: + logger.warning(f"Failed to retrieve trailing metadata: {e}") tracer.record_gfe_metrics(metadata) + except Exception as e: + logger.warning(f"Failed to record metrics: {e}") return response + + +class AsyncMetricsInterceptor( + grpc.aio.UnaryUnaryClientInterceptor, + grpc.aio.UnaryStreamClientInterceptor, + grpc.aio.StreamUnaryClientInterceptor, + grpc.aio.StreamStreamClientInterceptor, +): + """Async Interceptor that collects metrics for Cloud Spanner operations.""" + + async def intercept_unary_unary(self, continuation, client_call_details, request): + return await self._async_intercept(continuation, client_call_details, request) + + async def intercept_unary_stream(self, continuation, client_call_details, request): + return await self._async_intercept(continuation, client_call_details, request) + + async def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + return await self._async_intercept(continuation, client_call_details, request_iterator) + + async def intercept_stream_stream(self, continuation, client_call_details, request_iterator): + return await self._async_intercept(continuation, client_call_details, request_iterator) + + async def _async_intercept( + self, + continuation: Any, + call_details: grpc.ClientCallDetails, + request_or_iterator: Any, + ) -> Any: + # Implementation for async interceptor + factory = SpannerMetricsTracerFactory() + tracer = SpannerMetricsTracerFactory.get_current_tracer() + if tracer is None or not factory.enabled: + return await continuation(call_details, request_or_iterator) + + if not ( + tracer.client_attributes.get("project_id") + and tracer.client_attributes.get("instance_id") + and tracer.client_attributes.get("database") + ): + resources = MetricsInterceptor._extract_resource_from_path(call_details.metadata) + MetricsInterceptor._set_metrics_tracer_attributes(resources) + + method_name = call_details.method.removeprefix(SPANNER_METHOD_PREFIX).replace( + "/", "." + ) + + tracer.set_method(method_name) + tracer.record_attempt_start() + response = await continuation(call_details, request_or_iterator) + + if hasattr(response, "__anext__"): + return _AsyncStreamingResponseWrapper(response, tracer) + else: + return _AsyncUnaryResponseWrapper(response, tracer) + + +class _StreamingResponseWrapper: + """Wrapper for streaming RPC response iterators to defer metrics recording.""" + + def __init__(self, response, tracer): + self._response = response + self._tracer = tracer + self._metrics_recorded = False + self._iterator = None + + def __iter__(self): + self._iterator = iter(self._response) + return self + + def __next__(self): + if self._iterator is None: + self._iterator = iter(self._response) + try: + return next(self._iterator) + except StopIteration: + self._record_metrics() + raise + except Exception: + self._record_metrics() + raise + + def _record_metrics(self): + if self._metrics_recorded: + return + self._metrics_recorded = True + try: + self._tracer.record_attempt_completion() + metadata = [] + if hasattr(self._response, "initial_metadata"): + try: + metadata.extend(self._response.initial_metadata() or []) + except Exception as e: + logger.warning(f"Failed to retrieve initial metadata: {e}") + if hasattr(self._response, "trailing_metadata"): + try: + metadata.extend(self._response.trailing_metadata() or []) + except Exception as e: + logger.warning(f"Failed to retrieve trailing metadata: {e}") + self._tracer.record_gfe_metrics(metadata) + except Exception as e: + logger.warning(f"Failed to record metrics: {e}") + + def __del__(self): + try: + self._record_metrics() + except Exception: + pass + + def __getattr__(self, name): + return getattr(self._response, name) + + +class _AsyncUnaryResponseWrapper: + """Wrapper for async unary RPC response to defer metrics recording until awaited.""" + + def __init__(self, response, tracer): + self._response = response + self._tracer = tracer + self._metrics_recorded = False + + def __await__(self): + async def _wait(): + try: + return await self._response + finally: + await self._record_metrics() + return _wait().__await__() + + async def _record_metrics(self): + if self._metrics_recorded: + return + self._metrics_recorded = True + try: + self._tracer.record_attempt_completion() + metadata = [] + if hasattr(self._response, "initial_metadata"): + try: + res = self._response.initial_metadata() + if inspect.isawaitable(res): + res = await res + metadata.extend(res or []) + except Exception as e: + logger.warning(f"Failed to retrieve initial metadata: {e}") + if hasattr(self._response, "trailing_metadata"): + try: + res = self._response.trailing_metadata() + if inspect.isawaitable(res): + res = await res + metadata.extend(res or []) + except Exception as e: + logger.warning(f"Failed to retrieve trailing metadata: {e}") + self._tracer.record_gfe_metrics(metadata) + except Exception as e: + logger.warning(f"Failed to record metrics: {e}") + + def __del__(self): + if not self._metrics_recorded: + self._metrics_recorded = True + try: + self._tracer.record_attempt_completion() + except Exception: + pass + + def __getattr__(self, name): + return getattr(self._response, name) + + +class _AsyncStreamingResponseWrapper: + """Wrapper for async streaming RPC response iterators to defer metrics recording.""" + + def __init__(self, response, tracer): + self._response = response + self._tracer = tracer + self._metrics_recorded = False + self._iterator = None + + def __aiter__(self): + if hasattr(self._response, "__aiter__"): + self._iterator = self._response.__aiter__() + else: + self._iterator = self._response + return self + + async def __anext__(self): + if self._iterator is None: + if hasattr(self._response, "__aiter__"): + self._iterator = self._response.__aiter__() + else: + self._iterator = self._response + try: + return await self._iterator.__anext__() + except StopAsyncIteration: + await self._record_metrics() + raise + except Exception: + await self._record_metrics() + raise + + async def _record_metrics(self): + if self._metrics_recorded: + return + self._metrics_recorded = True + try: + self._tracer.record_attempt_completion() + metadata = [] + if hasattr(self._response, "initial_metadata"): + try: + res = self._response.initial_metadata() + if inspect.isawaitable(res): + res = await res + metadata.extend(res or []) + except Exception as e: + logger.warning(f"Failed to retrieve initial metadata: {e}") + if hasattr(self._response, "trailing_metadata"): + try: + res = self._response.trailing_metadata() + if inspect.isawaitable(res): + res = await res + metadata.extend(res or []) + except Exception as e: + logger.warning(f"Failed to retrieve trailing metadata: {e}") + self._tracer.record_gfe_metrics(metadata) + except Exception as e: + logger.warning(f"Failed to record metrics: {e}") + + def __del__(self): + if not self._metrics_recorded: + self._metrics_recorded = True + try: + self._tracer.record_attempt_completion() + except Exception: + pass + + def __getattr__(self, name): + return getattr(self._response, name) diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer.py index f79869948f99..27d33660e736 100644 --- a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer.py +++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer.py @@ -19,8 +19,9 @@ while the helper classes provide additional functionality and context for the metrics being traced. """ +import re from datetime import datetime -from typing import Dict +from typing import Any, Dict, Optional from grpc import StatusCode @@ -198,6 +199,8 @@ def __init__( instrument_operation_counter: "Counter", client_attributes: Dict[str, str], gfe_enabled: bool = False, + instrument_gfe_latency: Optional["Histogram"] = None, + instrument_gfe_missing_header_count: Optional["Counter"] = None, ): """ Initialize a MetricsTracer instance with the given parameters. @@ -214,6 +217,8 @@ def __init__( instrument_operation_counter (Counter): Instrument for counting operations. client_attributes (Dict[str, str]): Dictionary of client attributes used for metrics tracing. gfe_enabled (bool, optional): Indicates if GFE metrics are enabled. Defaults to False. + instrument_gfe_latency (Histogram, optional): Instrument for measuring GFE latency. + instrument_gfe_missing_header_count (Counter, optional): Instrument for counting missing GFE headers. """ self.current_op = MetricOpTracer() self._client_attributes = client_attributes @@ -221,8 +226,10 @@ def __init__( self._instrument_attempt_counter = instrument_attempt_counter self._instrument_operation_latency = instrument_operation_latency self._instrument_operation_counter = instrument_operation_counter + self._instrument_gfe_latency = instrument_gfe_latency + self._instrument_gfe_missing_header_count = instrument_gfe_missing_header_count self.enabled = enabled - self.gfe_enabled = gfe_enabled + self.gfe_enabled = True @staticmethod def _get_ms_time_diff(start: datetime, end: datetime) -> float: @@ -399,7 +406,11 @@ def record_gfe_latency(self, latency: int) -> None: Args: latency (int): The latency duration to be recorded. """ - if not self.enabled or not HAS_OPENTELEMETRY_INSTALLED or not self.gfe_enabled: + if ( + not self.enabled + or not HAS_OPENTELEMETRY_INSTALLED + or not getattr(self, "_instrument_gfe_latency", None) + ): return self._instrument_gfe_latency.record( amount=latency, attributes=self.client_attributes @@ -409,12 +420,72 @@ def record_gfe_missing_header_count(self) -> None: """ Increments the counter for missing GFE headers. """ - if not self.enabled or not HAS_OPENTELEMETRY_INSTALLED or not self.gfe_enabled: + if ( + not self.enabled + or not HAS_OPENTELEMETRY_INSTALLED + or not getattr(self, "_instrument_gfe_missing_header_count", None) + ): return self._instrument_gfe_missing_header_count.add( amount=1, attributes=self.client_attributes ) + @staticmethod + def extract_gfe_latency(metadata: Any) -> Optional[int]: + """ + Extracts the GFE latency value (in milliseconds) from response metadata. + """ + if not metadata: + return None + + header_vals = [] + if isinstance(metadata, dict): + for key, val in metadata.items(): + if key and str(key).lower() in ("server-timing", "server_timing"): + if isinstance(val, (list, tuple)): + header_vals.extend(val) + else: + header_vals.append(val) + elif isinstance(metadata, (list, tuple)): + for item in metadata: + if isinstance(item, (list, tuple)) and len(item) == 2: + key, val = item + if key and str(key).lower() in ("server-timing", "server_timing"): + if isinstance(val, (list, tuple)): + header_vals.extend(val) + else: + header_vals.append(val) + + for header_val in header_vals: + if not header_val: + continue + if isinstance(header_val, bytes): + try: + header_val = header_val.decode("utf-8") + except Exception: + header_val = str(header_val) + elif not isinstance(header_val, str): + header_val = str(header_val) + match = re.search(r"gfet4t7;\s*dur=([0-9.]+)", header_val) + if match: + try: + return int(float(match.group(1))) + except ValueError: + pass + return None + + def record_gfe_metrics(self, metadata: Any) -> None: + """ + Extracts and records GFE metrics from the RPC response metadata. + """ + if not self.enabled or not HAS_OPENTELEMETRY_INSTALLED: + return + latency = self.extract_gfe_latency(metadata) + if latency is not None: + self.record_gfe_latency(latency) + else: + self.record_gfe_missing_header_count() + def _create_operation_otel_attributes(self) -> dict: """ Create additional attributes for operation metrics tracing. diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer_factory.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer_factory.py index f22d285c9750..029dddfaa15a 100644 --- a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer_factory.py +++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer_factory.py @@ -85,6 +85,7 @@ def __init__(self, enabled: bool, service_name: str): project (str): The project ID for the monitored resource. """ self.enabled = enabled + self.gfe_enabled = True self._create_metric_instruments(service_name) self._client_attributes = {} @@ -268,6 +269,11 @@ def create_metrics_tracer(self) -> MetricsTracer: instrument_operation_latency=self._instrument_operation_latency, instrument_operation_counter=self._instrument_operation_counter, client_attributes=self._client_attributes.copy(), + gfe_enabled=True, + instrument_gfe_latency=getattr(self, "_instrument_gfe_latency", None), + instrument_gfe_missing_header_count=getattr( + self, "_instrument_gfe_missing_header_count", None + ), ) return metrics_tracer diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/spanner_metrics_tracer_factory.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/spanner_metrics_tracer_factory.py index 6fc5956582c1..7886e555f120 100644 --- a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/spanner_metrics_tracer_factory.py +++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/spanner_metrics_tracer_factory.py @@ -51,9 +51,7 @@ class SpannerMetricsTracerFactory(MetricsTracerFactory): "current_metrics_tracer", default=None ) - def __new__( - cls, enabled: bool = True, gfe_enabled: bool = False - ) -> "SpannerMetricsTracerFactory": + def __new__(cls, enabled: bool = True) -> "SpannerMetricsTracerFactory": """ Create a new instance of SpannerMetricsTracerFactory if it doesn't already exist. @@ -63,7 +61,6 @@ def __new__( Args: enabled (bool): A flag indicating whether metrics tracing is enabled. Defaults to True. - gfe_enabled (bool): A flag indicating whether GFE metrics are enabled. Defaults to False. Returns: SpannerMetricsTracerFactory: The singleton instance of SpannerMetricsTracerFactory. @@ -83,7 +80,7 @@ def __new__( cls._generate_client_hash(client_uid) ) cls._metrics_tracer_factory.set_location(_get_cloud_region()) - cls._metrics_tracer_factory.gfe_enabled = gfe_enabled + cls._metrics_tracer_factory.gfe_enabled = True if cls._metrics_tracer_factory.enabled != enabled: cls._metrics_tracer_factory.enabled = enabled diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py index c688b31eefc4..f51bd3a3a085 100644 --- a/packages/google-cloud-spanner/google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py +++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py @@ -32,7 +32,10 @@ from google.protobuf.json_format import MessageToJson from grpc.experimental import aio # type: ignore -from google.cloud.spanner_v1.metrics.metrics_interceptor import MetricsInterceptor +from google.cloud.spanner_v1.metrics.metrics_interceptor import ( + MetricsInterceptor, + AsyncMetricsInterceptor, +) from google.cloud.spanner_v1.types import ( commit_response, location, @@ -327,6 +330,18 @@ def __init__( ], ) + if metrics_interceptor is not None: + self._metrics_interceptor = AsyncMetricsInterceptor() + # Attach interceptor directly since grpc.aio does not provide intercept_channel. + if hasattr(self._grpc_channel, "_unary_unary_interceptors"): + self._grpc_channel._unary_unary_interceptors.append(self._metrics_interceptor) + if hasattr(self._grpc_channel, "_unary_stream_interceptors"): + self._grpc_channel._unary_stream_interceptors.append(self._metrics_interceptor) + if hasattr(self._grpc_channel, "_stream_unary_interceptors"): + self._grpc_channel._stream_unary_interceptors.append(self._metrics_interceptor) + if hasattr(self._grpc_channel, "_stream_stream_interceptors"): + self._grpc_channel._stream_stream_interceptors.append(self._metrics_interceptor) + self._interceptor = _LoggingClientAIOInterceptor() self._grpc_channel._unary_unary_interceptors.append(self._interceptor) self._logged_channel = self._grpc_channel diff --git a/packages/google-cloud-spanner/tests/mockserver_tests/test_gfe_metrics.py b/packages/google-cloud-spanner/tests/mockserver_tests/test_gfe_metrics.py new file mode 100644 index 000000000000..6d4cf51bff4e --- /dev/null +++ b/packages/google-cloud-spanner/tests/mockserver_tests/test_gfe_metrics.py @@ -0,0 +1,206 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest import mock + +import grpc +from google.api_core.client_options import ClientOptions +from google.auth.credentials import AnonymousCredentials +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import InMemoryMetricReader + +import google.cloud.spanner_v1.client as client_mod +from google.cloud.spanner_v1 import Client +from google.cloud.spanner_v1.metrics.metrics_interceptor import MetricsInterceptor +from google.cloud.spanner_v1.metrics.spanner_metrics_tracer_factory import ( + SpannerMetricsTracerFactory, +) +from google.cloud.spanner_v1.pool import FixedSizePool +from tests.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_select1_result, +) + + +class TestGFEMetricsIntegration(MockServerTestBase): + def setUp(self): + super().setUp() + os.environ["SPANNER_DISABLE_BUILTIN_METRICS"] = "false" + SpannerMetricsTracerFactory._metrics_tracer_factory = None + client_mod._metrics_monitor_initialized = False + + def tearDown(self): + super().tearDown() + os.environ["SPANNER_DISABLE_BUILTIN_METRICS"] = "true" + SpannerMetricsTracerFactory._metrics_tracer_factory = None + client_mod._metrics_monitor_initialized = False + + def test_gfe_metrics_exported(self): + add_select1_result() + reader = InMemoryMetricReader() + meter_provider = MeterProvider(metric_readers=[reader]) + + orig_call = grpc._channel._UnaryStreamMultiCallable.__call__ + orig_initial_metadata = grpc._channel._MultiThreadedRendezvous.initial_metadata + orig_trailing_metadata = ( + grpc._channel._MultiThreadedRendezvous.trailing_metadata + ) + + def custom_initial_metadata(self): + mocked = getattr(self, "_is_execute_streaming_sql_mock", False) + if mocked: + return (("server-timing", "gfet4t7; dur=55"),) + return orig_initial_metadata(self) + + def custom_trailing_metadata(self): + mocked = getattr(self, "_is_execute_streaming_sql_mock", False) + if mocked: + return (("server-timing", "gfet4t7; dur=55"),) + return orig_trailing_metadata(self) + + def custom_call(self_callable, request, *args, **kwargs): + method = getattr(self_callable, "_method", b"") + method_str = method.decode("utf-8") if isinstance(method, bytes) else method + response = orig_call(self_callable, request, *args, **kwargs) + if "ExecuteStreamingSql" in method_str: + response._is_execute_streaming_sql_mock = True + return response + + try: + with ( + mock.patch( + "google.cloud.spanner_v1.metrics.metrics_tracer_factory.get_meter_provider", + return_value=meter_provider, + ), + mock.patch( + "google.cloud.spanner_v1.client.MeterProvider", + return_value=meter_provider, + ), + mock.patch( + "google.cloud.spanner_v1.client._get_spanner_emulator_host", + return_value=None, + ), + mock.patch( + "grpc._channel._UnaryStreamMultiCallable.__call__", + custom_call, + ), + mock.patch( + "grpc._channel._MultiThreadedRendezvous.initial_metadata", + custom_initial_metadata, + ), + mock.patch( + "grpc._channel._MultiThreadedRendezvous.trailing_metadata", + custom_trailing_metadata, + ), + ): + client = Client( + project="p", + credentials=AnonymousCredentials(), + client_options=ClientOptions( + api_endpoint="localhost:" + str(MockServerTestBase.port), + ), + ) + instance = client.instance("test-instance") + database = instance.database( + "test-database", + pool=FixedSizePool(size=10), + enable_interceptors_in_tests=True, + ) + database._interceptors.append(MetricsInterceptor()) + database._spanner_api = ( + None # Force recreation with the new interceptor + ) + + with database.snapshot() as snapshot: + results = snapshot.execute_sql("select 1") + # Consume the streaming results to complete the stream + list(results) + + metric_data = reader.get_metrics_data() + self.assertIsNotNone(metric_data) + metrics = { + metric.name: metric + for rm in metric_data.resource_metrics + for sm in rm.scope_metrics + for metric in sm.metrics + } + + self.assertIn("gfe_latency", metrics, f"Metrics: {list(metrics.keys())}") + gfe_metric = metrics["gfe_latency"] + point = next(iter(gfe_metric.data.data_points)) + self.assertEqual(point.sum, 55) + + finally: + pass + + def test_gfe_missing_header_count_exported(self): + add_select1_result() + reader = InMemoryMetricReader() + meter_provider = MeterProvider(metric_readers=[reader]) + + try: + with ( + mock.patch( + "google.cloud.spanner_v1.metrics.metrics_tracer_factory.get_meter_provider", + return_value=meter_provider, + ), + mock.patch( + "google.cloud.spanner_v1.client.MeterProvider", + return_value=meter_provider, + ), + mock.patch( + "google.cloud.spanner_v1.client._get_spanner_emulator_host", + return_value=None, + ), + ): + client = Client( + project="p", + credentials=AnonymousCredentials(), + client_options=ClientOptions( + api_endpoint="localhost:" + str(MockServerTestBase.port), + ), + ) + instance = client.instance("test-instance") + database = instance.database( + "test-database", + pool=FixedSizePool(size=10), + enable_interceptors_in_tests=True, + ) + database._interceptors.append(MetricsInterceptor()) + database._spanner_api = ( + None # Force recreation with the new interceptor + ) + + with database.snapshot() as snapshot: + results = snapshot.execute_sql("select 1") + list(results) + + metric_data = reader.get_metrics_data() + self.assertIsNotNone(metric_data) + metrics = { + metric.name: metric + for rm in metric_data.resource_metrics + for sm in rm.scope_metrics + for metric in sm.metrics + } + + self.assertIn( + "gfe_missing_header_count", metrics, f"Metrics: {list(metrics.keys())}" + ) + missing_metric = metrics["gfe_missing_header_count"] + point = next(iter(missing_metric.data.data_points)) + self.assertGreaterEqual(point.value, 1) + finally: + pass diff --git a/packages/google-cloud-spanner/tests/unit/test_metrics_interceptor.py b/packages/google-cloud-spanner/tests/unit/test_metrics_interceptor.py index 6e091860b425..efa080191c9e 100644 --- a/packages/google-cloud-spanner/tests/unit/test_metrics_interceptor.py +++ b/packages/google-cloud-spanner/tests/unit/test_metrics_interceptor.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest @@ -41,7 +41,7 @@ def __init__(self): self.project = None self.instance = None self.database = None - self.gfe_enabled = False + self.gfe_enabled = True self.record_attempt_start = MagicMock() self.record_attempt_completion = MagicMock() self.set_method = MagicMock() @@ -99,10 +99,8 @@ def test_set_metrics_tracer_attributes(interceptor, mock_tracer_ctx): def test_intercept_with_tracer(interceptor, mock_tracer_ctx): # mock_tracer_ctx fixture sets the ContextVar - mock_tracer_ctx.gfe_enabled = False - - invoked_response = MagicMock() - invoked_response.initial_metadata.return_value = {} + invoked_response = Mock() + invoked_response.initial_metadata.return_value = [] mock_invoked_method = MagicMock(return_value=invoked_response) call_details = MagicMock( @@ -119,4 +117,5 @@ def test_intercept_with_tracer(interceptor, mock_tracer_ctx): assert response == invoked_response mock_tracer_ctx.record_attempt_start.assert_called() mock_tracer_ctx.record_attempt_completion.assert_called_once() + mock_tracer_ctx.record_gfe_metrics.assert_called_once() mock_invoked_method.assert_called_once_with("request", call_details) diff --git a/packages/google-cloud-spanner/tests/unit/test_metrics_tracer.py b/packages/google-cloud-spanner/tests/unit/test_metrics_tracer.py index 90b2f2f511f9..4769974f0c8a 100644 --- a/packages/google-cloud-spanner/tests/unit/test_metrics_tracer.py +++ b/packages/google-cloud-spanner/tests/unit/test_metrics_tracer.py @@ -264,3 +264,36 @@ def test_record_gfe_missing_header_count(metrics_tracer): metrics_tracer.record_gfe_missing_header_count() assert mock_gfe_missing_header_count.add.call_count == 1 # Should not increment metrics_tracer.enabled = True # Reset for next test + + +def test_extract_gfe_latency(): + # Valid trailing metadata list of tuples + metadata_list = [("server-timing", "gfet4t7; dur=123")] + assert MetricsTracer.extract_gfe_latency(metadata_list) == 123 + + # Valid metadata dict + metadata_dict = {"server-timing": "gfet4t7; dur=456"} + assert MetricsTracer.extract_gfe_latency(metadata_dict) == 456 + + # Missing header + assert MetricsTracer.extract_gfe_latency([("other-header", "val")]) is None + assert MetricsTracer.extract_gfe_latency(None) is None + + +def test_record_gfe_metrics(metrics_tracer): + mock_gfe_latency = mock.create_autospec(Histogram, instance=True) + mock_gfe_missing = mock.create_autospec(Counter, instance=True) + metrics_tracer._instrument_gfe_latency = mock_gfe_latency + metrics_tracer._instrument_gfe_missing_header_count = mock_gfe_missing + metrics_tracer.gfe_enabled = True + + # With header + metrics_tracer.record_gfe_metrics([("server-timing", "gfet4t7; dur=88")]) + assert mock_gfe_latency.record.call_count == 1 + assert mock_gfe_latency.record.call_args[1]["amount"] == 88 + assert mock_gfe_missing.add.call_count == 0 + + # Without header + metrics_tracer.record_gfe_metrics([("other", "1")]) + assert mock_gfe_latency.record.call_count == 1 + assert mock_gfe_missing.add.call_count == 1