Skip to content

Commit

Permalink
add table sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
takaaki7 authored and takaaki7 committed Oct 12, 2023
1 parent 6a6f36a commit c339b85
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type", defaultImpl = TableDataSource.class)
@JsonSubTypes({
@JsonSubTypes.Type(value = TableDataSource.class, name = "table"),
@JsonSubTypes.Type(value = TableDataSource.class, name = "sampled_table"),
@JsonSubTypes.Type(value = SampledTableDataSource.class, name = "sampled_table"),
@JsonSubTypes.Type(value = QueryDataSource.class, name = "query"),
@JsonSubTypes.Type(value = UnionDataSource.class, name = "union"),
@JsonSubTypes.Type(value = JoinDataSource.class, name = "join"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@
import org.apache.druid.segment.SegmentReference;

@JsonTypeName("sampled_table")
public class SampledTableDataSource implements DataSource
public class SampledTableDataSource extends TableDataSource
{
private final String name;
private final SamplingType samplingType;
private final float samplingPercentage;
private final int samplingPercentage;
public enum SamplingType implements Cacheable
{
FIXED_SHARD;
Expand Down Expand Up @@ -70,123 +69,59 @@ public byte[] getCacheKey()
/**
 * Creates a table datasource that carries sampling settings in addition to the table name.
 *
 * @param name               name of the underlying table; handed to the superclass constructor
 * @param samplingType       strategy used to choose which segments are read
 * @param samplingPercentage share of the table to read, as a whole-number percentage in {@code [0, 100]}
 * @throws IllegalArgumentException if {@code samplingPercentage} is outside {@code [0, 100]}
 */
public SampledTableDataSource(
    @JsonProperty("name") String name,
    @JsonProperty("samplingType") SamplingType samplingType,
    @JsonProperty("samplingPercentage") int samplingPercentage
)
{
  super(name);
  // Reject values that cannot be a percentage: a negative value would prune every segment, while a value
  // above 100 would silently disable sampling.
  if (samplingPercentage < 0 || samplingPercentage > 100) {
    throw new IllegalArgumentException(
        "'samplingPercentage' must be between 0 and 100, got " + samplingPercentage
    );
  }
  this.samplingType = samplingType;
  this.samplingPercentage = samplingPercentage;
}

@JsonCreator
public static SampledTableDataSource create(final String name, final String samplingType, final float samplingRatio)
public static SampledTableDataSource create(
@JsonProperty("name")final String name,
@JsonProperty("samplingType")final String samplingType,
@JsonProperty("samplingPercentage")final int samplingPercentage)
{
return new SampledTableDataSource(name, SamplingType.fromString(samplingType), samplingRatio);
return new SampledTableDataSource(name, SamplingType.fromString(samplingType), samplingPercentage);
}

@JsonProperty
public String getName()
{
return name;
}

@Override
public Set<String> getTableNames()
{
return Collections.singleton(name);
}

@Override
public List<DataSource> getChildren()
{
return Collections.emptyList();
}

@Override
public DataSource withChildren(List<DataSource> children)
{
if (!children.isEmpty()) {
throw new IAE("Cannot accept children");
}

return this;
}

@Override
public boolean isCacheable(boolean isBroker)
{
return true;
}

@Override
public boolean isGlobal()
{
return false;
}

@Override
public boolean isConcrete()
{
return true;
}

@Override
public Function<SegmentReference, SegmentReference> createSegmentMapFunction(
Query query,
AtomicLong cpuTime
)
{
return Function.identity();
}

@Override
public DataSource withUpdatedDataSource(DataSource newSource)
{
return newSource;
}

@Override
public byte[] getCacheKey()
{
return new byte[0];
}

@Override
public DataSourceAnalysis getAnalysis()
{
return new DataSourceAnalysis(this, null, null, Collections.emptyList());
}

/**
 * Returns the sampling strategy this datasource was constructed with. Annotated with {@code @JsonProperty} so that
 * Jackson includes it when the datasource is serialized.
 */
@JsonProperty
public SamplingType getSamplingType()
{
  return this.samplingType;
}

/**
 * Returns the stored sampling percentage unchanged (no scaling is applied here). Annotated with
 * {@code @JsonProperty} so that Jackson includes it when the datasource is serialized.
 */
@JsonProperty
public float getSamplingPercentage()
{
  return this.samplingPercentage;
}

@Override
public String toString()
{
return name;
}

@Override
public boolean equals(Object o)
{
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
if (!(o instanceof SampledTableDataSource)) {
return false;
}
if (!super.equals(o)) {
return false;
}

SampledTableDataSource that = (SampledTableDataSource) o;
return name.equals(that.name);

if (samplingPercentage != that.samplingPercentage) {
return false;
}
return samplingType == that.samplingType;
}

@Override
public int hashCode()
{
return Objects.hash(name);
public int hashCode() {
int result = super.hashCode();
result = 31 * result + (samplingType != null ? samplingType.hashCode() : 0);
result = 31 * result + samplingPercentage;
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,12 @@ public Object mergeValues(Object oldValue, Object newValue)
false
);

/**
 * Response-context key named {@code "samplingComposition"} holding a string value. In this change the clustered
 * client stores a value of the form {@code "<segments kept>/<segments before pruning>"} under it after shard
 * sampling has pruned the segment set.
 * NOTE(review): the two boolean constructor arguments are passed as {@code true, false}; their meaning is defined
 * by {@code StringKey} and should be confirmed there.
 */
public static final Key SAMPLING_COMPOSITION = new StringKey(
"samplingComposition",
true,
false
);

/**
* Indicates if a {@link ResponseContext} was truncated during serialization.
*/
Expand Down Expand Up @@ -488,6 +494,7 @@ public Object mergeValues(Object oldValue, Object newValue)
TIMEOUT_AT,
NUM_SCANNED_ROWS,
CPU_CONSUMED_NANOS,
SAMPLING_COMPOSITION,
TRUNCATED,
}
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.druid.query.DataSource;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.Query;
import org.apache.druid.query.SampledTableDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.UnionDataSource;
import org.apache.druid.query.UnnestDataSource;
Expand Down Expand Up @@ -113,10 +114,11 @@ public DataSource getBaseDataSource()
* Note that this can return empty even if {@link #isConcreteAndTableBased()} is true. This happens if the base
* datasource is a {@link UnionDataSource} of {@link TableDataSource}.
*/
public Optional<TableDataSource> getBaseTableDataSource()
{
public Optional<TableDataSource> getBaseTableDataSource() {
if (baseDataSource instanceof TableDataSource) {
return Optional.of((TableDataSource) baseDataSource);
} else if (baseDataSource instanceof SampledTableDataSource) {
return Optional.of((SampledTableDataSource) baseDataSource);
} else {
return Optional.empty();
}
Expand Down Expand Up @@ -216,6 +218,7 @@ public boolean isConcreteBased()
public boolean isTableBased()
{
return (baseDataSource instanceof TableDataSource
|| baseDataSource instanceof SampledTableDataSource
|| (baseDataSource instanceof UnionDataSource &&
baseDataSource.getChildren()
.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

package org.apache.druid.client;

import static org.apache.druid.query.context.ResponseContext.Keys.SAMPLING_COMPOSITION;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -346,7 +348,10 @@ ClusterQueryResult<T> run(
}

final Set<SegmentServerSelector> segmentServers = computeSegmentsToQuery(timeline, specificSegments);
pruneSegmentsForShardSampling(segmentServers);
Pair<Integer, Integer> ratio = pruneSegmentsForShardSampling(segmentServers);
if (ratio != null) {
responseContext.add(SAMPLING_COMPOSITION, ratio.lhs + "/" + ratio.rhs);
}
@Nullable
final byte[] queryCacheKey = cacheKeyManager.computeSegmentLevelQueryCacheKey();
if (query.getContext().get(QueryResource.HEADER_IF_NONE_MATCH) != null) {
Expand Down Expand Up @@ -507,30 +512,32 @@ private void computeUncoveredIntervals(TimelineLookup<String, ServerSelector> ti
}
}

private void pruneSegmentsForShardSampling(
final Set<SegmentServerSelector> segments
) {
if (
query.getDataSource() instanceof SampledTableDataSource
) {
private Pair<Integer, Integer> pruneSegmentsForShardSampling(final Set<SegmentServerSelector> segments) {
if (query.getDataSource() instanceof SampledTableDataSource) {
if (((SampledTableDataSource) query.getDataSource()).getSamplingType()
== SamplingType.FIXED_SHARD) {
int allShards = segments.stream().mapToInt(s->s.getSegmentDescriptor().getPartitionNumber()).max().getAsInt();
int allSegmentsSize = segments.size();
int allShards = segments.stream()
.mapToInt(s -> s.getSegmentDescriptor().getPartitionNumber()).max().getAsInt();
int targetShards = Math.round(
allShards * ((SampledTableDataSource) query.getDataSource()).getSamplingPercentage());
Iterator<SegmentServerSelector> iterator = segments.iterator();
int removedSegments = 0;
while (iterator.hasNext()) {
SegmentServerSelector segmentServerSelector = iterator.next();
SegmentDescriptor segmentDescriptor = segmentServerSelector.getSegmentDescriptor();
int shard = segmentDescriptor.getPartitionNumber();
if (targetShards < shard) {
removedSegments++;
iterator.remove();
}
}
return Pair.of(allSegmentsSize - removedSegments, allSegmentsSize);
} else {
throw new UnsupportedOperationException("");
}
}
return null;
}


Expand Down

0 comments on commit c339b85

Please sign in to comment.