diff --git a/src/bikes/core/models.py b/src/bikes/core/models.py index 7516188..462fe10 100644 --- a/src/bikes/core/models.py +++ b/src/bikes/core/models.py @@ -155,7 +155,7 @@ def fit(self, inputs: schemas.Inputs, targets: schemas.Targets) -> "BaselineSkle @T.override def predict(self, inputs: schemas.Inputs) -> schemas.Outputs: model = self.get_internal_model() - prediction = model.predict(inputs) # np.ndarray + prediction = model.predict(inputs) outputs = schemas.Outputs( {schemas.OutputsSchema.prediction: prediction}, index=inputs.index ) diff --git a/src/bikes/io/registries.py b/src/bikes/io/registries.py index 6e8d625..06d686f 100644 --- a/src/bikes/io/registries.py +++ b/src/bikes/io/registries.py @@ -57,7 +57,7 @@ class Saver(abc.ABC, pdt.BaseModel, strict=True, frozen=True, extra="forbid"): e.g., to switch between serialization flavors. Parameters: - path (str): model path inside the MLflow store. + path (str): model path inside the Mlflow store. """ KIND: str @@ -81,7 +81,7 @@ def save( class CustomSaver(Saver): - """Saver for project models using the MLflow PyFunc module. + """Saver for project models using the Mlflow PyFunc module. https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html """ @@ -89,7 +89,7 @@ class CustomSaver(Saver): KIND: T.Literal["CustomSaver"] = "CustomSaver" class Adapter(mlflow.pyfunc.PythonModel): # type: ignore[misc] - """Adapt a custom model to the MLflow PyFunc flavor for saving operations. + """Adapt a custom model to the Mlflow PyFunc flavor for saving operations. https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html?#mlflow.pyfunc.PythonModel """ @@ -134,12 +134,12 @@ def save( class BuiltinSaver(Saver): - """Saver for built-in models using an MLflow flavor module. + """Saver for built-in models using an Mlflow flavor module. https://mlflow.org/docs/latest/models.html#built-in-model-flavors Parameters: - flavor (str): MLflow flavor module to use for the serialization. + flavor (str): Mlflow flavor module to use for the serialization. """ KIND: T.Literal["BuiltinSaver"] = "BuiltinSaver" @@ -201,7 +201,7 @@ def load(self, uri: str) -> "Loader.Adapter": class CustomLoader(Loader): - """Loader for custom models using the MLflow PyFunc module. + """Loader for custom models using the Mlflow PyFunc module. https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html """ @@ -233,9 +233,9 @@ def load(self, uri: str) -> "CustomLoader.Adapter": class BuiltinLoader(Loader): - """Loader for built-in models using the MLflow PyFunc module. + """Loader for built-in models using the Mlflow PyFunc module. - Note: use MLflow PyFunc instead of flavors to use standard API. + Note: use Mlflow PyFunc instead of flavors to use standard API. https://mlflow.org/docs/latest/models.html#built-in-model-flavors """ @@ -298,17 +298,17 @@ def register(self, name: str, model_uri: str) -> Version: """ -class MLflowRegister(Register): - """Register for models in the MLflow Model Registry. +class MlflowRegister(Register): + """Register for models in the Mlflow Model Registry. https://mlflow.org/docs/latest/model-registry.html """ - KIND: T.Literal["MLflowRegister"] = "MLflowRegister" + KIND: T.Literal["MlflowRegister"] = "MlflowRegister" @T.override def register(self, name: str, model_uri: str) -> Version: return mlflow.register_model(name=name, model_uri=model_uri, tags=self.tags) -RegisterKind = MLflowRegister +RegisterKind = MlflowRegister diff --git a/src/bikes/io/services.py b/src/bikes/io/services.py index 6bb9387..d8657ae 100644 --- a/src/bikes/io/services.py +++ b/src/bikes/io/services.py @@ -81,12 +81,12 @@ def logger(self) -> loguru.Logger: return loguru.logger -class MLflowService(Service): - """Service for MLflow tracking and registry. +class MlflowService(Service): + """Service for Mlflow tracking and registry. Parameters: - tracking_uri (str): the URI for the MLflow tracking server. - registry_uri (str): the URI for the MLflow model registry. + tracking_uri (str): the URI for the Mlflow tracking server. + registry_uri (str): the URI for the Mlflow model registry. experiment_name (str): the name of tracking experiment. registry_name (str): the name of model registry. autolog_disable (bool): disable autologging. @@ -96,9 +96,24 @@ class MLflowService(Service): autolog_log_model_signatures (bool): If True, logs model signatures during autologging. autolog_log_models (bool): If True, enables logging of models during autologging. autolog_log_datasets (bool): If True, logs datasets used during autologging. - autolog_silent (bool): If True, suppresses all MLflow warnings during autologging. + autolog_silent (bool): If True, suppresses all Mlflow warnings during autologging. """ + class RunConfig(pdt.BaseModel, strict=True, frozen=True, extra="forbid"): + """Run configuration for Mlflow tracking. + + Parameters: + name (str): name of the run. + description (str | None): description of the run. + tags (dict[str, T.Any] | None): tags for the run. + log_system_metrics (bool | None): enable system metrics logging. + """ + + name: str + description: str | None = None + tags: dict[str, T.Any] | None = None + log_system_metrics: bool | None = None + # server uri tracking_uri: str = "./mlruns" registry_uri: str = "./mlruns" @@ -135,31 +150,25 @@ def start(self) -> None: ) @ctx.contextmanager - def run( - self, - name: str, - description: str | None = None, - tags: dict[str, T.Any] | None = None, - log_system_metrics: bool | None = None, - ) -> T.Generator[mlflow.ActiveRun, None, None]: - """Yield an active MLflow run and exit it afterwards. + def run_context(self, run_config: RunConfig) -> T.Generator[mlflow.ActiveRun, None, None]: + """Yield an active Mlflow run and exit it afterwards. Args: - name (str): name of the run. - description (str | None, optional): description of the run. Defaults to None. - tags (dict[str, T.Any] | None, optional): dict of tags of the run. Defaults to None. - log_system_metrics (bool | None, optional): enable system metrics logging. Defaults to None. + run (str): run parameters. Yields: T.Generator[mlflow.ActiveRun, None, None]: active run context. Will be closed as the end of context. """ with mlflow.start_run( - run_name=name, description=description, tags=tags, log_system_metrics=log_system_metrics + run_name=run_config.name, + tags=run_config.tags, + description=run_config.description, + log_system_metrics=run_config.log_system_metrics, ) as run: yield run def client(self) -> mt.MlflowClient: - """Return a new MLflow client. + """Return a new Mlflow client. Returns: MlflowClient: the mlflow client. diff --git a/src/bikes/jobs/base.py b/src/bikes/jobs/base.py index 45fd65a..7e7406c 100644 --- a/src/bikes/jobs/base.py +++ b/src/bikes/jobs/base.py @@ -26,13 +26,13 @@ class Job(abc.ABC, pdt.BaseModel, strict=True, frozen=True, extra="forbid"): Parameters: logger_service (services.LoggerService): manage the logging system. - mlflow_service (services.MLflowService): manage the mlflow system. + mlflow_service (services.MlflowService): manage the mlflow system. """ KIND: str logger_service: services.LoggerService = services.LoggerService() - mlflow_service: services.MLflowService = services.MLflowService() + mlflow_service: services.MlflowService = services.MlflowService() def __enter__(self) -> T.Self: """Enter the job context. @@ -43,7 +43,7 @@ def __enter__(self) -> T.Self: self.logger_service.start() logger = self.logger_service.logger() logger.debug("[START] Logger service: {}", self.logger_service) - logger.debug("[START] MLflow service: {}", self.mlflow_service) + logger.debug("[START] Mlflow service: {}", self.mlflow_service) self.mlflow_service.start() return self @@ -64,7 +64,7 @@ def __exit__( T.Literal[False]: always propagate exceptions. """ logger = self.logger_service.logger() - logger.debug("[STOP] MLflow service: {}", self.mlflow_service) + logger.debug("[STOP] Mlflow service: {}", self.mlflow_service) self.mlflow_service.stop() logger.debug("[STOP] Logger service: {}", self.logger_service) self.logger_service.stop() diff --git a/src/bikes/jobs/promotion.py b/src/bikes/jobs/promotion.py index 2554b43..85f4891 100644 --- a/src/bikes/jobs/promotion.py +++ b/src/bikes/jobs/promotion.py @@ -21,8 +21,8 @@ class PromotionJob(base.Job): KIND: T.Literal["PromotionJob"] = "PromotionJob" - version: int | None = None alias: str = "Champion" + version: int | None = None @T.override def run(self) -> base.Locals: diff --git a/src/bikes/jobs/training.py b/src/bikes/jobs/training.py index e7918ae..b245787 100644 --- a/src/bikes/jobs/training.py +++ b/src/bikes/jobs/training.py @@ -8,7 +8,7 @@ from bikes.core import metrics as metrics_ from bikes.core import models, schemas -from bikes.io import datasets, registries +from bikes.io import datasets, registries, services from bikes.jobs import base from bikes.utils import signers, splitters @@ -19,9 +19,7 @@ class TrainingJob(base.Job): """Train and register a single AI/ML model. Parameters: - run_name (str): name of the run. - run_description (str, optional): description of the run. - run_tags: (dict[str, T.Any], optional): tags for the run. + run_config (services.MlflowService.RunConfig): mlflow run config. inputs (datasets.ReaderKind): reader for the inputs data. targets (datasets.ReaderKind): reader for the targets data. model (models.ModelKind): machine learning model to train. @@ -35,9 +33,7 @@ class TrainingJob(base.Job): KIND: T.Literal["TrainingJob"] = "TrainingJob" # Run - run_name: str = "Tuning" - run_description: str | None = None - run_tags: dict[str, T.Any] | None = None + run_config: services.MlflowService.RunConfig = services.MlflowService.RunConfig(name="Training") # Data inputs: datasets.ReaderKind = pdt.Field(..., discriminator="KIND") targets: datasets.ReaderKind = pdt.Field(..., discriminator="KIND") @@ -55,7 +51,7 @@ class TrainingJob(base.Job): signer: signers.SignerKind = pdt.Field(signers.InferSigner(), discriminator="KIND") # Registrer # - avoid shadowing pydantic `register` pydantic function - registry: registries.RegisterKind = pdt.Field(registries.MLflowRegister(), discriminator="KIND") + registry: registries.RegisterKind = pdt.Field(registries.MlflowRegister(), discriminator="KIND") @T.override def run(self) -> base.Locals: @@ -65,9 +61,7 @@ def run(self) -> base.Locals: logger.info("With logger: {}", logger) # - mlflow client = self.mlflow_service.client() - with self.mlflow_service.run( - name=self.run_name, description=self.run_description, tags=self.run_tags - ) as run: + with self.mlflow_service.run_context(run_config=self.run_config) as run: logger.info("With mlflow run id: {}", run.info.run_id) # data # - inputs diff --git a/src/bikes/jobs/tuning.py b/src/bikes/jobs/tuning.py index 4855151..a5a3091 100644 --- a/src/bikes/jobs/tuning.py +++ b/src/bikes/jobs/tuning.py @@ -7,7 +7,7 @@ import pydantic as pdt from bikes.core import metrics, models, schemas -from bikes.io import datasets +from bikes.io import datasets, services from bikes.jobs import base from bikes.utils import searchers, splitters @@ -18,9 +18,7 @@ class TuningJob(base.Job): """Find the best hyperparameters for a model. Parameters: - run_name (str): name of the run. - run_description (str, optional): description of the run. - run_tags: (dict[str, T.Any], optional): tags for the run. + run_config (services.MlflowService.RunConfig): mlflow run config. inputs (datasets.ReaderKind): reader for the inputs data. targets (datasets.ReaderKind): reader for the targets data. model (models.ModelKind): machine learning model to tune. @@ -32,9 +30,7 @@ class TuningJob(base.Job): KIND: T.Literal["TuningJob"] = "TuningJob" # Run - run_name: str = "Tuning" - run_description: str | None = None - run_tags: dict[str, T.Any] | None = None + run_config: services.MlflowService.RunConfig = services.MlflowService.RunConfig(name="Tuning") # Data inputs: datasets.ReaderKind = pdt.Field(..., discriminator="KIND") targets: datasets.ReaderKind = pdt.Field(..., discriminator="KIND") @@ -64,9 +60,7 @@ def run(self) -> base.Locals: logger = self.logger_service.logger() logger.info("With logger: {}", logger) # - mlflow - with self.mlflow_service.run( - name=self.run_name, description=self.run_description, tags=self.run_tags - ) as run: + with self.mlflow_service.run_context(run_config=self.run_config) as run: logger.info("With mlflow run id: {}", run.info.run_id) # data # - inputs diff --git a/src/bikes/utils/searchers.py b/src/bikes/utils/searchers.py index 83cac4a..7ca7a34 100644 --- a/src/bikes/utils/searchers.py +++ b/src/bikes/utils/searchers.py @@ -60,17 +60,17 @@ def search( metric (metrics.Metric): main metric to optimize. inputs (schemas.Inputs): model inputs for tuning. targets (schemas.Targets): model targets for tuning. - cv (CrossValidation): structure for cross-folds strategy. + cv (CrossValidation): choice for cross-fold validation. Returns: - Results: all the results of the searcher process. + Results: all the results of the searcher execution process. """ class GridCVSearcher(Searcher): """Grid searcher with cross-fold validation. - Metric should return higher values for better models. + Convention: metric returns higher values for better models. Parameters: n_jobs (int, optional): number of jobs to run in parallel. diff --git a/src/bikes/utils/signers.py b/src/bikes/utils/signers.py index 2afcba3..b976bb5 100644 --- a/src/bikes/utils/signers.py +++ b/src/bikes/utils/signers.py @@ -22,7 +22,7 @@ class Signer(abc.ABC, pdt.BaseModel, strict=True, frozen=True, extra="forbid"): """Base class for generating model signatures. Allow to switch between model signing strategies. - e.g., automatic inference, manual signatures, ... + e.g., automatic inference, manual model signature, ... https://mlflow.org/docs/latest/models.html#model-signature-and-input-example """ diff --git a/src/bikes/utils/splitters.py b/src/bikes/utils/splitters.py index f072a7b..bbc52ed 100644 --- a/src/bikes/utils/splitters.py +++ b/src/bikes/utils/splitters.py @@ -67,7 +67,7 @@ class TrainTestSplitter(Splitter): """Split a dataframe into a train and test set. Parameters: - shuffle (bool): shuffle dataset before splitting it. + shuffle (bool): shuffle the dataset. Default is False. test_size (int | float): number/ratio for the test set. random_state (int): random state for the splitter object. """ diff --git a/tasks/cleans.py b/tasks/cleans.py index 0babcc3..02a719a 100644 --- a/tasks/cleans.py +++ b/tasks/cleans.py @@ -30,7 +30,7 @@ def pytest(ctx: Context) -> None: @task def coverage(ctx: Context) -> None: - """Clean coverage tool.""" + """Clean the coverage tool.""" ctx.run("rm -f .coverage*") @@ -104,7 +104,7 @@ def folders(_: Context) -> None: @task(pre=[venv, poetry, python]) def sources(_: Context) -> None: - """Run all folders tasks.""" + """Run all sources tasks.""" @task(pre=[tools, folders], default=True) diff --git a/tasks/formats.py b/tasks/formats.py index 8ba996d..154593e 100644 --- a/tasks/formats.py +++ b/tasks/formats.py @@ -10,7 +10,7 @@ @task def code(ctx: Context) -> None: - """Format code with ruff.""" + """Format python code with ruff.""" ctx.run("poetry run ruff format src/ tasks/ tests/") diff --git a/tasks/installs.py b/tasks/installs.py index ebd0727..05669be 100644 --- a/tasks/installs.py +++ b/tasks/installs.py @@ -16,7 +16,7 @@ def poetry(ctx: Context) -> None: @task def pre_commit(ctx: Context) -> None: - """Run pre-commit install.""" + """Install pre-commit hooks on git.""" ctx.run("poetry run pre-commit install --hook-type pre-push") ctx.run("poetry run pre-commit install --hook-type commit-msg") diff --git a/tasks/mlflow.py b/tasks/mlflow.py index 2a80a58..0482ce2 100644 --- a/tasks/mlflow.py +++ b/tasks/mlflow.py @@ -18,7 +18,7 @@ def doctor(ctx: Context) -> None: def serve( ctx: Context, host: str = "127.0.0.1", port: str = "5000", backend_uri: str = "./mlruns" ) -> None: - """""" + """Start mlflow server with the given host, port, and backend uri.""" ctx.run( f"poetry run mlflow server --host={host} --port={port} --backend-store-uri={backend_uri}" ) diff --git a/tests/conftest.py b/tests/conftest.py index dd5b98e..24a408a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -245,9 +245,9 @@ def logger_caplog( @pytest.fixture(scope="function", autouse=True) -def mlflow_service(tmp_path: str) -> T.Generator[services.MLflowService, None, None]: +def mlflow_service(tmp_path: str) -> T.Generator[services.MlflowService, None, None]: """Return and start the mlflow service.""" - service = services.MLflowService( + service = services.MlflowService( tracking_uri=f"{tmp_path}/tracking/", registry_uri=f"{tmp_path}/registry/", experiment_name="Experiment-Testing", @@ -312,10 +312,10 @@ def loader() -> registries.CustomLoader: @pytest.fixture(scope="session") -def register() -> registries.MLflowRegister: +def register() -> registries.MlflowRegister: """Return the default model register.""" - tags = {"registry": "mlflow"} - return registries.MLflowRegister(tags=tags) + tags = {"context": "test", "role": "fixture"} + return registries.MlflowRegister(tags=tags) @pytest.fixture(scope="function") @@ -325,10 +325,11 @@ def model_version( signature: signers.Signature, saver: registries.Saver, register: registries.Register, - mlflow_service: services.MLflowService, + mlflow_service: services.MlflowService, ) -> registries.Version: """Save and register the default model version.""" - with mlflow_service.run(name="Custom-Run"): + run_config = mlflow_service.RunConfig(name="Custom-Run") + with mlflow_service.run_context(run_config=run_config): info = saver.save(model=model, signature=signature, input_example=inputs) version = register.register(name=mlflow_service.registry_name, model_uri=info.model_uri) return version @@ -337,9 +338,9 @@ def model_version( @pytest.fixture(scope="function") def model_alias( model_version: registries.Version, - mlflow_service: services.MLflowService, + mlflow_service: services.MlflowService, ) -> registries.Alias: - """Promote the default model version to an alias.""" + """Promote the default model version with an alias.""" alias = "Promotion" client = mlflow_service.client() client.set_registered_model_alias( diff --git a/tests/core/test_models.py b/tests/core/test_models.py index e6befda..fa7d04c 100644 --- a/tests/core/test_models.py +++ b/tests/core/test_models.py @@ -53,9 +53,9 @@ def test_baseline_sklearn_model( model.fit(inputs=inputs_train, targets=targets_train) outputs = model.predict(inputs=inputs_test) # then - assert outputs.ndim == 2, "Outputs should be a dataframe!" - assert model.get_params() == params, "Model should have the given params!" - assert model.get_internal_model() is not None, "Internal model should be fitted!" assert not_fitted_error.match( "Model is not fitted yet!" ), "Model should raise an error when not fitted!" + assert outputs.ndim == 2, "Outputs should be a dataframe!" + assert model.get_params() == params, "Model should have the given params!" + assert model.get_internal_model() is not None, "Internal model should be fitted!" diff --git a/tests/io/test_registries.py b/tests/io/test_registries.py index eb185f9..1002f6a 100644 --- a/tests/io/test_registries.py +++ b/tests/io/test_registries.py @@ -34,7 +34,7 @@ def test_custom_pipeline( model: models.Model, inputs: schemas.Inputs, signature: signers.Signature, - mlflow_service: services.MLflowService, + mlflow_service: services.MlflowService, ) -> None: # given path = "custom" @@ -42,9 +42,10 @@ def test_custom_pipeline( tags = {"registry": "mlflow"} saver = registries.CustomSaver(path=path) loader = registries.CustomLoader() - register = registries.MLflowRegister(tags=tags) + register = registries.MlflowRegister(tags=tags) + run_config = services.MlflowService.RunConfig(name="Custom-Run") # when - with mlflow_service.run(name="Custom") as run: + with mlflow_service.run_context(run_config=run_config) as run: info = saver.save(model=model, signature=signature, input_example=inputs) version = register.register(name=name, model_uri=info.model_uri) model_uri = registries.uri_for_model_version(name=name, version=version.version) @@ -81,7 +82,7 @@ def test_builtin_pipeline( model: models.Model, inputs: schemas.Inputs, signature: signers.Signature, - mlflow_service: services.MLflowService, + mlflow_service: services.MlflowService, ) -> None: # given path = "builtin" @@ -90,9 +91,10 @@ def test_builtin_pipeline( tags = {"registry": "mlflow"} saver = registries.BuiltinSaver(path=path, flavor=flavor) loader = registries.BuiltinLoader() - register = registries.MLflowRegister(tags=tags) + register = registries.MlflowRegister(tags=tags) + run_config = services.MlflowService.RunConfig(name="Custom-Run") # when - with mlflow_service.run(name="Custom") as run: + with mlflow_service.run_context(run_config=run_config) as run: info = saver.save(model=model, signature=signature, input_example=inputs) version = register.register(name=name, model_uri=info.model_uri) model_uri = registries.uri_for_model_version(name=name, version=version.version) diff --git a/tests/io/test_services.py b/tests/io/test_services.py index 1502d7c..dca2ae7 100644 --- a/tests/io/test_services.py +++ b/tests/io/test_services.py @@ -21,21 +21,23 @@ def test_logger_service( assert "ERROR" in logger_caplog.messages, "Error message should be logged!" -def test_mlflow_service(mlflow_service: services.MLflowService) -> None: +def test_mlflow_service(mlflow_service: services.MlflowService) -> None: # given service = mlflow_service - name = "testing" - tags = {"service": "mlflow"} - description = "a test run." - log_system_metrics = True + run_config = mlflow_service.RunConfig( + name="testing", + tags={"service": "mlflow"}, + description="a test run.", + log_system_metrics=True, + ) # when client = service.client() - with service.run( - name=name, tags=tags, description=description, log_system_metrics=log_system_metrics - ) as context: + with service.run_context(run_config=run_config) as context: pass finished = client.get_run(run_id=context.info.run_id) # then + # - run + assert run_config.tags is not None, "Run config tags should be set!" # - mlflow assert service.tracking_uri == mlflow.get_tracking_uri(), "Tracking URI should be the same!" assert service.registry_uri == mlflow.get_registry_uri(), "Registry URI should be the same!" @@ -45,12 +47,13 @@ def test_mlflow_service(mlflow_service: services.MLflowService) -> None: assert service.registry_uri == client._registry_uri, "Tracking URI should be the same!" assert client.get_experiment_by_name(service.experiment_name), "Experiment should be setup!" # - context - assert context.info.status == "RUNNING", "Context should be running!" - assert context.info.run_name == name, "Context name should be the same!" - assert description in context.data.tags.values(), "Context desc. should be in tags values!" + assert context.info.run_name == run_config.name, "Context name should be the same!" + assert ( + run_config.description in context.data.tags.values() + ), "Context desc. should be in tags values!" assert ( - context.data.tags.items() > tags.items() + context.data.tags.items() > run_config.tags.items() ), "Context tags should be a subset of the given tags!" + assert context.info.status == "RUNNING", "Context should be running!" # - finished assert finished.info.status == "FINISHED", "Finished should be finished!" - assert finished.data.tags == context.data.tags, "Finished tags should be the same!" diff --git a/tests/jobs/test_base.py b/tests/jobs/test_base.py index 69df888..2bb2648 100644 --- a/tests/jobs/test_base.py +++ b/tests/jobs/test_base.py @@ -7,7 +7,7 @@ def test_job( - logger_service: services.LoggerService, mlflow_service: services.MLflowService + logger_service: services.LoggerService, mlflow_service: services.MlflowService ) -> None: # given class MyJob(base.Job): @@ -24,6 +24,6 @@ def run(self) -> base.Locals: # then # - inputs assert hasattr(job, "logger_service"), "Job should have an Logger service!" - assert hasattr(job, "mlflow_service"), "Job should have an MLflow service!" + assert hasattr(job, "mlflow_service"), "Job should have an Mlflow service!" # - outputs - assert set(out) == {"self", "a", "b"}, "Run should return the local variables!" + assert set(out) == {"self", "a", "b"}, "Run should return local variables!" diff --git a/tests/jobs/test_inference.py b/tests/jobs/test_inference.py index ee54941..ef721e6 100644 --- a/tests/jobs/test_inference.py +++ b/tests/jobs/test_inference.py @@ -8,7 +8,7 @@ def test_inference_job( - mlflow_service: services.MLflowService, + mlflow_service: services.MlflowService, logger_service: services.LoggerService, inputs_reader: datasets.Reader, tmp_outputs_writer: datasets.Writer, diff --git a/tests/jobs/test_promotion.py b/tests/jobs/test_promotion.py index c084010..9b8322b 100644 --- a/tests/jobs/test_promotion.py +++ b/tests/jobs/test_promotion.py @@ -24,7 +24,7 @@ ) def test_promotion_job( version: int | None, - mlflow_service: services.MLflowService, + mlflow_service: services.MlflowService, logger_service: services.LoggerService, model_version: registries.Version, ) -> None: diff --git a/tests/jobs/test_training.py b/tests/jobs/test_training.py index 76f77ef..a1cfe10 100644 --- a/tests/jobs/test_training.py +++ b/tests/jobs/test_training.py @@ -9,7 +9,7 @@ def test_training_job( - mlflow_service: services.MLflowService, + mlflow_service: services.MlflowService, logger_service: services.LoggerService, inputs_reader: datasets.Reader, targets_reader: datasets.Reader, @@ -21,18 +21,16 @@ def test_training_job( register: registries.Register, ) -> None: # given - run_name = "Training Test" - run_description = "Training job." - run_tags = {"context": "training"} + run_config = services.MlflowService.RunConfig( + name="TrainingTest", tags={"context": "training"}, description="Training job." + ) splitter = train_test_splitter client = mlflow_service.client() # when job = jobs.TrainingJob( mlflow_service=mlflow_service, logger_service=logger_service, - run_name=run_name, - run_description=run_description, - run_tags=run_tags, + run_config=run_config, inputs=inputs_reader, targets=targets_reader, model=model, @@ -71,9 +69,12 @@ def test_training_job( "model_version", } # - run - assert out["run"].info.run_name == run_name, "Run name should be the same!" - assert run_description in out["run"].data.tags.values(), "Run desc. should be tags!" - assert out["run"].data.tags.items() > run_tags.items(), "Run tags should be a subset of tags!" + assert run_config.tags is not None, "Run config tags should be set!" + assert out["run"].info.run_name == run_config.name, "Run name should be the same!" + assert run_config.description in out["run"].data.tags.values(), "Run desc. should be tags!" + assert ( + out["run"].data.tags.items() > run_config.tags.items() + ), "Run tags should be a subset of tags!" # - data assert out["inputs"].ndim == out["inputs_"].ndim == 2, "Inputs should be a dataframe!" assert out["targets"].ndim == out["targets_"].ndim == 2, "Targets should be a dataframe!" @@ -125,11 +126,11 @@ def test_training_job( experiment = client.get_experiment_by_name(name=mlflow_service.experiment_name) assert ( experiment.name == mlflow_service.experiment_name - ), "MLflow Experiment name should be the same!" + ), "Mlflow Experiment name should be the same!" runs = client.search_runs(experiment_ids=experiment.experiment_id) - assert len(runs) == 1, "There should be a single MLflow run for training!" - assert metric.name in runs[0].data.metrics, "Metric should be logged in MLflow!" - assert runs[0].info.status == "FINISHED", "MLflow run status should be set as FINISHED!" + assert len(runs) == 1, "There should be a single Mlflow run for training!" + assert metric.name in runs[0].data.metrics, "Metric should be logged in Mlflow!" + assert runs[0].info.status == "FINISHED", "Mlflow run status should be set as FINISHED!" # - mlflow registry model_version = client.get_model_version( name=mlflow_service.registry_name, version=out["model_version"].version diff --git a/tests/jobs/test_tuning.py b/tests/jobs/test_tuning.py index 6313472..f5d437c 100644 --- a/tests/jobs/test_tuning.py +++ b/tests/jobs/test_tuning.py @@ -9,7 +9,7 @@ def test_tuning_job( - mlflow_service: services.MLflowService, + mlflow_service: services.MlflowService, logger_service: services.LoggerService, inputs_reader: datasets.Reader, targets_reader: datasets.Reader, @@ -19,18 +19,16 @@ def test_tuning_job( searcher: searchers.Searcher, ) -> None: # given - run_name = "TuningTest" - run_description = "Tuning job." - run_tags = {"context": "tuning"} + run_config = services.MlflowService.RunConfig( + name="TuningTest", tags={"context": "tuning"}, description="Tuning job." + ) splitter = time_series_splitter client = mlflow_service.client() # when job = jobs.TuningJob( mlflow_service=mlflow_service, logger_service=logger_service, - run_name=run_name, - run_description=run_description, - run_tags=run_tags, + run_config=run_config, inputs=inputs_reader, targets=targets_reader, model=model, @@ -55,9 +53,12 @@ def test_tuning_job( "best_score", } # - run - assert out["run"].info.run_name == run_name, "Run name should be the same!" - assert run_description in out["run"].data.tags.values(), "Run desc. should be tags!" - assert out["run"].data.tags.items() > run_tags.items(), "Run tags should be a subset of tags!" + assert run_config.tags is not None, "Run config tags should be set!" + assert out["run"].info.run_name == run_config.name, "Run name should be the same!" + assert run_config.description in out["run"].data.tags.values(), "Run desc. should be tags!" + assert ( + out["run"].data.tags.items() > run_config.tags.items() + ), "Run tags should be a subset of tags!" # - data assert out["inputs"].ndim == out["inputs_"].ndim == 2, "Inputs should be a dataframe!" assert out["targets"].ndim == out["inputs_"].ndim == 2, "Targets should be a dataframe!" @@ -75,6 +76,6 @@ def test_tuning_job( experiment = client.get_experiment_by_name(name=mlflow_service.experiment_name) assert ( experiment.name == mlflow_service.experiment_name - ), "MLflow experiment name should be the same!" + ), "Mlflow experiment name should be the same!" runs = client.search_runs(experiment_ids=experiment.experiment_id) - assert len(runs) == len(out["results"]) + 1, "MLflow should have 1 run per result + parent!" + assert len(runs) == len(out["results"]) + 1, "Mlflow should have 1 run per result + parent!" diff --git a/tests/test_scripts.py b/tests/test_scripts.py index 40d3934..063a0f9 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -16,10 +16,10 @@ def test_schema(capsys: pc.CaptureFixture[str]) -> None: args = ["prog", "--schema"] # when scripts.main(args) - capture = capsys.readouterr() + captured = capsys.readouterr() # then - assert capture.err == "", "Captured error should be empty!" - assert json.loads(capture.out), "Captured output should be a JSON!" + assert captured.err == "", "Captured error should be empty!" + assert json.loads(captured.out), "Captured output should be a JSON!" @pytest.mark.parametrize( @@ -45,4 +45,4 @@ def test_main(scenario: str, confs_path: str, extra_config: str) -> None: argv = [config, "-e", extra_config] status = scripts.main(argv=argv) # then - assert status == 0, f"Job should succeed with status 0! Config: {config}" + assert status == 0, f"Job should succeed for config: {config}"