Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition bug in NativeMemoryCacheManager #2435

Draft
wants to merge 12 commits into
base: 2.x
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Fix shard level rescoring disabled setting flag [#2352](https://github.com/opensearch-project/k-NN/pull/2352)
* Fix filter rewrite logic which was resulting in getting inconsistent / incorrect results for cases where filter was getting rewritten for shards [#2359](https://github.com/opensearch-project/k-NN/pull/2359)
* Fixing it to retrieve space_type from index setting when both method and top level don't have the value. [#2374](https://github.com/opensearch-project/k-NN/pull/2374)
* Fix race condition bug in native memory cache where an entry can be evicted before being read for the query [#2435](https://github.com/opensearch-project/k-NN/pull/2435)
* Fixing the bug where setting rescore as false for on_disk knn_vector query is a no-op [#2399](https://github.com/opensearch-project/k-NN/pull/2399)
### Infrastructure
* Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,16 +297,25 @@ public CacheStats getCacheStats() {
return cache.stats();
}

/**
* Retrieves NativeMemoryAllocation associated with the nativeMemoryEntryContext.
*
* Convenience overload that delegates to
* {@link #get(NativeMemoryEntryContext, boolean, boolean)} with
* acquirePreemptiveReadLock set to false.
*
* @param nativeMemoryEntryContext Context from which to get NativeMemoryAllocation
* @param isAbleToTriggerEviction Determines if getting this allocation can evict other entries
* @return NativeMemoryAllocation associated with nativeMemoryEntryContext
* @throws ExecutionException if there is an exception when loading from the cache
*/
public NativeMemoryAllocation get(NativeMemoryEntryContext<?> nativeMemoryEntryContext, boolean isAbleToTriggerEviction)
throws ExecutionException {
return get(nativeMemoryEntryContext, isAbleToTriggerEviction, false);
}

/**
* Retrieves NativeMemoryAllocation associated with the nativeMemoryEntryContext.
*
* @param nativeMemoryEntryContext Context from which to get NativeMemoryAllocation
* @param isAbleToTriggerEviction Determines if getting this allocation can evict other entries
* @param acquirePreemptiveReadLock Determines whether to increment ref count during cache loading to prevent eviction race condition
* @return NativeMemoryAllocation associated with nativeMemoryEntryContext
* @throws ExecutionException if there is an exception when loading from the cache
*/
public NativeMemoryAllocation get(NativeMemoryEntryContext<?> nativeMemoryEntryContext, boolean isAbleToTriggerEviction)
throws ExecutionException {
public NativeMemoryAllocation get(
NativeMemoryEntryContext<?> nativeMemoryEntryContext,
boolean isAbleToTriggerEviction,
boolean acquirePreemptiveReadLock
) throws ExecutionException {
if (!isAbleToTriggerEviction
&& (maxWeight - getCacheSizeInKilobytes() - nativeMemoryEntryContext.calculateSizeInKB()) <= 0
&& !cache.asMap().containsKey(nativeMemoryEntryContext.getKey())) {
Expand Down Expand Up @@ -340,31 +349,47 @@ public NativeMemoryAllocation get(NativeMemoryEntryContext<?> nativeMemoryEntryC
if (result != null) {
accessRecencyQueue.remove(key);
accessRecencyQueue.addLast(key);
if (acquirePreemptiveReadLock) {
result.incRef();
}
return result;
}

// Cache Miss
// Evict before put
synchronized (this) {
if (getCacheSizeInKilobytes() + nativeMemoryEntryContext.calculateSizeInKB() >= maxWeight) {
Iterator<String> lruIterator = accessRecencyQueue.iterator();
while (lruIterator.hasNext()
&& (getCacheSizeInKilobytes() + nativeMemoryEntryContext.calculateSizeInKB() >= maxWeight)) {

String keyToRemove = lruIterator.next();
NativeMemoryAllocation allocationToRemove = cache.getIfPresent(keyToRemove);
if (allocationToRemove != null) {
allocationToRemove.close();
cache.invalidate(keyToRemove);
AtomicBoolean lockAcquired = new AtomicBoolean(false);
try {
synchronized (this) {
if (getCacheSizeInKilobytes() + nativeMemoryEntryContext.calculateSizeInKB() >= maxWeight) {
Iterator<String> lruIterator = accessRecencyQueue.iterator();
while (lruIterator.hasNext()
&& (getCacheSizeInKilobytes() + nativeMemoryEntryContext.calculateSizeInKB() >= maxWeight)) {

String keyToRemove = lruIterator.next();
NativeMemoryAllocation allocationToRemove = cache.getIfPresent(keyToRemove);
if (allocationToRemove != null) {
allocationToRemove.close();
cache.invalidate(keyToRemove);
}
lruIterator.remove();
}
lruIterator.remove();
}
result = cache.get(key, () -> {
NativeMemoryAllocation allocation = nativeMemoryEntryContext.load();
if (acquirePreemptiveReadLock) {
allocation.incRef();
lockAcquired.set(true);
}
return allocation;
});
accessRecencyQueue.addLast(key);
return result;
}

result = cache.get(key, nativeMemoryEntryContext::load);
accessRecencyQueue.addLast(key);

return result;
} catch (Exception e) {
if (result != null && lockAcquired.get()) {
result.decRef();
}
throw e;
}
} else {
return cache.get(nativeMemoryEntryContext.getKey(), nativeMemoryEntryContext::load);
Expand Down
9 changes: 7 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.common.featureflags.KNNFeatureFlags;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
Expand Down Expand Up @@ -298,6 +299,7 @@ private Map<Integer, Float> doANNSearch(

// We need to first get index allocation
NativeMemoryAllocation indexAllocation;
boolean acquirePreemptiveReadLock = KNNFeatureFlags.isForceEvictCacheEnabled();
try {
indexAllocation = nativeMemoryCacheManager.get(
new NativeMemoryEntryContext.IndexEntryContext(
Expand All @@ -314,7 +316,8 @@ private Map<Integer, Float> doANNSearch(
knnQuery.getIndexName(),
modelId
),
true
true,
acquirePreemptiveReadLock
);
} catch (ExecutionException e) {
GRAPH_QUERY_ERRORS.increment();
Expand All @@ -327,7 +330,9 @@ private Map<Integer, Float> doANNSearch(
FilterIdsSelector.FilterIdsSelectorType filterType = filterIdsSelector.getFilterType();
// Now that we have the allocation, we need to readLock it
indexAllocation.readLock();
indexAllocation.incRef();
if (!acquirePreemptiveReadLock) {
indexAllocation.incRef();
}
try {
if (indexAllocation.isClosed()) {
throw new RuntimeException("Index has already been closed");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.opensearch.common.io.PathUtils;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.knn.KNNTestCase;
Expand Down Expand Up @@ -99,6 +100,7 @@
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.KNN_FORCE_EVICT_CACHE_ENABLED_SETTING;

public class KNNWeightTests extends KNNTestCase {
private static final String FIELD_NAME = "target_field";
Expand Down Expand Up @@ -134,6 +136,9 @@ public class KNNWeightTests extends KNNTestCase {
public static void setUpClass() throws Exception {
final KNNSettings knnSettings = mock(KNNSettings.class);
knnSettingsMockedStatic = mockStatic(KNNSettings.class);
ClusterSettings clusterSettings = mock(ClusterSettings.class);
when(knnSettings.getSettingValue("knn.feature.cache.force_evict.enabled")).thenReturn(true);
when(clusterSettings.get(KNN_FORCE_EVICT_CACHE_ENABLED_SETTING)).thenReturn(false);
when(knnSettings.getSettingValue(eq(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED))).thenReturn(true);
when(knnSettings.getSettingValue(eq(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT))).thenReturn(CIRCUIT_BREAKER_LIMIT_100KB);
when(knnSettings.getSettingValue(eq(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED))).thenReturn(false);
Expand All @@ -152,6 +157,7 @@ public static void setUpClass() throws Exception {
final NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class);
final NativeMemoryAllocation nativeMemoryAllocation = mock(NativeMemoryAllocation.class);
when(nativeMemoryCacheManager.get(any(), anyBoolean())).thenReturn(nativeMemoryAllocation);
when(nativeMemoryCacheManager.get(any(), anyBoolean(), anyBoolean())).thenReturn(nativeMemoryAllocation);

nativeMemoryCacheManagerMockedStatic.when(NativeMemoryCacheManager::getInstance).thenReturn(nativeMemoryCacheManager);

Expand Down
Loading