Skip to content

Commit

Permalink
feat(structure): improve code structure (second review)
Browse files Browse the repository at this point in the history
  • Loading branch information
fmind committed Mar 16, 2024
1 parent f568fbc commit 20671c7
Show file tree
Hide file tree
Showing 24 changed files with 139 additions and 134 deletions.
2 changes: 1 addition & 1 deletion src/bikes/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def fit(self, inputs: schemas.Inputs, targets: schemas.Targets) -> "BaselineSkle
@T.override
def predict(self, inputs: schemas.Inputs) -> schemas.Outputs:
model = self.get_internal_model()
prediction = model.predict(inputs) # np.ndarray
prediction = model.predict(inputs)
outputs = schemas.Outputs(
{schemas.OutputsSchema.prediction: prediction}, index=inputs.index
)
Expand Down
24 changes: 12 additions & 12 deletions src/bikes/io/registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Saver(abc.ABC, pdt.BaseModel, strict=True, frozen=True, extra="forbid"):
e.g., to switch between serialization flavors.
Parameters:
path (str): model path inside the MLflow store.
path (str): model path inside the Mlflow store.
"""

KIND: str
Expand All @@ -81,15 +81,15 @@ def save(


class CustomSaver(Saver):
"""Saver for project models using the MLflow PyFunc module.
"""Saver for project models using the Mlflow PyFunc module.
https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html
"""

KIND: T.Literal["CustomSaver"] = "CustomSaver"

class Adapter(mlflow.pyfunc.PythonModel): # type: ignore[misc]
"""Adapt a custom model to the MLflow PyFunc flavor for saving operations.
"""Adapt a custom model to the Mlflow PyFunc flavor for saving operations.
https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html?#mlflow.pyfunc.PythonModel
"""
Expand Down Expand Up @@ -134,12 +134,12 @@ def save(


class BuiltinSaver(Saver):
"""Saver for built-in models using an MLflow flavor module.
"""Saver for built-in models using an Mlflow flavor module.
https://mlflow.org/docs/latest/models.html#built-in-model-flavors
Parameters:
flavor (str): MLflow flavor module to use for the serialization.
flavor (str): Mlflow flavor module to use for the serialization.
"""

KIND: T.Literal["BuiltinSaver"] = "BuiltinSaver"
Expand Down Expand Up @@ -201,7 +201,7 @@ def load(self, uri: str) -> "Loader.Adapter":


class CustomLoader(Loader):
"""Loader for custom models using the MLflow PyFunc module.
"""Loader for custom models using the Mlflow PyFunc module.
https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html
"""
Expand Down Expand Up @@ -233,9 +233,9 @@ def load(self, uri: str) -> "CustomLoader.Adapter":


class BuiltinLoader(Loader):
"""Loader for built-in models using the MLflow PyFunc module.
"""Loader for built-in models using the Mlflow PyFunc module.
Note: use MLflow PyFunc instead of flavors to use standard API.
Note: use Mlflow PyFunc instead of flavors to use standard API.
https://mlflow.org/docs/latest/models.html#built-in-model-flavors
"""
Expand Down Expand Up @@ -298,17 +298,17 @@ def register(self, name: str, model_uri: str) -> Version:
"""


class MLflowRegister(Register):
"""Register for models in the MLflow Model Registry.
class MlflowRegister(Register):
"""Register for models in the Mlflow Model Registry.
https://mlflow.org/docs/latest/model-registry.html
"""

KIND: T.Literal["MLflowRegister"] = "MLflowRegister"
KIND: T.Literal["MlflowRegister"] = "MlflowRegister"

@T.override
def register(self, name: str, model_uri: str) -> Version:
return mlflow.register_model(name=name, model_uri=model_uri, tags=self.tags)


RegisterKind = MLflowRegister
RegisterKind = MlflowRegister
47 changes: 28 additions & 19 deletions src/bikes/io/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ def logger(self) -> loguru.Logger:
return loguru.logger


