diff --git a/src/main/java/org/opensearch/plugin/insights/core/service/categorizer/QueryShapeGenerator.java b/src/main/java/org/opensearch/plugin/insights/core/service/categorizer/QueryShapeGenerator.java index 5664f3c4..205aa0f6 100644 --- a/src/main/java/org/opensearch/plugin/insights/core/service/categorizer/QueryShapeGenerator.java +++ b/src/main/java/org/opensearch/plugin/insights/core/service/categorizer/QueryShapeGenerator.java @@ -8,61 +8,20 @@ package org.opensearch.plugin.insights.core.service.categorizer; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.function.Function; +import java.util.concurrent.ConcurrentHashMap; import org.apache.lucene.util.BytesRef; import org.opensearch.common.hash.MurmurHash3; import org.opensearch.core.common.io.stream.NamedWriteable; -import org.opensearch.index.query.AbstractGeometryQueryBuilder; -import org.opensearch.index.query.CommonTermsQueryBuilder; -import org.opensearch.index.query.ExistsQueryBuilder; -import org.opensearch.index.query.FieldMaskingSpanQueryBuilder; -import org.opensearch.index.query.FuzzyQueryBuilder; -import org.opensearch.index.query.GeoBoundingBoxQueryBuilder; -import org.opensearch.index.query.GeoDistanceQueryBuilder; -import org.opensearch.index.query.GeoPolygonQueryBuilder; -import org.opensearch.index.query.MatchBoolPrefixQueryBuilder; -import org.opensearch.index.query.MatchPhrasePrefixQueryBuilder; -import org.opensearch.index.query.MatchPhraseQueryBuilder; -import org.opensearch.index.query.MatchQueryBuilder; -import org.opensearch.index.query.MultiTermQueryBuilder; -import org.opensearch.index.query.PrefixQueryBuilder; import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.RangeQueryBuilder; -import org.opensearch.index.query.RegexpQueryBuilder; -import org.opensearch.index.query.SpanNearQueryBuilder; -import org.opensearch.index.query.SpanTermQueryBuilder; -import org.opensearch.index.query.TermQueryBuilder; -import org.opensearch.index.query.TermsQueryBuilder; -import org.opensearch.index.query.WildcardQueryBuilder; +import org.opensearch.index.query.WithFieldName; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.PipelineAggregationBuilder; -import org.opensearch.search.aggregations.bucket.histogram.AutoDateHistogramAggregationBuilder; -import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder; -import org.opensearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder; -import org.opensearch.search.aggregations.bucket.histogram.VariableWidthHistogramAggregationBuilder; -import org.opensearch.search.aggregations.bucket.missing.MissingAggregationBuilder; -import org.opensearch.search.aggregations.bucket.range.AbstractRangeBuilder; -import org.opensearch.search.aggregations.bucket.range.GeoDistanceAggregationBuilder; -import org.opensearch.search.aggregations.bucket.range.IpRangeAggregationBuilder; -import org.opensearch.search.aggregations.bucket.sampler.DiversifiedAggregationBuilder; -import org.opensearch.search.aggregations.bucket.terms.RareTermsAggregationBuilder; -import org.opensearch.search.aggregations.bucket.terms.SignificantTermsAggregationBuilder; -import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; -import org.opensearch.search.aggregations.metrics.AvgAggregationBuilder; -import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; -import org.opensearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder; -import org.opensearch.search.aggregations.metrics.GeoCentroidAggregationBuilder; -import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder; -import org.opensearch.search.aggregations.metrics.MinAggregationBuilder; -import org.opensearch.search.aggregations.metrics.StatsAggregationBuilder; -import org.opensearch.search.aggregations.metrics.SumAggregationBuilder; -import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.FieldSortBuilder; @@ -74,9 +33,9 @@ public class QueryShapeGenerator { static final String EMPTY_STRING = ""; static final String ONE_SPACE_INDENT = " "; - static final Map, List>> QUERY_FIELD_DATA_MAP = FieldDataMapHelper.getQueryFieldDataMap(); - static final Map, List>> AGG_FIELD_DATA_MAP = FieldDataMapHelper.getAggFieldDataMap(); - static final Map, List>> SORT_FIELD_DATA_MAP = FieldDataMapHelper.getSortFieldDataMap(); + static final ConcurrentHashMap, List> QUERY_FIELD_DATA_MAP = FieldDataMapHelper.getQueryFieldDataMap(); + static final ConcurrentHashMap, List> AGG_FIELD_DATA_MAP = FieldDataMapHelper.getAggFieldDataMap(); + static final ConcurrentHashMap, List> SORT_FIELD_DATA_MAP = FieldDataMapHelper.getSortFieldDataMap(); /** * Method to get query shape hash code given a source @@ -161,7 +120,7 @@ static StringBuilder recursiveAggregationShapeBuilder( StringBuilder stringBuilder = new StringBuilder(); stringBuilder.append(baseIndent).append(ONE_SPACE_INDENT.repeat(2)).append(aggBuilder.getType()); if (showFields) { - stringBuilder.append(buildFieldDataString(AGG_FIELD_DATA_MAP.get(aggBuilder.getClass()), aggBuilder)); + stringBuilder.append(buildFieldDataString(AGG_FIELD_DATA_MAP, aggBuilder)); } stringBuilder.append("\n"); @@ -227,7 +186,7 @@ static String buildSortShape(List> sortBuilderList, Boolean showF StringBuilder stringBuilder = new StringBuilder(); stringBuilder.append(ONE_SPACE_INDENT.repeat(2)).append(sortBuilder.order()); if (showFields) { - stringBuilder.append(buildFieldDataString(SORT_FIELD_DATA_MAP.get(sortBuilder.getClass()), sortBuilder)); + stringBuilder.append(buildFieldDataString(SORT_FIELD_DATA_MAP, sortBuilder)); } shapeStrings.add(stringBuilder.toString()); } @@ -243,11 +202,20 @@ static String buildSortShape(List> sortBuilderList, Boolean showF * @return String: comma separated list with leading space in square brackets * Ex: " [my_field, width:5]" */ - static String buildFieldDataString(List> methods, NamedWriteable builder) { + static String buildFieldDataString(ConcurrentHashMap, List> fieldMethodsMap, NamedWriteable builder) { List fieldDataList = new ArrayList<>(); - if (methods != null) { - for (Function lambda : methods) { - fieldDataList.add(lambda.apply(builder)); + for (ConcurrentHashMap.Entry, List> entry : fieldMethodsMap.entrySet()) { + if (entry.getKey().isInstance(builder)) { + List methodsList = entry.getValue(); + for (Method method : methodsList) { + try { + fieldDataList.add((String) method.invoke(builder)); + } catch (Exception e) { + throw new RuntimeException( + String.format("No such method '%s' for class '%s'", method.getName(), builder.getClass().getName()) + ); + } + } } } return " [" + String.join(", ", fieldDataList) + "]"; @@ -257,85 +225,54 @@ static String buildFieldDataString(List> methods, Named * Helper class to create static field data maps */ private static class FieldDataMapHelper { - - // Helper method to create map entries - private static Map.Entry, List>> createEntry(Class clazz, Function extractor) { - return Map.entry(clazz, List.of(obj -> extractor.apply(clazz.cast(obj)))); - } + private static final String INVALID_METHOD_EXCEPTION_STRING = "Invalid method referenced during field data map creation"; /** - * Returns a map where the keys are query builders, and the values are lists of - * functions that extract field values from instances of these classes. + * Returns a map where the keys are query builder classes, and the values are lists of + * Methods that extract field data from instances of these classes. * - * @return a map with class types as keys and lists of field extraction functions as values. + * @return a map with class types as keys and lists of Methods as values. */ - private static Map, List>> getQueryFieldDataMap() { - return Map.ofEntries( - createEntry(AbstractGeometryQueryBuilder.class, AbstractGeometryQueryBuilder::fieldName), - createEntry(CommonTermsQueryBuilder.class, CommonTermsQueryBuilder::fieldName), - createEntry(ExistsQueryBuilder.class, ExistsQueryBuilder::fieldName), - createEntry(FieldMaskingSpanQueryBuilder.class, FieldMaskingSpanQueryBuilder::fieldName), - createEntry(FuzzyQueryBuilder.class, FuzzyQueryBuilder::fieldName), - createEntry(GeoBoundingBoxQueryBuilder.class, GeoBoundingBoxQueryBuilder::fieldName), - createEntry(GeoDistanceQueryBuilder.class, GeoDistanceQueryBuilder::fieldName), - createEntry(GeoPolygonQueryBuilder.class, GeoPolygonQueryBuilder::fieldName), - createEntry(MatchBoolPrefixQueryBuilder.class, MatchBoolPrefixQueryBuilder::fieldName), - createEntry(MatchQueryBuilder.class, MatchQueryBuilder::fieldName), - createEntry(MatchPhraseQueryBuilder.class, MatchPhraseQueryBuilder::fieldName), - createEntry(MatchPhrasePrefixQueryBuilder.class, MatchPhrasePrefixQueryBuilder::fieldName), - createEntry(MultiTermQueryBuilder.class, MultiTermQueryBuilder::fieldName), - createEntry(PrefixQueryBuilder.class, PrefixQueryBuilder::fieldName), - createEntry(RangeQueryBuilder.class, RangeQueryBuilder::fieldName), - createEntry(RegexpQueryBuilder.class, RegexpQueryBuilder::fieldName), - createEntry(SpanNearQueryBuilder.SpanGapQueryBuilder.class, SpanNearQueryBuilder.SpanGapQueryBuilder::fieldName), - createEntry(SpanTermQueryBuilder.class, SpanTermQueryBuilder::fieldName), - createEntry(TermQueryBuilder.class, TermQueryBuilder::fieldName), - createEntry(TermsQueryBuilder.class, TermsQueryBuilder::fieldName), - createEntry(WildcardQueryBuilder.class, WildcardQueryBuilder::fieldName) - ); + private static ConcurrentHashMap, List> getQueryFieldDataMap() { + ConcurrentHashMap, List> map = new ConcurrentHashMap<>(); + try { + map.put(WithFieldName.class, List.of(WithFieldName.class.getMethod("fieldName"))); + } catch (NoSuchMethodException e) { + throw new RuntimeException(INVALID_METHOD_EXCEPTION_STRING); + } + return map; } /** - * Returns a map where the keys are aggregation builders, and the values are lists of - * functions that extract field values from instances of these classes. + * Returns a map where the keys are aggregation builder classes, and the values are lists of + * Methods that extract field data from instances of these classes. * - * @return a map with class types as keys and lists of field extraction functions as values. + * @return a map with class types as keys and lists of Methods as values. */ - private static Map, List>> getAggFieldDataMap() { - return Map.ofEntries( - createEntry(IpRangeAggregationBuilder.class, IpRangeAggregationBuilder::field), - createEntry(AutoDateHistogramAggregationBuilder.class, AutoDateHistogramAggregationBuilder::field), - createEntry(DateHistogramAggregationBuilder.class, DateHistogramAggregationBuilder::field), - createEntry(HistogramAggregationBuilder.class, HistogramAggregationBuilder::field), - createEntry(VariableWidthHistogramAggregationBuilder.class, VariableWidthHistogramAggregationBuilder::field), - createEntry(MissingAggregationBuilder.class, MissingAggregationBuilder::field), - createEntry(AbstractRangeBuilder.class, AbstractRangeBuilder::field), - createEntry(GeoDistanceAggregationBuilder.class, GeoDistanceAggregationBuilder::field), - createEntry(DiversifiedAggregationBuilder.class, DiversifiedAggregationBuilder::field), - createEntry(RareTermsAggregationBuilder.class, RareTermsAggregationBuilder::field), - createEntry(SignificantTermsAggregationBuilder.class, SignificantTermsAggregationBuilder::field), - createEntry(TermsAggregationBuilder.class, TermsAggregationBuilder::field), - createEntry(AvgAggregationBuilder.class, AvgAggregationBuilder::field), - createEntry(CardinalityAggregationBuilder.class, CardinalityAggregationBuilder::field), - createEntry(ExtendedStatsAggregationBuilder.class, ExtendedStatsAggregationBuilder::field), - createEntry(GeoCentroidAggregationBuilder.class, GeoCentroidAggregationBuilder::field), - createEntry(MaxAggregationBuilder.class, MaxAggregationBuilder::field), - createEntry(MinAggregationBuilder.class, MinAggregationBuilder::field), - createEntry(StatsAggregationBuilder.class, StatsAggregationBuilder::field), - createEntry(SumAggregationBuilder.class, SumAggregationBuilder::field), - createEntry(ValueCountAggregationBuilder.class, ValueCountAggregationBuilder::field), - createEntry(ValuesSourceAggregationBuilder.class, ValuesSourceAggregationBuilder::field) - ); + private static ConcurrentHashMap, List> getAggFieldDataMap() { + ConcurrentHashMap, List> map = new ConcurrentHashMap<>(); + try { + map.put(ValuesSourceAggregationBuilder.class, List.of(ValuesSourceAggregationBuilder.class.getMethod("field"))); + } catch (NoSuchMethodException e) { + throw new RuntimeException(INVALID_METHOD_EXCEPTION_STRING); + } + return map; } /** - * Returns a map where the keys are sort builders, and the values are lists of - * functions that extract field values from instances of these classes. + * Returns a map where the keys are sort builder classes, and the values are lists of + * Methods that extract field data from instances of these classes. * - * @return a map with class types as keys and lists of field extraction functions as values. + * @return a map with class types as keys and lists of Methods as values. */ - private static Map, List>> getSortFieldDataMap() { - return Map.ofEntries(createEntry(FieldSortBuilder.class, FieldSortBuilder::getFieldName)); + private static ConcurrentHashMap, List> getSortFieldDataMap() { + ConcurrentHashMap, List> map = new ConcurrentHashMap<>(); + try { + map.put(FieldSortBuilder.class, List.of(FieldSortBuilder.class.getMethod("getFieldName"))); + } catch (NoSuchMethodException e) { + throw new RuntimeException(INVALID_METHOD_EXCEPTION_STRING); + } + return map; } } } diff --git a/src/main/java/org/opensearch/plugin/insights/core/service/categorizer/QueryShapeVisitor.java b/src/main/java/org/opensearch/plugin/insights/core/service/categorizer/QueryShapeVisitor.java index b0642005..9c573fe7 100644 --- a/src/main/java/org/opensearch/plugin/insights/core/service/categorizer/QueryShapeVisitor.java +++ b/src/main/java/org/opensearch/plugin/insights/core/service/categorizer/QueryShapeVisitor.java @@ -10,13 +10,13 @@ import static org.opensearch.plugin.insights.core.service.categorizer.QueryShapeGenerator.ONE_SPACE_INDENT; import static org.opensearch.plugin.insights.core.service.categorizer.QueryShapeGenerator.QUERY_FIELD_DATA_MAP; +import static org.opensearch.plugin.insights.core.service.categorizer.QueryShapeGenerator.buildFieldDataString; import java.util.ArrayList; import java.util.EnumMap; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.function.Function; import org.apache.lucene.search.BooleanClause; import org.opensearch.common.SetOnce; import org.opensearch.index.query.QueryBuilder; @@ -33,15 +33,7 @@ public final class QueryShapeVisitor implements QueryBuilderVisitor { @Override public void accept(QueryBuilder queryBuilder) { queryType.set(queryBuilder.getName()); - - List fieldDataList = new ArrayList<>(); - List> methods = QUERY_FIELD_DATA_MAP.get(queryBuilder.getClass()); - if (methods != null) { - for (Function lambda : methods) { - fieldDataList.add(lambda.apply(queryBuilder)); - } - } - fieldData.set(String.join(", ", fieldDataList)); + fieldData.set(buildFieldDataString(QUERY_FIELD_DATA_MAP, queryBuilder)); } @Override @@ -101,7 +93,7 @@ public String toJson() { public String prettyPrintTree(String indent, Boolean showFields) { StringBuilder outputBuilder = new StringBuilder(indent).append(queryType.get()); if (showFields) { - outputBuilder.append(" [").append(fieldData.get()).append("]"); + outputBuilder.append(fieldData.get()); } outputBuilder.append("\n"); for (Map.Entry> entry : childVisitors.entrySet()) {