diff --git a/src/otlp.rs b/src/otlp.rs index ba2ee7c..1dbc420 100644 --- a/src/otlp.rs +++ b/src/otlp.rs @@ -14,6 +14,19 @@ use opentelemetry_sdk::{ use std::time::Duration; use tracing::Level; +#[derive(Default, Debug, PartialEq)] +pub enum Protocol { + #[default] + Grpc, + HttpProtobuf, +} + +impl std::fmt::Display for Protocol { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + std::fmt::Debug::fmt(self, f) + } +} + #[must_use] pub fn identity( v: opentelemetry_otlp::OtlpTracePipeline, @@ -34,9 +47,9 @@ where let (protocol, endpoint) = infer_protocol_and_endpoint(maybe_protocol.as_deref(), maybe_endpoint.as_deref()); tracing::debug!(target: "otel::setup", OTEL_EXPORTER_OTLP_TRACES_ENDPOINT = endpoint); - tracing::debug!(target: "otel::setup", OTEL_EXPORTER_OTLP_TRACES_PROTOCOL = protocol); - let exporter: SpanExporterBuilder = match protocol.as_str() { - "http/protobuf" => opentelemetry_otlp::new_exporter() + tracing::debug!(target: "otel::setup", OTEL_EXPORTER_OTLP_TRACES_PROTOCOL = protocol.to_string()); + let exporter: SpanExporterBuilder = match protocol { + Protocol::HttpProtobuf => opentelemetry_otlp::new_exporter() .http() .with_http_client(HyperClient::new_with_timeout( hyper::Client::new(), @@ -45,7 +58,7 @@ where .with_endpoint(endpoint) .with_headers(read_headers_from_env()) .into(), - _ => opentelemetry_otlp::new_exporter() + Protocol::Grpc => opentelemetry_otlp::new_exporter() .tonic() .with_endpoint(endpoint) .into(), @@ -141,10 +154,10 @@ where fn infer_protocol_and_endpoint( maybe_protocol: Option<&str>, maybe_endpoint: Option<&str>, -) -> (String, String) { - let maybe_protocol = match maybe_protocol { - Some("grpc") => Some("grpc"), - Some("http") | Some("http/protobuf") => Some("http/protobuf"), +) -> (Protocol, String) { + let protocol = match maybe_protocol { + Some("grpc") => Some(Protocol::Grpc), + Some("http") | Some("http/protobuf") => Some(Protocol::HttpProtobuf), Some(other) => { tracing::warn!(target: "otel::setup", "unsupported protocol {other:?}"); None @@ -152,20 +165,20 @@ fn infer_protocol_and_endpoint( None => None, }; - let protocol = maybe_protocol.unwrap_or_else(|| { + let protocol = protocol.unwrap_or_else(|| { if maybe_endpoint.map_or(false, |e| e.contains(":4317")) { - "grpc" + Protocol::Grpc } else { - "http/protobuf" + Protocol::HttpProtobuf } }); let endpoint = match protocol { - "http/protobuf" => maybe_endpoint.unwrap_or("http://localhost:4318"), //Devskim: ignore DS137138 - _ => maybe_endpoint.unwrap_or("http://localhost:4317"), //Devskim: ignore DS137138 + Protocol::HttpProtobuf => maybe_endpoint.unwrap_or("http://localhost:4318"), //Devskim: ignore DS137138 + Protocol::Grpc => maybe_endpoint.unwrap_or("http://localhost:4317"), //Devskim: ignore DS137138 }; - (protocol.to_string(), endpoint.to_string()) + (protocol, endpoint.to_string()) } #[cfg(test)] @@ -174,40 +187,41 @@ mod tests { use rstest::rstest; use super::*; + use Protocol::*; #[rstest] - #[case(None, None, "http/protobuf", "http://localhost:4318")] //Devskim: ignore DS137138 - #[case(Some("http/protobuf"), None, "http/protobuf", "http://localhost:4318")] //Devskim: ignore DS137138 - #[case(Some("http"), None, "http/protobuf", "http://localhost:4318")] //Devskim: ignore DS137138 - #[case(Some("grpc"), None, "grpc", "http://localhost:4317")] //Devskim: ignore DS137138 - #[case(None, Some("http://localhost:4317"), "grpc", "http://localhost:4317")] + #[case(None, None, HttpProtobuf, "http://localhost:4318")] //Devskim: ignore DS137138 + #[case(Some("http/protobuf"), None, HttpProtobuf, "http://localhost:4318")] //Devskim: ignore DS137138 + #[case(Some("http"), None, HttpProtobuf, "http://localhost:4318")] //Devskim: ignore DS137138 + #[case(Some("grpc"), None, Grpc, "http://localhost:4317")] //Devskim: ignore DS137138 + #[case(None, Some("http://localhost:4317"), Grpc, "http://localhost:4317")] #[case( Some("http/protobuf"), Some("http://localhost:4318"), //Devskim: ignore DS137138 - "http/protobuf", + HttpProtobuf, "http://localhost:4318" //Devskim: ignore DS137138 )] #[case( Some("http/protobuf"), Some("https://examples.com:4318"), - "http/protobuf", + HttpProtobuf, "https://examples.com:4318" )] #[case( Some("http/protobuf"), Some("https://examples.com:4317"), - "http/protobuf", + HttpProtobuf, "https://examples.com:4317" )] fn test_infer_protocol_and_endpoint( #[case] traces_protocol: Option<&str>, #[case] traces_endpoint: Option<&str>, - #[case] expected_protocol: &str, + #[case] expected_protocol: Protocol, #[case] expected_endpoint: &str, ) { assert!( infer_protocol_and_endpoint(traces_protocol, traces_endpoint) - == (expected_protocol.to_string(), expected_endpoint.to_string()) + == (expected_protocol, expected_endpoint.to_string()) ); } }