diff --git a/processing/src/main/java/org/apache/druid/query/DataSource.java b/processing/src/main/java/org/apache/druid/query/DataSource.java
index 360c339627f9..4c4113220ed9 100644
--- a/processing/src/main/java/org/apache/druid/query/DataSource.java
+++ b/processing/src/main/java/org/apache/druid/query/DataSource.java
@@ -36,6 +36,7 @@
 @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type", defaultImpl = TableDataSource.class)
 @JsonSubTypes({
     @JsonSubTypes.Type(value = TableDataSource.class, name = "table"),
+    @JsonSubTypes.Type(value = SampledTableDataSource.class, name = "sampled_table"),
     @JsonSubTypes.Type(value = QueryDataSource.class, name = "query"),
     @JsonSubTypes.Type(value = UnionDataSource.class, name = "union"),
     @JsonSubTypes.Type(value = JoinDataSource.class, name = "join"),
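For context, a minimal sketch of how the new subtype resolves through Jackson once it is registered above. The mapper setup and the "wikipedia" datasource name are illustrative, not part of the patch:

    import com.fasterxml.jackson.databind.ObjectMapper;
    import org.apache.druid.jackson.DefaultObjectMapper;
    import org.apache.druid.query.DataSource;
    import org.apache.druid.query.SampledTableDataSource;

    public class SampledTableJsonExample
    {
      public static void main(String[] args) throws Exception
      {
        // DefaultObjectMapper is Druid's preconfigured Jackson mapper.
        ObjectMapper mapper = new DefaultObjectMapper();
        String json = "{\"type\":\"sampled_table\",\"name\":\"wikipedia\","
            + "\"samplingType\":\"fixed_shard\",\"samplingPercentage\":50}";
        // "type": "sampled_table" selects SampledTableDataSource via the @JsonSubTypes entry.
        DataSource dataSource = mapper.readValue(json, DataSource.class);
        System.out.println(dataSource instanceof SampledTableDataSource); // true
      }
    }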
diff --git a/processing/src/main/java/org/apache/druid/query/SampledTableDataSource.java b/processing/src/main/java/org/apache/druid/query/SampledTableDataSource.java
new file mode 100644
index 000000000000..5d1dd777f171
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/query/SampledTableDataSource.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonTypeName;
+import com.fasterxml.jackson.annotation.JsonValue;
+import org.apache.druid.java.util.common.Cacheable;
+import org.apache.druid.java.util.common.StringUtils;
+
+@JsonTypeName("sampled_table")
+public class SampledTableDataSource extends TableDataSource
+{
+  private final SamplingType samplingType;
+  private final int samplingPercentage;
+
+  public enum SamplingType implements Cacheable
+  {
+    FIXED_SHARD;
+
+    @JsonValue
+    @Override
+    public String toString()
+    {
+      return StringUtils.toLowerCase(this.name());
+    }
+
+    @JsonCreator
+    public static SamplingType fromString(String name)
+    {
+      return valueOf(StringUtils.toUpperCase(name));
+    }
+
+    @Override
+    public byte[] getCacheKey()
+    {
+      return new byte[]{(byte) this.ordinal()};
+    }
+  }
+
+  public SampledTableDataSource(
+      @JsonProperty("name") String name,
+      @JsonProperty("samplingType") SamplingType samplingType,
+      @JsonProperty("samplingPercentage") int samplingPercentage
+  )
+  {
+    super(name);
+    this.samplingType = samplingType;
+    this.samplingPercentage = samplingPercentage;
+  }
+
+  // Jackson allows only one property-based creator per class, so this factory is the
+  // single @JsonCreator entry point; it parses the samplingType string and delegates
+  // to the constructor above.
+  @JsonCreator
+  public static SampledTableDataSource create(
+      @JsonProperty("name") final String name,
+      @JsonProperty("samplingType") final String samplingType,
+      @JsonProperty("samplingPercentage") final int samplingPercentage
+  )
+  {
+    return new SampledTableDataSource(name, SamplingType.fromString(samplingType), samplingPercentage);
+  }
+
+  @JsonProperty
+  public SamplingType getSamplingType()
+  {
+    return samplingType;
+  }
+
+  @JsonProperty
+  public int getSamplingPercentage()
+  {
+    return samplingPercentage;
+  }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (this == o) {
+      return true;
+    }
+    if (!(o instanceof SampledTableDataSource)) {
+      return false;
+    }
+    if (!super.equals(o)) {
+      return false;
+    }
+
+    SampledTableDataSource that = (SampledTableDataSource) o;
+
+    if (samplingPercentage != that.samplingPercentage) {
+      return false;
+    }
+    return samplingType == that.samplingType;
+  }
+
+  @Override
+  public int hashCode()
+  {
+    int result = super.hashCode();
+    result = 31 * result + (samplingType != null ? samplingType.hashCode() : 0);
+    result = 31 * result + samplingPercentage;
+    return result;
+  }
+}
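Programmatic construction mirrors the JSON form. A short usage sketch; the table name is arbitrary:

    import org.apache.druid.query.SampledTableDataSource;
    import org.apache.druid.query.SampledTableDataSource.SamplingType;

    public class SampledTableUsageExample
    {
      public static void main(String[] args)
      {
        // Query roughly half of the shards of a hypothetical "wikipedia" table.
        SampledTableDataSource sampled =
            new SampledTableDataSource("wikipedia", SamplingType.FIXED_SHARD, 50);

        // The enum serializes lowercase via @JsonValue and parses case-insensitively
        // via fromString(), so "fixed_shard" and "FIXED_SHARD" both map to FIXED_SHARD.
        System.out.println(SamplingType.fromString("fixed_shard") == sampled.getSamplingType()); // true
      }
    }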
diff --git a/processing/src/main/java/org/apache/druid/query/context/ResponseContext.java b/processing/src/main/java/org/apache/druid/query/context/ResponseContext.java
index 6727782cc406..34cdb580eb70 100644
--- a/processing/src/main/java/org/apache/druid/query/context/ResponseContext.java
+++ b/processing/src/main/java/org/apache/druid/query/context/ResponseContext.java
@@ -453,6 +453,12 @@ public Object mergeValues(Object oldValue, Object newValue)
         false
     );
 
+    public static final Key SAMPLING_COMPOSITION = new StringKey(
+        "samplingComposition",
+        true,
+        false
+    );
+
    /**
     * Indicates if a {@link ResponseContext} was truncated during serialization.
     */
@@ -488,6 +494,7 @@ public Object mergeValues(Object oldValue, Object newValue)
             TIMEOUT_AT,
             NUM_SCANNED_ROWS,
             CPU_CONSUMED_NANOS,
+            SAMPLING_COMPOSITION,
             TRUNCATED,
         }
     );
@@ -738,6 +745,12 @@ public void addCpuNanos(long ns)
     addValue(Keys.CPU_CONSUMED_NANOS, ns);
   }
 
+  public void addSamplingComposition(String samplingComposition)
+  {
+    addValue(Keys.SAMPLING_COMPOSITION, samplingComposition);
+  }
+
   private Object addValue(Key key, Object value)
   {
     return getDelegate().merge(key, value, key::mergeValues);
diff --git a/processing/src/main/java/org/apache/druid/query/planning/DataSourceAnalysis.java b/processing/src/main/java/org/apache/druid/query/planning/DataSourceAnalysis.java
index f17ab6aec235..dfea4cc36d19 100644
--- a/processing/src/main/java/org/apache/druid/query/planning/DataSourceAnalysis.java
+++ b/processing/src/main/java/org/apache/druid/query/planning/DataSourceAnalysis.java
@@ -24,6 +24,7 @@
 import org.apache.druid.query.DataSource;
 import org.apache.druid.query.JoinDataSource;
 import org.apache.druid.query.Query;
+import org.apache.druid.query.SampledTableDataSource;
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.UnionDataSource;
 import org.apache.druid.query.UnnestDataSource;
@@ -117,6 +118,8 @@ public Optional<TableDataSource> getBaseTableDataSource()
   {
     if (baseDataSource instanceof TableDataSource) {
       return Optional.of((TableDataSource) baseDataSource);
+    } else if (baseDataSource instanceof SampledTableDataSource) {
+      return Optional.of((SampledTableDataSource) baseDataSource);
     } else {
       return Optional.empty();
     }
@@ -216,6 +219,7 @@ public boolean isConcreteBased()
   public boolean isTableBased()
   {
     return (baseDataSource instanceof TableDataSource
+            || baseDataSource instanceof SampledTableDataSource
             || (baseDataSource instanceof UnionDataSource
                 && baseDataSource.getChildren()
                                  .stream()
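Callers can read the new key back from the response context to see how much of the table was actually queried. A hedged sketch, assuming the usual ResponseContext.get(Key) accessor; the parsing helper is illustrative:

    import org.apache.druid.query.context.ResponseContext;

    public class SamplingCompositionReader
    {
      // Returns the fraction of candidate segments actually queried, or 1.0 if no sampling happened.
      static double sampledFraction(ResponseContext responseContext)
      {
        // "samplingComposition" carries "<queried segments>/<total segments>", e.g. "5/8",
        // as written by pruneSegmentsForShardSampling() in the next file.
        Object composition = responseContext.get(ResponseContext.Keys.SAMPLING_COMPOSITION);
        if (composition == null) {
          return 1.0;
        }
        String[] parts = ((String) composition).split("/");
        return Double.parseDouble(parts[0]) / Double.parseDouble(parts[1]);
      }
    }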
diff --git a/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java b/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java
index 19df276344ef..519296e388aa 100644
--- a/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java
+++ b/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java
@@ -69,6 +69,8 @@
 import org.apache.druid.query.QueryToolChest;
 import org.apache.druid.query.QueryToolChestWarehouse;
 import org.apache.druid.query.Result;
+import org.apache.druid.query.SampledTableDataSource;
+import org.apache.druid.query.SampledTableDataSource.SamplingType;
 import org.apache.druid.query.SegmentDescriptor;
 import org.apache.druid.query.aggregation.MetricManipulatorFns;
 import org.apache.druid.query.context.ResponseContext;
@@ -344,6 +346,10 @@ ClusterQueryResult<T> run(
       }
 
       final Set<SegmentServerSelector> segmentServers = computeSegmentsToQuery(timeline, specificSegments);
+      Pair<Integer, Integer> ratio = pruneSegmentsForShardSampling(segmentServers);
+      if (ratio != null) {
+        responseContext.addSamplingComposition(ratio.lhs + "/" + ratio.rhs);
+      }
       @Nullable
       final byte[] queryCacheKey = cacheKeyManager.computeSegmentLevelQueryCacheKey();
       if (query.getContext().get(QueryResource.HEADER_IF_NONE_MATCH) != null) {
@@ -460,6 +466,7 @@ private Set<SegmentServerSelector> computeSegmentsToQuery(
           segments.add(new SegmentServerSelector(server, segment));
         }
       }
+
       return segments;
     }
@@ -503,6 +510,41 @@ private void computeUncoveredIntervals(TimelineLookup<String, ServerSelector> timeline)
       }
     }
 
+    private Pair<Integer, Integer> pruneSegmentsForShardSampling(final Set<SegmentServerSelector> segments)
+    {
+      if (query.getDataSource() instanceof SampledTableDataSource) {
+        if (((SampledTableDataSource) query.getDataSource()).getSamplingType() == SamplingType.FIXED_SHARD) {
+          int allSegmentsSize = segments.size();
+          // Highest partition number across the candidate segments; 0 if the set is empty.
+          int allShards = segments.stream()
+                                  .mapToInt(s -> s.getSegmentDescriptor().getPartitionNumber())
+                                  .max()
+                                  .orElse(0);
+          // Number of shards to keep, rounded from the requested sampling percentage.
+          int targetShards = Math.round(
+              allShards * ((SampledTableDataSource) query.getDataSource()).getSamplingPercentage() / 100f
+          );
+          Iterator<SegmentServerSelector> iterator = segments.iterator();
+          int removedSegments = 0;
+          while (iterator.hasNext()) {
+            SegmentServerSelector segmentServerSelector = iterator.next();
+            SegmentDescriptor segmentDescriptor = segmentServerSelector.getSegmentDescriptor();
+            int shard = segmentDescriptor.getPartitionNumber();
+            if (targetShards < shard) {
+              removedSegments++;
+              iterator.remove();
+            }
+          }
+          // lhs = segments actually queried, rhs = all candidate segments.
+          return Pair.of(allSegmentsSize - removedSegments, allSegmentsSize);
+        } else {
+          throw new UnsupportedOperationException(
+              "Unsupported sampling type: " + ((SampledTableDataSource) query.getDataSource()).getSamplingType()
+          );
+        }
+      }
+      return null;
+    }
+
     private List<Pair<Interval, byte[]>> pruneSegmentsWithCachedResults(
         final byte[] queryCacheKey,
         final Set<SegmentServerSelector> segments
@@ -541,6 +578,7 @@ private Map<SegmentServerSelector, Cache.NamedKey> computePerSegmentCacheKeys(
         byte[] queryCacheKey
     )
     {
+      // cacheKeys map must preserve segment ordering, in order for shards to always be combined in the same order
       Map<SegmentServerSelector, Cache.NamedKey> cacheKeys = Maps.newLinkedHashMap();
       for (SegmentServerSelector segmentServer : segments) {
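To make the FIXED_SHARD arithmetic concrete, a standalone worked example of the same pruning rule; the partition layout and percentage are made up:

    import java.util.Arrays;

    public class ShardSamplingExample
    {
      public static void main(String[] args)
      {
        int[] partitions = {0, 1, 2, 3, 4, 5, 6, 7}; // eight shards of one interval
        int samplingPercentage = 50;

        int allShards = Arrays.stream(partitions).max().orElse(0);            // 7
        int targetShards = Math.round(allShards * samplingPercentage / 100f); // round(3.5) = 4

        // Segments whose partition number exceeds targetShards are pruned, so
        // partitions 0..4 survive and the reported samplingComposition is "5/8".
        long kept = Arrays.stream(partitions).filter(p -> p <= targetShards).count();
        System.out.println(kept + "/" + partitions.length); // prints 5/8
      }
    }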