From 0ac6a63ca7711754759188328a51b7b4ed623199 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Wed, 13 Nov 2024 22:30:27 -0800 Subject: [PATCH] Update with code review suggestions --- google/cloud/spanner_v1/database.py | 2 +- ...ility.py => test_observability_options.py} | 62 ++++++++----------- 2 files changed, 26 insertions(+), 38 deletions(-) rename tests/system/{test_observability.py => test_observability_options.py} (76%) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 4b6ff4bb4c..abddd5d97d 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -1116,7 +1116,7 @@ def observability_options(self): if not (self._instance and self._instance._client): return None - return getattr(self._instance._client, 'observability_options', None) + return getattr(self._instance._client, "observability_options", None) class BatchCheckout(object): diff --git a/tests/system/test_observability.py b/tests/system/test_observability_options.py similarity index 76% rename from tests/system/test_observability.py rename to tests/system/test_observability_options.py index 0ee5b413ab..1f9058b2d5 100644 --- a/tests/system/test_observability.py +++ b/tests/system/test_observability_options.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import mock import pytest import unittest -import mock from . import _helpers from google.cloud.spanner_v1 import Client, DirectedReadOptions @@ -37,7 +37,7 @@ @pytest.mark.skipif(not HAS_OTEL_INSTALLED, reason="OpenTelemetry needed.") @pytest.mark.skipif(not _helpers.USE_EMULATOR, reason="Emulator needed.") -class TestObservability(unittest.TestCase): +def test_observability_options_propagation(): PROJECT = _helpers.EMULATOR_PROJECT PATH = "projects/%s" % (PROJECT,) CONFIGURATION_NAME = "config-name" @@ -62,13 +62,7 @@ class TestObservability(unittest.TestCase): }, } - def test_observability_options_propagated_extended_tracing_off(self): - self.__test_observability_options(True) - - def test_observability_options_propagated(self): - self.__test_observability_options(False) - - def __test_observability_options(self, enable_extended_tracing): + def test_propagation(enable_extended_tracing): global_tracer_provider = TracerProvider(sampler=ALWAYS_ON) trace.set_tracer_provider(global_tracer_provider) global_trace_exporter = InMemorySpanExporter() @@ -86,17 +80,17 @@ def __test_observability_options(self, enable_extended_tracing): enable_extended_tracing=enable_extended_tracing, ) client = Client( - project=self.PROJECT, + project=PROJECT, observability_options=observability_options, credentials=_make_credentials(), ) instance = client.instance( - self.INSTANCE_ID, - self.CONFIGURATION_NAME, - display_name=self.DISPLAY_NAME, - node_count=self.NODE_COUNT, - labels=self.LABELS, + INSTANCE_ID, + CONFIGURATION_NAME, + display_name=DISPLAY_NAME, + node_count=NODE_COUNT, + labels=LABELS, ) try: @@ -104,13 +98,13 @@ def __test_observability_options(self, enable_extended_tracing): except: pass - db = instance.database(self.DATABASE_ID) + db = instance.database(DATABASE_ID) try: db.create() except: pass - self.assertEqual(db.observability_options, observability_options) + assert db.observability_options == observability_options with db.snapshot() as snapshot: res = snapshot.execute_sql("SELECT 1") for val in res: @@ -118,34 +112,24 @@ def __test_observability_options(self, enable_extended_tracing): from_global_spans = global_trace_exporter.get_finished_spans() from_inject_spans = inject_trace_exporter.get_finished_spans() - self.assertEqual( - len(from_global_spans), - 0, - "Expecting no spans from the global trace exporter", - ) - self.assertEqual( - len(from_inject_spans) >= 2, - True, - "Expecting at least 2 spans from the injected trace exporter", - ) + assert ( + len(from_global_spans) == 0 + ) # "Expecting no spans from the global trace exporter" + assert ( + len(from_inject_spans) >= 2 + ) # "Expecting at least 2 spans from the injected trace exporter" gotNames = [span.name for span in from_inject_spans] wantNames = ["CloudSpanner.CreateSession", "CloudSpanner.ReadWriteTransaction"] - self.assertEqual( - gotNames, - wantNames, - "Span names mismatch", - ) + assert gotNames == wantNames # Check for conformance of enable_extended_tracing lastSpan = from_inject_spans[len(from_inject_spans) - 1] wantAnnotatedSQL = "SELECT 1" if not enable_extended_tracing: wantAnnotatedSQL = None - self.assertEqual( - lastSpan.attributes.get("db.statement", None), - wantAnnotatedSQL, - "Mismatch in annotated sql", - ) + assert ( + lastSpan.attributes.get("db.statement", None) == wantAnnotatedSQL + ) # "Mismatch in annotated sql" try: db.delete() @@ -153,6 +137,10 @@ def __test_observability_options(self, enable_extended_tracing): except: pass + # Test the respective options for enable_extended_tracing + test_propagation(True) + test_propagation(False) + def _make_credentials(): import google.auth.credentials