Skip to content

Commit

Permalink
Attach user identity to router request logs (#15126)
Browse files Browse the repository at this point in the history
* Attach user identity to router request logs

* Add test

* More tests
  • Loading branch information
a2l007 authored Oct 19, 2023
1 parent 5c14b42 commit 7802078
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.Authenticator;
import org.apache.druid.server.security.AuthenticatorMapper;
import org.apache.druid.server.security.AuthorizationUtils;
import org.apache.druid.sql.http.SqlQuery;
import org.apache.druid.sql.http.SqlResource;
import org.eclipse.jetty.client.HttpClient;
Expand Down Expand Up @@ -303,6 +304,7 @@ protected void service(HttpServletRequest request, HttpServletResponse response)

/**
* Rebuilds the {@link SqlQuery} object with sqlQueryId and queryId context parameters if not present
*
* @param sqlQuery the original SqlQuery
* @return an updated sqlQuery object with sqlQueryId and queryId context parameters
*/
Expand Down Expand Up @@ -367,13 +369,16 @@ void handleQueryParseException(
// Log the error message
final String errorMessage = exceptionToReport.getMessage() == null
? "no error message" : exceptionToReport.getMessage();

AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(request);

if (isNativeQuery) {
requestLogger.logNativeQuery(
RequestLogLine.forNative(
null,
DateTimes.nowUtc(),
request.getRemoteAddr(),
new QueryStats(ImmutableMap.of("success", false, "exception", errorMessage))
new QueryStats(ImmutableMap.of("success", false, "exception", errorMessage, "identity", authenticationResult.getIdentity()))
)
);
} else {
Expand All @@ -383,7 +388,7 @@ void handleQueryParseException(
null,
DateTimes.nowUtc(),
request.getRemoteAddr(),
new QueryStats(ImmutableMap.of("success", false, "exception", errorMessage))
new QueryStats(ImmutableMap.of("success", false, "exception", errorMessage, "identity", authenticationResult.getIdentity()))
)
);
}
Expand Down Expand Up @@ -744,6 +749,8 @@ public void onComplete(Result result)
}
emitQueryTime(requestTimeNs, success, sqlQueryId, queryId);

AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(req);

//noinspection VariableNotUsedInsideIf
if (sqlQueryId != null) {
// SQL query doesn't have a native query translation in router. Hence, not logging the native query.
Expand All @@ -761,7 +768,9 @@ public void onComplete(Result result)
TimeUnit.NANOSECONDS.toMillis(requestTimeNs),
"success",
success
&& result.getResponse().getStatus() == Status.OK.getStatusCode()
&& result.getResponse().getStatus() == Status.OK.getStatusCode(),
"identity",
authenticationResult.getIdentity()
)
)
)
Expand All @@ -787,7 +796,9 @@ public void onComplete(Result result)
TimeUnit.NANOSECONDS.toMillis(requestTimeNs),
"success",
success
&& result.getResponse().getStatus() == Status.OK.getStatusCode()
&& result.getResponse().getStatus() == Status.OK.getStatusCode(),
"identity",
authenticationResult.getIdentity()
)
)
)
Expand Down Expand Up @@ -824,6 +835,7 @@ public void onFailure(Response response, Throwable failure)

failedQueryCount.incrementAndGet();
emitQueryTime(requestTimeNs, false, sqlQueryId, queryId);
AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(req);

//noinspection VariableNotUsedInsideIf
if (sqlQueryId != null) {
Expand All @@ -841,7 +853,9 @@ public void onFailure(Response response, Throwable failure)
"success",
false,
"exception",
errorMessage == null ? "no message" : errorMessage
errorMessage == null ? "no message" : errorMessage,
"identity",
authenticationResult.getIdentity()
)
)
)
Expand Down Expand Up @@ -871,7 +885,9 @@ public void onFailure(Response response, Throwable failure)
"success",
false,
"exception",
errorMessage == null ? "no message" : errorMessage
errorMessage == null ? "no message" : errorMessage,
"identity",
authenticationResult.getIdentity()
)
)
)
Expand All @@ -890,7 +906,12 @@ public void onFailure(Response response, Throwable failure)
super.onFailure(response, failure);
}

