Skip to content

Commit

Permalink
feat(service): Enable concurrency specification
Browse files Browse the repository at this point in the history
  • Loading branch information
devsjc committed Nov 4, 2024
1 parent d836b23 commit 17d4b14
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 5 deletions.
3 changes: 2 additions & 1 deletion Containerfile
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ COPY pyproject.toml /_lock/
# This layer is cached until uv.lock or pyproject.toml change.
# Delete any unwanted parts of the installed packages to reduce size
RUN --mount=type=cache,target=/root/.cache \
apt-get update && apt-get install build-essential -y && \
apt-get update && apt-get install gcc -y && \
echo "Creating virtualenv at /venv" && \
conda create -qy -p /venv python=3.12 numcodecs
RUN which gcc
Expand All @@ -49,6 +49,7 @@ RUN echo "Installing dependencies into /venv" && \
COPY . /src
RUN --mount=type=cache,target=/root/.cache \
uv pip install --no-deps --python=$UV_PROJECT_ENVIRONMENT /src
RUN uv pip install dllist && python

# Copy the virtualenv into a distroless image
# * These are small images that only contain the runtime dependencies
Expand Down
2 changes: 2 additions & 0 deletions src/nwp_consumer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
+-------------------------------+-------------------------------------+---------------------------------------------+
| MODEL_REPOSITORY | The model repository to use. | ceda-metoffice-global |
+-------------------------------+-------------------------------------+---------------------------------------------+
| CONCURRENCY                   | Whether to fetch data concurrently. | True                                        |
+-------------------------------+-------------------------------------+---------------------------------------------+
Development Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _download(self, url: str) -> ResultE[pathlib.Path]:
).with_suffix(".grib").expanduser()

# Only download the file if not already present
if not local_path.exists():
if not local_path.exists() or local_path.stat().st_size == 0:
local_path.parent.mkdir(parents=True, exist_ok=True)
log.debug("Requesting file from S3 at: '%s'", url)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,4 @@ class TestCase:
max_step=max(NOAAGFSS3ModelRepository.model().expected_coordinates.step),
)
self.assertEqual(result, t.expected)

9 changes: 7 additions & 2 deletions src/nwp_consumer/internal/services/archiver_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import dataclasses
import logging
import os
import pathlib
from typing import TYPE_CHECKING, override

from joblib import Parallel
from joblib import Parallel, cpu_count
from returns.result import Failure, ResultE, Success

from nwp_consumer.internal import entities, ports
Expand Down Expand Up @@ -86,8 +87,12 @@ def archive(self, year: int, month: int) -> ResultE[pathlib.Path]:
amr = amr_result.unwrap()

# Create a generator to fetch and process raw data
n_jobs: int = max(cpu_count() - 1, self.mr.repository().max_connections)
if os.getenv("CONCURRENCY", "True").capitalize() == "False":
n_jobs = 1
log.debug(f"Downloading using {n_jobs} concurrent thread(s)")
da_result_generator = Parallel(
n_jobs=self.mr.repository().max_connections - 1,
n_jobs=n_jobs,
prefer="threads",
return_as="generator_unordered",
)(amr.fetch_init_data(it=it))
Expand Down
7 changes: 6 additions & 1 deletion src/nwp_consumer/internal/services/consumer_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses
import datetime as dt
import logging
import os
import pathlib
from typing import override

Expand Down Expand Up @@ -72,9 +73,13 @@ def consume(self, it: dt.datetime | None = None) -> ResultE[pathlib.Path]:
))
amr = amr_result.unwrap()

n_jobs: int = max(cpu_count() - 1, self.mr.repository().max_connections)
if os.getenv("CONCURRENCY", "True").capitalize() == "False":
n_jobs = 1
log.debug(f"Downloading using {n_jobs} concurrent thread(s)")
fetch_result_generator = Parallel(
# TODO - fix segfault when using multiple threads
n_jobs=max(cpu_count() - 1, self.mr.repository().max_connections),
n_jobs=n_jobs,
prefer="threads",
return_as="generator_unordered",
)(amr.fetch_init_data(it=it))
Expand Down

0 comments on commit 17d4b14

Please sign in to comment.