diff --git a/nameko_grpc/entrypoint.py b/nameko_grpc/entrypoint.py index cebec63..16fded9 100644 --- a/nameko_grpc/entrypoint.py +++ b/nameko_grpc/entrypoint.py @@ -17,6 +17,7 @@ from nameko_grpc.inspection import Inspector from nameko_grpc.streams import ReceiveStream, SendStream from nameko_grpc.timeout import unbucket_timeout +from nameko_grpc.utils import Teeable log = getLogger(__name__) @@ -234,6 +235,8 @@ def handle_request(self, request_stream, response_stream): if self.cardinality in (Cardinality.UNARY_STREAM, Cardinality.UNARY_UNARY): request = next(request) + else: + request = Teeable(request) context = GrpcContext(request_stream, response_stream) @@ -261,18 +264,28 @@ def handle_request(self, request_stream, response_stream): def handle_result(self, response_stream, worker_ctx, result, exc_info): if self.cardinality in (Cardinality.STREAM_UNARY, Cardinality.UNARY_UNARY): - result = (result,) + response = (result,) + else: + result = Teeable(result) + response = result if exc_info is None: - try: - response_stream.populate(result) - except Exception as exception: - error = GrpcError( - status=StatusCode.UNKNOWN, - details="Exception iterating responses: {}".format(exception), - debug_error_string="", - ) - response_stream.close(error) + + def send_response(): + try: + response_stream.populate(response) + except Exception as exception: + error = GrpcError( + status=StatusCode.UNKNOWN, + details="Exception iterating responses: {}".format(exception), + debug_error_string="", + ) + response_stream.close(error) + + self.container.spawn_managed_thread( + send_response, identifier="send_response" + ) + else: error = GrpcError( status=StatusCode.UNKNOWN, diff --git a/nameko_grpc/tracer/__init__.py b/nameko_grpc/tracer/__init__.py new file mode 100644 index 0000000..1be39ab --- /dev/null +++ b/nameko_grpc/tracer/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- +from nameko_grpc.tracer.adapter import GrpcEntrypointAdapter # noqa: F401 +from 
nameko_grpc.tracer.dependency import GrpcTracer # noqa: F401 diff --git a/nameko_grpc/tracer/adapter.py b/nameko_grpc/tracer/adapter.py new file mode 100644 index 0000000..2cf22e4 --- /dev/null +++ b/nameko_grpc/tracer/adapter.py @@ -0,0 +1,252 @@ +# -*- coding: utf-8 -*- +from nameko_tracer import constants +from nameko_tracer.adapters import DefaultAdapter + +from nameko_grpc.constants import Cardinality + + +GRPC_STREAM = "GrpcStream" +GRPC_CONTEXT = "GrpcContext" +GRPC_REQUEST = "GrpcRequest" +GRPC_RESPONSE = "GrpcResponse" + + +def is_request_record(extra): + """ Determine whether the record represents a request. + """ + return extra["stage"] == constants.Stage.request + + +def is_response_record(extra): + """ Determine whether the record represents a response. + """ + return extra["stage"] == constants.Stage.response + + +def is_stream_record(extra): + """ Determine whether the record represents part of a streaming request or response. + """ + return "stream_part" in extra + + +def is_error_record(extra): + """ Determine whether the record represents an error response + """ + return is_response_record(extra) and extra["exc_info_"] + + +def is_streaming_request_method(extra): + """ Determine whether the record relates to a method that has a streaming request. + + Note that the record may be the request or response trace, or part of one of these. + """ + cardinality = get_cardinality(extra) + return cardinality in (Cardinality.STREAM_UNARY, Cardinality.STREAM_STREAM) + + +def is_streaming_response_method(extra): + """ Determine whether the record relates to a method that has a streaming response. + + Note that the record may be the request or response trace, or part of one of these. + """ + cardinality = get_cardinality(extra) + return cardinality in (Cardinality.UNARY_STREAM, Cardinality.STREAM_STREAM) + + +def get_cardinality(extra): + """ Extract the cardinality of the method that this record relates to. 
+ """ + return extra["worker_ctx"].entrypoint.cardinality + + +def add_cardinality(extra): + """ Add the cardinality of the method to which this record relates to the trace + data. + """ + trace_data = extra[constants.TRACE_KEY] + trace_data["cardinality"] = get_cardinality(extra) + + +def add_stream_part(extra): + """ If this record represents part of a stream, add the stream part identifier to + the trace data. + """ + if not is_stream_record(extra): + return + trace_data = extra[constants.TRACE_KEY] + trace_data["stream_part"] = extra["stream_part"] + + +def add_stream_age(extra): + """ If this record represents part of a stream, add the commulative stream age to + the trace data. + """ + if not is_stream_record(extra): + return + trace_data = extra[constants.TRACE_KEY] + trace_data["stream_age"] = extra["stream_age"] + + +def add_grpc_request(extra): + """ Add the GRPC request message to the trace data for this record, under the + `grpc_request` key. + + All records receive a value for this key. + + If this record relates to a method that has a streaming request and the record + does not represent part that stream (i.e. it's a "top-level" record, or a response), + the value is the GRPC_STREAM placeholder. + """ + trace_data = extra[constants.TRACE_KEY] + + if is_streaming_request_method(extra): + if is_request_record(extra) and is_stream_record(extra): + trace_data["grpc_request"] = extra["request"] + else: + trace_data["grpc_request"] = GRPC_STREAM + else: + request, context = extra["worker_ctx"].args + trace_data["grpc_request"] = request + + +def add_grpc_response(extra): + """ Add the GRPC response message to the trace data for this record, under the + `grpc_response` key. + + Only response records receive a value for this key. + + If this record relates to a method that has a streaming response and the record + does not represent part of that stream (i.e. it's the "top-level" record), + the value is the GRPC_STREAM placeholder. 
+ """ + if not is_response_record(extra): + return + + trace_data = extra[constants.TRACE_KEY] + + if is_streaming_response_method(extra): + if is_stream_record(extra): + trace_data["grpc_response"] = extra["result"] + else: + trace_data["grpc_response"] = GRPC_STREAM + else: + trace_data["grpc_response"] = extra["result"] + + +def add_grpc_context(extra): + """ Add the GRPC context object to the trace data for this record, under the + `grpc_context` key. + """ + request, context = extra["worker_ctx"].args + + trace_data = extra[constants.TRACE_KEY] + trace_data["grpc_context"] = context + + +def clean_call_args(extra): + """ Replace the `context` and `request` keys of `call_args` in the trace data for + this record. + + These objects are exposed in the `grpc_context` and `grpc_request` fields + respectively and don't need to be in multiple places. See `add_grpc_context` and + `add_grpc_request` respectively. + + The value for `context` is the GRPC_CONTEXT placeholder. The value for `request` + is the GRPC_REQUEST placeholder, or GRPC_STREAM placeholder if this record relates + to a method that has a streaming request. + """ + + trace_data = extra[constants.TRACE_KEY] + trace_data["call_args"]["context"] = GRPC_CONTEXT + + if is_streaming_request_method(extra): + trace_data["call_args"]["request"] = GRPC_STREAM + else: + trace_data["call_args"]["request"] = GRPC_REQUEST + + +def clean_response(extra): + """ Replaces the `response` key in the trace data for this record. + + Only successful response records have a value for this key. + + The GRPC response message is exposed in the `grpc_response` field and doesn't need + to be in multiple places. See `add_grpc_response`. + + The value for `response` is the GRPC_RESPONSE placeholder, or GRPC_STREAM + placeholder if this record relates to a method that has a streaming response. 
+ + """ + + if not is_response_record(extra) or is_error_record(extra): + return + + trace_data = extra[constants.TRACE_KEY] + + if is_streaming_response_method(extra) and not is_stream_record(extra): + trace_data["response"] = GRPC_STREAM + else: + trace_data["response"] = GRPC_RESPONSE + + +def clean_response_status(extra): + """ Replaces `response_status` keys in the trace data for this record. + + Only response records have a value for this key. + + The value for is unchanged unless this record relates to a method that has a + streaming response and the record does not represent part of that stream + (i.e. it's the "top-level" record), and the record does not already indicate an + error (i.e. the method immediately failed). + + The status of the response is therefore not yet known, so the value is set to + `None`. + + """ + if not is_response_record(extra) or is_error_record(extra): + return + + trace_data = extra[constants.TRACE_KEY] + + if is_streaming_response_method(extra) and not is_stream_record(extra): + # response status still unknown + trace_data["response_status"] = None + + +class GrpcEntrypointAdapter(DefaultAdapter): + """ Logging adapter for methods decorated with the Grpc entrypoint. 
+ + Records may represent one of the following: + + * The request to a "unary request" RPC method + * The response from a "unary response" RPC method + * The "top-level" request to a "streaming request" RPC method + * Each "part" of the stream to a "streaming request" RPC method + * The "top-level" response from a "streaming response" RPC method + * Each "part" of the stream from a "streaming response" RPC method + + """ + + def process(self, message, kwargs): + message, kwargs = super().process(message, kwargs) + + extra = kwargs["extra"] + + add_cardinality(extra) + add_stream_part(extra) + add_stream_age(extra) + + add_grpc_request(extra) + add_grpc_response(extra) + add_grpc_context(extra) + + clean_call_args(extra) + clean_response(extra) + clean_response_status(extra) + + return message, kwargs + + def get_result(self, result): + """ Override get_result to remove serialization. + """ + return result diff --git a/nameko_grpc/tracer/dependency.py b/nameko_grpc/tracer/dependency.py new file mode 100644 index 0000000..6eb3699 --- /dev/null +++ b/nameko_grpc/tracer/dependency.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +import logging +import sys +from datetime import datetime + +from nameko_tracer import Tracer, constants + +from nameko_grpc.constants import Cardinality + + +logger = logging.getLogger(__name__) + +GRPC_ADAPTER = { + "nameko_grpc.entrypoint.Grpc": ("nameko_grpc.tracer.GrpcEntrypointAdapter") +} + + +class GrpcTracer(Tracer): + """ Extend nameko_tracer.Tracer to add support for the Grpc entrypoint, including + streaming requests and responses. 
+ """ + + def setup(self): + self.configure_adapter_types(GRPC_ADAPTER) + super().setup() + + def log_request(self, worker_ctx): + request, context = worker_ctx.args + + cardinality = worker_ctx.entrypoint.cardinality + + request_stream = None + if cardinality in (Cardinality.STREAM_UNARY, Cardinality.STREAM_STREAM): + request_stream = request.tee() + + timestamp = datetime.utcnow() + self.worker_timestamps[worker_ctx] = timestamp + + extra = { + "stage": constants.Stage.request, + "worker_ctx": worker_ctx, + "timestamp": timestamp, + } + try: + adapter = self.adapter_factory(worker_ctx) + adapter.info("[%s] entrypoint call trace", worker_ctx.call_id, extra=extra) + except Exception: + logger.warning("Failed to log entrypoint trace", exc_info=True) + + if request_stream: + self.container.spawn_managed_thread( + lambda: self.log_request_stream(worker_ctx, request_stream) + ) + + def log_result(self, worker_ctx, result, exc_info): + + cardinality = worker_ctx.entrypoint.cardinality + + timestamp = datetime.utcnow() + worker_setup_timestamp = self.worker_timestamps[worker_ctx] + response_time = (timestamp - worker_setup_timestamp).total_seconds() + + result_stream = None + + if exc_info is None and cardinality in ( + Cardinality.UNARY_STREAM, + Cardinality.STREAM_STREAM, + ): + result_stream = result.tee() + + extra = { + "stage": constants.Stage.response, + "worker_ctx": worker_ctx, + "result": result, + "exc_info_": exc_info, + "timestamp": timestamp, + "response_time": response_time, + } + + try: + adapter = self.adapter_factory(worker_ctx) + if exc_info: + adapter.warning( + "[%s] entrypoint result trace", worker_ctx.call_id, extra=extra + ) + else: + adapter.info( + "[%s] entrypoint result trace", worker_ctx.call_id, extra=extra + ) + except Exception: + logger.warning("Failed to log entrypoint trace", exc_info=True) + + if result_stream: + self.container.spawn_managed_thread( + lambda: self.log_result_stream(worker_ctx, result_stream) + ) + + def 
log_request_stream(self, worker_ctx, request_stream): + + stream_start = datetime.utcnow() + + for index, request in enumerate(request_stream, start=1): + + timestamp = datetime.utcnow() + stream_age = (timestamp - stream_start).total_seconds() + + extra = { + "stage": constants.Stage.request, + "worker_ctx": worker_ctx, + "timestamp": timestamp, + "stream_age": stream_age, + "stream_part": index, + "request": request, + } + try: + adapter = self.adapter_factory(worker_ctx) + adapter.info( + "[%s] entrypoint call trace [stream_part %s]", + worker_ctx.call_id, + index, + extra=extra, + ) + except Exception: + logger.warning("Failed to log entrypoint trace", exc_info=True) + + def log_result_stream(self, worker_ctx, result_stream): + + stream_start = datetime.utcnow() + worker_setup_timestamp = self.worker_timestamps[worker_ctx] + + def log(stream_part, result, exc_info, level): + timestamp = datetime.utcnow() + stream_age = (timestamp - stream_start).total_seconds() + response_time = (timestamp - worker_setup_timestamp).total_seconds() + + extra = { + "stage": constants.Stage.response, + "worker_ctx": worker_ctx, + "result": result, + "exc_info_": exc_info, + "timestamp": timestamp, + "response_time": response_time, + "stream_age": stream_age, + "stream_part": stream_part, + } + try: + adapter = self.adapter_factory(worker_ctx) + adapter.log( + level, + "[%s] entrypoint result trace [stream_part %s]", + worker_ctx.call_id, + stream_part, + extra=extra, + ) + except Exception: + logger.warning("Failed to log entrypoint trace", exc_info=True) + + try: + for index, result in enumerate(result_stream, start=1): + log(index, result, None, logging.INFO) + except Exception: + log(index + 1, None, sys.exc_info(), logging.WARNING) + + def worker_setup(self, worker_ctx): + """ Log entrypoint call details + """ + self.log_request(worker_ctx) + + def worker_result(self, worker_ctx, result=None, exc_info=None): + """ Log entrypoint result details + """ + 
self.log_result(worker_ctx, result, exc_info) diff --git a/nameko_grpc/tracer/formatter.py b/nameko_grpc/tracer/formatter.py new file mode 100644 index 0000000..7ab5110 --- /dev/null +++ b/nameko_grpc/tracer/formatter.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +import json + +from google.protobuf.json_format import MessageToJson +from nameko_tracer import constants +from nameko_tracer.formatters import JSONFormatter + +from nameko_grpc.constants import Cardinality +from nameko_grpc.context import GrpcContext + + +def default(obj): + try: + return MessageToJson(obj) + except Exception: + pass + + if isinstance(obj, GrpcContext): + return { + "request_metadata": obj.invocation_metadata(), + "response_headers": obj.response_stream.headers.for_application, + "response_trailers": obj.response_stream.trailers.for_application, + } + + try: + if obj in Cardinality: + return obj.name + except TypeError: + pass # https://docs.python.org/3/whatsnew/3.7.html#enum + + return str(obj) + + +def serialise(obj): + return json.dumps(obj, default=default) + + +class GrpcJsonFormatter(JSONFormatter): + + extra_serialise_keys = ( + constants.CONTEXT_DATA_KEY, + constants.EXCEPTION_ARGS_KEY, + "grpc_context", + ) + + def format(self, record): + + trace = getattr(record, constants.TRACE_KEY) + + for key in self.extra_serialise_keys: + if key in trace: + trace[key] = serialise(trace[key]) + + return serialise(trace) diff --git a/nameko_grpc/utils.py b/nameko_grpc/utils.py new file mode 100644 index 0000000..5c02dfd --- /dev/null +++ b/nameko_grpc/utils.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +import collections +from threading import Lock + + +def raisetee(iterable, n=2): + """ Alternative to `itertools.tee` that will raise from all iterators if the + source iterable raises. 
+ + Modified from the "roughly equivalent" example in the documentation at + https://docs.python.org/3/library/itertools.html#itertools.tee + """ + source = iter(iterable) + deques = [collections.deque() for i in range(n)] + + def gen(mydeque): + while True: + if not mydeque: + try: + val = next(source) + except StopIteration: + return + except Exception as exc: + val = exc + for d in deques: + d.append(val) + yield mydeque.popleft() + + def read(generator): + for item in generator: + if isinstance(item, Exception): + raise item + yield item + + return tuple(read(gen(d)) for d in deques) + + +class ThreadSafeTee: + """ Thread-safe wrapper for `itertools.tee` (or `raisetee`) objects. + + Copied from https://stackoverflow.com/questions/6703594/itertools-tee-thread-safe + """ + + def __init__(self, tee, lock): + self.tee = tee + self.lock = lock + + def __iter__(self): + return self + + def __next__(self): + with self.lock: + return next(self.tee) + + def __copy__(self): + return ThreadSafeTee(self.tee.__copy__(), self.lock) + + +def safetee(iterable, n): + """ Replacement for `itertools.tee` that returns `ThreadSafeTee` objects. + """ + lock = Lock() + return (ThreadSafeTee(tee, lock) for tee in raisetee(iterable, n)) + + +class Teeable: + """ Wrapper for `iterable`s that allows them to later be `tee`d + (as in `itertools.tee`) *and* used in a thread-safe manner. + + This is useful for wrapping generators and other iterables that cannot be copied, + such as streaming requests and responses. It is required by extensions which + inspect requests and responses, such as the Nameko Tracer. 
+ """ + + def __init__(self, iterable): + self.iterable = iter(iterable) + + def __iter__(self): + return self + + def __next__(self): + return next(self.iterable) + + def tee(self): + self.iterable, safe_tee = safetee(self.iterable, 2) + return safe_tee diff --git a/setup.cfg b/setup.cfg index 1e02266..a057b40 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,3 +26,5 @@ skip=.tox,.git [tool:pytest] markers = equivalence: mark a test as an equivalence test. +filterwarnings = + ignore:.*non-Enums in containment checks will raise TypeError.*:DeprecationWarning diff --git a/setup.py b/setup.py index 55aaf41..8ca1882 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ packages=find_packages(exclude=["test", "test.*"]), install_requires=["nameko", "h2>=3"], extras_require={ + "tracer": ["nameko_tracer"], "dev": [ "pytest", "grpcio", @@ -30,7 +31,7 @@ "pre-commit", "wrapt", "zmq", - ] + ], }, zip_safe=True, license="Apache License, Version 2.0", diff --git a/test/spec/tracer_nameko.py b/test/spec/tracer_nameko.py new file mode 100644 index 0000000..8632185 --- /dev/null +++ b/test/spec/tracer_nameko.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +from example_nameko import example + +from nameko_grpc.tracer import GrpcTracer + + +class tracer(example): + + tracer = GrpcTracer() diff --git a/test/test_tracer.py b/test/test_tracer.py new file mode 100644 index 0000000..9158547 --- /dev/null +++ b/test/test_tracer.py @@ -0,0 +1,1207 @@ +# -*- coding: utf-8 -*- +""" Test integration with https://github.com/nameko/nameko-tracer +""" + +import json +import logging +import socket +from datetime import datetime + +import pytest +from google.protobuf.json_format import MessageToJson +from grpc import StatusCode +from nameko_tracer.constants import Stage + +from nameko_grpc.constants import Cardinality +from nameko_grpc.context import GrpcContext +from nameko_grpc.exceptions import GrpcError +from nameko_grpc.tracer.adapter import ( + GRPC_CONTEXT, + GRPC_REQUEST, + GRPC_RESPONSE, + 
GRPC_STREAM, +) +from nameko_grpc.tracer.formatter import GrpcJsonFormatter + + +@pytest.fixture +def client_type(): + return "nameko" # no point testing multiple clients + + +@pytest.fixture +def server(start_nameko_server): + return start_nameko_server("tracer") + + +@pytest.fixture +def caplog(caplog): + with caplog.at_level(logging.INFO): + yield caplog + + +@pytest.fixture +def get_log_records(caplog): + def is_trace(record): + return record.name == "nameko_tracer" + + def is_request(record): + return record.stage == Stage.request + + def is_response(record): + return record.stage == Stage.response + + def is_stream(record): + return getattr(record, "stream_part", False) + + def is_not_stream(record): + return not is_stream(record) + + def match_all(*fns): + def check(val): + return all(fn(val) for fn in fns) + + return check + + def extract_trace(record): + return record.nameko_trace + + def extract_records(): + trace_records = list(filter(is_trace, caplog.records)) + + request_trace = next( + filter(match_all(is_request, is_not_stream), trace_records) + ) + response_trace = next( + filter(match_all(is_response, is_not_stream), trace_records) + ) + + request_stream = list(filter(match_all(is_request, is_stream), trace_records)) + response_stream = list(filter(match_all(is_response, is_stream), trace_records)) + return request_trace, response_trace, request_stream, response_stream + + return extract_records + + +@pytest.fixture +def check_trace(): + def check(record, requires): + data = record.nameko_trace + for key, value in requires.items(): + if callable(value): + assert value(data, key) + else: + assert data[key] == value + + return check + + +@pytest.fixture +def check_format(): + + formatter = GrpcJsonFormatter() + + def check(record, requires): + + formatted = formatter.format(record) + data = json.loads(formatted) + for key, value in requires.items(): + if callable(value): + assert value(data[key]) + else: + assert data[key] == value + + return check + 
+ +@pytest.mark.usefixtures("predictable_call_ids") +class TestEssentialFields: + """ Verify "essential" fields are present on every log trace, including stream + parts. + + Essential fields: + + - hostname + - timestamp + - entrypoint_name + - entrypoint_type + - service + - cardinality + - call_id + - call_id_stack + - stage + - stream_part (for streams) + - stream_age (for streams) + - context_data + + """ + + def test_unary_unary( + self, client, protobufs, get_log_records, check_trace, check_format + ): + request = protobufs.ExampleRequest(value="A") + response = client.unary_unary(request, metadata=[("foo", "bar")]) + assert response.message == "A" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = { + "hostname": socket.gethostname(), + "timestamp": lambda data, key: isinstance(data[key], datetime), + "entrypoint_name": "unary_unary", + "entrypoint_type": "Grpc", + "service": "example", + "cardinality": Cardinality.UNARY_UNARY, + "call_id": "example.unary_unary.0", + "call_id_stack": ["example.unary_unary.0"], + "context_data": {"foo": "bar"}, + } + + check_trace(request_trace, dict(common, **{"stage": "request"})) + check_format( + request_trace, + {"context_data": '{"foo": "bar"}', "cardinality": "UNARY_UNARY"}, + ) + + check_trace(response_trace, dict(common, **{"stage": "response"})) + check_format( + response_trace, + {"context_data": '{"foo": "bar"}', "cardinality": "UNARY_UNARY"}, + ) + + assert len(request_stream) == 0 + + assert len(result_stream) == 0 + + def test_unary_stream( + self, client, protobufs, get_log_records, check_trace, check_format + ): + request = protobufs.ExampleRequest(value="A", response_count=2) + responses = list(client.unary_stream(request, metadata=[("foo", "bar")])) + assert [(response.message, response.seqno) for response in responses] == [ + ("A", 1), + ("A", 2), + ] + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = { + 
"hostname": socket.gethostname(), + "timestamp": lambda data, key: isinstance(data[key], datetime), + "entrypoint_name": "unary_stream", + "entrypoint_type": "Grpc", + "service": "example", + "cardinality": Cardinality.UNARY_STREAM, + "call_id": "example.unary_stream.0", + "call_id_stack": ["example.unary_stream.0"], + "context_data": {"foo": "bar"}, + } + + check_trace(request_trace, dict(common, **{"stage": "request"})) + check_format( + request_trace, + {"context_data": '{"foo": "bar"}', "cardinality": "UNARY_STREAM"}, + ) + + check_trace(response_trace, dict(common, **{"stage": "response"})) + check_format( + response_trace, + {"context_data": '{"foo": "bar"}', "cardinality": "UNARY_STREAM"}, + ) + + assert len(request_stream) == 0 + + assert len(result_stream) == 2 + for index, trace in enumerate(result_stream): + check_trace( + trace, + dict( + common, + **{ + "stage": "response", + "stream_part": index + 1, + "stream_age": lambda data, key: data[key] > 0, + } + ), + ) + + def test_stream_unary( + self, client, protobufs, get_log_records, check_trace, check_format + ): + def generate_requests(): + for value in ["A", "B"]: + yield protobufs.ExampleRequest(value=value) + + requests = list(generate_requests()) + response = client.stream_unary(requests, metadata=[("foo", "bar")]) + assert response.message == "A,B" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = { + "hostname": socket.gethostname(), + "timestamp": lambda data, key: isinstance(data[key], datetime), + "entrypoint_name": "stream_unary", + "entrypoint_type": "Grpc", + "service": "example", + "cardinality": Cardinality.STREAM_UNARY, + "call_id": "example.stream_unary.0", + "call_id_stack": ["example.stream_unary.0"], + "context_data": {"foo": "bar"}, + } + + check_trace(request_trace, dict(common, **{"stage": "request"})) + check_format( + request_trace, + {"context_data": '{"foo": "bar"}', "cardinality": "STREAM_UNARY"}, + ) + + 
check_trace(response_trace, dict(common, **{"stage": "response"})) + check_format( + response_trace, + {"context_data": '{"foo": "bar"}', "cardinality": "STREAM_UNARY"}, + ) + + assert len(request_stream) == 2 + for index, trace in enumerate(request_stream): + check_trace( + trace, + dict( + common, + **{ + "stage": "request", + "stream_part": index + 1, + "stream_age": lambda data, key: data[key] > 0, + } + ), + ) + + assert len(result_stream) == 0 + + def test_stream_stream( + self, client, protobufs, get_log_records, check_trace, check_format + ): + def generate_requests(): + for value in ["A", "B"]: + yield protobufs.ExampleRequest(value=value) + + requests = list(generate_requests()) + responses = list(client.stream_stream(requests, metadata=[("foo", "bar")])) + assert [(response.message, response.seqno) for response in responses] == [ + ("A", 1), + ("B", 2), + ] + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = { + "hostname": socket.gethostname(), + "timestamp": lambda data, key: isinstance(data[key], datetime), + "entrypoint_name": "stream_stream", + "entrypoint_type": "Grpc", + "service": "example", + "cardinality": Cardinality.STREAM_STREAM, + "call_id": "example.stream_stream.0", + "call_id_stack": ["example.stream_stream.0"], + "context_data": {"foo": "bar"}, + } + + check_trace(request_trace, dict(common, **{"stage": "request"})) + check_format( + request_trace, + {"context_data": '{"foo": "bar"}', "cardinality": "STREAM_STREAM"}, + ) + + check_trace(response_trace, dict(common, **{"stage": "response"})) + check_format( + response_trace, + {"context_data": '{"foo": "bar"}', "cardinality": "STREAM_STREAM"}, + ) + + assert len(request_stream) == 2 + for index, trace in enumerate(request_stream): + check_trace( + trace, + dict( + common, + **{ + "stage": "request", + "stream_part": index + 1, + "stream_age": lambda data, key: data[key] > 0, + } + ), + ) + + assert len(result_stream) == 2 + for index, trace 
in enumerate(result_stream): + check_trace( + trace, + dict( + common, + **{ + "stage": "response", + "stream_part": index + 1, + "stream_age": lambda data, key: data[key] > 0, + } + ), + ) + + def test_error_before_response( + self, client, protobufs, get_log_records, check_trace, check_format + ): + request = protobufs.ExampleRequest(value="A") + with pytest.raises(GrpcError) as error: + client.unary_error(request, metadata=[("foo", "bar")]) + assert error.value.status == StatusCode.UNKNOWN + assert error.value.details == "Exception calling application: boom" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = { + "hostname": socket.gethostname(), + "timestamp": lambda data, key: isinstance(data[key], datetime), + "entrypoint_name": "unary_error", + "entrypoint_type": "Grpc", + "service": "example", + "cardinality": Cardinality.UNARY_UNARY, + "call_id": "example.unary_error.0", + "call_id_stack": ["example.unary_error.0"], + "context_data": {"foo": "bar"}, + } + + check_trace(request_trace, dict(common, **{"stage": "request"})) + check_format( + request_trace, + {"context_data": '{"foo": "bar"}', "cardinality": "UNARY_UNARY"}, + ) + + check_trace(response_trace, dict(common, **{"stage": "response"})) + check_format( + response_trace, + {"context_data": '{"foo": "bar"}', "cardinality": "UNARY_UNARY"}, + ) + + assert len(request_stream) == 0 + + assert len(result_stream) == 0 + + def test_error_while_streaming_response( + self, client, protobufs, get_log_records, check_trace, check_format + ): + + # NOTE it's important that the server sleeps between streaming responses + # otherwise it terminates the stream with an error before any parts of the + # response stream are put on the wire + request = protobufs.ExampleRequest(value="A", response_count=10, delay=10) + responses = [] + with pytest.raises(GrpcError) as error: + for response in client.stream_error(request, metadata=[("foo", "bar")]): + responses.append(response) 
+ + assert error.value.status == StatusCode.UNKNOWN + assert error.value.details == "Exception iterating responses: boom" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = { + "hostname": socket.gethostname(), + "timestamp": lambda data, key: isinstance(data[key], datetime), + "entrypoint_name": "stream_error", + "entrypoint_type": "Grpc", + "service": "example", + "cardinality": Cardinality.UNARY_STREAM, + "call_id": "example.stream_error.0", + "call_id_stack": ["example.stream_error.0"], + "context_data": {"foo": "bar"}, + } + + check_trace(request_trace, dict(common, **{"stage": "request"})) + check_format( + request_trace, + {"context_data": '{"foo": "bar"}', "cardinality": "UNARY_STREAM"}, + ) + + check_trace(response_trace, dict(common, **{"stage": "response"})) + check_format( + response_trace, + {"context_data": '{"foo": "bar"}', "cardinality": "UNARY_STREAM"}, + ) + + assert len(request_stream) == 0 + + assert len(result_stream) == 10 + for index, trace in enumerate(result_stream): + check_trace( + trace, + dict( + common, + **{ + "stage": "response", + "stream_part": index + 1, + "stream_age": lambda data, key: data[key] > 0, + } + ), + ) + + +class TestCallArgsField: + def test_unary_unary(self, client, protobufs, get_log_records, check_trace): + request = protobufs.ExampleRequest(value="A") + response = client.unary_unary(request) + assert response.message == "A" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = { + "call_args": {"context": GRPC_CONTEXT, "request": GRPC_REQUEST}, + "call_args_redacted": False, + } + + check_trace(request_trace, common) + + check_trace(response_trace, common) + + assert len(request_stream) == 0 + + assert len(result_stream) == 0 + + def test_unary_stream(self, client, protobufs, get_log_records, check_trace): + request = protobufs.ExampleRequest(value="A", response_count=2) + responses = list(client.unary_stream(request)) + 
assert [(response.message, response.seqno) for response in responses] == [ + ("A", 1), + ("A", 2), + ] + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = { + "call_args": {"context": GRPC_CONTEXT, "request": GRPC_REQUEST}, + "call_args_redacted": False, + } + + check_trace(request_trace, common) + + check_trace(response_trace, common) + + assert len(request_stream) == 0 + + assert len(result_stream) == 2 + for index, trace in enumerate(result_stream): + check_trace(trace, common) + + def test_stream_unary(self, client, protobufs, get_log_records, check_trace): + def generate_requests(): + for value in ["A", "B"]: + yield protobufs.ExampleRequest(value=value) + + requests = list(generate_requests()) + response = client.stream_unary(requests) + assert response.message == "A,B" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + # streaming request + common = { + "call_args": {"context": GRPC_CONTEXT, "request": GRPC_STREAM}, + "call_args_redacted": False, + } + + check_trace(request_trace, common) + + check_trace(response_trace, common) + + assert len(request_stream) == 2 + for index, trace in enumerate(request_stream): + check_trace(trace, common) + + assert len(result_stream) == 0 + + def test_stream_stream(self, client, protobufs, get_log_records, check_trace): + def generate_requests(): + for value in ["A", "B"]: + yield protobufs.ExampleRequest(value=value) + + requests = list(generate_requests()) + responses = list(client.stream_stream(requests)) + assert [(response.message, response.seqno) for response in responses] == [ + ("A", 1), + ("B", 2), + ] + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + # streaming request + common = { + "call_args": {"context": GRPC_CONTEXT, "request": GRPC_STREAM}, + "call_args_redacted": False, + } + + check_trace(request_trace, common) + + check_trace(response_trace, common) + + assert len(request_stream) == 
2 + for index, trace in enumerate(request_stream): + check_trace(trace, common) + + assert len(result_stream) == 2 + for index, trace in enumerate(result_stream): + check_trace(trace, common) + + +class TestResponseFields: + def test_unary_unary(self, client, protobufs, get_log_records, check_trace): + request = protobufs.ExampleRequest(value="A") + response = client.unary_unary(request) + assert response.message == "A" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace( + response_trace, + { + "response": GRPC_RESPONSE, + "response_status": "success", + "response_time": lambda data, key: data[key] > 0, + }, + ) + + assert len(result_stream) == 0 + + def test_unary_stream(self, client, protobufs, get_log_records, check_trace): + request = protobufs.ExampleRequest(value="A", response_count=2) + responses = list(client.unary_stream(request)) + assert [(response.message, response.seqno) for response in responses] == [ + ("A", 1), + ("A", 2), + ] + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace( + response_trace, + { + "response": GRPC_STREAM, # streaming response + "response_status": None, # still pending + "response_time": lambda data, key: data[key] > 0, + }, + ) + + assert len(result_stream) == 2 + for index, trace in enumerate(result_stream): + check_trace( + trace, + { + "response": GRPC_RESPONSE, # individual response + "response_status": "success", + "response_time": lambda data, key: data[key] > 0, + }, + ) + + def test_stream_unary(self, client, protobufs, get_log_records, check_trace): + def generate_requests(): + for value in ["A", "B"]: + yield protobufs.ExampleRequest(value=value) + + requests = list(generate_requests()) + response = client.stream_unary(requests) + assert response.message == "A,B" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace( + response_trace, + { + "response": GRPC_RESPONSE, + 
"response_status": "success", + "response_time": lambda data, key: data[key] > 0, + }, + ) + + assert len(result_stream) == 0 + + def test_stream_stream(self, client, protobufs, get_log_records, check_trace): + def generate_requests(): + for value in ["A", "B"]: + yield protobufs.ExampleRequest(value=value) + + requests = list(generate_requests()) + responses = list(client.stream_stream(requests)) + assert [(response.message, response.seqno) for response in responses] == [ + ("A", 1), + ("B", 2), + ] + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace( + response_trace, + { + "response": GRPC_STREAM, # streaming response + "response_status": None, # still pending + "response_time": lambda data, key: data[key] > 0, + }, + ) + + assert len(result_stream) == 2 + for index, trace in enumerate(result_stream): + check_trace( + trace, + { + "response": GRPC_RESPONSE, # individual response + "response_status": "success", + "response_time": lambda data, key: data[key] > 0, + }, + ) + + def test_error_before_response( + self, client, protobufs, get_log_records, check_trace + ): + request = protobufs.ExampleRequest(value="A") + with pytest.raises(GrpcError) as error: + client.unary_error(request) + assert error.value.status == StatusCode.UNKNOWN + assert error.value.details == "Exception calling application: boom" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace( + response_trace, + { + "response": lambda data, key: key not in data, + "response_status": "error", + "response_time": lambda data, key: data[key] > 0, + }, + ) + + assert len(result_stream) == 0 + + def test_error_while_streaming_response( + self, client, protobufs, get_log_records, check_trace + ): + + # NOTE it's important that the server sleeps between streaming responses + # otherwise it terminates the stream with an error before any parts of the + # response stream are put on the wire + request = 
protobufs.ExampleRequest(value="A", response_count=10, delay=10) + responses = [] + with pytest.raises(GrpcError) as error: + for response in client.stream_error(request): + responses.append(response) + + assert error.value.status == StatusCode.UNKNOWN + assert error.value.details == "Exception iterating responses: boom" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace( + response_trace, + { + "response": GRPC_STREAM, # streaming response + "response_status": None, # still pending + "response_time": lambda data, key: data[key] > 0, + }, + ) + + assert len(result_stream) == 10 + + # check first 9 stream parts + for index, trace in enumerate(result_stream[:-1]): + check_trace( + trace, + { + "response": GRPC_RESPONSE, # individual response + "response_status": "success", + "response_time": lambda data, key: data[key] > 0, + }, + ) + + # check last stream part + check_trace( + result_stream[-1], + { + "response": lambda data, key: key not in data, + "response_status": "error", + "response_time": lambda data, key: data[key] > 0, + }, + ) + + +class TestGrpcRequestFields: + def test_unary_unary( + self, client, protobufs, get_log_records, check_trace, check_format + ): + request = protobufs.ExampleRequest(value="A") + response = client.unary_unary(request) + assert response.message == "A" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common_trace = {"grpc_request": request} + common_format = {"grpc_request": MessageToJson(request)} + + check_trace(request_trace, common_trace) + check_format(request_trace, common_format) + + check_trace(response_trace, common_trace) + check_format(response_trace, common_format) + + assert len(request_stream) == 0 + + assert len(result_stream) == 0 + + def test_unary_stream( + self, client, protobufs, get_log_records, check_trace, check_format + ): + request = protobufs.ExampleRequest(value="A", response_count=2) + responses = 
list(client.unary_stream(request)) + assert [(response.message, response.seqno) for response in responses] == [ + ("A", 1), + ("A", 2), + ] + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common_trace = {"grpc_request": request} + common_format = {"grpc_request": MessageToJson(request)} + + check_trace(request_trace, common_trace) + check_format(request_trace, common_format) + + check_trace(response_trace, common_trace) + check_format(response_trace, common_format) + + assert len(request_stream) == 0 + + assert len(result_stream) == 2 + for index, trace in enumerate(result_stream): + check_trace(trace, common_trace) + check_format(trace, common_format) + + def test_stream_unary( + self, client, protobufs, get_log_records, check_trace, check_format + ): + def generate_requests(): + for value in ["A", "B"]: + yield protobufs.ExampleRequest(value=value) + + requests = list(generate_requests()) + response = client.stream_unary(requests) + assert response.message == "A,B" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = {"grpc_request": GRPC_STREAM} + + check_trace(request_trace, common) + check_format(request_trace, common) + + check_trace(response_trace, common) + check_format(response_trace, common) + + assert len(request_stream) == 2 + for index, trace in enumerate(request_stream): + check_trace(trace, {"grpc_request": requests[index]}) + check_format(trace, {"grpc_request": MessageToJson(requests[index])}) + + assert len(result_stream) == 0 + + def test_stream_stream( + self, client, protobufs, get_log_records, check_trace, check_format + ): + def generate_requests(): + for value in ["A", "B"]: + yield protobufs.ExampleRequest(value=value) + + requests = list(generate_requests()) + responses = list(client.stream_stream(requests)) + assert [(response.message, response.seqno) for response in responses] == [ + ("A", 1), + ("B", 2), + ] + + request_trace,
response_trace, request_stream, result_stream = get_log_records() + + common = {"grpc_request": GRPC_STREAM} + + check_trace(request_trace, common) + check_format(request_trace, common) + + check_trace(response_trace, common) + check_format(response_trace, common) + + assert len(request_stream) == 2 + for index, trace in enumerate(request_stream): + check_trace(trace, {"grpc_request": requests[index]}) + check_format(trace, {"grpc_request": MessageToJson(requests[index])}) + + assert len(result_stream) == 2 + for index, trace in enumerate(result_stream): + check_trace(trace, common) + check_format(trace, common) + + +class TestGrpcResponseFields: + def test_unary_unary( + self, client, protobufs, get_log_records, check_trace, check_format + ): + request = protobufs.ExampleRequest(value="A") + response = client.unary_unary(request) + assert response.message == "A" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace(response_trace, {"grpc_response": response}) + check_format(response_trace, {"grpc_response": MessageToJson(response)}) + + assert len(result_stream) == 0 + + def test_unary_stream( + self, client, protobufs, get_log_records, check_trace, check_format + ): + request = protobufs.ExampleRequest(value="A", response_count=2) + responses = list(client.unary_stream(request)) + assert [(response.message, response.seqno) for response in responses] == [ + ("A", 1), + ("A", 2), + ] + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace(response_trace, {"grpc_response": GRPC_STREAM}) + check_format(response_trace, {"grpc_response": GRPC_STREAM}) + + assert len(result_stream) == 2 + for index, trace in enumerate(result_stream): + check_trace(trace, {"grpc_response": responses[index]}) + check_format(trace, {"grpc_response": MessageToJson(responses[index])}) + + def test_stream_unary( + self, client, protobufs, get_log_records, check_trace, check_format + ): + def 
generate_requests(): + for value in ["A", "B"]: + yield protobufs.ExampleRequest(value=value) + + requests = list(generate_requests()) + response = client.stream_unary(requests) + assert response.message == "A,B" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace(response_trace, {"grpc_response": response}) + check_format(response_trace, {"grpc_response": MessageToJson(response)}) + + assert len(result_stream) == 0 + + def test_stream_stream( + self, client, protobufs, get_log_records, check_trace, check_format + ): + def generate_requests(): + for value in ["A", "B"]: + yield protobufs.ExampleRequest(value=value) + + requests = list(generate_requests()) + responses = list(client.stream_stream(requests)) + assert [(response.message, response.seqno) for response in responses] == [ + ("A", 1), + ("B", 2), + ] + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace(response_trace, {"grpc_response": GRPC_STREAM}) + check_format(response_trace, {"grpc_response": GRPC_STREAM}) + + assert len(result_stream) == 2 + for index, trace in enumerate(result_stream): + check_trace(trace, {"grpc_response": responses[index]}) + check_format(trace, {"grpc_response": MessageToJson(responses[index])}) + + def test_error_before_response( + self, client, protobufs, get_log_records, check_trace + ): + request = protobufs.ExampleRequest(value="A") + with pytest.raises(GrpcError) as error: + client.unary_error(request) + assert error.value.status == StatusCode.UNKNOWN + assert error.value.details == "Exception calling application: boom" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace(response_trace, {"grpc_response": None}) + + assert len(result_stream) == 0 + + def test_error_while_streaming_response( + self, client, protobufs, get_log_records, check_trace + ): + + # NOTE it's important that the server sleeps between streaming responses + # 
otherwise it terminates the stream with an error before any parts of the + # response stream are put on the wire + request = protobufs.ExampleRequest(value="A", response_count=10, delay=10) + responses = [] + with pytest.raises(GrpcError) as error: + for response in client.stream_error(request): + responses.append(response) + + assert error.value.status == StatusCode.UNKNOWN + assert error.value.details == "Exception iterating responses: boom" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace( + response_trace, {"grpc_response": GRPC_STREAM} # streaming response + ) + + assert len(result_stream) == 10 + + # check first 9 stream parts + for index, trace in enumerate(result_stream[:-1]): + check_trace(trace, {"grpc_response": responses[index]}) + + # check last stream part + check_trace(result_stream[-1], {"grpc_response": None}) + + +class TestGrpcContextField: + def test_unary_unary(self, client, protobufs, get_log_records, check_trace): + request = protobufs.ExampleRequest(value="A") + response = client.unary_unary(request) + assert response.message == "A" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = {"grpc_context": lambda data, key: isinstance(data[key], GrpcContext)} + + check_trace(request_trace, common) + + check_trace(response_trace, common) + + assert len(request_stream) == 0 + + assert len(result_stream) == 0 + + def test_unary_stream(self, client, protobufs, get_log_records, check_trace): + request = protobufs.ExampleRequest(value="A", response_count=2) + responses = list(client.unary_stream(request)) + assert [(response.message, response.seqno) for response in responses] == [ + ("A", 1), + ("A", 2), + ] + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = {"grpc_context": lambda data, key: isinstance(data[key], GrpcContext)} + + check_trace(request_trace, common) + + check_trace(response_trace, common) + + 
assert len(request_stream) == 0 + + assert len(result_stream) == 2 + for index, trace in enumerate(result_stream): + check_trace(trace, common) + + def test_stream_unary(self, client, protobufs, get_log_records, check_trace): + def generate_requests(): + for value in ["A", "B"]: + yield protobufs.ExampleRequest(value=value) + + requests = list(generate_requests()) + response = client.stream_unary(requests) + assert response.message == "A,B" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = {"grpc_context": lambda data, key: isinstance(data[key], GrpcContext)} + + check_trace(request_trace, common) + + check_trace(response_trace, common) + + assert len(request_stream) == 2 + for index, trace in enumerate(request_stream): + check_trace(trace, common) + + assert len(result_stream) == 0 + + def test_stream_stream(self, client, protobufs, get_log_records, check_trace): + def generate_requests(): + for value in ["A", "B"]: + yield protobufs.ExampleRequest(value=value) + + requests = list(generate_requests()) + responses = list(client.stream_stream(requests)) + assert [(response.message, response.seqno) for response in responses] == [ + ("A", 1), + ("B", 2), + ] + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + common = {"grpc_context": lambda data, key: isinstance(data[key], GrpcContext)} + + check_trace(request_trace, common) + + check_trace(response_trace, common) + + assert len(request_stream) == 2 + for index, trace in enumerate(request_stream): + check_trace(trace, common) + + assert len(result_stream) == 2 + for index, trace in enumerate(result_stream): + check_trace(trace, common) + + def test_invocation_metadata( + self, client, protobufs, get_log_records, check_trace, check_format + ): + request = protobufs.ExampleRequest(value="A") + response = client.unary_unary( + request, metadata=[("foo", "foo1"), ("foo", "foo2")] + ) + assert response.message == "A" + + request_trace, 
response_trace, request_stream, result_stream = get_log_records() + + expected = json.dumps( + { + "request_metadata": [("foo", "foo1"), ("foo", "foo2")], + "response_headers": [], + "response_trailers": [], + } + ) + + check_format(request_trace, {"grpc_context": expected}) + + check_format(response_trace, {"grpc_context": expected}) + + def test_response_headers( + self, client, protobufs, get_log_records, check_trace, check_format + ): + request = protobufs.ExampleRequest(value="A") + response = client.unary_unary( + request, metadata=[("echo-header-foo", "foo1"), ("echo-header-foo", "foo2")] + ) + assert response.message == "A" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + expected = json.dumps( + { + "request_metadata": [ + ("echo-header-foo", "foo1"), + ("echo-header-foo", "foo2"), + ], + "response_headers": [("foo", "foo1"), ("foo", "foo2")], + "response_trailers": [], + } + ) + + check_format(request_trace, {"grpc_context": expected}) + + check_format(response_trace, {"grpc_context": expected}) + + def test_response_trailers( + self, client, protobufs, get_log_records, check_trace, check_format + ): + request = protobufs.ExampleRequest(value="A") + response = client.unary_unary( + request, + metadata=[("echo-trailer-foo", "foo1"), ("echo-trailer-foo", "foo2")], + ) + assert response.message == "A" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + expected = json.dumps( + { + "request_metadata": [ + ("echo-trailer-foo", "foo1"), + ("echo-trailer-foo", "foo2"), + ], + "response_headers": [], + "response_trailers": [("foo", "foo1"), ("foo", "foo2")], + } + ) + + check_format(request_trace, {"grpc_context": expected}) + + check_format(response_trace, {"grpc_context": expected}) + + +class TestExceptionFields: + def test_error_before_response( + self, client, protobufs, get_log_records, check_trace, check_format + ): + request = protobufs.ExampleRequest(value="A") + with 
pytest.raises(GrpcError) as error: + client.unary_error(request) + assert error.value.status == StatusCode.UNKNOWN + assert error.value.details == "Exception calling application: boom" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace( + response_trace, + { + "exception_value": "boom", + "exception_type": "Error", + "exception_path": "example_nameko.Error", + "exception_args": ["boom"], + "exception_traceback": ( + lambda data, key: 'raise Error("boom")' in data[key] + ), + "exception_expected": True, + }, + ) + + check_format(response_trace, {"exception_args": json.dumps(["boom"])}) + + assert len(result_stream) == 0 + + def test_error_while_streaming_response( + self, client, protobufs, get_log_records, check_trace, check_format + ): + + # NOTE it's important that the server sleeps between streaming responses + # otherwise it terminates the stream with an error before any parts of the + # response stream are put on the wire + request = protobufs.ExampleRequest(value="A", response_count=10, delay=10) + responses = [] + with pytest.raises(GrpcError) as error: + for response in client.stream_error(request): + responses.append(response) + + assert error.value.status == StatusCode.UNKNOWN + assert error.value.details == "Exception iterating responses: boom" + + request_trace, response_trace, request_stream, result_stream = get_log_records() + + check_trace(response_trace, {"response": GRPC_STREAM, "response_status": None}) + + assert len(result_stream) == 10 + + # check first 9 stream parts + for index, trace in enumerate(result_stream[:-1]): + check_trace(trace, {"exception_value": lambda data, key: key not in data}) + + # check last stream part + check_trace( + result_stream[-1], + { + "exception_value": "boom", + "exception_type": "Error", + "exception_path": "example_nameko.Error", + "exception_args": ["boom"], + "exception_traceback": ( + lambda data, key: 'raise Error("boom")' in data[key] + ), + 
"exception_expected": True, + }, + ) + check_format(result_stream[-1], {"exception_args": json.dumps(["boom"])}) diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000..b12d408 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +import random +import time +from collections import defaultdict + +import eventlet +import pytest + +from nameko_grpc.utils import Teeable + + +class TestTeeable: + @pytest.fixture + def tracker(self): + return [] + + @pytest.fixture + def make_generator(self, tracker): + def gen(): + for i in range(10): + tracker.append(i) + yield i + + return gen + + def test_not_iterable(self): + with pytest.raises(TypeError): + Teeable(1) + + def test_tees_are_independent(self, make_generator): + + gen = Teeable(make_generator()) + tee = gen.tee() + + assert next(gen) == 0 + assert next(gen) == 1 + assert next(tee) == 0 + assert next(gen) == 2 + assert next(tee) == 1 + assert next(tee) == 2 + + def test_generator_only_advances_once(self, make_generator, tracker): + gen = Teeable(make_generator()) + tee = gen.tee() + + assert next(gen) == 0 + assert next(gen) == 1 + assert next(tee) == 0 + assert next(tee) == 1 + assert next(gen) == 2 + assert next(tee) == 2 + + assert tracker == [0, 1, 2] + + def test_wrap_after_start(self, make_generator): + generator = make_generator() + + assert next(generator) == 0 + assert next(generator) == 1 + + gen = Teeable(generator) + tee = gen.tee() + + assert next(gen) == 2 + assert next(tee) == 2 + assert next(gen) == 3 + assert next(tee) == 3 + + def test_tee_after_start(self, make_generator): + generator = make_generator() + gen = Teeable(generator) + + assert next(generator) == 0 + assert next(generator) == 1 + + tee = gen.tee() + + assert next(gen) == 2 + assert next(tee) == 2 + assert next(gen) == 3 + assert next(tee) == 3 + + def test_reentrant(self, make_generator, tracker): + + gen = Teeable(make_generator()) + tee1 = gen.tee() + tee2 = gen.tee() + + 
assert next(gen) == 0 + assert next(gen) == 1 + assert next(tee1) == 0 + assert next(tee1) == 1 + assert next(tee2) == 0 + assert next(tee2) == 1 + + def test_thread_safe(self, make_generator, tracker): + + gen = Teeable(make_generator()) + tee = gen.tee() + + consume_trackers = defaultdict(list) + + def consume(iterable, ident=None): + for i in iterable: + time.sleep(random.random() / 10) + consume_trackers[ident].append(i) + + gt1 = eventlet.spawn(consume, gen, ident="gen") + gt2 = eventlet.spawn(consume, tee, ident="tee") + + gt1.wait() + gt2.wait() + + assert consume_trackers["gen"] == list(range(10)) + assert consume_trackers["tee"] == list(range(10)) + assert tracker == list(range(10)) + + def test_thread_safe_and_reentrant(self, make_generator, tracker): + + gen = Teeable(make_generator()) + tee1 = gen.tee() + tee2 = gen.tee() + tee3 = gen.tee() + + consume_trackers = defaultdict(list) + + def consume(iterable, ident=None): + for i in iterable: + time.sleep(random.random() / 10) + consume_trackers[ident].append(i) + + gt1 = eventlet.spawn(consume, gen, ident="gen") + gt2 = eventlet.spawn(consume, tee1, ident="tee1") + gt3 = eventlet.spawn(consume, tee2, ident="tee2") + gt4 = eventlet.spawn(consume, tee3, ident="tee3") + + gt1.wait() + gt2.wait() + gt3.wait() + gt4.wait() + + assert consume_trackers["gen"] == list(range(10)) + assert consume_trackers["tee1"] == list(range(10)) + assert consume_trackers["tee2"] == list(range(10)) + assert consume_trackers["tee3"] == list(range(10)) + assert tracker == list(range(10)) + + def test_generator_throws(self, tracker): + class Boom(Exception): + pass + + def make_generator(): + for i in range(10): + tracker.append(i) + yield i + raise Boom("boom") + + with pytest.raises(Boom): + list(make_generator()) + + gen = Teeable(make_generator()) + tee = gen.tee() + + with pytest.raises(Boom): + list(gen) + + with pytest.raises(Boom): + list(tee) diff --git a/tox.ini b/tox.ini index beac1da..f30d3e8 100644 --- a/tox.ini +++ 
b/tox.ini @@ -9,5 +9,5 @@ commands = static: pip install --editable .[dev] static: make static - test: pip install --editable .[dev] + test: pip install --editable .[tracer,dev] test: make test