Skip to content

Commit

Permalink
Add example program tracking_merlion.py, and a corresponding test case
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Sep 12, 2023
1 parent 2ea2e70 commit 07a0c65
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:
pip install "setuptools>=64" --upgrade
# Install package in editable mode.
pip install --use-pep517 --prefer-binary --editable=.[develop,docs,test]
pip install --use-pep517 --prefer-binary --editable=.[examples,develop,docs,test]
- name: Run linter and software tests
run: |
Expand Down
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
- Project: Add `versioningit`, for effortless versioning
- Add patch for SQLAlchemy Inspector's `get_table_names`
- Reorder CrateDB SQLAlchemy Dialect polyfills
- Add example experiment program `tracking_merlion.py`, and corresponding tests
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ for [MLflow Tracking].

Install the most recent version of the `mlflow-cratedb` package.
```shell
pip install --upgrade 'git+https://github.com/crate-workbench/mlflow-cratedb'
pip install --upgrade 'git+https://github.com/crate-workbench/mlflow-cratedb#egg=mlflow-cratedb[examples]'
```

To verify if the installation worked, you can inspect the version numbers
Expand Down Expand Up @@ -54,7 +54,7 @@ git clone https://github.com/crate-workbench/mlflow-cratedb
cd mlflow-cratedb
python3 -m venv .venv
source .venv/bin/activate
pip install --editable='.[develop,docs,test]'
pip install --editable='.[examples,develop,docs,test]'
```

Run linters and software tests, skipping slow tests:
Expand All @@ -74,13 +74,17 @@ pytest -m slow
[Siddharth Murching], [Corey Zumar], [Harutaka Kawamura], [Ben Wilson], and
all other contributors for conceiving and maintaining [MLflow].

[Andreas Nigg] for contributing the [tracking_merlion.py](./examples/tracking_merlion.py)
ML experiment program, which is using [Merlion].


[Andreas Nigg]: https://github.com/andnig
[Ben Wilson]: https://github.com/BenWilson2
[Corey Zumar]: https://github.com/dbczumar
[CrateDB]: https://github.com/crate/crate
[CrateDB Cloud]: https://console.cratedb.cloud/
[Harutaka Kawamura]: https://github.com/harupy
[Merlion]: https://github.com/salesforce/Merlion
[MLflow]: https://mlflow.org/
[MLflow Tracking]: https://mlflow.org/docs/latest/tracking.html
[Siddharth Murching]: https://github.com/smurching
186 changes: 186 additions & 0 deletions examples/tracking_merlion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
"""
About
Use MLflow and CrateDB to track the metrics, parameters, and outcomes of an ML
experiment program using Merlion. It uses the `machine_temperature_system_failure.csv`
dataset from the Numenta Anomaly Benchmark data.
- https://github.com/crate-workbench/mlflow-cratedb
- https://mlflow.org/docs/latest/tracking.html
Usage
Before running the program, optionally define the `MLFLOW_TRACKING_URI` environment
variable, in order to record events and metrics either directly into the database,
or by submitting them to an MLflow Tracking Server.
# Use CrateDB database directly
export MLFLOW_TRACKING_URI="crate://crate@localhost/?schema=mlflow"
# Use MLflow Tracking Server
export MLFLOW_TRACKING_URI=http://127.0.0.1:5000
Resources
- https://mlflow.org/
- https://github.com/crate/crate
- https://github.com/salesforce/Merlion
- https://github.com/numenta/NAB
"""

import os

import mlflow
import numpy as np
import pandas as pd
from crate import client
from merlion.evaluate.anomaly import TSADMetric
from merlion.models.defaults import DefaultDetector, DefaultDetectorConfig
from merlion.utils import TimeSeries


def connect_database():
"""
Connect to CrateDB, and return database connection object.
"""
dburi = os.getenv("CRATEDB_HTTP_URL", "http://crate@localhost:4200")
return client.connect(dburi)


def table_exists(table_name: str, schema_name: str = "doc") -> bool:
"""
Check if database table exists.
"""
conn = connect_database()
cursor = conn.cursor()
sql = (
f"SELECT table_name FROM information_schema.tables " # noqa: S608
f"WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'"
)
cursor.execute(sql)
rowcount = cursor.rowcount
cursor.close()
conn.close()
return rowcount > 0


def import_data(table_name: str):
"""
Download Numenta Anomaly Benchmark data, and load into database.
"""

data = pd.read_csv(
"https://raw.githubusercontent.com/numenta/NAB/master/data/realKnownCause/machine_temperature_system_failure.csv"
)

# Split the data into chunks of 1000 rows each for better insert performance.
chunk_size = 1000
chunks = np.array_split(data, int(len(data) / chunk_size))

# Insert data into CrateDB.
with connect_database() as conn:
cursor = conn.cursor()
# Create the table if it doesn't exist.
cursor.execute(f"CREATE TABLE IF NOT EXISTS {table_name} (timestamp TIMESTAMP, temperature DOUBLE)")
# Insert the data in chunks.
for chunk in chunks:
sql = f"INSERT INTO {table_name} (timestamp, temperature) VALUES (?, ?)" # noqa: S608
cursor.executemany(sql, list(chunk.itertuples(index=False, name=None)))


def read_data(table_name: str) -> pd.DataFrame:
"""
Read data from database into pandas DataFrame.
"""
conn = connect_database()
with conn:
cursor = conn.cursor()
cursor.execute(
f"""SELECT
DATE_BIN('5 min'::INTERVAL, "timestamp", 0) AS timestamp,
MAX(temperature) AS value
FROM {table_name}
GROUP BY timestamp
ORDER BY timestamp ASC"""
)
data = cursor.fetchall()

# Convert database response to pandas DataFrame.
time_series = pd.DataFrame(
[{"timestamp": pd.Timestamp.fromtimestamp(item[0] / 1000), "value": item[1]} for item in data]
)
# Set the timestamp as the index
return time_series.set_index("timestamp")


def run_experiment(time_series: pd.DataFrame):
"""
Run experiment on DataFrame, using Merlion. Track it using MLflow.
"""
mlflow.set_experiment("numenta-merlion-experiment")

with mlflow.start_run():
train_data = TimeSeries.from_pd(time_series[time_series.index < pd.to_datetime("2013-12-15")])
test_data = TimeSeries.from_pd(time_series[time_series.index >= pd.to_datetime("2013-12-15")])

model = DefaultDetector(DefaultDetectorConfig())
model.train(train_data=train_data)

test_pred = model.get_anomaly_label(time_series=test_data)

# Prepare the test labels
time_frames = [
["2013-12-15 17:50:00.000000", "2013-12-17 17:00:00.000000"],
["2014-01-27 14:20:00.000000", "2014-01-29 13:30:00.000000"],
["2014-02-07 14:55:00.000000", "2014-02-09 14:05:00.000000"],
]

time_frames = [[pd.to_datetime(start), pd.to_datetime(end)] for start, end in time_frames]
time_series["test_labels"] = 0
for start, end in time_frames:
time_series.loc[(time_series.index >= start) & (time_series.index <= end), "test_labels"] = 1

test_labels = TimeSeries.from_pd(time_series["test_labels"])

p = TSADMetric.Precision.value(ground_truth=test_labels, predict=test_pred)
r = TSADMetric.Recall.value(ground_truth=test_labels, predict=test_pred)
f1 = TSADMetric.F1.value(ground_truth=test_labels, predict=test_pred)
mttd = TSADMetric.MeanTimeToDetect.value(ground_truth=test_labels, predict=test_pred)
print(f"Precision: {p:.4f}, Recall: {r:.4f}, F1: {f1:.4f}\n" f"Mean Time To Detect: {mttd}") # noqa: T201

mlflow.log_metric("precision", p)
mlflow.log_metric("recall", r)
mlflow.log_metric("f1", f1)
mlflow.log_metric("mttd", mttd.total_seconds())
mlflow.log_param("anomaly_threshold", model.config.threshold.alm_threshold)
mlflow.log_param("min_alm_window", model.config.threshold.min_alm_in_window)
mlflow.log_param("alm_window_minutes", model.config.threshold.alm_window_minutes)
mlflow.log_param("alm_suppress_minutes", model.config.threshold.alm_suppress_minutes)
mlflow.log_param("ensemble_size", model.config.model.combiner.n_models)

# Save the model to MLflow.
model.save("model")
mlflow.log_artifact("model")


def main():
"""
Provision dataset, and run experiment.
"""

# Table name where the actual data is stored.
data_table = "machine_data"

# Provision data to operate on, only once.
if not table_exists(data_table):
import_data(data_table)

# Read data into pandas DataFrame.
data = read_data(data_table)

# Run experiment on data.
run_experiment(data)


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,17 @@ develop = [
"ruff==0.0.287",
"validate-pyproject<0.15",
]
examples = [
"salesforce-merlion<2.1",
]
release = [
"build<2",
'minibump<1; python_version >= "3.10"',
"twine<5",
]
test = [
"coverage<8",
"psutil<6",
"pytest<8",
]
[project.scripts]
Expand Down
57 changes: 57 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import logging
import sys
import time
from pathlib import Path

import mlflow
import pytest
import sqlalchemy as sa

from mlflow_cratedb.adapter.setup_db import _setup_db_create_tables, _setup_db_drop_tables
from tests.util import process

# The canonical database schema used for example purposes is `examples`.
DB_URI = "crate://crate@localhost/?schema=examples"
MLFLOW_TRACKING_URI = "http://127.0.0.1:5000"


logger = logging.getLogger(__name__)


@pytest.fixture
def engine():
yield mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(DB_URI)


def test_tracking_merlion(engine: sa.Engine):
_setup_db_drop_tables(engine=engine)
_setup_db_create_tables(engine=engine)
tracking_merlion = Path(__file__).parent.parent.joinpath("examples").joinpath("tracking_merlion.py")
cmd_server = [
"mlflow-cratedb",
"server",
"--workers=1",
f"--backend-store-uri={DB_URI}",
"--gunicorn-opts='--log-level=debug'",
]
cmd_client = [
sys.executable,
tracking_merlion,
]

logger.info("Starting server")
with process(cmd_server, stdout=sys.stdout.buffer, stderr=sys.stderr.buffer, close_fds=True) as server_process:
logger.info(f"Started server with process id: {server_process.pid}")
# TODO: Wait for HTTP response.
time.sleep(4)
logger.info("Starting client")
with process(
cmd_client,
env={"MLFLOW_TRACKING_URI": MLFLOW_TRACKING_URI},
stdout=sys.stdout.buffer,
stderr=sys.stderr.buffer,
) as client_process:
client_process.wait(timeout=120)
assert client_process.returncode == 0

# TODO: Verify database content.
23 changes: 23 additions & 0 deletions tests/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Source: mlflow:tests/integration/utils.py and mlflow:tests/store/tracking/test_file_store.py
import subprocess
from contextlib import contextmanager
from typing import List

import psutil
from click.testing import CliRunner
from mlflow.entities import DatasetInput

Expand Down Expand Up @@ -28,3 +31,23 @@ def assert_dataset_inputs_equal(inputs1: List[DatasetInput], inputs2: List[Datas
tag2 = tags2[idx]
assert tag1.key == tag1.key
assert tag1.value == tag2.value


@contextmanager
def process(*args, **kwargs) -> subprocess.Popen:
"""
Wrapper around `subprocess.Popen` to also terminate child processes after exiting.
https://gist.github.com/jizhilong/6687481#gistcomment-3057122
"""
proc = subprocess.Popen(*args, **kwargs) # noqa: S603
try:
yield proc
finally:
try:
children = psutil.Process(proc.pid).children(recursive=True)
except psutil.NoSuchProcess:
return
for child in children:
child.kill()
proc.kill()

0 comments on commit 07a0c65

Please sign in to comment.