Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in product check #8296

Merged
merged 4 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,31 @@
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Elastic.Clients.Elasticsearch.Serverless.Requests;
using System.Threading;
using Elastic.Transport;
using Elastic.Transport.Diagnostics;
using Elastic.Transport.Products.Elasticsearch;

#if ELASTICSEARCH_SERVERLESS
using Elastic.Clients.Elasticsearch.Serverless.Requests;
#else
using Elastic.Clients.Elasticsearch.Requests;
#endif

#if ELASTICSEARCH_SERVERLESS
namespace Elastic.Clients.Elasticsearch.Serverless;
#else

namespace Elastic.Clients.Elasticsearch;
#endif

/// <summary>
/// A strongly-typed client for communicating with Elasticsearch server endpoints.
/// </summary>
public partial class ElasticsearchClient
{
private const string OpenTelemetrySpanAttributePrefix = "db.elasticsearch.";

// This should be updated if any of the code uses semantic conventions defined in newer schema versions.
private const string OpenTelemetrySchemaVersion = "https://opentelemetry.io/schemas/1.21.0";

Expand Down Expand Up @@ -82,17 +92,18 @@ internal ElasticsearchClient(ITransport<IElasticsearchClientSettings> transport)
public Serializer SourceSerializer => _transport.Configuration.SourceSerializer;
public ITransport<IElasticsearchClientSettings> Transport => _transport;

private ProductCheckStatus _productCheckStatus;
private partial void SetupNamespaces();

private volatile int _productCheckStatus;
flobernd marked this conversation as resolved.
Show resolved Hide resolved

private enum ProductCheckStatus
{
NotChecked,
Succeeded,
Failed
NotChecked = 0,
InProgress = 1,
flobernd marked this conversation as resolved.
Show resolved Hide resolved
Succeeded = 2,
Failed = 3
}

private partial void SetupNamespaces();

internal TResponse DoRequest<TRequest, TResponse, TRequestParameters>(TRequest request)
where TRequest : Request<TRequestParameters>
where TResponse : TransportResponse, new()
Expand Down Expand Up @@ -133,48 +144,116 @@ private ValueTask<TResponse> DoRequestCoreAsync<TRequest, TResponse, TRequestPar
where TResponse : TransportResponse, new()
where TRequestParameters : RequestParameters, new()
{
if (_productCheckStatus == ProductCheckStatus.Failed)
throw new UnsupportedProductException(UnsupportedProductException.InvalidProductError);
// The product check modifies request parameters and therefore must not be executed concurrently.
// We use a lockless CAS approach to make sure that only a single product check request is executed at a time.
// We do not guarantee that the product check is always performed on the first request.

var (requestModified, hadRequestConfig, originalHeaders) = AttachProductCheckHeaderIfRequired<TRequest, TRequestParameters>(request);
var (resolvedUrl, urlTemplate, resolvedRouteValues, postData) = PrepareRequest<TRequest, TRequestParameters>(request, forceConfiguration);
var openTelemetryData = PrepareOpenTelemetryData<TRequest, TRequestParameters>(request, resolvedRouteValues);
var productCheckStatus = Interlocked.CompareExchange(
ref _productCheckStatus,
(int)ProductCheckStatus.InProgress,
(int)ProductCheckStatus.NotChecked
);

if (_productCheckStatus == ProductCheckStatus.Succeeded && !requestModified)
return productCheckStatus switch
{
if (isAsync)
return new ValueTask<TResponse>(_transport.RequestAsync<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData, cancellationToken));
else
return new ValueTask<TResponse>(_transport.Request<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData));
(int)ProductCheckStatus.NotChecked => SendRequestWithProductCheck(),
(int)ProductCheckStatus.InProgress or
(int)ProductCheckStatus.Succeeded => SendRequest(),
(int)ProductCheckStatus.Failed => throw new UnsupportedProductException(UnsupportedProductException.InvalidProductError),
_ => throw new InvalidOperationException("unreachable")
};

ValueTask<TResponse> SendRequest()
{
var (resolvedUrl, _, resolvedRouteValues, postData) = PrepareRequest<TRequest, TRequestParameters>(request, forceConfiguration);
var openTelemetryData = PrepareOpenTelemetryData<TRequest, TRequestParameters>(request, resolvedRouteValues);

return isAsync
? new ValueTask<TResponse>(_transport
.RequestAsync<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData, cancellationToken))
: new ValueTask<TResponse>(_transport
.Request<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData));
}

return SendRequest(isAsync);
async ValueTask<TResponse> SendRequestWithProductCheck()
{
try
{
return await SendRequestWithProductCheckCore().ConfigureAwait(false);
}
catch
{
// Re-try product check on next request

Interlocked.CompareExchange(
ref _productCheckStatus,
(int)ProductCheckStatus.NotChecked,
(int)ProductCheckStatus.InProgress
);
flobernd marked this conversation as resolved.
Show resolved Hide resolved

async ValueTask<TResponse> SendRequest(bool isAsync)
throw;
}
}

