From 702062b75c6beb0f6e51e352b7bb460ba194fc6f Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Thu, 7 Sep 2023 18:47:25 +0200
Subject: [PATCH] Tests: Add test infrastructure from MLflow repository

---
 tests/__init__.py      |  0
 tests/abstract.py      | 63 ++++++++++++++++++++++++++++++++++++++++++
 tests/test_tracking.py |  7 ++++---
 tests/util.py          | 30 ++++++++++++++++++++
 4 files changed, 97 insertions(+), 3 deletions(-)
 create mode 100644 tests/__init__.py
 create mode 100644 tests/abstract.py
 create mode 100644 tests/util.py

diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/abstract.py b/tests/abstract.py
new file mode 100644
index 0000000..6912033
--- /dev/null
+++ b/tests/abstract.py
@@ -0,0 +1,63 @@
+# Source: mlflow:tests/store/tracking/__init__.py
+import json
+
+import pytest
+from mlflow.entities import RunTag
+from mlflow.models import Model
+from mlflow.utils.mlflow_tags import MLFLOW_LOGGED_MODELS
+
+
+class AbstractStoreTest:
+    def create_test_run(self):
+        raise Exception("this should be overridden")
+
+    def get_store(self):
+        raise Exception("this should be overridden")
+
+    def test_record_logged_model(self):
+        store = self.get_store()
+        run_id = self.create_test_run().info.run_id
+        m = Model(artifact_path="model/path", run_id=run_id, flavors={"tf": "flavor body"})
+        store.record_logged_model(run_id, m)
+        self._verify_logged(
+            store,
+            run_id=run_id,
+            params=[],
+            metrics=[],
+            tags=[RunTag(MLFLOW_LOGGED_MODELS, json.dumps([m.to_dict()]))],
+        )
+        m2 = Model(artifact_path="some/other/path", run_id=run_id, flavors={"R": {"property": "value"}})
+        store.record_logged_model(run_id, m2)
+        self._verify_logged(
+            store,
+            run_id,
+            params=[],
+            metrics=[],
+            tags=[RunTag(MLFLOW_LOGGED_MODELS, json.dumps([m.to_dict(), m2.to_dict()]))],
+        )
+        m3 = Model(artifact_path="some/other/path2", run_id=run_id, flavors={"R2": {"property": "value"}})
+        store.record_logged_model(run_id, m3)
+        self._verify_logged(
+            store,
+            run_id,
+            params=[],
+            metrics=[],
+            tags=[RunTag(MLFLOW_LOGGED_MODELS, json.dumps([m.to_dict(), m2.to_dict(), m3.to_dict()]))],
+        )
+        with pytest.raises(
+            TypeError,
+            match="Argument 'mlflow_model' should be mlflow.models.Model, got '<class 'dict'>'",
+        ):
+            store.record_logged_model(run_id, m.to_dict())
+
+    @staticmethod
+    def _verify_logged(store, run_id, metrics, params, tags):
+        run = store.get_run(run_id)
+        all_metrics = sum([store.get_metric_history(run_id, key) for key in run.data.metrics], [])
+        assert len(all_metrics) == len(metrics)
+        logged_metrics = [(m.key, m.value, m.timestamp, m.step) for m in all_metrics]
+        assert set(logged_metrics) == {(m.key, m.value, m.timestamp, m.step) for m in metrics}
+        logged_tags = set(run.data.tags.items())
+        assert {(tag.key, tag.value) for tag in tags} <= logged_tags
+        assert len(run.data.params) == len(params)
+        assert set(run.data.params.items()) == {(param.key, param.value) for param in params}
diff --git a/tests/test_tracking.py b/tests/test_tracking.py
index 718dff6..c3588bd 100644
--- a/tests/test_tracking.py
+++ b/tests/test_tracking.py
@@ -1,3 +1,4 @@
+# Source: mlflow:tests/tracking/test_tracking.py
 import json
 import math
 import os
@@ -68,9 +69,9 @@
 from mlflow.utils.time_utils import get_current_time_millis
 from mlflow.utils.uri import extract_db_type_from_uri
 
-from tests.integration.utils import invoke_cli_runner
-from tests.store.tracking import AbstractStoreTest
-from tests.store.tracking.test_file_store import assert_dataset_inputs_equal
+from mlflow_cratedb.adapter.db import CRATEDB
+from .abstract import AbstractStoreTest
+from .util import invoke_cli_runner, assert_dataset_inputs_equal
 
 DB_URI = "sqlite:///"
 ARTIFACT_URI = "artifact_folder"
diff --git a/tests/util.py b/tests/util.py
new file mode 100644
index 0000000..8741ada
--- /dev/null
+++ b/tests/util.py
@@ -0,0 +1,30 @@
+# Source: mlflow:tests/integration/utils.py and mlflow:tests/store/tracking/test_file_store.py
+from typing import List
+
+from click.testing import CliRunner
+from mlflow.entities import DatasetInput
+
+
+def invoke_cli_runner(*args, **kwargs):
+    """
+    Helper method to invoke the CliRunner while asserting that the exit code is actually 0.
+    """
+
+    res = CliRunner().invoke(*args, **kwargs)
+    assert res.exit_code == 0, f"Got non-zero exit code {res.exit_code}. Output is: {res.output}"
+    return res
+
+
+def assert_dataset_inputs_equal(inputs1: List[DatasetInput], inputs2: List[DatasetInput]):
+    inputs1 = sorted(inputs1, key=lambda inp: (inp.dataset.name, inp.dataset.digest))
+    inputs2 = sorted(inputs2, key=lambda inp: (inp.dataset.name, inp.dataset.digest))
+    assert len(inputs1) == len(inputs2)
+    for idx, inp1 in enumerate(inputs1):
+        inp2 = inputs2[idx]
+        assert dict(inp1.dataset) == dict(inp2.dataset)
+        tags1 = sorted(inp1.tags, key=lambda tag: tag.key)
+        tags2 = sorted(inp2.tags, key=lambda tag: tag.key)
+        for idx, tag1 in enumerate(tags1):
+            tag2 = tags2[idx]
+            assert tag1.key == tag2.key
+            assert tag1.value == tag2.value
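
Note on usage: AbstractStoreTest is a mixin; create_test_run() and get_store()
raise until a concrete test class overrides them, at which point pytest picks
up the inherited test_record_logged_model(). A minimal sketch of such a
subclass, assuming MLflow's SqlAlchemyStore and a throwaway SQLite database;
the class name and fixture wiring here are illustrative, not part of this
patch:

    import pytest
    from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore

    from tests.abstract import AbstractStoreTest


    class TestSqlAlchemyStore(AbstractStoreTest):
        @pytest.fixture(autouse=True)
        def _store(self, tmp_path):
            # One file-backed SQLite store per test, so every test starts
            # from a clean database. The path comes from pytest's tmp_path.
            self.store = SqlAlchemyStore(f"sqlite:///{tmp_path}/mlflow.db", str(tmp_path))

        def get_store(self):
            return self.store

        def create_test_run(self):
            # Create an experiment and a run, so the inherited test has a
            # run_id to attach logged models to.
            experiment_id = self.store.create_experiment("test-experiment")
            return self.store.create_run(
                experiment_id=experiment_id,
                user_id="tester",
                start_time=0,
                tags=[],
                run_name="test-run",
            )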
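
Similarly, invoke_cli_runner() wraps click.testing.CliRunner and asserts on
the exit code, so callers do not have to repeat that check. A sketch of how a
caller might use it against MLflow's real top-level Click group; the test
itself is made up:

    from mlflow import cli

    from tests.util import invoke_cli_runner


    def test_cli_help():
        # --help exits with status 0 and has no side effects, which makes
        # it a cheap smoke test for the wrapper.
        res = invoke_cli_runner(cli.cli, ["--help"])
        assert "Usage" in res.output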