Patch summary:
  * OfflineClusterIntegrationTest.testInBuiltVirtualColumns: bound the virtual-column query with LIMIT 10,
    assert the row count, and add a check that filtering with $segmentName IN (...) over a random subset
    of the segment names reports a matching numSegmentsMatched.
  * ConstantValueStringDictionary: override both getDictIds variants; dict id 0 is added when the constant
    value appears in the lookup values (the sorted variant uses a binary search and does not use the
    algorithm argument).

diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
index 68c36234b27d..e20ad0c9fdeb 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
@@ -1767,21 +1767,41 @@ public void testBrokerResponseMetadata(boolean useMultiStageQueryEngine)
   public void testInBuiltVirtualColumns(boolean useMultiStageQueryEngine)
       throws Exception {
     setUseMultiStageQueryEngine(useMultiStageQueryEngine);
-    String query = "SELECT $docId, $hostName, $segmentName FROM mytable";
+
+    String query = "SELECT $docId, $hostName, $segmentName FROM mytable LIMIT 10";
     JsonNode response = postQuery(query);
     JsonNode resultTable = response.get("resultTable");
     JsonNode dataSchema = resultTable.get("dataSchema");
     assertEquals(dataSchema.get("columnNames").toString(), "[\"$docId\",\"$hostName\",\"$segmentName\"]");
     assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"STRING\",\"STRING\"]");
     JsonNode rows = resultTable.get("rows");
+    assertEquals(rows.size(), 10);
     String expectedHostName = NetUtils.getHostnameOrAddress();
     String expectedSegmentNamePrefix = "mytable_";
     for (int i = 0; i < 10; i++) {
       JsonNode row = rows.get(i);
       assertEquals(row.get(0).asInt(), i);
       assertEquals(row.get(1).asText(), expectedHostName);
-      assertTrue(row.get(2).asText().startsWith(expectedSegmentNamePrefix));
+      String segmentName = row.get(2).asText();
+      assertTrue(segmentName.startsWith(expectedSegmentNamePrefix));
     }
+
+    // Collect all segment names
+    query = "SELECT DISTINCT $segmentName FROM mytable LIMIT 10000";
+    response = postQuery(query);
+    rows = response.get("resultTable").get("rows");
+    int numSegments = rows.size();
+    List<String> segmentNames = new ArrayList<>(numSegments);
+    for (int i = 0; i < numSegments; i++) {
+      segmentNames.add(rows.get(i).get(0).asText());
+    }
+    // Test IN clause on $segmentName
+    Collections.shuffle(segmentNames);
+    int numSegmentsToQuery = RANDOM.nextInt(numSegments) + 1;
+    query = "SELECT COUNT(*) FROM mytable WHERE $segmentName IN ('" + String.join("','",
+        segmentNames.subList(0, numSegmentsToQuery)) + "')";
+    response = postQuery(query);
+    assertEquals(response.get("numSegmentsMatched").asInt(), numSegmentsToQuery);
   }
 
   @Test(dataProvider = "useBothQueryEngines")
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/readers/ConstantValueStringDictionary.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/readers/ConstantValueStringDictionary.java
index 76cecaeecf99..270a8905d2f4 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/readers/ConstantValueStringDictionary.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/readers/ConstantValueStringDictionary.java
@@ -18,7 +18,10 @@
  */
 package org.apache.pinot.segment.local.segment.index.readers;
 
+import it.unimi.dsi.fastutil.ints.IntSet;
 import java.math.BigDecimal;
+import java.util.Collections;
+import java.util.List;
 import org.apache.pinot.spi.data.FieldSpec.DataType;
 
 import static java.nio.charset.StandardCharsets.UTF_8;
@@ -113,4 +116,18 @@ public String getStringValue(int dictId) {
   public byte[] getBytesValue(int dictId) {
     return _bytes;
   }
+
+  @Override
+  public void getDictIds(List<String> values, IntSet dictIds) {
+    if (values.contains(_value)) {
+      dictIds.add(0);
+    }
+  }
+
+  @Override
+  public void getDictIds(List<String> sortedValues, IntSet dictIds, SortedBatchLookupAlgorithm algorithm) {
+    if (Collections.binarySearch(sortedValues, _value) >= 0) {
+      dictIds.add(0);
+    }
+  }
 }