diff --git a/outlines/caching.py b/outlines/caching.py index 6fdda6214..6f05dd665 100644 --- a/outlines/caching.py +++ b/outlines/caching.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import fcntl import functools import os from typing import Callable, Optional @@ -11,6 +12,21 @@ _caching_enabled = True +class FileLock: + def __init__(self, lock_path): + lock_dir = os.path.dirname(lock_path) + os.makedirs(lock_dir, exist_ok=True) + self.lock_file = open(lock_path, "a+") + + def __enter__(self): + fcntl.flock(self.lock_file, fcntl.LOCK_EX) + return self.lock_file + + def __exit__(self, *args): + fcntl.flock(self.lock_file, fcntl.LOCK_UN) + self.lock_file.close() + + class CloudpickleDisk(Disk): def __init__(self, directory, compress_level=1, **kwargs): self.compress_level = compress_level @@ -52,12 +68,13 @@ def get_cache(): home_dir = os.path.expanduser("~") cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines") - memory = Cache( - cache_dir, - eviction_policy="none", - cull_limit=0, - disk=CloudpickleDisk, - ) + with FileLock(os.path.join(cache_dir, "outlines_cache.lock")): + memory = Cache( + cache_dir, + eviction_policy="none", + cull_limit=0, + disk=CloudpickleDisk, + ) # ensure if version upgrade occurs, old cache is pruned if outlines_version != memory.get("__version__"):