-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example experiment program, and corresponding tests
- Loading branch information
Showing
7 changed files
with
247 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
""" | ||
Use MLflow and CrateDB to track the events and outcomes of an ML experiment program | ||
using Merlion. It uses the `machine_temperature_system_failure.csv` dataset from | ||
the Numenta Anomaly Benchmark data. | ||
- https://github.com/crate-workbench/mlflow-cratedb | ||
Resources: | ||
- https://mlflow.org/ | ||
- https://crate.io/ | ||
- https://github.com/salesforce/Merlion | ||
- https://github.com/numenta/NAB | ||
""" | ||
|
||
import os | ||
|
||
import mlflow | ||
import numpy as np | ||
import pandas as pd | ||
from crate import client | ||
from crate.client.exceptions import ProgrammingError | ||
from merlion.evaluate.anomaly import TSADMetric | ||
from merlion.models.defaults import DefaultDetector, DefaultDetectorConfig | ||
from merlion.utils import TimeSeries | ||
|
||
|
||
def connect_database():
    """
    Open a connection to CrateDB and hand it back to the caller.

    The HTTP endpoint is taken from the `CRATEDB_HTTP_URL` environment variable;
    when that variable is not set, `http://crate@localhost:4200` is used.
    """
    return client.connect(os.environ.get("CRATEDB_HTTP_URL", "http://crate@localhost:4200"))
