Skip to content

Commit

Permalink
Add example experiment program, and corresponding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Sep 10, 2023
1 parent 298572e commit 4637c69
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:
pip install "setuptools>=64" --upgrade
# Install package in editable mode.
pip install --use-pep517 --prefer-binary --editable=.[develop,docs,test]
pip install --use-pep517 --prefer-binary --editable=.[examples,develop,docs,test]
- name: Run linter and software tests
run: |
Expand Down
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
- Project: Add `versioningit`, for effortless versioning
- Add patch for SQLAlchemy Inspector's `get_table_names`
- Reorder CrateDB SQLAlchemy Dialect polyfills
- Add example experiment program, and corresponding tests
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ for [MLflow Tracking].

Install the most recent version of the `mlflow-cratedb` package.
```shell
pip install --upgrade 'git+https://github.com/crate-workbench/mlflow-cratedb'
pip install --upgrade 'git+https://github.com/crate-workbench/mlflow-cratedb#egg=mlflow-cratedb[examples]'
```

To verify if the installation worked, you can inspect the version numbers
Expand Down Expand Up @@ -54,7 +54,7 @@ git clone https://github.com/crate-workbench/mlflow-cratedb
cd mlflow-cratedb
python3 -m venv .venv
source .venv/bin/activate
pip install --editable='.[develop,docs,test]'
pip install --editable='.[examples,develop,docs,test]'
```

Run linters and software tests, skipping slow tests:
Expand All @@ -74,13 +74,17 @@ pytest -m slow
[Siddharth Murching], [Corey Zumar], [Harutaka Kawamura], [Ben Wilson], and
all other contributors for conceiving and maintaining [MLflow].

[Andreas Nigg] for contributing the [tracking_merlion.py](./examples/tracking_merlion.py)
ML experiment program, which is using [Merlion].


[Andreas Nigg]: https://github.com/andnig
[Ben Wilson]: https://github.com/BenWilson2
[Corey Zumar]: https://github.com/dbczumar
[CrateDB]: https://github.com/crate/crate
[CrateDB Cloud]: https://console.cratedb.cloud/
[Harutaka Kawamura]: https://github.com/harupy
[Merlion]: https://github.com/salesforce/Merlion
[MLflow]: https://mlflow.org/
[MLflow Tracking]: https://mlflow.org/docs/latest/tracking.html
[Siddharth Murching]: https://github.com/smurching
155 changes: 155 additions & 0 deletions examples/tracking_merlion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""
Use MLflow and CrateDB to track the events and outcomes of an ML experiment program
using Merlion. It uses the `machine_temperature_system_failure.csv` dataset from
the Numenta Anomaly Benchmark data.
- https://github.com/crate-workbench/mlflow-cratedb
Resources:
- https://mlflow.org/
- https://crate.io/
- https://github.com/salesforce/Merlion
- https://github.com/numenta/NAB
"""

import os

import mlflow
import numpy as np
import pandas as pd
from crate import client
from crate.client.exceptions import ProgrammingError
from merlion.evaluate.anomaly import TSADMetric
from merlion.models.defaults import DefaultDetector, DefaultDetectorConfig
from merlion.utils import TimeSeries


def connect_database():
"""
Connect to CrateDB, and return database connection object.
"""
dburi = os.getenv("CRATEDB_HTTP_URL", "http://crate@localhost:4200")
return client.connect(dburi)


def provision_data():
"""
Download Numenta Anomaly Benchmark data, and load into database.
"""
data = pd.read_csv(
"https://raw.githubusercontent.com/numenta/NAB/master/data/realKnownCause/machine_temperature_system_failure.csv"
)

# Connecting to CrateDB.
conn = connect_database()
try:
conn.cursor().execute("SELECT * FROM machine_data LIMIT 1;")
return
except ProgrammingError as ex:
if "Relation 'machine_data' unknown" not in str(ex):
raise

# Split the data into chunks of 1000 rows each for better insert performance.
chunk_size = 1000
chunks = np.array_split(data, int(len(data) / chunk_size))

# Insert the data into CrateDB.
with conn:
cursor = conn.cursor()
# Create the table if it doesn't exist.
cursor.execute("CREATE TABLE IF NOT EXISTS machine_data (timestamp TIMESTAMP, temperature DOUBLE)")
# Insert the data in chunks.
for chunk in chunks:
cursor.executemany(
"INSERT INTO machine_data (timestamp, temperature) VALUES (?, ?)",
list(chunk.itertuples(index=False, name=None)),
)


def read_data() -> pd.DataFrame:
"""
Read data from database into pandas DataFrame.
"""
conn = connect_database()
with conn:
cursor = conn.cursor()
cursor.execute(
"""SELECT
DATE_BIN('5 min'::INTERVAL, "timestamp", 0) AS timestamp,
MAX(temperature) AS value
FROM machine_data
GROUP BY timestamp
ORDER BY timestamp ASC"""
)
data = cursor.fetchall()

# Convert database response to pandas DataFrame.
time_series = pd.DataFrame(
[{"timestamp": pd.Timestamp.fromtimestamp(item[0] / 1000), "value": item[1]} for item in data]
)
# Set the timestamp as the index
return time_series.set_index("timestamp")


def run_experiment(time_series: pd.DataFrame):
"""
Run experiment on DataFrame, using Merlion. Track it using MLflow.
"""
mlflow.set_experiment("numenta-merlion-experiment")

with mlflow.start_run():
train_data = TimeSeries.from_pd(time_series[time_series.index < pd.to_datetime("2013-12-15")])
test_data = TimeSeries.from_pd(time_series[time_series.index >= pd.to_datetime("2013-12-15")])

model = DefaultDetector(DefaultDetectorConfig())
model.train(train_data=train_data)

test_pred = model.get_anomaly_label(time_series=test_data)

# Prepare the test labels
time_frames = [
["2013-12-15 17:50:00.000000", "2013-12-17 17:00:00.000000"],
["2014-01-27 14:20:00.000000", "2014-01-29 13:30:00.000000"],
["2014-02-07 14:55:00.000000", "2014-02-09 14:05:00.000000"],
]

time_frames = [[pd.to_datetime(start), pd.to_datetime(end)] for start, end in time_frames]
time_series["test_labels"] = 0
for start, end in time_frames:
time_series.loc[(time_series.index >= start) & (time_series.index <= end), "test_labels"] = 1

test_labels = TimeSeries.from_pd(time_series["test_labels"])

p = TSADMetric.Precision.value(ground_truth=test_labels, predict=test_pred)
r = TSADMetric.Recall.value(ground_truth=test_labels, predict=test_pred)
f1 = TSADMetric.F1.value(ground_truth=test_labels, predict=test_pred)
mttd = TSADMetric.MeanTimeToDetect.value(ground_truth=test_labels, predict=test_pred)
print(f"Precision: {p:.4f}, Recall: {r:.4f}, F1: {f1:.4f}\n" f"Mean Time To Detect: {mttd}") # noqa: T201

mlflow.log_metric("precision", p)
mlflow.log_metric("recall", r)
mlflow.log_metric("f1", f1)
mlflow.log_metric("mttd", mttd.total_seconds())
mlflow.log_param("anomaly_threshold", model.config.threshold.alm_threshold)
mlflow.log_param("min_alm_window", model.config.threshold.min_alm_in_window)
mlflow.log_param("alm_window_minutes", model.config.threshold.alm_window_minutes)
mlflow.log_param("alm_suppress_minutes", model.config.threshold.alm_suppress_minutes)
mlflow.log_param("ensemble_size", model.config.model.combiner.n_models)

# Save the model to MLflow.
model.save("model")
mlflow.log_artifact("model")


def main():
"""
Provision dataset, and run experiment.
"""
provision_data()
data = read_data()
run_experiment(data)


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,17 @@ develop = [
"ruff==0.0.287",
"validate-pyproject<0.15",
]
examples = [
"salesforce-merlion<2.1",
]
release = [
"build<2",
'minibump<1; python_version >= "3.10"',
"twine<5",
]
test = [
"coverage<8",
"psutil<6",
"pytest<8",
]
[project.scripts]
Expand Down
57 changes: 57 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import logging
import sys
import time
from pathlib import Path

import mlflow
import pytest
import sqlalchemy as sa

from mlflow_cratedb.adapter.setup_db import _setup_db_create_tables, _setup_db_drop_tables
from tests.util import process

# The canonical database schema used for example purposes is `examples`.
DB_URI = "crate://crate@localhost/?schema=examples"
MLFLOW_TRACKING_URI = "http://127.0.0.1:5000"


logger = logging.getLogger(__name__)


@pytest.fixture
def engine():
yield mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(DB_URI)


def test_tracking_merlion(engine: sa.Engine):
_setup_db_drop_tables(engine=engine)
_setup_db_create_tables(engine=engine)
tracking_merlion = Path(__file__).parent.parent.joinpath("examples").joinpath("tracking_merlion.py")
cmd_server = [
"mlflow-cratedb",
"server",
"--workers=1",
f"--backend-store-uri={DB_URI}",
"--gunicorn-opts='--log-level=debug'",
]
cmd_client = [
sys.executable,
tracking_merlion,
]

logger.info("Starting server")
with process(cmd_server, stdout=sys.stdout.buffer, stderr=sys.stderr.buffer, close_fds=True) as server_process:
logger.info(f"Started server with process id: {server_process.pid}")
# TODO: Wait for HTTP response.
time.sleep(4)
logger.info("Starting client")
with process(
cmd_client,
env={"MLFLOW_TRACKING_URI": MLFLOW_TRACKING_URI},
stdout=sys.stdout.buffer,
stderr=sys.stderr.buffer,
) as client_process:
client_process.wait(timeout=120)
assert client_process.returncode == 0

# TODO: Verify database content.
23 changes: 23 additions & 0 deletions tests/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Source: mlflow:tests/integration/utils.py and mlflow:tests/store/tracking/test_file_store.py
import subprocess
from contextlib import contextmanager
from typing import List

import psutil
from click.testing import CliRunner
from mlflow.entities import DatasetInput

Expand Down Expand Up @@ -28,3 +31,23 @@ def assert_dataset_inputs_equal(inputs1: List[DatasetInput], inputs2: List[Datas
tag2 = tags2[idx]
assert tag1.key == tag1.key
assert tag1.value == tag2.value


@contextmanager
def process(*args, **kwargs) -> subprocess.Popen:
"""
Wrapper around `subprocess.Popen` to also terminate child processes after exiting.
https://gist.github.com/jizhilong/6687481#gistcomment-3057122
"""
proc = subprocess.Popen(*args, **kwargs) # noqa: S603
try:
yield proc
finally:
try:
children = psutil.Process(proc.pid).children(recursive=True)
except psutil.NoSuchProcess:
return
for child in children:
child.kill()
proc.kill()

0 comments on commit 4637c69

Please sign in to comment.