from __future__ import annotations

# File: metricflow-semantics/metricflow_semantics/collection_helpers/lru_cache.py

import threading
from typing import Dict, Generic, Optional, TypeVar

KeyT = TypeVar("KeyT")
ValueT = TypeVar("ValueT")


class LruCache(Generic[KeyT, ValueT]):
    """An LRU cache based on the insertion order of dictionaries.

    Since Python dictionaries iterate in the order that keys were inserted, they are used as the basis of this cache.
    When an item is retrieved, the item in the dictionary is removed then re-inserted, which moves it to the end of
    the iteration order. The first key in iteration order is therefore always the least recently used one.

    This cache is used instead of the `functools.lru_cache` decorator for class instance methods as `lru_cache` keeps
    a reference to the instance, preventing garbage collection of the instance using the decorator until the eviction
    of the associated entry.

    Each `LruCache` instance guards its dictionary with its own lock. NOTE(review): when the same `cache_dict` is
    handed to several `LruCache` instances, each instance locks independently, so accesses made through different
    instances are not serialized against each other.
    """

    def __init__(self, max_cache_items: int, cache_dict: Optional[Dict[KeyT, ValueT]] = None) -> None:
        """Initializer.

        Args:
            max_cache_items: Limit of cache items to store. Once the limit is hit, the oldest item is evicted. A limit
                of zero or less means that `set` stores nothing.
            cache_dict: For shared use cases - the dictionary to use for the cache. The dictionary is used as-is (not
                copied), including when it is empty.
        """
        self._lock = threading.Lock()
        self._max_cache_items = max_cache_items
        # Compare against `None` instead of relying on truthiness: an empty dictionary passed in by the caller must
        # still be the dictionary that backs this cache, otherwise sharing silently breaks.
        self._cache_dict: Dict[KeyT, ValueT] = {} if cache_dict is None else cache_dict

    def get(self, key: KeyT) -> Optional[ValueT]:
        """Return the value stored for `key` and mark it as most recently used.

        Returns:
            The stored value, or `None` if `key` is not in the cache.
        """
        with self._lock:
            try:
                # Popping and re-inserting moves the key to the end of the dictionary's iteration order. Detecting a
                # miss through `KeyError` (rather than comparing the value to `None`) keeps the recency update
                # working for entries whose stored value is `None`.
                value = self._cache_dict.pop(key)
            except KeyError:
                return None
            self._cache_dict[key] = value
            return value

    def set(self, key: KeyT, value: ValueT) -> None:
        """Store `value` under `key`, evicting the least recently used entries if the cache is full.

        If `key` is already present, the call is a no-op: the stored value and its position in the eviction order are
        left unchanged. If the configured limit is zero or less, nothing is stored.
        """
        with self._lock:
            if key in self._cache_dict:
                return

            # A non-positive limit leaves no room for any entry. Returning here also avoids calling `next()` on the
            # iterator of an empty dictionary in the eviction loop below.
            if self._max_cache_items <= 0:
                return

            # Usually a single eviction is enough, but a shared dictionary may already hold more items than the limit.
            while len(self._cache_dict) >= self._max_cache_items:
                oldest_key = next(iter(self._cache_dict))
                del self._cache_dict[oldest_key]

            self._cache_dict[key] = value

    def copy(self) -> LruCache[KeyT, ValueT]:
        """Return a new cache with the same item limit and a shallow copy of the current entries.

        The copy has its own dictionary and its own lock, so later changes to either cache do not affect the other.
        """
        # Take the snapshot under the lock so that it cannot interleave with a `get` / `set` on this instance.
        with self._lock:
            snapshot = dict(self._cache_dict)
        return LruCache(max_cache_items=self._max_cache_items, cache_dict=snapshot)


# File: metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_lru_cache.py
#
# In the repository this test lives in its own module and obtains `LruCache` from
# `metricflow_semantics.collection_helpers.lru_cache`.


def test_lru_cache() -> None:
    """Check that the least recently used entry is the one evicted once the cache is full."""
    cache = LruCache[str, str](max_cache_items=2)
    cache.set("key_0", "value_0")
    cache.set("key_1", "value_1")
    cache.set("key_2", "value_2")

    # Adding "key_2" to a full two-item cache should have evicted "key_0".
    assert cache.get("key_0") is None

    # Get "key_1" so that it's not evicted next.
    assert cache.get("key_1") == "value_1"

    # This should evict "key_2".
    cache.set("key_0", "value_0")
    assert cache.get("key_2") is None