|
||
|
||
def provision_data(chunk_size: int = 1000):
    """
    Download Numenta Anomaly Benchmark data, and load into database.

    Nothing is downloaded or written when the `machine_data` table can already
    be queried.

    :param chunk_size: Maximum number of rows sent per bulk insert.
    :raises ValueError: When `chunk_size` is smaller than one.
    :raises ProgrammingError: When probing for the table fails for any other
        reason than the table being unknown.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    conn = connect_database()
    # Entering the connection context up front makes sure the connection is
    # also released on the early return below.
    with conn:
        # Probe for the table before downloading anything, so the dataset is
        # only fetched when it actually needs to be loaded.
        try:
            conn.cursor().execute("SELECT * FROM machine_data LIMIT 1;")
        except ProgrammingError as ex:
            if "Relation 'machine_data' unknown" not in str(ex):
                raise
        else:
            return

        data = pd.read_csv(
            "https://raw.githubusercontent.com/numenta/NAB/master/data/realKnownCause/machine_temperature_system_failure.csv"
        )

        cursor = conn.cursor()
        # Create the table if it doesn't exist.
        cursor.execute("CREATE TABLE IF NOT EXISTS machine_data (timestamp TIMESTAMP, temperature DOUBLE)")

        # Insert the data in slices of at most `chunk_size` rows. Slicing by
        # position also works when there are fewer rows than `chunk_size`.
        for start in range(0, len(data), chunk_size):
            chunk = data.iloc[start : start + chunk_size]
            cursor.executemany(
                "INSERT INTO machine_data (timestamp, temperature) VALUES (?, ?)",
                list(chunk.itertuples(index=False, name=None)),
            )

        # Make the freshly inserted rows visible to subsequent queries.
        cursor.execute("REFRESH TABLE machine_data")
|
||
|
||
def read_data() -> pd.DataFrame:
    """
    Read data from database into pandas DataFrame.

    The query bins `machine_data` into 5-minute intervals and takes the maximum
    temperature per bin.

    :return: DataFrame indexed by `timestamp`, with a single `value` column.
    """
    conn = connect_database()
    with conn:
        cursor = conn.cursor()
        cursor.execute(
            """SELECT
            DATE_BIN('5 min'::INTERVAL, "timestamp", 0) AS timestamp,
            MAX(temperature) AS value
            FROM machine_data
            GROUP BY timestamp
            ORDER BY timestamp ASC"""
        )
        rows = cursor.fetchall()

    # Name the columns explicitly, so an empty result set still produces a
    # frame that has a `timestamp` column to index on.
    time_series = pd.DataFrame(rows, columns=["timestamp", "value"])

    # The first column holds epoch milliseconds. Convert them without going
    # through the local timezone of the machine, so the outcome does not depend
    # on where the program runs.
    time_series["timestamp"] = pd.to_datetime(time_series["timestamp"], unit="ms")

    # Set the timestamp as the index
    return time_series.set_index("timestamp")
|
||
|
||
def run_experiment(time_series: pd.DataFrame):
    """
    Run experiment on DataFrame, using Merlion. Track it using MLflow.

    The data is split at 2013-12-15: everything before is used for training,
    everything from that point on for evaluation. Metrics, a few model
    parameters, and the saved model are logged to the active MLflow run.

    :param time_series: DataFrame indexed by timestamp. It is not modified.
    """
    mlflow.set_experiment("numenta-merlion-experiment")

    # Boundary between training and evaluation period.
    split_at = pd.to_datetime("2013-12-15")

    with mlflow.start_run():
        train_data = TimeSeries.from_pd(time_series[time_series.index < split_at])
        test_data = TimeSeries.from_pd(time_series[time_series.index >= split_at])

        model = DefaultDetector(DefaultDetectorConfig())
        model.train(train_data=train_data)

        test_pred = model.get_anomaly_label(time_series=test_data)

        # Prepare the test labels: time frames used as anomaly ground truth.
        time_frames = [
            ["2013-12-15 17:50:00.000000", "2013-12-17 17:00:00.000000"],
            ["2014-01-27 14:20:00.000000", "2014-01-29 13:30:00.000000"],
            ["2014-02-07 14:55:00.000000", "2014-02-09 14:05:00.000000"],
        ]

        # Keep the labels in a separate series instead of adding a column to
        # the caller's DataFrame, so repeated runs on the same frame see the
        # same input.
        labels = pd.Series(0, index=time_series.index, name="test_labels")
        for start, end in time_frames:
            begin, finish = pd.to_datetime(start), pd.to_datetime(end)
            labels.loc[(labels.index >= begin) & (labels.index <= finish)] = 1

        test_labels = TimeSeries.from_pd(labels)

        p = TSADMetric.Precision.value(ground_truth=test_labels, predict=test_pred)
        r = TSADMetric.Recall.value(ground_truth=test_labels, predict=test_pred)
        f1 = TSADMetric.F1.value(ground_truth=test_labels, predict=test_pred)
        mttd = TSADMetric.MeanTimeToDetect.value(ground_truth=test_labels, predict=test_pred)
        print(f"Precision: {p:.4f}, Recall: {r:.4f}, F1: {f1:.4f}\n" f"Mean Time To Detect: {mttd}")  # noqa: T201

        mlflow.log_metric("precision", p)
        mlflow.log_metric("recall", r)
        mlflow.log_metric("f1", f1)
        mlflow.log_metric("mttd", mttd.total_seconds())
        mlflow.log_param("anomaly_threshold", model.config.threshold.alm_threshold)
        mlflow.log_param("min_alm_window", model.config.threshold.min_alm_in_window)
        mlflow.log_param("alm_window_minutes", model.config.threshold.alm_window_minutes)
        mlflow.log_param("alm_suppress_minutes", model.config.threshold.alm_suppress_minutes)
        mlflow.log_param("ensemble_size", model.config.model.combiner.n_models)

        # Save the model to MLflow.
        model.save("model")
        mlflow.log_artifact("model")
|
||
|
||
def main():
    """
    Program entry point: load the dataset into the database when needed,
    read it back, and run the experiment on it.
    """
    provision_data()
    run_experiment(read_data())
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import logging | ||
import sys | ||
import time | ||
from pathlib import Path | ||
|
||
import mlflow | ||
import pytest | ||
import sqlalchemy as sa | ||
|
||
from mlflow_cratedb.adapter.setup_db import _setup_db_create_tables, _setup_db_drop_tables | ||
from tests.util import process | ||
|
||
# The canonical database schema used for example purposes is `examples`. | ||
DB_URI = "crate://crate@localhost/?schema=examples" | ||
MLFLOW_TRACKING_URI = "http://127.0.0.1:5000" | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@pytest.fixture
def engine():
    """
    Provide an SQLAlchemy engine connected to `DB_URI`.

    The engine is disposed after the test has finished, so its connection pool
    does not outlive the test.
    """
    db_engine = mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(DB_URI)
    yield db_engine
    db_engine.dispose()
|
||
|
||
def _wait_for_http(url: str, timeout: float = 30.0, interval: float = 0.5) -> None:
    """
    Block until `url` answers with any HTTP response.

    :param url: Address to poll.
    :param timeout: Maximum number of seconds to keep polling.
    :param interval: Pause between two attempts, in seconds.
    :raises TimeoutError: When no HTTP response arrived within `timeout`.
    """
    import urllib.error
    import urllib.request

    deadline = time.monotonic() + timeout
    while True:
        try:
            with urllib.request.urlopen(url, timeout=5):  # noqa: S310
                return
        except urllib.error.HTTPError:
            # An HTTP error status still proves that the server is answering.
            return
        except OSError:
            # Connection refused or timed out (`URLError` is an `OSError`):
            # the server is not ready yet, so retry until the deadline.
            if time.monotonic() >= deadline:
                raise TimeoutError(f"No HTTP response from {url} within {timeout} seconds") from None
            time.sleep(interval)


def test_tracking_merlion(engine: sa.Engine):
    """
    Run the `tracking_merlion.py` example against a freshly started
    `mlflow-cratedb` server, and verify the program exits successfully.
    """
    import os

    _setup_db_drop_tables(engine=engine)
    _setup_db_create_tables(engine=engine)
    tracking_merlion = Path(__file__).parent.parent.joinpath("examples").joinpath("tracking_merlion.py")
    cmd_server = [
        "mlflow-cratedb",
        "server",
        "--workers=1",
        f"--backend-store-uri={DB_URI}",
        "--gunicorn-opts='--log-level=debug'",
    ]
    cmd_client = [
        sys.executable,
        str(tracking_merlion),
    ]

    # Extend the current environment instead of replacing it, so the client
    # still sees variables like `PATH`, `HOME`, or `CRATEDB_HTTP_URL`.
    client_env = {**os.environ, "MLFLOW_TRACKING_URI": MLFLOW_TRACKING_URI}

    logger.info("Starting server")
    with process(cmd_server, stdout=sys.stdout.buffer, stderr=sys.stderr.buffer, close_fds=True) as server_process:
        logger.info("Started server with process id: %s", server_process.pid)
        # Wait until the server responds, rather than sleeping a fixed time.
        _wait_for_http(MLFLOW_TRACKING_URI)
        logger.info("Starting client")
        with process(
            cmd_client,
            env=client_env,
            stdout=sys.stdout.buffer,
            stderr=sys.stderr.buffer,
        ) as client_process:
            client_process.wait(timeout=120)
            assert client_process.returncode == 0

    # TODO: Verify database content.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters