Skip to content

Commit

Permalink
Add HttpHeaders in broker event listener requestContext (apache#12258)
Browse files Browse the repository at this point in the history
  • Loading branch information
tibrewalpratik17 authored Jan 29, 2024
1 parent 4823802 commit 6cc1915
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.spi.env.PinotConfiguration;
import org.apache.pinot.spi.eventlistener.query.BrokerQueryEventListener;
import org.apache.pinot.spi.eventlistener.query.PinotBrokerQueryEventListenerFactory;
import org.apache.pinot.spi.exception.BadQueryRequestException;
import org.apache.pinot.spi.trace.RequestContext;
import org.apache.pinot.spi.trace.Tracing;
Expand Down Expand Up @@ -259,6 +260,11 @@ public BrokerResponse handleRequest(JsonNode request, @Nullable SqlNodeAndOption

long requestId = _brokerIdGenerator.get();
requestContext.setRequestId(requestId);
if (httpHeaders != null) {
requestContext.setRequestHttpHeaders(httpHeaders.getRequestHeaders().entrySet().stream()
.filter(entry -> PinotBrokerQueryEventListenerFactory.getAllowlistQueryRequestHeaders()
.contains(entry.getKey())).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
}

// First-stage access control to prevent unauthenticated requests from using up resources. Secondary table-level
// check comes later.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,25 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import org.apache.pinot.spi.env.PinotConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.apache.pinot.spi.utils.CommonConstants.CONFIG_OF_BROKER_EVENT_LISTENER_CLASS_NAME;
import static org.apache.pinot.spi.utils.CommonConstants.CONFIG_OF_REQUEST_CONTEXT_TRACKED_HEADER_KEYS;
import static org.apache.pinot.spi.utils.CommonConstants.DEFAULT_BROKER_EVENT_LISTENER_CLASS_NAME;


public class PinotBrokerQueryEventListenerFactory {
private static final Logger LOGGER = LoggerFactory.getLogger(PinotBrokerQueryEventListenerFactory.class);
private static BrokerQueryEventListener _brokerQueryEventListener = null;
private static List<String> _allowlistQueryRequestHeaders = new ArrayList<>();

private PinotBrokerQueryEventListenerFactory() {
}
Expand All @@ -44,6 +50,8 @@ private PinotBrokerQueryEventListenerFactory() {
public synchronized static void init(PinotConfiguration eventListenerConfiguration) {
// Initializes BrokerQueryEventListener.
initializeBrokerQueryEventListener(eventListenerConfiguration);
// Initializes request headers
initializeAllowlistQueryRequestHeaders(eventListenerConfiguration);
}

/**
Expand Down Expand Up @@ -78,6 +86,19 @@ private static void initializeBrokerQueryEventListener(PinotConfiguration eventL
+ "Please check if any pinot-event-listener related jar is actually added to the classpath.");
}

/**
* Initializes allowlist request-headers to extract from query request.
* @param eventListenerConfiguration The subset of the configuration containing the event-listener-related keys
*/
private static void initializeAllowlistQueryRequestHeaders(PinotConfiguration eventListenerConfiguration) {
List<String> allowlistQueryRequestHeaders =
Splitter.on(",").omitEmptyStrings().trimResults()
.splitToList(eventListenerConfiguration.getProperty(CONFIG_OF_REQUEST_CONTEXT_TRACKED_HEADER_KEYS, ""));

LOGGER.info("{}: allowlist headers will be used for PinotBrokerQueryEventListener", allowlistQueryRequestHeaders);
registerAllowlistQueryRequestHeaders(allowlistQueryRequestHeaders);
}

/**
* Registers a broker event listener.
*/
Expand All @@ -86,6 +107,14 @@ private static void registerBrokerEventListener(BrokerQueryEventListener brokerQ
_brokerQueryEventListener = brokerQueryEventListener;
}

/**
* Registers allowlist http headers for query-requests.
*/
private static void registerAllowlistQueryRequestHeaders(List<String> allowlistQueryRequestHeaders) {
LOGGER.info("Registering query request headers allowlist : {}", allowlistQueryRequestHeaders);
_allowlistQueryRequestHeaders = ImmutableList.copyOf(allowlistQueryRequestHeaders);
}

/**
* Returns the brokerQueryEventListener. If the BrokerQueryEventListener is null,
* first creates and initializes the BrokerQueryEventListener.
Expand All @@ -103,4 +132,9 @@ public static synchronized BrokerQueryEventListener getBrokerQueryEventListener(
public static BrokerQueryEventListener getBrokerQueryEventListener() {
return getBrokerQueryEventListener(new PinotConfiguration(Collections.emptyMap()));
}

@VisibleForTesting
public static List<String> getAllowlistQueryRequestHeaders() {
return _allowlistQueryRequestHeaders;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public class DefaultRequestContext implements RequestScope {
private long _explainPlanNumMatchAllFilterSegments;
private Map<String, String> _traceInfo = new HashMap<>();
private List<String> _processingExceptions = new ArrayList<>();
private Map<String, List<String>> _requestHttpHeaders = new HashMap<>();

public DefaultRequestContext() {
}
Expand Down Expand Up @@ -562,6 +563,16 @@ public void setProcessingExceptions(List<String> processingExceptions) {
_processingExceptions.addAll(processingExceptions);
}

@Override
public Map<String, List<String>> getRequestHttpHeaders() {
return _requestHttpHeaders;
}

@Override
public void setRequestHttpHeaders(Map<String, List<String>> requestHttpHeaders) {
_requestHttpHeaders.putAll(requestHttpHeaders);
}

@Override
public void close() {
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ default boolean isSampledRequest() {

void setProcessingExceptions(List<String> processingExceptions);

Map<String, List<String>> getRequestHttpHeaders();

void setRequestHttpHeaders(Map<String, List<String>> requestHttpHeaders);

enum FanoutType {
OFFLINE, REALTIME, HYBRID
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ private CommonConstants() {
public static final String UNKNOWN = "unknown";
public static final String CONFIG_OF_METRICS_FACTORY_CLASS_NAME = "factory.className";
public static final String CONFIG_OF_BROKER_EVENT_LISTENER_CLASS_NAME = "factory.className";
public static final String CONFIG_OF_REQUEST_CONTEXT_TRACKED_HEADER_KEYS = "request.context.tracked.header.keys";
public static final String DEFAULT_METRICS_FACTORY_CLASS_NAME =
"org.apache.pinot.plugin.metrics.yammer.YammerMetricsFactory";
public static final String DEFAULT_BROKER_EVENT_LISTENER_CLASS_NAME =
Expand Down

0 comments on commit 6cc1915

Please sign in to comment.