feat: train udsink request deduplication (#39)
Signed-off-by: Avik Basu <[email protected]>
Co-authored-by: Nandita Koppisetty <[email protected]>
ab93 and nkoppisetty authored Jan 17, 2023
1 parent d27bf45 commit 4789b19
Showing 8 changed files with 158 additions and 70 deletions.
1 change: 1 addition & 0 deletions .dockerignore
@@ -12,3 +12,4 @@ LICENSE
.codecov.yml
.coveragerc
.flake8
.hack/
21 changes: 21 additions & 0 deletions .hack/changelog.sh
@@ -0,0 +1,21 @@
#!/usr/bin/env sh
set -eu

echo '# Changelog'
echo

tag=
git tag -l 'v*' | sed 's/-rc/~/' | sort -rV | sed 's/~/-rc/' | while read last; do
if [ "$tag" != "" ]; then
echo "## $(git for-each-ref --format='%(refname:strip=2) (%(creatordate:short))' refs/tags/${tag})"
echo
git_log='git --no-pager log --no-merges --invert-grep --grep=^\(build\|chore\|ci\|docs\|test\):'
$git_log --format=' * [%h](https://github.com/numaproj/numalogic-prometheus/commit/%H) %s' $last..$tag
echo
echo "### Contributors"
echo
$git_log --format=' * %an' $last..$tag | sort -u
echo
fi
tag=$last
done
31 changes: 28 additions & 3 deletions numaprom/udsink/train.py
@@ -17,10 +17,16 @@

from numaprom._constants import DEFAULT_PROMETHEUS_SERVER
from numaprom.prometheus import Prometheus
from numaprom.redis import get_redis_client
from numaprom.tools import get_metric_config, save_model

LOGGER = logging.getLogger(__name__)

HOST = os.getenv("REDIS_HOST")
PORT = os.getenv("REDIS_PORT")
AUTH = os.getenv("REDIS_AUTH")
EXPIRY = int(os.getenv("REDIS_EXPIRY", 300))


def clean_data(df: pd.DataFrame, limit=12) -> pd.DataFrame:
    df.replace([np.inf, -np.inf], np.nan, inplace=True)
@@ -73,6 +79,18 @@ def _preprocess(x_raw):
    return x_scaled


def _is_new_request(namespace: str, metric: str) -> bool:
    redis_client = get_redis_client(HOST, PORT, password=AUTH, recreate=False)
    r_key = f"train::{namespace}:{metric}"

    value = redis_client.get(r_key)
    if value:
        return False

    redis_client.setex(r_key, time=EXPIRY, value=1)
    return True


def train(datums: List[Datum]) -> Responses:
    responses = Responses()

@@ -82,6 +100,13 @@ def train(datums: List[Datum]) -> Responses:
        namespace = payload["namespace"]
        metric_name = payload["name"]

        is_new = _is_new_request(namespace, metric_name)
        if not is_new:
            warn_msg = f"Skipping train request with namespace: {namespace}, metric: {metric_name}"
            LOGGER.warning(warn_msg)
            responses.append(Response.as_failure(_datum.id, err_msg=warn_msg))
            continue

        metric_config = get_metric_config(metric_name)
        model_config = metric_config["model_config"]
        win_size = model_config["win_size"]
@@ -90,12 +115,12 @@ def train(datums: List[Datum]) -> Responses:
        train_df = clean_data(train_df)

        if len(train_df) < model_config["win_size"]:
-            warn_msg = (
+            _info_msg = (
                f"Skipping training since traindata size: {train_df.shape} "
                f"is less than winsize: {win_size}"
            )
-            LOGGER.warning(warn_msg)
-            responses.append(Response.as_failure(_datum.id, err_msg=warn_msg))
+            LOGGER.info(_info_msg)
+            responses.append(Response.as_failure(_datum.id, err_msg=_info_msg))
            continue

        x_train = _preprocess(train_df.to_numpy())
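The new _is_new_request helper above deduplicates train requests by keying a Redis entry on namespace and metric and letting it expire after REDIS_EXPIRY seconds (300 by default), so repeat requests for the same series inside that window are rejected. A minimal sketch of the same pattern, run against fakeredis instead of a live Redis server (the namespace and metric names below are illustrative):

import fakeredis

EXPIRY = 300  # seconds; mirrors the REDIS_EXPIRY default in the diff
redis_client = fakeredis.FakeStrictRedis(decode_responses=True)


def is_new_request(namespace: str, metric: str) -> bool:
    """Return True only for the first request per (namespace, metric) within EXPIRY seconds."""
    r_key = f"train::{namespace}:{metric}"
    if redis_client.get(r_key):
        return False
    redis_client.setex(r_key, time=EXPIRY, value=1)
    return True


print(is_new_request("sandbox", "cpu"))     # True  -> training proceeds
print(is_new_request("sandbox", "cpu"))     # False -> duplicate within the expiry window is skipped
print(is_new_request("sandbox", "memory"))  # True  -> different metric, separate dedup key

Note that the get-then-setex pair is not atomic; two concurrent requests could both pass the check, but for deduplicating training jobs the cost of an occasional duplicate fit is small.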
130 changes: 70 additions & 60 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic-prometheus"
version = "0.0.4"
version = "0.1.0a0"
description = "ML inference on numaflow using numalogic on Prometheus metrics"
authors = ["Numalogic developers"]
packages = [{ include = "numaprom" }]
16 changes: 15 additions & 1 deletion tests/__init__.py
@@ -1,14 +1,28 @@
import logging
from unittest.mock import patch

import fakeredis

server = fakeredis.FakeServer()
redis_client = fakeredis.FakeStrictRedis(server=server, decode_responses=True)

LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.DEBUG)

stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.DEBUG)


formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
stream_handler.setFormatter(formatter)

LOGGER.addHandler(stream_handler)


with patch("numaprom.redis.get_redis_client") as mock_get_redis_client:
    mock_get_redis_client.return_value = redis_client
    from numaprom.udf import window
    from numaprom.udsink import train


__all__ = ["redis_client", "window"]
__all__ = ["redis_client", "window", "train"]
26 changes: 22 additions & 4 deletions tests/test_trainer.py
@@ -9,7 +9,6 @@
from pynumaflow.sink import Datum

import trainer
-from numaprom.udsink import train

from numaprom._constants import TESTS_DIR, METRIC_CONFIG
from numaprom.prometheus import Prometheus
@@ -19,18 +18,21 @@
mock_argocd_query_metric,
mock_rollout_query_metric,
)
+from tests import train, redis_client

DATA_DIR = os.path.join(TESTS_DIR, "resources", "data")
STREAM_DATA_PATH = os.path.join(DATA_DIR, "stream.json")


-def as_datum(data: Union[str, bytes, dict]) -> Datum:
+def as_datum(data: Union[str, bytes, dict], msg_id="1") -> Datum:
    if type(data) is not bytes:
        data = json.dumps(data).encode("utf-8")
    elif type(data) == dict:
        data = json.dumps(data)

-    return Datum(sink_msg_id="1", value=data, event_time=datetime.now(), watermark=datetime.now())
+    return Datum(
+        sink_msg_id=msg_id, value=data, event_time=datetime.now(), watermark=datetime.now()
+    )


@patch.dict(METRIC_CONFIG, return_mock_metric_config())
@@ -41,13 +43,29 @@ class TestTrainer(unittest.TestCase):
"resume_training": False,
}

def setUp(self) -> None:
redis_client.flushall()

@patch.object(MLflowRegistry, "save", Mock(return_value=1))
@patch.object(Prometheus, "query_metric", Mock(return_value=mock_argocd_query_metric()))
def test_argocd_trainer(self):
def test_argocd_trainer_01(self):
datums = [as_datum(self.train_payload)]
_out = train(datums)
self.assertTrue(_out.items()[0].success)
self.assertEqual("1", _out.items()[0].id)
self.assertEqual(1, len(_out.items()))

@patch.object(MLflowRegistry, "save", Mock(return_value=1))
@patch.object(Prometheus, "query_metric", Mock(return_value=mock_argocd_query_metric()))
def test_argocd_trainer_02(self):
datums = [as_datum(self.train_payload), as_datum(self.train_payload, msg_id="2")]
_out = train(datums)
self.assertTrue(_out.items()[0].success)
self.assertEqual("1", _out.items()[0].id)

self.assertFalse(_out.items()[1].success)
self.assertEqual("2", _out.items()[1].id)
self.assertEqual(2, len(_out.items()))

@unittest.skip("Need to update for rollouts")
@patch.object(Prometheus, "query_metric", Mock(return_value=mock_rollout_query_metric()))
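One operational consequence of the TTL-based dedup worth noting: until the key expires (300 s by default), even a legitimate retrain request for the same series is rejected. A hedged sketch of how the key could be cleared by hand to force an immediate retrain, assuming direct access to the same Redis instance (the series name and connection details are placeholders):

import os

import redis  # redis-py; assumed to be compatible with the client get_redis_client returns

r = redis.Redis(
    host=os.getenv("REDIS_HOST", "localhost"),
    port=int(os.getenv("REDIS_PORT", "6379")),
    password=os.getenv("REDIS_AUTH"),
)

# Dedup keys follow the f"train::{namespace}:{metric}" pattern from train.py.
r.delete("train::sandbox:metric_1")  # the next train request for this series is accepted again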
1 change: 0 additions & 1 deletion tests/udf/test_postprocess.py
@@ -49,7 +49,6 @@ def test_postprocess(self):
        _in = get_datum(msg.value)
        _out = postprocess("", _in)
        data = _out.items()[0].value.decode("utf-8")
-        print("DATA", data)
        prom_payload = PrometheusPayload.from_json(data)

        if prom_payload.name != "metric_3_anomaly":
