Skip to content

Commit

Permalink
all: implement custom tracer_provider injection
Browse files Browse the repository at this point in the history
An important feature for observability is to allow the injection
of a custom tracer_provider instead of always using the global
tracer_provider.
  • Loading branch information
odeke-em committed Nov 10, 2024
1 parent 41604fe commit eca7d36
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 11 deletions.
7 changes: 5 additions & 2 deletions google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,16 @@ def get_tracer(tracer_provider=None):


@contextmanager
def trace_call(name, session, extra_attributes=None):
def trace_call(name, session, extra_attributes=None, tracer_provider=None):
if not HAS_OPENTELEMETRY_INSTALLED or not session:
# Empty context manager. Users will have to check if the generated value is None or a span
yield None
return

tracer = get_tracer()
if tracer_provider is None and getattr(session, "_tracer_provider", None):
tracer_provider = session._tracer_provider

tracer = get_tracer(tracer_provider)

# Set base attributes that we know for every trace created
attributes = {
Expand Down
3 changes: 3 additions & 0 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __init__(
query_options=None,
route_to_leader_enabled=True,
directed_read_options=None,
tracer_provider=None,
):
self._emulator_host = _get_spanner_emulator_host()

Expand Down Expand Up @@ -187,6 +188,7 @@ def __init__(

self._route_to_leader_enabled = route_to_leader_enabled
self._directed_read_options = directed_read_options
self._tracer_provider = tracer_provider

@property
def credentials(self):
Expand Down Expand Up @@ -371,6 +373,7 @@ def instance(
self._emulator_host,
labels,
processing_units,
tracer_provider=self._tracer_provider,
)

def list_instances(self, filter_="", page_size=None):
Expand Down
14 changes: 12 additions & 2 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(
database_role=None,
enable_drop_protection=False,
proto_descriptors=None,
tracer_provider=None,
):
self.database_id = database_id
self._instance = instance
Expand All @@ -178,11 +179,15 @@ def __init__(
self._reconciling = False
self._directed_read_options = self._instance._client.directed_read_options
self._proto_descriptors = proto_descriptors
self._tracer_provider = tracer_provider

if pool is None:
pool = BurstyPool(database_role=database_role)
pool = BurstyPool(
database_role=database_role, tracer_provider=self._tracer_provider
)

self._pool = pool
self._pool._tracer_provider = self._tracer_provider
pool.bind(self)

@classmethod
Expand Down Expand Up @@ -742,7 +747,12 @@ def session(self, labels=None, database_role=None):
# If role is specified in param, then that role is used
# instead.
role = database_role or self._database_role
return Session(self, labels=labels, database_role=role)
return Session(
self,
labels=labels,
database_role=role,
tracer_provider=self._tracer_provider,
)

def snapshot(self, **kw):
"""Return an object which wraps a snapshot.
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/spanner_v1/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(
emulator_host=None,
labels=None,
processing_units=None,
tracer_provider=None,
):
self.instance_id = instance_id
self._client = client
Expand All @@ -145,6 +146,7 @@ def __init__(
if labels is None:
labels = {}
self.labels = labels
self._tracer_provider = tracer_provider

def _update_from_pb(self, instance_pb):
"""Refresh self from the server-provided protobuf.
Expand Down Expand Up @@ -499,6 +501,7 @@ def database(
database_role=database_role,
enable_drop_protection=enable_drop_protection,
proto_descriptors=proto_descriptors,
tracer_provider=self._tracer_provider,
)
else:
return TestDatabase(
Expand All @@ -511,6 +514,7 @@ def database(
database_dialect=database_dialect,
database_role=database_role,
enable_drop_protection=enable_drop_protection,
tracer_provider=self._tracer_provider,
)

def list_databases(self, page_size=None):
Expand Down
23 changes: 18 additions & 5 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ class AbstractSessionPool(object):

_database = None

def __init__(self, labels=None, database_role=None):
def __init__(self, labels=None, database_role=None, tracer_provider=None):
if labels is None:
labels = {}
self._labels = labels
self._database_role = database_role
self._tracer_provider = tracer_provider

@property
def labels(self):
Expand Down Expand Up @@ -178,8 +179,11 @@ def __init__(
default_timeout=DEFAULT_TIMEOUT,
labels=None,
database_role=None,
tracer_provider=None,
):
super(FixedSizePool, self).__init__(labels=labels, database_role=database_role)
super(FixedSizePool, self).__init__(
labels=labels, database_role=database_role, tracer_provider=tracer_provider
)
self.size = size
self.default_timeout = default_timeout
self._sessions = queue.LifoQueue(size)
Expand Down Expand Up @@ -284,8 +288,12 @@ class BurstyPool(AbstractSessionPool):
:param database_role: (Optional) user-assigned database_role for the session.
"""

def __init__(self, target_size=10, labels=None, database_role=None):
super(BurstyPool, self).__init__(labels=labels, database_role=database_role)
def __init__(
self, target_size=10, labels=None, database_role=None, tracer_provider=None
):
super(BurstyPool, self).__init__(
labels=labels, database_role=database_role, tracer_provider=tracer_provider
)
self.target_size = target_size
self._database = None
self._sessions = queue.LifoQueue(target_size)
Expand Down Expand Up @@ -392,8 +400,11 @@ def __init__(
ping_interval=3000,
labels=None,
database_role=None,
tracer_provider=None,
):
super(PingingPool, self).__init__(labels=labels, database_role=database_role)
super(PingingPool, self).__init__(
labels=labels, database_role=database_role, tracer_provider=tracer_provider
)
self.size = size
self.default_timeout = default_timeout
self._delta = datetime.timedelta(seconds=ping_interval)
Expand Down Expand Up @@ -546,6 +557,7 @@ def __init__(
ping_interval=3000,
labels=None,
database_role=None,
tracer_provider=None,
):
"""This throws a deprecation warning on initialization."""
warn(
Expand All @@ -561,6 +573,7 @@ def __init__(
ping_interval,
labels=labels,
database_role=database_role,
tracer_provider=tracer_provider,
)

self.begin_pending_transactions()
Expand Down
3 changes: 2 additions & 1 deletion google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ class Session(object):
_session_id = None
_transaction = None

def __init__(self, database, labels=None, database_role=None):
def __init__(self, database, labels=None, database_role=None, tracer_provider=None):
self._database = database
if labels is None:
labels = {}
self._labels = labels
self._database_role = database_role
self._tracer_provider = tracer_provider

def __lt__(self, other):
return self._session_id < other._session_id
Expand Down
8 changes: 7 additions & 1 deletion tests/unit/test__opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ def _make_rpc_error(error_cls, trailing_metadata=None):
def _make_session():
from google.cloud.spanner_v1.session import Session

return mock.Mock(autospec=Session, instance=True)
session = mock.Mock(autospec=Session, instance=True)
# Setting _tracer_provider to None is to avoid the nasty spill-over
# of mock._tracer_provider spuriously failing tests, because per
# unittest.mock.Mock's definition invoking any attribute or method
# returns another mock.
setattr(session, "_tracer_provider", None)
return session


# Skip all of these tests if we don't have OpenTelemetry
Expand Down

0 comments on commit eca7d36

Please sign in to comment.