Skip to content

Commit

Permalink
Audit more endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
kfaraz committed Dec 4, 2023
1 parent 9cb9dc6 commit 74860c7
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public Response taskPost(
auditManager.doAudit(
AuditEvent.builder()
.key(task.getDataSource())
.type("submit.ingestion.task")
.type("ingestion.batch")
.payload(new TaskIdentifier(task.getId(), task.getGroupId(), task.getType()))
.auditInfo(new AuditInfo(author, comment, req.getRemoteAddr()))
.build()
Expand Down
22 changes: 7 additions & 15 deletions processing/src/main/java/org/apache/druid/audit/AuditManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ public interface AuditManager
* audited changes. Only SQL-based implementations need to implement this method,
* other implementations call {@link #doAudit} by default.
*
 * @param event  Event to audit
 * @param handle JDBI Handle representing connection to the database
*/
default void doAudit(AuditEvent event, Handle handle) throws IOException
{
Expand All @@ -53,10 +53,7 @@ default void doAudit(AuditEvent event, Handle handle) throws IOException
  /**
   * Fetches audit entries made for the given key, type and interval. Implementations
   * that do not maintain an audit history should return an empty list.
   *
   * @param key      Key of the audited entries to fetch (e.g. a datasource name — confirm against callers)
   * @param type     Type of audit entry
   * @param interval Eligible interval for audit time
   * @return List of recorded audit events satisfying the passed parameters.
   */
  List<AuditEvent> fetchAuditHistory(String key, String type, Interval interval);

Expand All @@ -65,34 +62,29 @@ default void doAudit(AuditEvent event, Handle handle) throws IOException
  /**
   * Fetches audit entries of the given type made in the given interval.
   * Implementations that do not maintain an audit history should return an empty list.
   *
   * @param type     Type of audit entry
   * @param interval Eligible interval for audit time
   * @return List of recorded audit events satisfying the passed parameters.
   */
  List<AuditEvent> fetchAuditHistory(String type, Interval interval);

  /**
   * Provides the last N entries of audit history for the given key and type.
   *
   * @param key   Key of the audited entries to fetch
   * @param type  Type of audit entry
   * @param limit Maximum number of entries to return (the "N" above)
   * @return List of recorded audit events satisfying the passed parameters.
   */
  List<AuditEvent> fetchAuditHistory(String key, String type, int limit);

  /**
   * Provides the last N entries of audit history for the given type.
   *
   * @param type  Type of AuditEvent
   * @param limit Maximum number of entries to return (the "N" above)
   * @return List of recorded audit events satisfying the passed parameters.
   */
  List<AuditEvent> fetchAuditHistory(String type, int limit);

  /**
   * Removes audit logs that were created before the given timestamp.
   *
   * @param timestamp Cutoff timestamp in milliseconds; audit logs created
   *                  before this instant are removed
   * @return Number of audit logs removed
   */
  int removeAuditLogsOlderThan(long timestamp);
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public void testSet()

jacksonConfigManager.set(key, val, auditInfo);

ArgumentCaptor<AuditEvent.Builder> auditCapture = ArgumentCaptor.forClass(AuditEvent.Builder.class);
ArgumentCaptor<AuditEvent> auditCapture = ArgumentCaptor.forClass(AuditEvent.class);
Mockito.verify(mockAuditManager).doAudit(auditCapture.capture());
Assert.assertNotNull(auditCapture.getValue());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,22 @@
import org.apache.druid.error.InvalidInput;
import org.apache.druid.guice.annotations.Json;
import org.apache.druid.guice.annotations.JsonNonNull;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;

import java.io.IOException;

public class AuditSerdeHelper implements AuditEvent.PayloadDeserializer
public class AuditSerdeHelper
{
/**
* Default message stored instead of the actual audit payload if the audit
* payload size exceeds the maximum size limit.
*/
private static final String PAYLOAD_TRUNCATED_MSG =
"Payload truncated as it exceeds 'druid.audit.manager.maxPayloadSizeBytes'";
private static final String SERIALIZE_ERROR_MSG =
"Error serializing payload";
private static final Logger log = new Logger(AuditSerdeHelper.class);

private final ObjectMapper jsonMapper;
private final ObjectMapper jsonMapperSkipNulls;
Expand Down Expand Up @@ -72,7 +75,6 @@ public AuditRecord processAuditEvent(AuditEvent event)
);
}

@Override
public <T> T deserializePayloadFromString(String serializedPayload, Class<T> clazz)
{
if (serializedPayload == null || serializedPayload.isEmpty()) {
Expand Down Expand Up @@ -105,7 +107,9 @@ private String serializePayloadToString(Object payload)
: jsonMapper.writeValueAsString(payload);
}
catch (IOException e) {
throw new ISE(e, "Could not serialize audit payload[%s]", payload);
// Do not throw exception, only log error
log.error(e, "Could not serialize audit payload[%s]", payload);
return SERIALIZE_ERROR_MSG;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
import com.sun.jersey.spi.container.ResourceFilters;
import it.unimi.dsi.fastutil.objects.Object2LongMap;
import org.apache.commons.lang.StringUtils;
import org.apache.druid.audit.AuditEvent;
import org.apache.druid.audit.AuditInfo;
import org.apache.druid.audit.AuditManager;
import org.apache.druid.client.CoordinatorServerView;
import org.apache.druid.client.DruidDataSource;
import org.apache.druid.client.DruidServer;
Expand Down Expand Up @@ -70,7 +73,9 @@
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
Expand Down Expand Up @@ -110,6 +115,7 @@ public class DataSourcesResource
private final OverlordClient overlordClient;
private final AuthorizerMapper authorizerMapper;
private final DruidCoordinator coordinator;
private final AuditManager auditManager;

@Inject
public DataSourcesResource(
Expand All @@ -118,7 +124,8 @@ public DataSourcesResource(
MetadataRuleManager metadataRuleManager,
@Nullable OverlordClient overlordClient,
AuthorizerMapper authorizerMapper,
DruidCoordinator coordinator
DruidCoordinator coordinator,
AuditManager auditManager
)
{
this.serverInventoryView = serverInventoryView;
Expand All @@ -127,6 +134,7 @@ public DataSourcesResource(
this.overlordClient = overlordClient;
this.authorizerMapper = authorizerMapper;
this.coordinator = coordinator;
this.auditManager = auditManager;
}

@GET
Expand Down Expand Up @@ -220,13 +228,19 @@ public Response markAsUsedNonOvershadowedSegments(
@Consumes(MediaType.APPLICATION_JSON)
public Response markSegmentsAsUnused(
@PathParam("dataSourceName") final String dataSourceName,
final MarkDataSourceSegmentsPayload payload
final MarkDataSourceSegmentsPayload payload,
@HeaderParam(AuditManager.X_DRUID_AUTHOR) @DefaultValue("") final String author,
@HeaderParam(AuditManager.X_DRUID_COMMENT) @DefaultValue("") final String comment,
@Context final HttpServletRequest req
)
{
MarkSegments markSegments = () -> {
final Interval interval = payload.getInterval();
final int numUpdatedSegments;
final Object auditPayload;
if (interval != null) {
return segmentsMetadataManager.markAsUnusedSegmentsInInterval(dataSourceName, interval);
numUpdatedSegments = segmentsMetadataManager.markAsUnusedSegmentsInInterval(dataSourceName, interval);
auditPayload = Collections.singletonMap("interval", interval);
} else {
final Set<SegmentId> segmentIds =
payload.getSegmentIds()
Expand All @@ -236,12 +250,24 @@ public Response markSegmentsAsUnused(
.collect(Collectors.toSet());

// Note: segments for the "wrong" datasource are ignored.
return segmentsMetadataManager.markSegmentsAsUnused(
numUpdatedSegments = segmentsMetadataManager.markSegmentsAsUnused(
segmentIds.stream()
.filter(segmentId -> segmentId.getDataSource().equals(dataSourceName))
.collect(Collectors.toSet())
);
auditPayload = Collections.singletonMap("segmentIds", segmentIds);
}
if (author != null && !author.isEmpty()) {
auditManager.doAudit(
AuditEvent.builder()
.key(dataSourceName)
.type("segments.markUnused")
.payload(auditPayload)
.auditInfo(new AuditInfo(author, comment, req.getRemoteAddr()))
.build()
);
}
return numUpdatedSegments;
};
return doMarkSegmentsWithPayload("markSegmentsAsUnused", dataSourceName, payload, markSegments);
}
Expand Down Expand Up @@ -312,7 +338,8 @@ private static Response doMarkSegments(String method, String dataSourceName, Mar
public Response markAsUnusedAllSegmentsOrKillUnusedSegmentsInInterval(
@PathParam("dataSourceName") final String dataSourceName,
@QueryParam("kill") final String kill,
@QueryParam("interval") final String interval
@QueryParam("interval") final String interval,
@Context HttpServletRequest req
)
{
if (overlordClient == null) {
Expand All @@ -321,7 +348,7 @@ public Response markAsUnusedAllSegmentsOrKillUnusedSegmentsInInterval(

boolean killSegments = kill != null && Boolean.valueOf(kill);
if (killSegments) {
return killUnusedSegmentsInInterval(dataSourceName, interval);
return killUnusedSegmentsInInterval(dataSourceName, interval, null, null, req);
} else {
MarkSegments markSegments = () -> segmentsMetadataManager.markAsUnusedAllSegmentsInDataSource(dataSourceName);
return doMarkSegments("markAsUnusedAllSegments", dataSourceName, markSegments);
Expand All @@ -334,7 +361,10 @@ public Response markAsUnusedAllSegmentsOrKillUnusedSegmentsInInterval(
@Produces(MediaType.APPLICATION_JSON)
public Response killUnusedSegmentsInInterval(
@PathParam("dataSourceName") final String dataSourceName,
@PathParam("interval") final String interval
@PathParam("interval") final String interval,
@HeaderParam(AuditManager.X_DRUID_AUTHOR) @DefaultValue("") final String author,
@HeaderParam(AuditManager.X_DRUID_COMMENT) @DefaultValue("") final String comment,
@Context final HttpServletRequest req
)
{
if (overlordClient == null) {
Expand All @@ -345,7 +375,18 @@ public Response killUnusedSegmentsInInterval(
}
final Interval theInterval = Intervals.of(interval.replace('_', '/'));
try {
FutureUtils.getUnchecked(overlordClient.runKillTask("api-issued", dataSourceName, theInterval, null), true);
final String killTaskId = FutureUtils.getUnchecked(
overlordClient.runKillTask("api-issued", dataSourceName, theInterval, null),
true
);
auditManager.doAudit(
AuditEvent.builder()
.key(dataSourceName)
.type("segments.killTask")
.payload(ImmutableMap.of("killTaskId", killTaskId, "interval", theInterval))
.auditInfo(new AuditInfo(author, comment, req.getRemoteAddr()))
.build()
);
return Response.ok().build();
}
catch (Exception e) {
Expand Down
Loading

0 comments on commit 74860c7

Please sign in to comment.