feat(services): Use contextmanager for performancemonitor
devsjc committed Nov 22, 2024
1 parent b478038 commit cf42210
Showing 5 changed files with 202 additions and 185 deletions.
54 changes: 38 additions & 16 deletions src/nwp_consumer/internal/entities/performance.py
@@ -6,6 +6,7 @@

import time
from threading import Thread
from types import TracebackType

import psutil

@@ -20,43 +21,64 @@ class PerformanceMonitor(Thread):

thread: Thread
memory_buffer: list[int]
stop: bool
cpu_buffer: list[float]
start_time: float
end_time: float
stop: bool = True

def __init__(self) -> None:
"""Create a new instance."""
def __enter__(self) -> None:
"""Start the monitor."""
super().__init__()
self.stop = False
self.memory_buffer: list[int] = []
self.cpu_buffer: list[float] = []
self.start_time = time.time()
self.start()

def get_memory(self) -> int:
"""Get memory of a process and its children."""
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Stop the performance monitor, saving the results."""
self.stop = True
self.end_time = time.time()
super().join(timeout=30)

def get_usage(self) -> tuple[int, float]:
"""Get usage of a process and its children."""
p = psutil.Process()
# CPU usage of process and its children
cpu: float = p.cpu_percent()
# Memory usage does not reflect child processes
# * Manually add the memory usage of child processes
memory: int = p.memory_info().rss
for c in p.children():
memory += c.memory_info().rss
return memory
return memory, cpu

def get_runtime(self) -> int:
"""Get the runtime of the thread in seconds."""
return int(self.end_time - self.start_time)

def run(self) -> None:
"""Run the thread."""
memory_start = self.get_memory()
memory_start, cpu_start = self.get_usage()
while not self.stop:
self.memory_buffer.append(self.get_memory() - memory_start)
new_memory, new_cpu = self.get_usage()
# Memory is just a total, so get the delta
self.memory_buffer.append(new_memory - memory_start)
# CPU is calculated by psutil against the base CPU,
# so no need to get a delta
self.cpu_buffer.append(new_cpu)
time.sleep(0.2)

def join(self, timeout: int | None = None) -> None: # type: ignore
"""Stop the thread."""
self.stop = True
self.end_time = time.time()
super().join(timeout=timeout)
def max_memory_mb(self) -> float:
"""Get the maximum memory usage during the thread's runtime."""
return max(self.memory_buffer) / 1e6

def max_cpu_percent(self) -> float:
"""Get the maximum CPU usage during the thread's runtime."""
return max(self.cpu_buffer)

def __enter__(self) -> "PerformanceMonitor":
"""Enter a context."""
return self
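
For orientation, here is a minimal usage sketch of the reworked monitor. The import path and the workload function are assumptions based on how archiver_service.py (below) uses the class; they are not part of this commit.

```python
import time

from nwp_consumer.internal import entities  # import path assumed from archiver_service.py


def run_expensive_workload() -> None:
    """Hypothetical stand-in for the real consume/archive work."""
    time.sleep(1.0)


monitor = entities.PerformanceMonitor()
with monitor:  # __enter__ (re)initialises the thread and starts sampling
    run_expensive_workload()

# __exit__ has set the stop flag and joined the sampling thread,
# so the buffers and timestamps are final once the block ends.
print(monitor.get_runtime())      # wall-clock seconds spent inside the block
print(monitor.max_memory_mb())    # peak memory delta (process + children) in MB
print(monitor.max_cpu_percent())  # peak CPU percent sampled via psutil
```

Because `__enter__` returns `None`, the instance is constructed first and the bare `with monitor:` form is used rather than `with PerformanceMonitor() as monitor:`, matching the services below.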
13 changes: 10 additions & 3 deletions src/nwp_consumer/internal/entities/test_tensorstore.py
@@ -5,6 +5,7 @@
import os
import unittest
from collections.abc import Generator
from types import TracebackType
from unittest.mock import patch

import numpy as np
@@ -23,13 +24,14 @@
logging.getLogger("werkzeug").setLevel(logging.ERROR)


class MockS3Bucket(contextlib.ContextDecorator):
class MockS3Bucket:

client: BotocoreClient
server: ThreadedMotoServer
bucket: str = "test-bucket"

def __enter__(self) -> None:
"""Create a mock S3 server and bucket."""
self.server = ThreadedMotoServer()
self.server.start()

@@ -46,18 +48,23 @@ def __enter__(self) -> None:
Bucket=self.bucket,
)

def __exit__(self, *exc) -> bool: # type:ignore
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Delete all objects in the bucket and stop the server."""
response = self.client.list_objects_v2(
Bucket=self.bucket,
)
if "Contents" in response:
for obj in response["Contents"]:
self.client.delete_object(
Bucket=self.bucket,
Key=obj["Key"],
)
self.server.stop()
return False


class TestTensorStore(unittest.TestCase):
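A rough sketch of how the reworked mock is meant to be consumed by a test case, assuming it sits in the same test module; the test class and body are placeholders, and only the `with` pattern comes from this diff.

```python
import unittest


class TestWithMockBucket(unittest.TestCase):
    def test_against_mocked_s3(self) -> None:
        # __enter__ starts a ThreadedMotoServer and creates "test-bucket";
        # __exit__ deletes any leftover objects and stops the server.
        with MockS3Bucket():
            ...  # exercise code that reads from / writes to the mocked bucket
```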
150 changes: 73 additions & 77 deletions src/nwp_consumer/internal/services/archiver_service.py
@@ -40,92 +40,88 @@ def __init__(
@override
def archive(self, year: int, month: int) -> ResultE[str]:
monitor = entities.PerformanceMonitor()
with monitor:

init_times = self.mr.repository().month_its(year=year, month=month)

# Create a store for the archive
init_store_result: ResultE[entities.TensorStore] = \
entities.TensorStore.initialize_empty_store(
model=self.mr.model().name,
repository=self.mr.repository().name,
coords=dataclasses.replace(
self.mr.model().expected_coordinates,
init_time=init_times,
),
)

init_times = self.mr.repository().month_its(year=year, month=month)
if isinstance(init_store_result, Failure):
return Failure(OSError(
f"Failed to initialize store for {year}-{month}: {init_store_result!s}"),
)
store = init_store_result.unwrap()

missing_times_result = store.missing_times()
if isinstance(missing_times_result, Failure):
return Failure(missing_times_result.failure())
log.info(f"{len(missing_times_result.unwrap())} missing init_times in store.")

failed_times: list[dt.datetime] = []
for n, it in enumerate(missing_times_result.unwrap()):
log.info(
f"Consuming data from {self.mr.repository().name} for {it:%Y-%m-%d %H:%M} "
f"(time {n + 1}/{len(missing_times_result.unwrap())})",
)

# Create a store for the archive
init_store_result: ResultE[entities.TensorStore] = \
entities.TensorStore.initialize_empty_store(
model=self.mr.model().name,
repository=self.mr.repository().name,
coords=dataclasses.replace(
self.mr.model().expected_coordinates,
init_time=init_times,
),
)
# Authenticate with the model repository
amr_result = self.mr.authenticate()
if isinstance(amr_result, Failure):
store.delete_store()
return Failure(OSError(
"Unable to authenticate with model repository "
f"'{self.mr.repository().name}': "
f"{amr_result.failure()}",
))
amr = amr_result.unwrap()

# Create a generator to fetch and process raw data
n_jobs: int = max(cpu_count() - 1, self.mr.repository().max_connections)
if os.getenv("CONCURRENCY", "True").capitalize() == "False":
n_jobs = 1
log.debug(f"Downloading using {n_jobs} concurrent thread(s)")
da_result_generator = Parallel(
n_jobs=n_jobs,
prefer="threads",
return_as="generator_unordered",
)(amr.fetch_init_data(it=it))

# Regionally write the results of the generator as they are ready
for da_result in da_result_generator:
write_result = da_result.bind(store.write_to_region)
# Fail soft if a region fails to write
if isinstance(write_result, Failure):
log.error(f"Failed to write time {it:%Y-%m-%d %H:%M}: {write_result}")
failed_times.append(it)

del da_result_generator

# Add the failed times to the store's metadata
store.update_attrs({
"failed_times": ", ".join([t.strftime("Day %d %H:%M") for t in failed_times]),
})

# Postprocess the dataset as required
# postprocess_result = store.postprocess(self._mr.metadata().postprocess_options)
# if isinstance(postprocess_result, Failure):
# return Failure(postprocess_result.failure())

if isinstance(init_store_result, Failure):
monitor.join() # TODO: Make this a context manager instead
return Failure(OSError(
f"Failed to initialize store for {year}-{month}: {init_store_result!s}"),
)
store = init_store_result.unwrap()

missing_times_result = store.missing_times()
if isinstance(missing_times_result, Failure):
monitor.join()
return Failure(missing_times_result.failure())
log.info(f"{len(missing_times_result.unwrap())} missing init_times in store.")

failed_times: list[dt.datetime] = []
for n, it in enumerate(missing_times_result.unwrap()):
log.info(
f"Consuming data from {self.mr.repository().name} for {it:%Y-%m-%d %H:%M} "
f"(time {n + 1}/{len(missing_times_result.unwrap())})",
)

# Authenticate with the model repository
amr_result = self.mr.authenticate()
if isinstance(amr_result, Failure):
store.delete_store()
monitor.join()
return Failure(OSError(
"Unable to authenticate with model repository "
f"'{self.mr.repository().name}': "
f"{amr_result.failure()}",
))
amr = amr_result.unwrap()

# Create a generator to fetch and process raw data
n_jobs: int = max(cpu_count() - 1, self.mr.repository().max_connections)
if os.getenv("CONCURRENCY", "True").capitalize() == "False":
n_jobs = 1
log.debug(f"Downloading using {n_jobs} concurrent thread(s)")
da_result_generator = Parallel(
n_jobs=n_jobs,
prefer="threads",
return_as="generator_unordered",
)(amr.fetch_init_data(it=it))

# Regionally write the results of the generator as they are ready
for da_result in da_result_generator:
write_result = da_result.bind(store.write_to_region)
# Fail soft if a region fails to write
if isinstance(write_result, Failure):
log.error(f"Failed to write time {it:%Y-%m-%d %H:%M}: {write_result}")
failed_times.append(it)

del da_result_generator

# Add the failed times to the store's metadata
store.update_attrs({
"failed_times": ", ".join([t.strftime("Day %d %H:%M") for t in failed_times]),
})

# Postprocess the dataset as required
# postprocess_result = store.postprocess(self._mr.metadata().postprocess_options)
# if isinstance(postprocess_result, Failure):
# monitor.join() # TODO: Make this a context manager instead
# return Failure(postprocess_result.failure())

monitor.join()
notify_result = self.nr().notify(
message=entities.StoreCreatedNotification(
filename=pathlib.Path(store.path).name,
size_mb=store.size_kb // 1024,
performance=entities.PerformanceMetadata(
duration_seconds=monitor.get_runtime(),
memory_mb=max(monitor.memory_buffer) / 1e6,
memory_mb=monitor.max_memory_mb(),
),
),
)
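
The removed half of the diff shows the motivation for the change: every early `return Failure(...)` previously had to remember to call `monitor.join()` first. Below is a simplified before/after sketch of that shape; the names follow archiver_service.py, while `do_work()` is a hypothetical stand-in for the store setup and download steps.

```python
from returns.result import Failure, ResultE, Success

from nwp_consumer.internal import entities  # import path assumed from archiver_service.py


def do_work() -> ResultE[str]:
    """Hypothetical stand-in for the store setup and download steps."""
    return Success("data written")


def archive_before() -> ResultE[str]:
    # Before: each failure path had to stop the monitor explicitly.
    monitor = entities.PerformanceMonitor()
    result = do_work()
    if isinstance(result, Failure):
        monitor.join()  # easy to forget when adding a new early return
        return Failure(result.failure())
    monitor.join()
    return result


def archive_after() -> ResultE[str]:
    # After: __exit__ stops and joins the monitor on every exit path,
    # including early returns and exceptions raised inside the block.
    monitor = entities.PerformanceMonitor()
    with monitor:
        result = do_work()
        if isinstance(result, Failure):
            return Failure(result.failure())
    return result
```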
