diff --git a/acquire/volatilestream.py b/acquire/volatilestream.py index 19e57f0..d985a83 100644 --- a/acquire/volatilestream.py +++ b/acquire/volatilestream.py @@ -1,7 +1,9 @@ import os +from concurrent import futures from io import SEEK_SET, UnsupportedOperation from pathlib import Path from stat import S_IRGRP, S_IROTH, S_IRUSR +from typing import Any, Callable from dissect.util.stream import AlignedStream @@ -14,6 +16,35 @@ HAS_FCNTL = False +def timeout(func: Callable, *, timelimit: int) -> Callable: + """Timeout a function if it takes too long to complete. + + Args: + func: a function to wrap. + timelimit: The time in seconds that an operation is allowed to run. + + Raises: + TimeoutError: If its time exceeds the timelimit + """ + + def wrapper(*args: Any, **kwargs: Any) -> Any: + with futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(func, *args, **kwargs) + + try: + result = future.result(timelimit) + except futures.TimeoutError: + raise TimeoutError + finally: + # Make sure the thread stops right away. + executor._threads.clear() + futures.thread._threads_queues.clear() + + return result + + return wrapper + + class VolatileStream(AlignedStream): """Streaming class to handle various procfs and sysfs edge-cases. Backed by `AlignedStream`. @@ -41,6 +72,8 @@ def __init__( st_mode = os.fstat(self.fd).st_mode write_only = (st_mode & (S_IRUSR | S_IRGRP | S_IROTH)) == 0 # novermin + self._os_read = timeout(os.read, timelimit=5) + super().__init__(0 if write_only else size) def seek(self, pos: int, whence: int = SEEK_SET) -> int: @@ -53,8 +86,8 @@ def _read(self, offset: int, length: int) -> bytes: result = [] while length: try: - buf = os.read(self.fd, min(length, self.size - offset)) - except BlockingIOError: + buf = self._os_read(self.fd, min(length, self.size - offset)) + except (BlockingIOError, TimeoutError): break if not buf: diff --git a/tests/test_volatile.py b/tests/test_volatile.py new file mode 100644 index 0000000..6e73df3 --- /dev/null +++ b/tests/test_volatile.py @@ -0,0 +1,20 @@ +from time import sleep, time + +import pytest + +from acquire.volatilestream import timeout + + +def test_timeout(): + def snooze(): + sleep(10) + + function = timeout(snooze, timelimit=5) + start = time() + + with pytest.raises(TimeoutError): + function() + + end = time() + + assert end - start < 6