class MLflowService(Service):
"""Service for MLflow tracking and registry.
class MlflowService(Service):
"""Service for Mlflow tracking and registry.
Parameters:
tracking_uri (str): the URI for the MLflow tracking server.
registry_uri (str): the URI for the MLflow model registry.
tracking_uri (str): the URI for the Mlflow tracking server.
registry_uri (str): the URI for the Mlflow model registry.
experiment_name (str): the name of tracking experiment.
registry_name (str): the name of model registry.
autolog_disable (bool): disable autologging.
Expand All @@ -96,9 +96,24 @@ class MLflowService(Service):
autolog_log_model_signatures (bool): If True, logs model signatures during autologging.
autolog_log_models (bool): If True, enables logging of models during autologging.
autolog_log_datasets (bool): If True, logs datasets used during autologging.
autolog_silent (bool): If True, suppresses all MLflow warnings during autologging.
autolog_silent (bool): If True, suppresses all Mlflow warnings during autologging.
"""

class RunConfig(pdt.BaseModel, strict=True, frozen=True, extra="forbid"):
"""Run configuration for Mlflow tracking.
Parameters:
name (str): name of the run.
description (str | None): description of the run.
tags (dict[str, T.Any] | None): tags for the run.
log_system_metrics (bool | None): enable system metrics logging.
"""

name: str
description: str | None = None
tags: dict[str, T.Any] | None = None
log_system_metrics: bool | None = None

# server uri
tracking_uri: str = "./mlruns"
registry_uri: str = "./mlruns"
Expand Down Expand Up @@ -135,31 +150,25 @@ def start(self) -> None:
)

@ctx.contextmanager
def run(
self,
name: str,
description: str | None = None,
tags: dict[str, T.Any] | None = None,
log_system_metrics: bool | None = None,
) -> T.Generator[mlflow.ActiveRun, None, None]:
"""Yield an active MLflow run and exit it afterwards.
def run_context(self, run_config: RunConfig) -> T.Generator[mlflow.ActiveRun, None, None]:
"""Yield an active Mlflow run and exit it afterwards.
Args:
name (str): name of the run.
description (str | None, optional): description of the run. Defaults to None.
tags (dict[str, T.Any] | None, optional): dict of tags of the run. Defaults to None.
log_system_metrics (bool | None, optional): enable system metrics logging. Defaults to None.
run (str): run parameters.
Yields:
T.Generator[mlflow.ActiveRun, None, None]: active run context. Will be closed as the end of context.
"""
with mlflow.start_run(
run_name=name, description=description, tags=tags, log_system_metrics=log_system_metrics
run_name=run_config.name,
tags=run_config.tags,
description=run_config.description,
log_system_metrics=run_config.log_system_metrics,
) as run:
yield run

def client(self) -> mt.MlflowClient:
"""Return a new MLflow client.
"""Return a new Mlflow client.
Returns:
MlflowClient: the mlflow client.
Expand Down
8 changes: 4 additions & 4 deletions src/bikes/jobs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ class Job(abc.ABC, pdt.BaseModel, strict=True, frozen=True, extra="forbid"):
Parameters:
logger_service (services.LoggerService): manage the logging system.
mlflow_service (services.MLflowService): manage the mlflow system.
mlflow_service (services.MlflowService): manage the mlflow system.
"""

KIND: str

logger_service: services.LoggerService = services.LoggerService()
mlflow_service: services.MLflowService = services.MLflowService()
mlflow_service: services.MlflowService = services.MlflowService()

def __enter__(self) -> T.Self:
"""Enter the job context.
Expand All @@ -43,7 +43,7 @@ def __enter__(self) -> T.Self:
self.logger_service.start()
logger = self.logger_service.logger()
logger.debug("[START] Logger service: {}", self.logger_service)
logger.debug("[START] MLflow service: {}", self.mlflow_service)
logger.debug("[START] Mlflow service: {}", self.mlflow_service)
self.mlflow_service.start()
return self

