From cf42210487c4b3b9e58b63ddeffffe6416f49b31 Mon Sep 17 00:00:00 2001
From: devsjc <47188100+devsjc@users.noreply.github.com>
Date: Fri, 22 Nov 2024 12:31:18 +0000
Subject: [PATCH] feat(services): Use context manager for PerformanceMonitor

---
 .../internal/entities/performance.py      |  54 ++++--
 .../internal/entities/test_tensorstore.py |  13 +-
 .../internal/services/archiver_service.py | 150 ++++++++--------
 .../internal/services/consumer_service.py | 156 +++++++++---------
 .../internal/services/test_consumer.py    |  14 +-
 5 files changed, 202 insertions(+), 185 deletions(-)

diff --git a/src/nwp_consumer/internal/entities/performance.py b/src/nwp_consumer/internal/entities/performance.py
index 3c646c6d..4b36fe27 100644
--- a/src/nwp_consumer/internal/entities/performance.py
+++ b/src/nwp_consumer/internal/entities/performance.py
@@ -6,6 +6,7 @@
 
 import time
 from threading import Thread
+from types import TracebackType
 
 import psutil
 
@@ -20,25 +21,42 @@ class PerformanceMonitor(Thread):
 
     thread: Thread
     memory_buffer: list[int]
-    stop: bool
+    cpu_buffer: list[float]
     start_time: float
    end_time: float
+    stop: bool = True
 
-    def __init__(self) -> None:
-        """Create a new instance."""
+    def __enter__(self) -> None:
+        """Start the monitor."""
         super().__init__()
         self.stop = False
         self.memory_buffer: list[int] = []
+        self.cpu_buffer: list[float] = []
         self.start_time = time.time()
         self.start()
 
-    def get_memory(self) -> int:
-        """Get memory of a process and its children."""
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        """Stop the performance monitor, saving the results."""
+        self.stop = True
+        self.end_time = time.time()
+        super().join(timeout=30)
+
+    def get_usage(self) -> tuple[int, float]:
+        """Get usage of a process and its children."""
         p = psutil.Process()
+        # CPU usage of process and its children
+        cpu: float = p.cpu_percent()
+        # Memory usage does not reflect child processes
+        # * Manually add the memory usage of child processes
         memory: int = p.memory_info().rss
         for c in p.children():
             memory += c.memory_info().rss
-        return memory
+        return memory, cpu
 
     def get_runtime(self) -> int:
         """Get the runtime of the thread in seconds."""
@@ -46,17 +64,21 @@ def get_runtime(self) -> int:
 
     def run(self) -> None:
         """Run the thread."""
-        memory_start = self.get_memory()
+        memory_start, cpu_start = self.get_usage()
         while not self.stop:
-            self.memory_buffer.append(self.get_memory() - memory_start)
+            new_memory, new_cpu = self.get_usage()
+            # Memory is just a total, so get the delta
+            self.memory_buffer.append(new_memory - memory_start)
+            # CPU is calculated by psutil against the base CPU,
+            # so no need to get a delta
+            self.cpu_buffer.append(new_cpu)
             time.sleep(0.2)
 
-    def join(self, timeout: int | None = None) -> None:  # type: ignore
-        """Stop the thread."""
-        self.stop = True
-        self.end_time = time.time()
-        super().join(timeout=timeout)
+    def max_memory_mb(self) -> float:
+        """Get the maximum memory usage during the thread's runtime."""
+        return max(self.memory_buffer) / 1e6
+
+    def max_cpu_percent(self) -> float:
+        """Get the maximum CPU usage during the thread's runtime."""
+        return max(self.cpu_buffer)
 
-    def __enter__(self) -> "PerformanceMonitor":
-        """Enter a context."""
-        return self
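A minimal usage sketch of the reworked monitor, for reviewers. The workload function is a made-up placeholder, and the import path is assumed from the file layout above. Because `__enter__` returns None, the instance is bound before the `with` statement rather than via `as`:

```python
from nwp_consumer.internal.entities.performance import PerformanceMonitor

def run_expensive_pipeline() -> None:
    # Hypothetical workload to profile; long enough to be sampled at 0.2s
    sum(i * i for i in range(10_000_000))

monitor = PerformanceMonitor()
with monitor:  # __enter__ starts the sampling thread
    run_expensive_pipeline()
# __exit__ has stopped the thread and set end_time, so the aggregates are safe to read
print(monitor.get_runtime(), monitor.max_memory_mb(), monitor.max_cpu_percent())
```

Note that `run()` calls `get_usage()` once before the loop even though `cpu_start` is never read: psutil's `Process.cpu_percent()` returns a meaningless 0.0 on its first call and measures against that call as the baseline thereafter, which is what the "calculated by psutil against the base CPU" comment refers to.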
diff --git a/src/nwp_consumer/internal/entities/test_tensorstore.py b/src/nwp_consumer/internal/entities/test_tensorstore.py
index 07f3590e..69b8a835 100644
--- a/src/nwp_consumer/internal/entities/test_tensorstore.py
+++ b/src/nwp_consumer/internal/entities/test_tensorstore.py
@@ -5,6 +5,7 @@
 import os
 import unittest
 from collections.abc import Generator
+from types import TracebackType
 from unittest.mock import patch
 
 import numpy as np
@@ -23,13 +24,14 @@
 logging.getLogger("werkzeug").setLevel(logging.ERROR)
 
 
-class MockS3Bucket(contextlib.ContextDecorator):
+class MockS3Bucket:
 
     client: BotocoreClient
     server: ThreadedMotoServer
     bucket: str = "test-bucket"
 
     def __enter__(self) -> None:
+        """Create a mock S3 server and bucket."""
         self.server = ThreadedMotoServer()
         self.server.start()
 
@@ -46,10 +48,16 @@ def __enter__(self) -> None:
             Bucket=self.bucket,
         )
 
-    def __exit__(self, *exc) -> bool:  # type:ignore
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        """Delete all objects in the bucket and stop the server."""
         response = self.client.list_objects_v2(
             Bucket=self.bucket,
         )
         if "Contents" in response:
             for obj in response["Contents"]:
                 self.client.delete_object(
@@ -57,7 +65,6 @@ def __exit__(self, *exc) -> bool:  # type:ignore
                     Key=obj["Key"],
                 )
         self.server.stop()
-        return False
 
 
 class TestTensorStore(unittest.TestCase):
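One design note on the change above: dropping `contextlib.ContextDecorator` means `MockS3Bucket` can no longer be applied as a function decorator, only used in a `with` block. A sketch of the assumed test pattern (the test case and body are illustrative only):

```python
import unittest

class TestWithMockBucket(unittest.TestCase):  # hypothetical test case
    def test_roundtrip(self) -> None:
        with MockS3Bucket():
            ...  # exercise code that reads and writes s3://test-bucket
        # On exit, every object has been deleted and the moto server stopped,
        # so tests cannot leak state into one another.
```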
"True").capitalize() == "False": + n_jobs = 1 + log.debug(f"Downloading using {n_jobs} concurrent thread(s)") + da_result_generator = Parallel( + n_jobs=n_jobs, + prefer="threads", + return_as="generator_unordered", + )(amr.fetch_init_data(it=it)) + + # Regionally write the results of the generator as they are ready + for da_result in da_result_generator: + write_result = da_result.bind(store.write_to_region) + # Fail soft if a region fails to write + if isinstance(write_result, Failure): + log.error(f"Failed to write time {it:%Y-%m-%d %H:%M}: {write_result}") + failed_times.append(it) + + del da_result_generator + + # Add the failed times to the store's metadata + store.update_attrs({ + "failed_times": ", ".join([t.strftime("Day %d %H:%M") for t in failed_times]), + }) + + # Postprocess the dataset as required + # postprocess_result = store.postprocess(self._mr.metadata().postprocess_options) + # if isinstance(postprocess_result, Failure): + # return Failure(postprocess_result.failure()) - if isinstance(init_store_result, Failure): - monitor.join() # TODO: Make this a context manager instead - return Failure(OSError( - f"Failed to initialize store for {year}-{month}: {init_store_result!s}"), - ) - store = init_store_result.unwrap() - - missing_times_result = store.missing_times() - if isinstance(missing_times_result, Failure): - monitor.join() - return Failure(missing_times_result.failure()) - log.info(f"{len(missing_times_result.unwrap())} missing init_times in store.") - - failed_times: list[dt.datetime] = [] - for n, it in enumerate(missing_times_result.unwrap()): - log.info( - f"Consuming data from {self.mr.repository().name} for {it:%Y-%m-%d %H:%M} " - f"(time {n + 1}/{len(missing_times_result.unwrap())})", - ) - - # Authenticate with the model repository - amr_result = self.mr.authenticate() - if isinstance(amr_result, Failure): - store.delete_stoe() - monitor.join() - return Failure(OSError( - "Unable to authenticate with model repository " - f"'{self.mr.repository().name}': " - f"{amr_result.failure()}", - )) - amr = amr_result.unwrap() - - # Create a generator to fetch and process raw data - n_jobs: int = max(cpu_count() - 1, self.mr.repository().max_connections) - if os.getenv("CONCURRENCY", "True").capitalize() == "False": - n_jobs = 1 - log.debug(f"Downloading using {n_jobs} concurrent thread(s)") - da_result_generator = Parallel( - n_jobs=n_jobs, - prefer="threads", - return_as="generator_unordered", - )(amr.fetch_init_data(it=it)) - - # Regionally write the results of the generator as they are ready - for da_result in da_result_generator: - write_result = da_result.bind(store.write_to_region) - # Fail soft if a region fails to write - if isinstance(write_result, Failure): - log.error(f"Failed to write time {it:%Y-%m-%d %H:%M}: {write_result}") - failed_times.append(it) - - del da_result_generator - - # Add the failed times to the store's metadata - store.update_attrs({ - "failed_times": ", ".join([t.strftime("Day %d %H:%M") for t in failed_times]), - }) - - # Postprocess the dataset as required - # postprocess_result = store.postprocess(self._mr.metadata().postprocess_options) - # if isinstance(postprocess_result, Failure): - # monitor.join() # TODO: Make this a context manager instead - # return Failure(postprocess_result.failure()) - - monitor.join() notify_result = self.nr().notify( message=entities.StoreCreatedNotification( filename=pathlib.Path(store.path).name, size_mb=store.size_kb // 1024, performance=entities.PerformanceMetadata( 
diff --git a/src/nwp_consumer/internal/services/consumer_service.py b/src/nwp_consumer/internal/services/consumer_service.py
index 46d603af..749c90b9 100644
--- a/src/nwp_consumer/internal/services/consumer_service.py
+++ b/src/nwp_consumer/internal/services/consumer_service.py
@@ -42,98 +42,96 @@ def consume(self, it: dt.datetime | None = None) -> ResultE[str]:
         # generators involved - it seemed clearer to be explicit. However,
         # it would be much neater to refactor this to be more functional.
         monitor = entities.PerformanceMonitor()
-
-        if it is None:
-            it = self.mr.repository().determine_latest_it_from(dt.datetime.now(tz=dt.UTC))
-        log.info(
-            f"Consuming data from repository '{self.mr.repository().name}' "
-            f"for the '{self.mr.model().name}' model "
-            f"spanning init time '{it:%Y-%m-%d %H:%M}'",
-        )
-
-        # Create a store for the init time
-        init_store_result: ResultE[entities.TensorStore] = \
-            entities.TensorStore.initialize_empty_store(
-                model=self.mr.model().name,
-                repository=self.mr.repository().name,
-                coords=dataclasses.replace(self.mr.model().expected_coordinates, init_time=[it]),
+        with monitor:
+            if it is None:
+                it = self.mr.repository().determine_latest_it_from(dt.datetime.now(tz=dt.UTC))
+            log.info(
+                f"Consuming data from repository '{self.mr.repository().name}' "
+                f"for the '{self.mr.model().name}' model "
+                f"spanning init time '{it:%Y-%m-%d %H:%M}'",
             )
-        if isinstance(init_store_result, Failure):
-            monitor.join()  # TODO: Make this a context manager instead
-            return Failure(OSError(
-                f"Failed to initialize store for init time: {init_store_result!s}",
-            ))
-        store = init_store_result.unwrap()
-
-        amr_result = self.mr.authenticate()
-        if isinstance(amr_result, Failure):
-            monitor.join()
-            store.delete_store()
-            return Failure(OSError(
-                "Unable to authenticate with model repository "
-                f"'{self.mr.repository().name}': "
-                f"{amr_result.failure()}",
-            ))
-        amr = amr_result.unwrap()
-
-        # Create a generator to fetch and process raw data
-        n_jobs: int = max(cpu_count() - 1, self.mr.repository().max_connections)
-        if os.getenv("CONCURRENCY", "True").capitalize() == "False":
-            n_jobs = 1
-        log.debug(f"Downloading using {n_jobs} concurrent thread(s)")
-        fetch_result_generator = Parallel(
-            n_jobs=n_jobs,
-            prefer="threads",
-            return_as="generator_unordered",
-        )(amr.fetch_init_data(it=it))
-
-        # Regionally write the results of the generator as they are ready
-        failed_etls: int = 0
-        for fetch_result in fetch_result_generator:
-            if isinstance(fetch_result, Failure):
-                log.error(
-                    f"Error fetching data for init time '{it:%Y-%m-%d %H:%M}' "
-                    f"and model {self.mr.repository().name}: {fetch_result.failure()!s}",
+            # Create a store for the init time
+            init_store_result: ResultE[entities.TensorStore] = \
+                entities.TensorStore.initialize_empty_store(
+                    model=self.mr.model().name,
+                    repository=self.mr.repository().name,
+                    coords=dataclasses.replace(
+                        self.mr.model().expected_coordinates,
+                        init_time=[it],
+                    ),
                 )
-                failed_etls += 1
-                continue
-            for da in fetch_result.unwrap():
-                write_result = store.write_to_region(da)
-                if isinstance(write_result, Failure):
+
+            if isinstance(init_store_result, Failure):
+                return Failure(OSError(
+                    f"Failed to initialize store for init time: {init_store_result!s}",
+                ))
+            store = init_store_result.unwrap()
+
+            amr_result = self.mr.authenticate()
+            if isinstance(amr_result, Failure):
+                store.delete_store()
+                return Failure(OSError(
+                    "Unable to authenticate with model repository "
+                    f"'{self.mr.repository().name}': "
+                    f"{amr_result.failure()}",
+                ))
+            amr = amr_result.unwrap()
+
+            # Create a generator to fetch and process raw data
+            n_jobs: int = max(cpu_count() - 1, self.mr.repository().max_connections)
+            if os.getenv("CONCURRENCY", "True").capitalize() == "False":
+                n_jobs = 1
+            log.debug(f"Downloading using {n_jobs} concurrent thread(s)")
+            fetch_result_generator = Parallel(
+                n_jobs=n_jobs,
+                prefer="threads",
+                return_as="generator_unordered",
+            )(amr.fetch_init_data(it=it))
+
+            # Regionally write the results of the generator as they are ready
+            failed_etls: int = 0
+            for fetch_result in fetch_result_generator:
+                if isinstance(fetch_result, Failure):
                     log.error(
-                        f"Error writing data for init time '{it:%Y-%m-%d %H:%M}' "
-                        f"and model {self.mr.repository().name}: "
-                        f"{write_result.failure()!s}",
+                        f"Error fetching data for init time '{it:%Y-%m-%d %H:%M}' "
+                        f"and model {self.mr.repository().name}: {fetch_result.failure()!s}",
                     )
                     failed_etls += 1
+                    continue
+                for da in fetch_result.unwrap():
+                    write_result = store.write_to_region(da)
+                    if isinstance(write_result, Failure):
+                        log.error(
+                            f"Error writing data for init time '{it:%Y-%m-%d %H:%M}' "
+                            f"and model {self.mr.repository().name}: "
+                            f"{write_result.failure()!s}",
+                        )
+                        failed_etls += 1
+
+            del fetch_result_generator
+            # Fail hard if any of the writes failed
+            # * TODO: Consider just how hard we want to fail in this instance
+            if failed_etls > 0:
+                store.delete_store()
+                return Failure(OSError(
+                    f"Failed to write {failed_etls} regions "
+                    f"for init time '{it:%Y-%m-%d %H:%M}'. "
+                    "See error logs for details.",
+                ))
+
+            # Postprocess the dataset as required
+            # postprocess_result = store.postprocess(self.mr.repository().postprocess_options)
+            # if isinstance(postprocess_result, Failure):
+            #     return Failure(postprocess_result.failure())
-        del fetch_result_generator
-        # Fail hard if any of the writes failed
-        # * TODO: Consider just how hard we want to fail in this instance
-        if failed_etls > 0:
-            monitor.join()
-            store.delete_store()
-            return Failure(OSError(
-                f"Failed to write {failed_etls} regions "
-                f"for init time '{it:%Y-%m-%d %H:%M}'. "
-                "See error logs for details.",
-            ))
-
-        # Postprocess the dataset as required
-        # postprocess_result = store.postprocess(self.mr.repository().postprocess_options)
-        # if isinstance(postprocess_result, Failure):
-        #     monitor.join()  # TODO: Make this a context manager instead
-        #     return Failure(postprocess_result.failure())
-
-        monitor.join()
         notify_result = self.nr().notify(
             message=entities.StoreCreatedNotification(
                 filename=pathlib.Path(store.path).name,
                 size_mb=store.size_kb // 1024,  # TODO: 2024-11-19 check this is right
                 performance=entities.PerformanceMetadata(
                     duration_seconds=monitor.get_runtime(),
-                    memory_mb=max(monitor.memory_buffer) / 1e6,
+                    memory_mb=monitor.max_memory_mb(),
                ),
            ),
        )
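Both services keep the same CONCURRENCY gate shown above. Since `str.capitalize()` upper-cases the first character and lower-cases the rest, any casing of "false" disables the parallel download path; a standalone check of that behaviour:

```python
import os

for raw in ("false", "FALSE", "fAlSe"):
    os.environ["CONCURRENCY"] = raw
    # Mirrors the gate in both services: every spelling normalises to "False"
    assert os.getenv("CONCURRENCY", "True").capitalize() == "False"
```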
diff --git a/src/nwp_consumer/internal/services/test_consumer.py b/src/nwp_consumer/internal/services/test_consumer.py
index f73ecd96..8936835d 100644
--- a/src/nwp_consumer/internal/services/test_consumer.py
+++ b/src/nwp_consumer/internal/services/test_consumer.py
@@ -1,5 +1,5 @@
 import datetime as dt
-import pathlib
+import shutil
 import unittest
 from collections.abc import Callable, Iterator
 from typing import override
@@ -87,18 +87,9 @@ def notify(
         self,
         message: entities.StoreAppendedNotification | entities.StoreCreatedNotification,
     ) -> ResultE[str]:
-        """See parent class."""
         return Success(str(message))
 
 
-class DummyZarrRepository(ports.ZarrRepository):
-
-    @override
-    def save(self, src: pathlib.Path, dst: pathlib.Path) -> ResultE[str]:
-        """See parent class."""
-        return Success(str(dst))
-
-
 class TestParallelConsumer(unittest.TestCase):
 
     def test_consume(self) -> None:
" - "See error logs for details.", - )) - - # Postprocess the dataset as required - # postprocess_result = store.postprocess(self.mr.repository().postprocess_options) - # if isinstance(postprocess_result, Failure): - # monitor.join() # TODO: Make this a context manager instead - # return Failure(postprocess_result.failure()) - - monitor.join() notify_result = self.nr().notify( message=entities.StoreCreatedNotification( filename=pathlib.Path(store.path).name, size_mb=store.size_kb // 1024, # TODO: 2024-11-19 check this is right performance=entities.PerformanceMetadata( duration_seconds=monitor.get_runtime(), - memory_mb=max(monitor.memory_buffer) / 1e6, + memory_mb=monitor.max_memory_mb(), ), ), ) diff --git a/src/nwp_consumer/internal/services/test_consumer.py b/src/nwp_consumer/internal/services/test_consumer.py index f73ecd96..8936835d 100644 --- a/src/nwp_consumer/internal/services/test_consumer.py +++ b/src/nwp_consumer/internal/services/test_consumer.py @@ -1,5 +1,5 @@ import datetime as dt -import pathlib +import shutil import unittest from collections.abc import Callable, Iterator from typing import override @@ -87,18 +87,9 @@ def notify( self, message: entities.StoreAppendedNotification | entities.StoreCreatedNotification, ) -> ResultE[str]: - """See parent class.""" return Success(str(message)) -class DummyZarrRepository(ports.ZarrRepository): - - @override - def save(self, src: pathlib.Path, dst: pathlib.Path) -> ResultE[str]: - """See parent class.""" - return Success(str(dst)) - - class TestParallelConsumer(unittest.TestCase): def test_consume(self) -> None: @@ -120,7 +111,10 @@ def test_consume(self) -> None: ["init_time", "step", "variable", "latitude", "longitude"], ) + path = result.unwrap() + shutil.rmtree(path) if __name__ == "__main__": unittest.main() +