From 8782798aebcfa2f4080deb55fb6f53f014b759a3 Mon Sep 17 00:00:00 2001
From: Nicholas Smith
Date: Sun, 21 Apr 2024 09:38:55 -0500
Subject: [PATCH] Implement a file handle cache (#54)

* Implement a file handle cache

To prevent constantly opening and closing in calls to cat_file.
Also, fix up the vector read size code to work correctly

* lint

* Try to fire and forget sync close

* No TTL yet but the tests pass

* Implement TTL, but not checked in a watch loop

* Background cache pruner

* Play around with better typing a bit

* Test to show closing is ok while reading
---
 src/fsspec_xrootd/xrootd.py | 359 +++++++++++++++++++++---------------
 tests/test_basicio.py       |  36 ++++
 2 files changed, 243 insertions(+), 152 deletions(-)

diff --git a/src/fsspec_xrootd/xrootd.py b/src/fsspec_xrootd/xrootd.py
index 6ec6e66..da2b7f7 100644
--- a/src/fsspec_xrootd/xrootd.py
+++ b/src/fsspec_xrootd/xrootd.py
@@ -3,13 +3,16 @@
 import asyncio
 import io
 import os.path
+import time
 import warnings
+import weakref
 from collections import defaultdict
+from dataclasses import dataclass
 from enum import IntEnum
-from functools import partial
-from typing import Any, Callable, Iterable
+from typing import Any, Callable, Coroutine, Iterable, TypeVar
 
-from fsspec.asyn import AsyncFileSystem, _run_coros_in_chunks, sync_wrapper
+from fsspec.asyn import AsyncFileSystem, _run_coros_in_chunks, sync, sync_wrapper
+from fsspec.exceptions import FSTimeoutError
 from fsspec.spec import AbstractBufferedFile
 from XRootD import client
 from XRootD.client.flags import (
@@ -26,52 +29,45 @@ class ErrorCodes(IntEnum):
     INVALID_PATH = 400
 
 
-def _handle(
-    future: asyncio.Future[tuple[XRootDStatus, Any]],
-    status: XRootDStatus,
-    content: Any,
-    servers: HostList,
-) -> None:
-    """Sets result of _async_wrap() future.
+T = TypeVar("T")
+# TODO: Protocol typing when kwargs is supported
 
-    Parameters
-    ----------
-    future: asyncio future, created in _async_wrap()
-    status: XRootDStatus, pyxrootd response object
-    content: any, whatever was returned from pyxrootd function
-    servers: Hostlist, iterable list of host info (currently unused)
-
-    Returns
-    -------
-    Sets the future result.
-    """
-    if future.cancelled():
-        return
-    try:
-        future.get_loop().call_soon_threadsafe(future.set_result, (status, content))
-    except Exception as exc:
-        future.get_loop().call_soon_threadsafe(future.set_exception, exc)
-
 
-async def _async_wrap(func: Callable[..., Any], *args: Any) -> Any:
-    """Wraps pyxrootd functions to run asynchronously. Returns future to be awiated.
+def _async_wrap(
+    func: Callable[..., XRootDStatus | tuple[XRootDStatus, T]]
+) -> Callable[..., Coroutine[Any, Any, tuple[XRootDStatus, T]]]:
+    """Wraps pyxrootd functions to run asynchronously. Returns an async callable.
 
     Parameters
     ----------
-    func: pyxrootd function, needs to have a callback option
-    args: non-keyworded arguments for pyxrootd function
+    func: XRootD function; needs to have a callback option
 
     Returns
     -------
-    An asyncio future. Result is set when _handle() is called back.
+    A function with the same signature as func, but with an implicit `callback` argument
     """
-    future = asyncio.get_running_loop().create_future()
-    submit_status = func(*args, callback=partial(_handle, future))
-    if not submit_status.ok:
-        raise OSError(
-            f"Failed to submit {func!r} request: {submit_status.message.strip()}"
-        )
-    return await future
+    future: asyncio.Future[tuple[XRootDStatus, T]] = (
+        asyncio.get_running_loop().create_future()
+    )
+
+    def callback(status: XRootDStatus, content: T, servers: HostList) -> None:
+        if future.cancelled():
+            return
+        loop = future.get_loop()
+        try:
+            loop.call_soon_threadsafe(future.set_result, (status, content))
+        except Exception as exc:
+            loop.call_soon_threadsafe(future.set_exception, exc)
+
+    async def wrapped(*args: Any, **kwargs: Any) -> tuple[XRootDStatus, T]:
+        submit_status: XRootDStatus = func(*args, **kwargs, callback=callback)
+        if not submit_status.ok:
+            raise OSError(
+                f"Failed to submit {func!r} request: {submit_status.message.strip()}"
+            )
+        return await future
+
+    return wrapped
 
 
 def _chunks_to_vectors(
@@ -139,6 +135,90 @@ def _vectors_to_chunks(
     return deets
 
 
+@dataclass
+class _CacheItem:
+    accessed: float
+    handle: client.File
+
+
+class ReadonlyFileHandleCache:
+    def __init__(self, loop: Any, max_items: int | None, ttl: int):
+        self.loop = loop
+        self._max_items = max_items
+        self._ttl = int(ttl)
+        self._cache: dict[str, _CacheItem] = {}
+        sync(loop, self._start_pruner)
+        weakref.finalize(self, self._close_all, loop, self._cache)
+
+    @staticmethod
+    def _close_all(loop: Any, cache: dict[str, _CacheItem]) -> None:
+        if loop is not None and loop.is_running():
+
+            async def closure() -> None:
+                await asyncio.gather(
+                    *(_async_wrap(item.handle.close)() for item in cache.values())
+                )
+
+            try:
+                sync(loop, closure, timeout=0.5)
+            except (TimeoutError, FSTimeoutError, NotImplementedError):
+                pass
+        else:
+            # fire and forget
+            for item in cache.values():
+                item.handle.close(callback=lambda *args: None)
+        cache.clear()
+
+    def close_all(self) -> None:
+        self._close_all(self.loop, self._cache)
+
+    async def _close(self, url: str, timeout: int) -> None:
+        item = self._cache.pop(url, None)
+        if item:
+            status, _ = await _async_wrap(item.handle.close)(timeout=timeout)
+            if not status.ok:
+                raise OSError(f"Failed to close file: {status.message}")
+
+    close = sync_wrapper(_close)
+
+    async def _start_pruner(self) -> None:
+        self._prune_task = asyncio.create_task(self._pruner())
+
+    async def _pruner(self) -> None:
+        while True:
+            await self._prune_cache(self._ttl // 2)
+            await asyncio.sleep(self._ttl)
+
+    async def _prune_cache(self, timeout: int) -> None:
+        now = time.monotonic()
+        oldest_keys = sorted((item.accessed, key) for key, item in self._cache.items())
+        to_close = []
+        if self._max_items:
+            to_close += oldest_keys[: -self._max_items]
+            oldest_keys = oldest_keys[-self._max_items :]
+        for last_access, key in oldest_keys:
+            if now - last_access > self._ttl:
+                to_close.append((last_access, key))
+        await asyncio.gather(*(self._close(key, timeout) for _, key in to_close))
+
+    async def _open(self, url: str, timeout: int) -> Any:  # client.File
+        if url in self._cache:
+            item = self._cache[url]
+            item.accessed = time.monotonic()
+            return item.handle
+        handle = client.File()
+        status, _ = await _async_wrap(handle.open)(
+            url,
+            OpenFlags.READ,
+            timeout=timeout,
+        )
+        if not status.ok:
+            raise OSError(f"Failed to open file: {status.message}")
+        self._cache[url] = _CacheItem(accessed=time.monotonic(), handle=handle)
+        await self._prune_cache(timeout)
+        return handle
+
+
 class XRootDFileSystem(AsyncFileSystem):  # type: ignore[misc]
     protocol = "root"
     root_marker = "/"
@@ -153,7 +233,7 @@ def __init__(
         self,
         hostid: str,
         asynchronous: bool = False,
-        loop: Any = None,
+        loop: asyncio.AbstractEventLoop | None = None,
         **storage_options: Any,
     ) -> None:
         """
@@ -177,15 +257,24 @@ def __init__(
             raise ValueError(f"Invalid hostid: {hostid!r}")
         storage_options.setdefault("listing_expiry_time", 0)
         self.storage_options = storage_options
+        self._readonly_filehandle_cache = ReadonlyFileHandleCache(
+            self.loop,
+            max_items=storage_options.get("filehandle_cache_size", 256),
+            ttl=storage_options.get("filehandle_cache_ttl", 30),
+        )
 
     def invalidate_cache(self, path: str | None = None) -> None:
         if path is None:
             self.dircache.clear()
+            self._readonly_filehandle_cache.close_all()
         else:
             try:
                 del self.dircache[path]
             except KeyError:
                 pass
+            self._readonly_filehandle_cache.close(
+                self.unstrip_protocol(path), self.timeout
+            )
 
     @staticmethod
     def _get_kwargs_from_urls(u: str) -> dict[Any, Any]:
@@ -206,17 +295,22 @@ def _strip_protocol(cls, path: str | list[str]) -> Any:
             raise ValueError("Strip protocol not given string or list")
 
     def unstrip_protocol(self, name: str) -> str:
-        return f"{self.protocol}://{self.hostid}/{name}"
+        prefix = f"{self.protocol}://{self.hostid}/"
+        if name.startswith(prefix):
+            return name
+        return prefix + name
 
     async def _mkdir(
         self, path: str, create_parents: bool = True, **kwargs: Any
     ) -> None:
         if create_parents:
-            status, n = await _async_wrap(
-                self._myclient.mkdir, path, MkDirFlags.MAKEPATH, self.timeout
+            status, _ = await _async_wrap(self._myclient.mkdir)(
+                path, flags=MkDirFlags.MAKEPATH, timeout=self.timeout
             )
         else:
-            status, n = await _async_wrap(self._myclient.mkdir, path, self.timeout)
+            status, _ = await _async_wrap(self._myclient.mkdir)(
+                path, timeout=self.timeout
+            )
         if not status.ok:
             raise OSError(f"Directory not made properly: {status.message}")
 
@@ -226,8 +320,8 @@ async def _makedirs(self, path: str, exist_ok: bool = False) -> None:
                 raise OSError(
                     "Location already exists and exist_ok arg was set to false"
                 )
-        status, n = await _async_wrap(
-            self._myclient.mkdir, path, MkDirFlags.MAKEPATH, self.timeout
+        status, _ = await _async_wrap(self._myclient.mkdir)(
+            path, MkDirFlags.MAKEPATH, timeout=self.timeout
         )
         if not status.ok and not (status.code == ErrorCodes.INVALID_PATH and exist_ok):
             raise OSError(f"Directory not made properly: {status.message}")
@@ -250,31 +344,30 @@ async def _rm(
         )
 
     async def _rmdir(self, path: str) -> None:
-        status, n = await _async_wrap(self._myclient.rmdir, path, self.timeout)
+        status, _ = await _async_wrap(self._myclient.rmdir)(path, self.timeout)
         if not status.ok:
             raise OSError(f"Directory not removed properly: {status.message}")
 
     rmdir = sync_wrapper(_rmdir)
 
     async def _rm_file(self, path: str, **kwargs: Any) -> None:
-        status, n = await _async_wrap(self._myclient.rm, path, self.timeout)
+        status, _ = await _async_wrap(self._myclient.rm)(path, self.timeout)
         if not status.ok:
             raise OSError(f"File not removed properly: {status.message}")
 
     async def _touch(self, path: str, truncate: bool = False, **kwargs: Any) -> None:
         if truncate or not await self._exists(path):
-            status, _ = await _async_wrap(
-                self._myclient.truncate, path, 0, self.timeout
+            status, _ = await _async_wrap(self._myclient.truncate)(
+                path, size=0, timeout=self.timeout
            )
            if not status.ok:
                raise OSError(f"File not touched properly: {status.message}")
        else:
            len = await self._info(path)
-            status, _ = await _async_wrap(
-                self._myclient.truncate,
+            status, _ = await _async_wrap(self._myclient.truncate)(
                 path,
-                len.get("size"),
-                self.timeout,
+                size=len.get("size"),
+                timeout=self.timeout,
             )
             if not status.ok:
                 raise OSError(f"File not touched properly: {status.message}")
 
@@ -282,7 +375,7 @@ async def _touch(self, path: str, truncate: bool = False, **kwargs: Any) -> None
     touch = sync_wrapper(_touch)
 
     async def _modified(self, path: str) -> Any:
-        status, statInfo = await _async_wrap(self._myclient.stat, path, self.timeout)
+        status, statInfo = await _async_wrap(self._myclient.stat)(path, self.timeout)  # type: ignore[var-annotated]
         return statInfo.modtime
 
     modified = sync_wrapper(_modified)
@@ -291,7 +384,7 @@ async def _exists(self, path: str, **kwargs: Any) -> bool:
         if path in self.dircache:
             return True
         else:
-            status, _ = await _async_wrap(self._myclient.stat, path, self.timeout)
+            status, _ = await _async_wrap(self._myclient.stat)(path, self.timeout)
             if status.code == ErrorCodes.INVALID_PATH:
                 return False
             elif not status.ok:
@@ -311,7 +404,7 @@ async def _info(self, path: str, **kwargs: Any) -> dict[str, Any]:
                 }
             raise OSError("_ls_from_cache() failed to function")
         else:
-            status, deet = await _async_wrap(self._myclient.stat, path, self.timeout)
+            status, deet = await _async_wrap(self._myclient.stat)(path, self.timeout)
             if not status.ok:
                 raise OSError(f"File stat request failed: {status.message}")
             if deet.flags & StatInfoFlags.IS_DIR:
@@ -345,8 +438,8 @@ async def _ls(self, path: str, detail: bool = True, **kwargs: Any) -> list[Any]:
                 os.path.basename(item["name"]) for item in self._ls_from_cache(path)
             ]
         else:
-            status, deets = await _async_wrap(
-                self._myclient.dirlist, path, DirListFlags.STAT, self.timeout
+            status, deets = await _async_wrap(self._myclient.dirlist)(  # type: ignore[var-annotated]
+                path, DirListFlags.STAT, self.timeout
             )
             if not status.ok:
                 raise OSError(
@@ -386,74 +479,50 @@ async def _ls(self, path: str, detail: bool = True, **kwargs: Any) -> list[Any]:
     async def _cat_file(
         self, path: str, start: int | None, end: int | None, **kwargs: Any
     ) -> Any:
-        _myFile = client.File()
-        try:
-            status, _n = await _async_wrap(
-                _myFile.open,
-                self.unstrip_protocol(path),
-                OpenFlags.READ,
-                self.timeout,
-            )
-            if not status.ok:
-                raise OSError(f"File failed to read: {status.message}")
-
-            n_bytes = end
-            if start is not None and end is not None:
-                n_bytes = end - start
-
-            status, data = await _async_wrap(
-                _myFile.read,
-                start or 0,
-                n_bytes or 0,
-                self.timeout,
-            )
-            if not status.ok:
-                raise OSError(f"Bytes failed to read from open file: {status.message}")
-            return data
-        finally:
-            status, _n = await _async_wrap(
-                _myFile.close,
-                self.timeout,
-            )
+        _myFile = await self._readonly_filehandle_cache._open(
+            self.unstrip_protocol(path),
+            self.timeout,
+        )
+        n_bytes = end
+        if start is not None and end is not None:
+            n_bytes = end - start
+
+        status, data = await _async_wrap(_myFile.read)(  # type: ignore[var-annotated]
+            start or 0,
+            n_bytes or 0,
+            self.timeout,
+        )
+        if not status.ok:
+            raise OSError(f"Bytes failed to read from open file: {status.message}")
+        return data
 
     async def _get_file(
         self, rpath: str, lpath: str, chunk_size: int = 262_144, **kwargs: Any
     ) -> None:
         # Open the remote file for reading
-        remote_file = client.File()
-
-        try:
-            status, _n = await _async_wrap(
-                remote_file.open,
-                self.unstrip_protocol(rpath),
-                OpenFlags.READ,
-                self.timeout,
-            )
-            if not status.ok:
-                raise OSError(f"Remote file failed to open: {status.message}")
-
-            with open(lpath, "wb") as local_file:
-                start: int = 0
-                while True:
-                    # Read a chunk of content from the remote file
-                    status, chunk = await _async_wrap(
-                        remote_file.read, start, chunk_size, self.timeout
-                    )
-                    start += chunk_size
+        remote_file = await self._readonly_filehandle_cache._open(
+            self.unstrip_protocol(rpath),
+            self.timeout,
+        )
 
-                    if not status.ok:
-                        raise OSError(f"Remote file failed to read: {status.message}")
+        with open(lpath, "wb") as local_file:
+            start: int = 0
+            while True:
+                # Read a chunk of content from the remote file
+                status, chunk = await _async_wrap(remote_file.read)(  # type: ignore[var-annotated]
+                    start, chunk_size, self.timeout
+                )
+                start += chunk_size
 
-                    # Break if there is no more content
-                    if not chunk:
-                        break
+                if not status.ok:
+                    raise OSError(f"Remote file failed to read: {status.message}")
 
-                    # Write the chunk to the local file
-                    local_file.write(chunk)
+                # Break if there is no more content
+                if not chunk:
+                    break
 
-        finally:
-            # Close the remote file
-            await _async_wrap(remote_file.close, self.timeout)
+                # Write the chunk to the local file
+                local_file.write(chunk)
 
     @classmethod
     async def _get_max_chunk_info(cls, file: Any) -> tuple[int, int]:
@@ -476,8 +545,8 @@ async def _get_max_chunk_info(cls, file: Any) -> tuple[int, int]:
         data_server = f"{data_server.protocol}://{data_server.hostid}/"
         if data_server not in cls._dataserver_info_cache:
             fs = client.FileSystem(data_server)
-            status, result = await _async_wrap(
-                fs.query, QueryCode.CONFIG, "readv_iov_max readv_ior_max"
+            status, result = await _async_wrap(fs.query)(  # type: ignore[var-annotated]
+                QueryCode.CONFIG, "readv_iov_max readv_ior_max"
             )
             if not status.ok:
                 raise OSError(
@@ -515,39 +584,24 @@ async def _cat_vector_read(
             Tuple containing path name and a list of returned bytes in the same
             order as requested.
""" - try: - _myFile = client.File() - status, _n = await _async_wrap( - _myFile.open, - self.protocol + "://" + self.storage_options["hostid"] + "/" + path, - OpenFlags.READ, - self.timeout, - ) - if not status.ok: - raise OSError(f"File did not open properly: {status.message}") + _myFile = await self._readonly_filehandle_cache._open( + self.unstrip_protocol(path), + self.timeout, + ) - max_num_chunks, max_chunk_size = await self._get_max_chunk_info(_myFile) - vectors = _chunks_to_vectors(chunks, max_num_chunks, max_chunk_size) + max_num_chunks, max_chunk_size = await self._get_max_chunk_info(_myFile) + vectors = _chunks_to_vectors(chunks, max_num_chunks, max_chunk_size) - coros = [_async_wrap(_myFile.vector_read, v, self.timeout) for v in vectors] + coros = [_async_wrap(_myFile.vector_read)(v, self.timeout) for v in vectors] # type: ignore[var-annotated] - results = await _run_coros_in_chunks( - coros, batch_size=batch_size, nofiles=True - ) - result_bufs = [] - for status, buffers in results: - if not status.ok: - raise OSError( - f"File did not vector_read properly: {status.message}" - ) - result_bufs.append(buffers) - deets = _vectors_to_chunks(chunks, result_bufs) + results = await _run_coros_in_chunks(coros, batch_size=batch_size, nofiles=True) + result_bufs = [] + for status, buffers in results: + if not status.ok: + raise OSError(f"File did not vector_read properly: {status.message}") + result_bufs.append(buffers) + deets = _vectors_to_chunks(chunks, result_bufs) - finally: - status, _n = await _async_wrap( - _myFile.close, - self.timeout, - ) return (path, deets) async def _cat_ranges( @@ -702,13 +756,14 @@ def __init__( if not isinstance(path, str): raise ValueError(f"Path expected to be string, path: {path}") + # Ensure any read-only handle is closed + fs.invalidate_cache(path) self._myFile = client.File() - status, _n = self._myFile.open( + status, _ = self._myFile.open( fs.unstrip_protocol(path), self.mode, timeout=self.timeout, ) - if not status.ok: raise OSError(f"File did not open properly: {status.message}") diff --git a/tests/test_basicio.py b/tests/test_basicio.py index ed4033b..d257575 100644 --- a/tests/test_basicio.py +++ b/tests/test_basicio.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import os import shutil import socket @@ -49,6 +50,10 @@ def localserver(tmpdir_factory): @pytest.fixture() def clear_server(localserver): remoteurl, localpath = localserver + fs, _, _ = fsspec.get_fs_token_paths(remoteurl) + # The open file handles on client side imply an open file handle on the server, + # so removing the directory doesn't actually work until the client closes its handles! 
+    fs.invalidate_cache()
     shutil.rmtree(localpath)
     os.mkdir(localpath)
     yield
@@ -467,3 +472,34 @@ def test_cache_directory(localserver, clear_server, tmp_path):
     with open(cache_directory / os.listdir(cache_directory)[0], "rb") as f:
         contents = f.read()
     assert contents == TESTDATA1.encode("utf-8")
+
+
+def test_close_while_reading(localserver, clear_server):
+    remoteurl, localpath = localserver
+    data = TESTDATA1 * int(1e8 / len(TESTDATA1))
+    with open(localpath + "/testfile.txt", "w") as fout:
+        fout.write(data)
+
+    fs, _, (path,) = fsspec.get_fs_token_paths(remoteurl + "/testfile.txt")
+
+    async def reader():
+        tic = time.monotonic()
+        await fs._cat_file(path, start=0, end=None)
+        toc = time.monotonic()
+        return tic, toc
+
+    async def closer():
+        await asyncio.sleep(0.001)
+        tic = time.monotonic()
+        await fs._readonly_filehandle_cache._close(path, 1)
+        toc = time.monotonic()
+        return tic, toc
+
+    async def run():
+        (read_start, read_stop), (close_start, close_stop) = await asyncio.gather(
+            reader(), closer()
+        )
+        assert read_start < close_start < read_stop
+        assert read_start < close_stop < read_stop
+
+    asyncio.run(run())
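
Usage note (not part of the patch): a minimal sketch of how the new file handle cache
is tuned from user code. The endpoint and file path below are placeholders; the
"filehandle_cache_size" and "filehandle_cache_ttl" keys are the storage options read in
XRootDFileSystem.__init__ above, with patch defaults of 256 handles and a 30 second TTL.

    import fsspec

    # Placeholder endpoint and path, for illustration only.
    fs = fsspec.filesystem(
        "root",
        hostid="eospublic.cern.ch",
        filehandle_cache_size=128,  # max cached read-only handles (patch default: 256)
        filehandle_cache_ttl=10,    # seconds before an idle handle is pruned (patch default: 30)
    )

    # Repeated reads of the same path reuse a cached read-only handle
    # instead of reopening the file on every cat_file call.
    first_kb = fs.cat_file("/eos/opendata/example/file.root", start=0, end=1024)

    # Invalidating the cache also closes any cached read-only handles.
    fs.invalidate_cache()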