Expand All @@ -64,7 +64,7 @@ def __exit__(
T.Literal[False]: always propagate exceptions.
"""
logger = self.logger_service.logger()
logger.debug("[STOP] MLflow service: {}", self.mlflow_service)
logger.debug("[STOP] Mlflow service: {}", self.mlflow_service)
self.mlflow_service.stop()
logger.debug("[STOP] Logger service: {}", self.logger_service)
self.logger_service.stop()
Expand Down
2 changes: 1 addition & 1 deletion src/bikes/jobs/promotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class PromotionJob(base.Job):

KIND: T.Literal["PromotionJob"] = "PromotionJob"

version: int | None = None
alias: str = "Champion"
version: int | None = None

@T.override
def run(self) -> base.Locals:
Expand Down
16 changes: 5 additions & 11 deletions src/bikes/jobs/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from bikes.core import metrics as metrics_
from bikes.core import models, schemas
from bikes.io import datasets, registries
from bikes.io import datasets, registries, services
from bikes.jobs import base
from bikes.utils import signers, splitters

Expand All @@ -19,9 +19,7 @@ class TrainingJob(base.Job):
"""Train and register a single AI/ML model.
Parameters:
run_name (str): name of the run.
run_description (str, optional): description of the run.
run_tags: (dict[str, T.Any], optional): tags for the run.
run_config (services.MlflowService.RunConfig): mlflow run config.
inputs (datasets.ReaderKind): reader for the inputs data.
targets (datasets.ReaderKind): reader for the targets data.
model (models.ModelKind): machine learning model to train.
Expand All @@ -35,9 +33,7 @@ class TrainingJob(base.Job):
KIND: T.Literal["TrainingJob"] = "TrainingJob"

# Run
run_name: str = "Tuning"
run_description: str | None = None
run_tags: dict[str, T.Any] | None = None
run_config: services.MlflowService.RunConfig = services.MlflowService.RunConfig(name="Training")
# Data
inputs: datasets.ReaderKind = pdt.Field(..., discriminator="KIND")
targets: datasets.ReaderKind = pdt.Field(..., discriminator="KIND")
Expand All @@ -55,7 +51,7 @@ class TrainingJob(base.Job):
signer: signers.SignerKind = pdt.Field(signers.InferSigner(), discriminator="KIND")
# Registrer
# - avoid shadowing pydantic `register` pydantic function
registry: registries.RegisterKind = pdt.Field(registries.MLflowRegister(), discriminator="KIND")
registry: registries.RegisterKind = pdt.Field(registries.MlflowRegister(), discriminator="KIND")

@T.override
def run(self) -> base.Locals:
Expand All @@ -65,9 +61,7 @@ def run(self) -> base.Locals:
logger.info("With logger: {}", logger)
# - mlflow
client = self.mlflow_service.client()
with self.mlflow_service.run(
name=self.run_name, description=self.run_description, tags=self.run_tags
) as run:
with self.mlflow_service.run_context(run_config=self.run_config) as run:
logger.info("With mlflow run id: {}", run.info.run_id)
# data
# - inputs
Expand Down
14 changes: 4 additions & 10 deletions src/bikes/jobs/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pydantic as pdt

from bikes.core import metrics, models, schemas
from bikes.io import datasets
from bikes.io import datasets, services
from bikes.jobs import base
from bikes.utils import searchers, splitters

Expand All @@ -18,9 +18,7 @@ class TuningJob(base.Job):
"""Find the best hyperparameters for a model.
Parameters:
run_name (str): name of the run.
run_description (str, optional): description of the run.
run_tags: (dict[str, T.Any], optional): tags for the run.
run_config (services.MlflowService.RunConfig): mlflow run config.
inputs (datasets.ReaderKind): reader for the inputs data.
targets (datasets.ReaderKind): reader for the targets data.
model (models.ModelKind): machine learning model to tune.
Expand All @@ -32,9 +30,7 @@ class TuningJob(base.Job):
KIND: T.Literal["TuningJob"] = "TuningJob"

# Run
run_name: str = "Tuning"
run_description: str | None = None
run_tags: dict[str, T.Any] | None = None
run_config: services.MlflowService.RunConfig = services.MlflowService.RunConfig(name="Tuning")
# Data
inputs: datasets.ReaderKind = pdt.Field(..., discriminator="KIND")
targets: datasets.ReaderKind = pdt.Field(..., discriminator="KIND")
Expand Down Expand Up @@ -64,9 +60,7 @@ def run(self) -> base.Locals:
logger = self.logger_service.logger()
logger.info("With logger: {}", logger)
# - mlflow
with self.mlflow_service.run(
name=self.run_name, description=self.run_description, tags=self.run_tags
) as run:
with self.mlflow_service.run_context(run_config=self.run_config) as run:
logger.info("With mlflow run id: {}", run.info.run_id)
# data
# - inputs
Expand Down
6 changes: 3 additions & 3 deletions src/bikes/utils/searchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,17 @@ def search(
metric (metrics.Metric): main metric to optimize.
inputs (schemas.Inputs): model inputs for tuning.
targets (schemas.Targets): model targets for tuning.
cv (CrossValidation): structure for cross-folds strategy.
cv (CrossValidation): choice for cross-fold validation.
Returns:
Results: all the results of the searcher process.
Results: all the results of the searcher execution process.
"""


class GridCVSearcher(Searcher):
"""Grid searcher with cross-fold validation.
Metric should return higher values for better models.
Convention: metric returns higher values for better models.
Parameters:
n_jobs (int, optional): number of jobs to run in parallel.
Expand Down
2 changes: 1 addition & 1 deletion src/bikes/utils/signers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Signer(abc.ABC, pdt.BaseModel, strict=True, frozen=True, extra="forbid"):
"""Base class for generating model signatures.
Allow to switch between model signing strategies.
e.g., automatic inference, manual signatures, ...
e.g., automatic inference, manual model signature, ...
https://mlflow.org/docs/latest/models.html#model-signature-and-input-example
"""
Expand Down
2 changes: 1 addition & 1 deletion src/bikes/utils/splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class TrainTestSplitter(Splitter):
"""Split a dataframe into a train and test set.
Parameters:
shuffle (bool): shuffle dataset before splitting it.
shuffle (bool): shuffle the dataset. Default is False.
test_size (int | float): number/ratio for the test set.
random_state (int): random state for the splitter object.
"""
Expand Down
4 changes: 2 additions & 2 deletions tasks/cleans.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def pytest(ctx: Context) -> None:

@task
def coverage(ctx: Context) -> None:
"""Clean coverage tool."""
"""Clean the coverage tool."""
ctx.run("rm -f .coverage*")


Expand Down Expand Up @@ -104,7 +104,7 @@ def folders(_: Context) -> None:

@task(pre=[venv, poetry, python])
def sources(_: Context) -> None:
"""Run all folders tasks."""
"""Run all sources tasks."""


@task(pre=[tools, folders], default=True)
Expand Down
2 changes: 1 addition & 1 deletion tasks/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

@task
def code(ctx: Context) -> None:
"""Format code with ruff."""
"""Format python code with ruff."""
ctx.run("poetry run ruff format src/ tasks/ tests/")


Expand Down
2 changes: 1 addition & 1 deletion tasks/installs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def poetry(ctx: Context) -> None:

@task
def pre_commit(ctx: Context) -> None:
"""Run pre-commit install."""
"""Install pre-commit hooks on git."""
ctx.run("poetry run pre-commit install --hook-type pre-push")
ctx.run("poetry run pre-commit install --hook-type commit-msg")

Expand Down
2 changes: 1 addition & 1 deletion tasks/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def doctor(ctx: Context) -> None:
def serve(
ctx: Context, host: str = "127.0.0.1", port: str = "5000", backend_uri: str = "./mlruns"
) -> None:
""""""
"""Start mlflow server with the given host, port, and backend uri."""
ctx.run(
f"poetry run mlflow server --host={host} --port={port} --backend-store-uri={backend_uri}"
)
Expand Down
Loading

0 comments on commit 20671c7

Please sign in to comment.