private void emitQueryTime(long requestTimeNs, boolean success, @Nullable String sqlQueryId, @Nullable String queryId)
private void emitQueryTime(
long requestTimeNs,
boolean success,
@Nullable String sqlQueryId,
@Nullable String queryId
)
{
QueryMetrics queryMetrics;
if (sqlQueryId != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
import org.apache.druid.server.router.QueryHostFinder;
import org.apache.druid.server.router.RendezvousHashAvaticaConnectionBalancer;
import org.apache.druid.server.security.AllowAllAuthorizer;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.AuthenticatorMapper;
import org.apache.druid.server.security.Authorizer;
import org.apache.druid.server.security.AuthorizerMapper;
Expand Down Expand Up @@ -227,7 +229,7 @@ public void testSqlQueryProxy() throws Exception

Properties properties = new Properties();
properties.setProperty("druid.router.sql.enable", "true");
verifyServletCallsForQuery(query, true, false, hostFinder, properties);
verifyServletCallsForQuery(query, true, false, hostFinder, properties, false);
}

@Test
Expand All @@ -244,7 +246,7 @@ public void testQueryProxy() throws Exception
EasyMock.expect(hostFinder.pickServer(query)).andReturn(new TestServer("http", "1.2.3.4", 9999)).once();
EasyMock.replay(hostFinder);

verifyServletCallsForQuery(query, false, false, hostFinder, new Properties());
verifyServletCallsForQuery(query, false, false, hostFinder, new Properties(), false);
}

@Test
Expand All @@ -258,7 +260,7 @@ public void testJDBCSqlProxy() throws Exception
.once();
EasyMock.replay(hostFinder);

verifyServletCallsForQuery(jdbcRequest, false, true, hostFinder, new Properties());
verifyServletCallsForQuery(jdbcRequest, false, true, hostFinder, new Properties(), false);
}

@Test
Expand Down Expand Up @@ -408,6 +410,7 @@ public void testHandleQueryParseExceptionWithFilterDisabled() throws Exception
new Properties(),
new ServerConfig()
);
Mockito.when(request.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).thenReturn(new AuthenticationResult("userA", "basic", "basic", null));
IOException testException = new IOException(errorMessage);
servlet.handleQueryParseException(request, response, mockMapper, testException, false);
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
Expand Down Expand Up @@ -454,6 +457,7 @@ public ErrorResponseTransformStrategy getErrorResponseTransformStrategy()
}
}
);
Mockito.when(request.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).thenReturn(new AuthenticationResult("userA", "basic", "basic", null));
IOException testException = new IOException(errorMessage);
servlet.handleQueryParseException(request, response, mockMapper, testException, false);
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
Expand Down Expand Up @@ -501,6 +505,7 @@ public ErrorResponseTransformStrategy getErrorResponseTransformStrategy()
}
}
);
Mockito.when(request.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).thenReturn(new AuthenticationResult("userA", "basic", "basic", null));
IOException testException = new IOException(errorMessage);
servlet.handleQueryParseException(request, response, mockMapper, testException, false);
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
Expand All @@ -512,6 +517,46 @@ public ErrorResponseTransformStrategy getErrorResponseTransformStrategy()
Assert.assertNull(((QueryException) captor.getValue()).getHost());
}

@Test
public void testNativeQueryProxyFailure() throws Exception
{
final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource("foo")
.intervals("2000/P1D")
.granularity(Granularities.ALL)
.context(ImmutableMap.of("queryId", "dummy"))
.build();

final QueryHostFinder hostFinder = EasyMock.createMock(QueryHostFinder.class);
EasyMock.expect(hostFinder.pickServer(query)).andReturn(new TestServer("http", "1.2.3.4", 9999)).once();
EasyMock.replay(hostFinder);

verifyServletCallsForQuery(query, false, false, hostFinder, new Properties(), true);
}

