Skip to content

Commit

Permalink
Refactor query shape field data maps
Browse files Browse the repository at this point in the history
Signed-off-by: David Zane <[email protected]>
  • Loading branch information
dzane17 committed Sep 9, 2024
1 parent 4d896cc commit 20763f8
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,61 +8,20 @@

package org.opensearch.plugin.insights.core.service.categorizer;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.hash.MurmurHash3;
import org.opensearch.core.common.io.stream.NamedWriteable;
import org.opensearch.index.query.AbstractGeometryQueryBuilder;
import org.opensearch.index.query.CommonTermsQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.FieldMaskingSpanQueryBuilder;
import org.opensearch.index.query.FuzzyQueryBuilder;
import org.opensearch.index.query.GeoBoundingBoxQueryBuilder;
import org.opensearch.index.query.GeoDistanceQueryBuilder;
import org.opensearch.index.query.GeoPolygonQueryBuilder;
import org.opensearch.index.query.MatchBoolPrefixQueryBuilder;
import org.opensearch.index.query.MatchPhrasePrefixQueryBuilder;
import org.opensearch.index.query.MatchPhraseQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.MultiTermQueryBuilder;
import org.opensearch.index.query.PrefixQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.RegexpQueryBuilder;
import org.opensearch.index.query.SpanNearQueryBuilder;
import org.opensearch.index.query.SpanTermQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.index.query.WildcardQueryBuilder;
import org.opensearch.index.query.WithFieldName;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregatorFactories;
import org.opensearch.search.aggregations.PipelineAggregationBuilder;
import org.opensearch.search.aggregations.bucket.histogram.AutoDateHistogramAggregationBuilder;
import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder;
import org.opensearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder;
import org.opensearch.search.aggregations.bucket.histogram.VariableWidthHistogramAggregationBuilder;
import org.opensearch.search.aggregations.bucket.missing.MissingAggregationBuilder;
import org.opensearch.search.aggregations.bucket.range.AbstractRangeBuilder;
import org.opensearch.search.aggregations.bucket.range.GeoDistanceAggregationBuilder;
import org.opensearch.search.aggregations.bucket.range.IpRangeAggregationBuilder;
import org.opensearch.search.aggregations.bucket.sampler.DiversifiedAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.RareTermsAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.SignificantTermsAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.AvgAggregationBuilder;
import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder;
import org.opensearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.GeoCentroidAggregationBuilder;
import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder;
import org.opensearch.search.aggregations.metrics.MinAggregationBuilder;
import org.opensearch.search.aggregations.metrics.StatsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.SumAggregationBuilder;
import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder;
import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.FieldSortBuilder;
Expand All @@ -74,9 +33,9 @@
public class QueryShapeGenerator {
static final String EMPTY_STRING = "";
static final String ONE_SPACE_INDENT = " ";
static final Map<Class<?>, List<Function<Object, String>>> QUERY_FIELD_DATA_MAP = FieldDataMapHelper.getQueryFieldDataMap();
static final Map<Class<?>, List<Function<Object, String>>> AGG_FIELD_DATA_MAP = FieldDataMapHelper.getAggFieldDataMap();
static final Map<Class<?>, List<Function<Object, String>>> SORT_FIELD_DATA_MAP = FieldDataMapHelper.getSortFieldDataMap();
static final ConcurrentHashMap<Class<?>, List<Method>> QUERY_FIELD_DATA_MAP = FieldDataMapHelper.getQueryFieldDataMap();
static final ConcurrentHashMap<Class<?>, List<Method>> AGG_FIELD_DATA_MAP = FieldDataMapHelper.getAggFieldDataMap();
static final ConcurrentHashMap<Class<?>, List<Method>> SORT_FIELD_DATA_MAP = FieldDataMapHelper.getSortFieldDataMap();

/**
* Method to get query shape hash code given a source
Expand Down Expand Up @@ -161,7 +120,7 @@ static StringBuilder recursiveAggregationShapeBuilder(
StringBuilder stringBuilder = new StringBuilder();
stringBuilder.append(baseIndent).append(ONE_SPACE_INDENT.repeat(2)).append(aggBuilder.getType());
if (showFields) {
stringBuilder.append(buildFieldDataString(AGG_FIELD_DATA_MAP.get(aggBuilder.getClass()), aggBuilder));
stringBuilder.append(buildFieldDataString(AGG_FIELD_DATA_MAP, aggBuilder));
}
stringBuilder.append("\n");

Expand Down Expand Up @@ -227,7 +186,7 @@ static String buildSortShape(List<SortBuilder<?>> sortBuilderList, Boolean showF
StringBuilder stringBuilder = new StringBuilder();
stringBuilder.append(ONE_SPACE_INDENT.repeat(2)).append(sortBuilder.order());
if (showFields) {
stringBuilder.append(buildFieldDataString(SORT_FIELD_DATA_MAP.get(sortBuilder.getClass()), sortBuilder));
stringBuilder.append(buildFieldDataString(SORT_FIELD_DATA_MAP, sortBuilder));
}
shapeStrings.add(stringBuilder.toString());
}
Expand All @@ -243,11 +202,20 @@ static String buildSortShape(List<SortBuilder<?>> sortBuilderList, Boolean showF
* @return String: comma separated list with leading space in square brackets
* Ex: " [my_field, width:5]"
*/
static String buildFieldDataString(List<Function<Object, String>> methods, NamedWriteable builder) {
static String buildFieldDataString(ConcurrentHashMap<Class<?>, List<Method>> fieldMethodsMap, NamedWriteable builder) {
List<String> fieldDataList = new ArrayList<>();
if (methods != null) {
for (Function<Object, String> lambda : methods) {
fieldDataList.add(lambda.apply(builder));
for (ConcurrentHashMap.Entry<Class<?>, List<Method>> entry : fieldMethodsMap.entrySet()) {
if (entry.getKey().isInstance(builder)) {
List<Method> methodsList = entry.getValue();
for (Method method : methodsList) {
try {
fieldDataList.add((String) method.invoke(builder));
} catch (Exception e) {
throw new RuntimeException(
String.format("No such method '%s' for class '%s'", method.getName(), builder.getClass().getName())
);
}
}
}
}
return " [" + String.join(", ", fieldDataList) + "]";
Expand All @@ -257,85 +225,54 @@ static String buildFieldDataString(List<Function<Object, String>> methods, Named
* Helper class to create static field data maps
*/
private static class FieldDataMapHelper {

// Helper method to create map entries
private static <T> Map.Entry<Class<?>, List<Function<Object, String>>> createEntry(Class<T> clazz, Function<T, String> extractor) {
return Map.entry(clazz, List.of(obj -> extractor.apply(clazz.cast(obj))));
}
private static final String INVALID_METHOD_EXCEPTION_STRING = "Invalid method referenced during field data map creation";

/**
* Returns a map where the keys are query builders, and the values are lists of
* functions that extract field values from instances of these classes.
* Returns a map where the keys are query builder classes, and the values are lists of
* Methods that extract field data from instances of these classes.
*
* @return a map with class types as keys and lists of field extraction functions as values.
* @return a map with class types as keys and lists of Methods as values.
*/
private static Map<Class<?>, List<Function<Object, String>>> getQueryFieldDataMap() {
return Map.ofEntries(
createEntry(AbstractGeometryQueryBuilder.class, AbstractGeometryQueryBuilder::fieldName),
createEntry(CommonTermsQueryBuilder.class, CommonTermsQueryBuilder::fieldName),
createEntry(ExistsQueryBuilder.class, ExistsQueryBuilder::fieldName),
createEntry(FieldMaskingSpanQueryBuilder.class, FieldMaskingSpanQueryBuilder::fieldName),
createEntry(FuzzyQueryBuilder.class, FuzzyQueryBuilder::fieldName),
createEntry(GeoBoundingBoxQueryBuilder.class, GeoBoundingBoxQueryBuilder::fieldName),
createEntry(GeoDistanceQueryBuilder.class, GeoDistanceQueryBuilder::fieldName),
createEntry(GeoPolygonQueryBuilder.class, GeoPolygonQueryBuilder::fieldName),
createEntry(MatchBoolPrefixQueryBuilder.class, MatchBoolPrefixQueryBuilder::fieldName),
createEntry(MatchQueryBuilder.class, MatchQueryBuilder::fieldName),
createEntry(MatchPhraseQueryBuilder.class, MatchPhraseQueryBuilder::fieldName),
createEntry(MatchPhrasePrefixQueryBuilder.class, MatchPhrasePrefixQueryBuilder::fieldName),
createEntry(MultiTermQueryBuilder.class, MultiTermQueryBuilder::fieldName),
createEntry(PrefixQueryBuilder.class, PrefixQueryBuilder::fieldName),
createEntry(RangeQueryBuilder.class, RangeQueryBuilder::fieldName),
createEntry(RegexpQueryBuilder.class, RegexpQueryBuilder::fieldName),
createEntry(SpanNearQueryBuilder.SpanGapQueryBuilder.class, SpanNearQueryBuilder.SpanGapQueryBuilder::fieldName),
createEntry(SpanTermQueryBuilder.class, SpanTermQueryBuilder::fieldName),
createEntry(TermQueryBuilder.class, TermQueryBuilder::fieldName),
createEntry(TermsQueryBuilder.class, TermsQueryBuilder::fieldName),
createEntry(WildcardQueryBuilder.class, WildcardQueryBuilder::fieldName)
);
private static ConcurrentHashMap<Class<?>, List<Method>> getQueryFieldDataMap() {
ConcurrentHashMap<Class<?>, List<Method>> map = new ConcurrentHashMap<>();
try {
map.put(WithFieldName.class, List.of(WithFieldName.class.getMethod("fieldName")));
} catch (NoSuchMethodException e) {
throw new RuntimeException(INVALID_METHOD_EXCEPTION_STRING);
}
return map;
}

/**
* Returns a map where the keys are aggregation builders, and the values are lists of
* functions that extract field values from instances of these classes.
* Returns a map where the keys are aggregation builder classes, and the values are lists of
* Methods that extract field data from instances of these classes.
*
* @return a map with class types as keys and lists of field extraction functions as values.
* @return a map with class types as keys and lists of Methods as values.
*/
private static Map<Class<?>, List<Function<Object, String>>> getAggFieldDataMap() {
return Map.ofEntries(
createEntry(IpRangeAggregationBuilder.class, IpRangeAggregationBuilder::field),
createEntry(AutoDateHistogramAggregationBuilder.class, AutoDateHistogramAggregationBuilder::field),
createEntry(DateHistogramAggregationBuilder.class, DateHistogramAggregationBuilder::field),
createEntry(HistogramAggregationBuilder.class, HistogramAggregationBuilder::field),
createEntry(VariableWidthHistogramAggregationBuilder.class, VariableWidthHistogramAggregationBuilder::field),
createEntry(MissingAggregationBuilder.class, MissingAggregationBuilder::field),
createEntry(AbstractRangeBuilder.class, AbstractRangeBuilder::field),
createEntry(GeoDistanceAggregationBuilder.class, GeoDistanceAggregationBuilder::field),
createEntry(DiversifiedAggregationBuilder.class, DiversifiedAggregationBuilder::field),
createEntry(RareTermsAggregationBuilder.class, RareTermsAggregationBuilder::field),
createEntry(SignificantTermsAggregationBuilder.class, SignificantTermsAggregationBuilder::field),
createEntry(TermsAggregationBuilder.class, TermsAggregationBuilder::field),
createEntry(AvgAggregationBuilder.class, AvgAggregationBuilder::field),
createEntry(CardinalityAggregationBuilder.class, CardinalityAggregationBuilder::field),
createEntry(ExtendedStatsAggregationBuilder.class, ExtendedStatsAggregationBuilder::field),
createEntry(GeoCentroidAggregationBuilder.class, GeoCentroidAggregationBuilder::field),
createEntry(MaxAggregationBuilder.class, MaxAggregationBuilder::field),
createEntry(MinAggregationBuilder.class, MinAggregationBuilder::field),
createEntry(StatsAggregationBuilder.class, StatsAggregationBuilder::field),
createEntry(SumAggregationBuilder.class, SumAggregationBuilder::field),
createEntry(ValueCountAggregationBuilder.class, ValueCountAggregationBuilder::field),
createEntry(ValuesSourceAggregationBuilder.class, ValuesSourceAggregationBuilder::field)
);
private static ConcurrentHashMap<Class<?>, List<Method>> getAggFieldDataMap() {
ConcurrentHashMap<Class<?>, List<Method>> map = new ConcurrentHashMap<>();
try {
map.put(ValuesSourceAggregationBuilder.class, List.of(ValuesSourceAggregationBuilder.class.getMethod("field")));
} catch (NoSuchMethodException e) {
throw new RuntimeException(INVALID_METHOD_EXCEPTION_STRING);
}
return map;
}

/**
* Returns a map where the keys are sort builders, and the values are lists of
* functions that extract field values from instances of these classes.
* Returns a map where the keys are sort builder classes, and the values are lists of
* Methods that extract field data from instances of these classes.
*
* @return a map with class types as keys and lists of field extraction functions as values.
* @return a map with class types as keys and lists of Methods as values.
*/
private static Map<Class<?>, List<Function<Object, String>>> getSortFieldDataMap() {
return Map.ofEntries(createEntry(FieldSortBuilder.class, FieldSortBuilder::getFieldName));
private static ConcurrentHashMap<Class<?>, List<Method>> getSortFieldDataMap() {
ConcurrentHashMap<Class<?>, List<Method>> map = new ConcurrentHashMap<>();
try {
map.put(FieldSortBuilder.class, List.of(FieldSortBuilder.class.getMethod("getFieldName")));
} catch (NoSuchMethodException e) {
throw new RuntimeException(INVALID_METHOD_EXCEPTION_STRING);
}
return map;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

import static org.opensearch.plugin.insights.core.service.categorizer.QueryShapeGenerator.ONE_SPACE_INDENT;
import static org.opensearch.plugin.insights.core.service.categorizer.QueryShapeGenerator.QUERY_FIELD_DATA_MAP;
import static org.opensearch.plugin.insights.core.service.categorizer.QueryShapeGenerator.buildFieldDataString;

import java.util.ArrayList;
import java.util.EnumMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import org.apache.lucene.search.BooleanClause;
import org.opensearch.common.SetOnce;
import org.opensearch.index.query.QueryBuilder;
Expand All @@ -33,15 +33,7 @@ public final class QueryShapeVisitor implements QueryBuilderVisitor {
@Override
public void accept(QueryBuilder queryBuilder) {
queryType.set(queryBuilder.getName());

List<String> fieldDataList = new ArrayList<>();
List<Function<Object, String>> methods = QUERY_FIELD_DATA_MAP.get(queryBuilder.getClass());
if (methods != null) {
for (Function<Object, String> lambda : methods) {
fieldDataList.add(lambda.apply(queryBuilder));
}
}
fieldData.set(String.join(", ", fieldDataList));
fieldData.set(buildFieldDataString(QUERY_FIELD_DATA_MAP, queryBuilder));
}

@Override
Expand Down Expand Up @@ -101,7 +93,7 @@ public String toJson() {
public String prettyPrintTree(String indent, Boolean showFields) {
StringBuilder outputBuilder = new StringBuilder(indent).append(queryType.get());
if (showFields) {
outputBuilder.append(" [").append(fieldData.get()).append("]");
outputBuilder.append(fieldData.get());
}
outputBuilder.append("\n");
for (Map.Entry<BooleanClause.Occur, List<QueryShapeVisitor>> entry : childVisitors.entrySet()) {
Expand Down

0 comments on commit 20763f8

Please sign in to comment.