async ValueTask<TResponse> SendRequestWithProductCheckCore()
{
// Attach product check header

var hadRequestConfig = false;
HeadersList? originalHeaders = null;

if (request.RequestParameters.RequestConfiguration is null)
request.RequestParameters.RequestConfiguration = new RequestConfiguration();
else
{
originalHeaders = request.RequestParameters.RequestConfiguration.ResponseHeadersToParse;
hadRequestConfig = true;
}

request.RequestParameters.RequestConfiguration.ResponseHeadersToParse = request.RequestParameters.RequestConfiguration.ResponseHeadersToParse.Count == 0
? new HeadersList("x-elastic-product")
: new HeadersList(request.RequestParameters.RequestConfiguration.ResponseHeadersToParse, "x-elastic-product");

// Send request

var (resolvedUrl, _, resolvedRouteValues, postData) = PrepareRequest<TRequest, TRequestParameters>(request, forceConfiguration);
var openTelemetryData = PrepareOpenTelemetryData<TRequest, TRequestParameters>(request, resolvedRouteValues);

TResponse response;

if (isAsync)
response = await _transport.RequestAsync<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData, cancellationToken).ConfigureAwait(false);
{
response = await _transport
.RequestAsync<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData, cancellationToken)
.ConfigureAwait(false);
}
else
response = _transport.Request<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData);
{
response = _transport
.Request<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData);
}

PostRequestProductCheck<TRequest, TResponse>(request, response);
// Evaluate product check result

if (_productCheckStatus == ProductCheckStatus.Failed)
var productCheckSucceeded = response.ApiCallDetails.TryGetHeader("x-elastic-product", out var values) &&
values.Single().Equals("Elasticsearch", StringComparison.Ordinal);
flobernd marked this conversation as resolved.
Show resolved Hide resolved

_productCheckStatus = productCheckSucceeded
? (int)ProductCheckStatus.Succeeded
: (int)ProductCheckStatus.Failed;

if (_productCheckStatus == (int)ProductCheckStatus.Failed)
flobernd marked this conversation as resolved.
Show resolved Hide resolved
throw new UnsupportedProductException(UnsupportedProductException.InvalidProductError);

if (request.RequestParameters.RequestConfiguration is not null)
{
if (!hadRequestConfig)
{
request.RequestParameters.RequestConfiguration = null;
}
else if (originalHeaders.HasValue && originalHeaders.Value.Count > 0)
{
request.RequestParameters.RequestConfiguration.ResponseHeadersToParse = originalHeaders.Value;
}
}
if (request.RequestParameters.RequestConfiguration is null)
return response;

// Reset request configuration

if (!hadRequestConfig)
request.RequestParameters.RequestConfiguration = null;
else if (originalHeaders is { Count: > 0 })
request.RequestParameters.RequestConfiguration.ResponseHeadersToParse = originalHeaders.Value;

return response;
}
Expand Down Expand Up @@ -215,42 +294,6 @@ private static OpenTelemetryData PrepareOpenTelemetryData<TRequest, TRequestPara
return openTelemetryData;
}

private (bool requestModified, bool hadRequestConfig, HeadersList? originalHeaders) AttachProductCheckHeaderIfRequired<TRequest, TRequestParameters>(TRequest request)
where TRequest : Request<TRequestParameters>
where TRequestParameters : RequestParameters, new()
{
var requestModified = false;
var hadRequestConfig = false;
HeadersList? originalHeaders = null;

// If we have not yet checked the product name, add the product header to the list of headers to parse.
if (_productCheckStatus == ProductCheckStatus.NotChecked)
{
requestModified = true;

if (request.RequestParameters.RequestConfiguration is null)
{
request.RequestParameters.RequestConfiguration = new RequestConfiguration();
}
else
{
originalHeaders = request.RequestParameters.RequestConfiguration.ResponseHeadersToParse;
hadRequestConfig = true;
}

if (request.RequestParameters.RequestConfiguration.ResponseHeadersToParse.Count == 0)
{
request.RequestParameters.RequestConfiguration.ResponseHeadersToParse = new HeadersList("x-elastic-product");
}
else
{
request.RequestParameters.RequestConfiguration.ResponseHeadersToParse = new HeadersList(request.RequestParameters.RequestConfiguration.ResponseHeadersToParse, "x-elastic-product");
}
}

return (requestModified, hadRequestConfig, originalHeaders);
}

private (string resolvedUrl, string urlTemplate, Dictionary<string, string>? resolvedRouteValues, PostData data) PrepareRequest<TRequest, TRequestParameters>(TRequest request,
Action<IRequestConfiguration>? forceConfiguration)
where TRequest : Request<TRequestParameters>
Expand Down Expand Up @@ -278,21 +321,6 @@ private static OpenTelemetryData PrepareOpenTelemetryData<TRequest, TRequestPara
return (resolvedUrl, urlTemplate, routeValues, postData);
}

private void PostRequestProductCheck<TRequest, TResponse>(TRequest request, TResponse response)
where TRequest : Request
where TResponse : TransportResponse, new()
{
if (response.ApiCallDetails.HttpStatusCode.HasValue && response.ApiCallDetails.HttpStatusCode.Value >= 200 && response.ApiCallDetails.HttpStatusCode.Value <= 299 && _productCheckStatus == ProductCheckStatus.NotChecked)
{
if (!response.ApiCallDetails.TryGetHeader("x-elastic-product", out var values) || !values.Single().Equals("Elasticsearch", StringComparison.Ordinal))
{
_productCheckStatus = ProductCheckStatus.Failed;
}

_productCheckStatus = ProductCheckStatus.Succeeded;
}
}

private static void ForceConfiguration<TRequestParameters>(Request<TRequestParameters> request, Action<IRequestConfiguration> forceConfiguration)
where TRequestParameters : RequestParameters, new()
{
Expand Down
Loading
Loading