@Test
public void testSqlQueryProxyFailure() throws Exception
{
final SqlQuery query = new SqlQuery(
"SELECT * FROM foo",
ResultFormat.ARRAY,
false,
false,
false,
ImmutableMap.of("sqlQueryId", "dummy"),
null
);
final QueryHostFinder hostFinder = EasyMock.createMock(QueryHostFinder.class);
EasyMock.expect(hostFinder.findServerSql(
query.withOverridenContext(ImmutableMap.of("sqlQueryId", "dummy", "queryId", "dummy")))
).andReturn(new TestServer("http", "1.2.3.4", 9999)).once();
EasyMock.replay(hostFinder);

Properties properties = new Properties();
properties.setProperty("druid.router.sql.enable", "true");
verifyServletCallsForQuery(query, true, false, hostFinder, properties, true);
}

/**
* Verifies that the Servlet calls the right methods the right number of times.
*/
Expand All @@ -520,7 +565,8 @@ private void verifyServletCallsForQuery(
boolean isNativeSql,
boolean isJDBCSql,
QueryHostFinder hostFinder,
Properties properties
Properties properties,
boolean isFailure
) throws Exception
{
final ObjectMapper jsonMapper = TestHelper.makeJsonMapper();
Expand Down Expand Up @@ -587,27 +633,30 @@ public int read()
EasyMock.expectLastCall();
requestMock.setAttribute("org.apache.druid.proxy.to.host.scheme", "http");
EasyMock.expectLastCall();
EasyMock.expect(requestMock.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(new AuthenticationResult("userA", "basic", "basic", null));
if (isFailure) {
EasyMock.expect(requestMock.getRemoteAddr()).andReturn("0.0.0.0:0");
}

EasyMock.replay(requestMock);

final AtomicLong didService = new AtomicLong();
final Request proxyRequestMock = Mockito.spy(Request.class);
final Result result = new Result(
proxyRequestMock,
new HttpResponse(proxyRequestMock, ImmutableList.of())
{
@Override
public HttpFields getHeaders()
{
HttpFields httpFields = new HttpFields();
if (isJDBCSql) {
httpFields.add(new HttpField("X-Druid-SQL-Query-Id", "jdbcDummy"));
} else if (isNativeSql) {
httpFields.add(new HttpField("X-Druid-SQL-Query-Id", "dummy"));
}
return httpFields;
}
HttpResponse response = new HttpResponse(proxyRequestMock, ImmutableList.of())
{
@Override
public HttpFields getHeaders()
{
HttpFields httpFields = new HttpFields();
if (isJDBCSql) {
httpFields.add(new HttpField("X-Druid-SQL-Query-Id", "jdbcDummy"));
} else if (isNativeSql) {
httpFields.add(new HttpField("X-Druid-SQL-Query-Id", "dummy"));
}
);
return httpFields;
}
};
final Result result = new Result(proxyRequestMock, response);
final StubServiceEmitter stubServiceEmitter = new StubServiceEmitter("", "");
final AsyncQueryForwardingServlet servlet = new AsyncQueryForwardingServlet(
new MapQueryToolChestWarehouse(ImmutableMap.of()),
Expand Down Expand Up @@ -640,7 +689,11 @@ protected void doService(
// partial state of the servlet. Hence, only catching the exact exception to avoid possible errors.
// Further, the metric assertions are also done to ensure that the metrics have emitted.
try {
servlet.newProxyResponseListener(requestMock, null).onComplete(result);
if (isFailure) {
servlet.newProxyResponseListener(requestMock, null).onFailure(response, new Throwable("Proxy failed"));
} else {
servlet.newProxyResponseListener(requestMock, null).onComplete(result);
}
}
catch (NullPointerException ignored) {
}
Expand Down

0 comments on commit 7802078

Please sign in to comment.