diff --git a/services/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java b/services/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java index 75b13a39f1ff..5134a8109b8a 100644 --- a/services/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java +++ b/services/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java @@ -56,6 +56,7 @@ import org.apache.druid.server.security.AuthenticationResult; import org.apache.druid.server.security.Authenticator; import org.apache.druid.server.security.AuthenticatorMapper; +import org.apache.druid.server.security.AuthorizationUtils; import org.apache.druid.sql.http.SqlQuery; import org.apache.druid.sql.http.SqlResource; import org.eclipse.jetty.client.HttpClient; @@ -303,6 +304,7 @@ protected void service(HttpServletRequest request, HttpServletResponse response) /** * Rebuilds the {@link SqlQuery} object with sqlQueryId and queryId context parameters if not present + * * @param sqlQuery the original SqlQuery * @return an updated sqlQuery object with sqlQueryId and queryId context parameters */ @@ -367,13 +369,16 @@ void handleQueryParseException( // Log the error message final String errorMessage = exceptionToReport.getMessage() == null ? "no error message" : exceptionToReport.getMessage(); + + AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(request); + if (isNativeQuery) { requestLogger.logNativeQuery( RequestLogLine.forNative( null, DateTimes.nowUtc(), request.getRemoteAddr(), - new QueryStats(ImmutableMap.of("success", false, "exception", errorMessage)) + new QueryStats(ImmutableMap.of("success", false, "exception", errorMessage, "identity", authenticationResult.getIdentity())) ) ); } else { @@ -383,7 +388,7 @@ void handleQueryParseException( null, DateTimes.nowUtc(), request.getRemoteAddr(), - new QueryStats(ImmutableMap.of("success", false, "exception", errorMessage)) + new QueryStats(ImmutableMap.of("success", false, "exception", errorMessage, "identity", authenticationResult.getIdentity())) ) ); } @@ -744,6 +749,8 @@ public void onComplete(Result result) } emitQueryTime(requestTimeNs, success, sqlQueryId, queryId); + AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(req); + //noinspection VariableNotUsedInsideIf if (sqlQueryId != null) { // SQL query doesn't have a native query translation in router. Hence, not logging the native query. @@ -761,7 +768,9 @@ public void onComplete(Result result) TimeUnit.NANOSECONDS.toMillis(requestTimeNs), "success", success - && result.getResponse().getStatus() == Status.OK.getStatusCode() + && result.getResponse().getStatus() == Status.OK.getStatusCode(), + "identity", + authenticationResult.getIdentity() ) ) ) @@ -787,7 +796,9 @@ public void onComplete(Result result) TimeUnit.NANOSECONDS.toMillis(requestTimeNs), "success", success - && result.getResponse().getStatus() == Status.OK.getStatusCode() + && result.getResponse().getStatus() == Status.OK.getStatusCode(), + "identity", + authenticationResult.getIdentity() ) ) ) @@ -824,6 +835,7 @@ public void onFailure(Response response, Throwable failure) failedQueryCount.incrementAndGet(); emitQueryTime(requestTimeNs, false, sqlQueryId, queryId); + AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(req); //noinspection VariableNotUsedInsideIf if (sqlQueryId != null) { @@ -841,7 +853,9 @@ public void onFailure(Response response, Throwable failure) "success", false, "exception", - errorMessage == null ? "no message" : errorMessage + errorMessage == null ? "no message" : errorMessage, + "identity", + authenticationResult.getIdentity() ) ) ) @@ -871,7 +885,9 @@ public void onFailure(Response response, Throwable failure) "success", false, "exception", - errorMessage == null ? "no message" : errorMessage + errorMessage == null ? "no message" : errorMessage, + "identity", + authenticationResult.getIdentity() ) ) ) @@ -890,7 +906,12 @@ public void onFailure(Response response, Throwable failure) super.onFailure(response, failure); } - private void emitQueryTime(long requestTimeNs, boolean success, @Nullable String sqlQueryId, @Nullable String queryId) + private void emitQueryTime( + long requestTimeNs, + boolean success, + @Nullable String sqlQueryId, + @Nullable String queryId + ) { QueryMetrics queryMetrics; if (sqlQueryId != null) { diff --git a/services/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java b/services/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java index 6facaa547780..54238fe8cce6 100644 --- a/services/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java +++ b/services/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java @@ -65,6 +65,8 @@ import org.apache.druid.server.router.QueryHostFinder; import org.apache.druid.server.router.RendezvousHashAvaticaConnectionBalancer; import org.apache.druid.server.security.AllowAllAuthorizer; +import org.apache.druid.server.security.AuthConfig; +import org.apache.druid.server.security.AuthenticationResult; import org.apache.druid.server.security.AuthenticatorMapper; import org.apache.druid.server.security.Authorizer; import org.apache.druid.server.security.AuthorizerMapper; @@ -227,7 +229,7 @@ public void testSqlQueryProxy() throws Exception Properties properties = new Properties(); properties.setProperty("druid.router.sql.enable", "true"); - verifyServletCallsForQuery(query, true, false, hostFinder, properties); + verifyServletCallsForQuery(query, true, false, hostFinder, properties, false); } @Test @@ -244,7 +246,7 @@ public void testQueryProxy() throws Exception EasyMock.expect(hostFinder.pickServer(query)).andReturn(new TestServer("http", "1.2.3.4", 9999)).once(); EasyMock.replay(hostFinder); - verifyServletCallsForQuery(query, false, false, hostFinder, new Properties()); + verifyServletCallsForQuery(query, false, false, hostFinder, new Properties(), false); } @Test @@ -258,7 +260,7 @@ public void testJDBCSqlProxy() throws Exception .once(); EasyMock.replay(hostFinder); - verifyServletCallsForQuery(jdbcRequest, false, true, hostFinder, new Properties()); + verifyServletCallsForQuery(jdbcRequest, false, true, hostFinder, new Properties(), false); } @Test @@ -408,6 +410,7 @@ public void testHandleQueryParseExceptionWithFilterDisabled() throws Exception new Properties(), new ServerConfig() ); + Mockito.when(request.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).thenReturn(new AuthenticationResult("userA", "basic", "basic", null)); IOException testException = new IOException(errorMessage); servlet.handleQueryParseException(request, response, mockMapper, testException, false); ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); @@ -454,6 +457,7 @@ public ErrorResponseTransformStrategy getErrorResponseTransformStrategy() } } ); + Mockito.when(request.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).thenReturn(new AuthenticationResult("userA", "basic", "basic", null)); IOException testException = new IOException(errorMessage); servlet.handleQueryParseException(request, response, mockMapper, testException, false); ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); @@ -501,6 +505,7 @@ public ErrorResponseTransformStrategy getErrorResponseTransformStrategy() } } ); + Mockito.when(request.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).thenReturn(new AuthenticationResult("userA", "basic", "basic", null)); IOException testException = new IOException(errorMessage); servlet.handleQueryParseException(request, response, mockMapper, testException, false); ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); @@ -512,6 +517,46 @@ public ErrorResponseTransformStrategy getErrorResponseTransformStrategy() Assert.assertNull(((QueryException) captor.getValue()).getHost()); } + @Test + public void testNativeQueryProxyFailure() throws Exception + { + final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder() + .dataSource("foo") + .intervals("2000/P1D") + .granularity(Granularities.ALL) + .context(ImmutableMap.of("queryId", "dummy")) + .build(); + + final QueryHostFinder hostFinder = EasyMock.createMock(QueryHostFinder.class); + EasyMock.expect(hostFinder.pickServer(query)).andReturn(new TestServer("http", "1.2.3.4", 9999)).once(); + EasyMock.replay(hostFinder); + + verifyServletCallsForQuery(query, false, false, hostFinder, new Properties(), true); + } + + @Test + public void testSqlQueryProxyFailure() throws Exception + { + final SqlQuery query = new SqlQuery( + "SELECT * FROM foo", + ResultFormat.ARRAY, + false, + false, + false, + ImmutableMap.of("sqlQueryId", "dummy"), + null + ); + final QueryHostFinder hostFinder = EasyMock.createMock(QueryHostFinder.class); + EasyMock.expect(hostFinder.findServerSql( + query.withOverridenContext(ImmutableMap.of("sqlQueryId", "dummy", "queryId", "dummy"))) + ).andReturn(new TestServer("http", "1.2.3.4", 9999)).once(); + EasyMock.replay(hostFinder); + + Properties properties = new Properties(); + properties.setProperty("druid.router.sql.enable", "true"); + verifyServletCallsForQuery(query, true, false, hostFinder, properties, true); + } + /** * Verifies that the Servlet calls the right methods the right number of times. */ @@ -520,7 +565,8 @@ private void verifyServletCallsForQuery( boolean isNativeSql, boolean isJDBCSql, QueryHostFinder hostFinder, - Properties properties + Properties properties, + boolean isFailure ) throws Exception { final ObjectMapper jsonMapper = TestHelper.makeJsonMapper(); @@ -587,27 +633,30 @@ public int read() EasyMock.expectLastCall(); requestMock.setAttribute("org.apache.druid.proxy.to.host.scheme", "http"); EasyMock.expectLastCall(); + EasyMock.expect(requestMock.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(new AuthenticationResult("userA", "basic", "basic", null)); + if (isFailure) { + EasyMock.expect(requestMock.getRemoteAddr()).andReturn("0.0.0.0:0"); + } + EasyMock.replay(requestMock); final AtomicLong didService = new AtomicLong(); final Request proxyRequestMock = Mockito.spy(Request.class); - final Result result = new Result( - proxyRequestMock, - new HttpResponse(proxyRequestMock, ImmutableList.of()) - { - @Override - public HttpFields getHeaders() - { - HttpFields httpFields = new HttpFields(); - if (isJDBCSql) { - httpFields.add(new HttpField("X-Druid-SQL-Query-Id", "jdbcDummy")); - } else if (isNativeSql) { - httpFields.add(new HttpField("X-Druid-SQL-Query-Id", "dummy")); - } - return httpFields; - } + HttpResponse response = new HttpResponse(proxyRequestMock, ImmutableList.of()) + { + @Override + public HttpFields getHeaders() + { + HttpFields httpFields = new HttpFields(); + if (isJDBCSql) { + httpFields.add(new HttpField("X-Druid-SQL-Query-Id", "jdbcDummy")); + } else if (isNativeSql) { + httpFields.add(new HttpField("X-Druid-SQL-Query-Id", "dummy")); } - ); + return httpFields; + } + }; + final Result result = new Result(proxyRequestMock, response); final StubServiceEmitter stubServiceEmitter = new StubServiceEmitter("", ""); final AsyncQueryForwardingServlet servlet = new AsyncQueryForwardingServlet( new MapQueryToolChestWarehouse(ImmutableMap.of()), @@ -640,7 +689,11 @@ protected void doService( // partial state of the servlet. Hence, only catching the exact exception to avoid possible errors. // Further, the metric assertions are also done to ensure that the metrics have emitted. try { - servlet.newProxyResponseListener(requestMock, null).onComplete(result); + if (isFailure) { + servlet.newProxyResponseListener(requestMock, null).onFailure(response, new Throwable("Proxy failed")); + } else { + servlet.newProxyResponseListener(requestMock, null).onComplete(result); + } } catch (NullPointerException ignored) { }