diff --git a/.gitignore b/.gitignore index 5d8a5e59..239e8455 100644 --- a/.gitignore +++ b/.gitignore @@ -160,3 +160,5 @@ _readthedocs **/charts/coffea-casa/charts **/Chart.lock +# pyright lsp +pyrightconfig.json diff --git a/coffea_casa/plugin.py b/coffea_casa/plugin.py index 181b5928..ef966f18 100644 --- a/coffea_casa/plugin.py +++ b/coffea_casa/plugin.py @@ -4,9 +4,12 @@ import logging import uuid import subprocess +import gc +import datetime -from distributed.diagnostics.plugin import NannyPlugin -from dask.utils import tmpfile +from distributed.compatibility import PeriodicCallback +from distributed.diagnostics.plugin import NannyPlugin, WorkerPlugin +from dask.utils import tmpfile, parse_bytes logger = logging.getLogger(__name__) @@ -128,3 +131,56 @@ def teardown(self, nanny): return return + + + +class PeriodicGC(WorkerPlugin): + """ + A WorkerPlugin that periodically triggers garbage collection (GC) on a worker node. + The GC is triggered if the process memory exceeds a specified threshold. + + Attributes + ---------- + freq : datetime.timedelta + The frequency of garbage collection. Default is 1ms. + tresh : int + The threshold memory in bytes. If the process memory exceeds this value, garbage collection is triggered. Default is 100MB. + + + Setup via: + >>> periodic_gc = PeriodicGC() + >>> client.register_plugin(periodic_gc) + """ + + def __init__( + self, + freq: datetime.timedelta = datetime.timedelta(milliseconds=1), + tresh: int = parse_bytes("100 MB"), + ) -> None: + """ + Parameters: + freq: Frequency of garbage collection in seconds. Default is 1ms. + tresh: Threshold memory in bytes. If the process memory exceeds this value, garbage collection is triggered. Default is 100MB. + """ + self.freq = freq + self.tresh = tresh + + def setup(self, worker) -> None: + """ + Set up the periodic callback for garbage collection on the worker node. + + Parameters + ---------- + worker : distributed.worker.Worker + The worker node on which to set up the periodic callback. + """ + pc = PeriodicCallback(self._gc_collect, self.freq) + worker.periodic_callbacks["coffea_casa_gc_collect"] = pc + self.worker = worker + + def _gc_collect(self) -> None: + """ + Trigger garbage collection if the process memory exceeds the threshold. + """ + if self.worker.monitor.get_process_memory() >= self.tresh: + gc.collect()