Skip to content

Commit

Permalink
support forecasting (#2499)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbchao authored Jan 31, 2024
1 parent 4bd1b97 commit 5036435
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 35 deletions.
1 change: 1 addition & 0 deletions erroranalysis/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ pytest-mock==3.6.1

requirements-parser==0.2.0
rai_test_utils[object_detection]
scikit-learn<=1.3.2
interpret-core[required]<=0.3.2
17 changes: 8 additions & 9 deletions responsibleai/responsibleai/_internal/_served_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

import json

import requests

from raiutils.webservice import post_with_retries
from responsibleai.serialization_utilities import serialize_json_safe


Expand Down Expand Up @@ -37,14 +36,14 @@ def forecast(self, X):
# request formatting according to mlflow docs
# https://mlflow.org/docs/latest/cli.html#mlflow-models-serve
# JSON safe serialization takes care of datetime columns
response = requests.post(
url=f"http://localhost:{self.port}/invocations",
headers={"Content-Type": "application/json"},
data=json.dumps(
{"dataframe_split": X.to_dict(orient='split')},
default=serialize_json_safe))
uri = f"http://localhost:{self.port}/invocations"
input_data = json.dumps(
{"dataframe_split": X.to_dict(orient='split')},
default=serialize_json_safe)
headers = {"Content-Type": "application/json"}
try:
response.raise_for_status()
response = post_with_retries(uri, input_data, headers,
max_retries=15, retry_delay=30)
except Exception:
raise RuntimeError(
"Could not retrieve predictions. "
Expand Down
8 changes: 6 additions & 2 deletions responsibleai/responsibleai/rai_insights/rai_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,10 +1285,14 @@ def _get_feature_ranges(
res_object[_UNIQUE_VALUES] = unique_value.tolist()
elif datetime_features is not None and col in datetime_features:
res_object[_RANGE_TYPE] = "datetime"
min_value = test[col].min()
min_value = pd.to_datetime(min_value)
res_object[_MIN_VALUE] = \
test[col].min().strftime(_STRF_TIME_FORMAT)
min_value.strftime(_STRF_TIME_FORMAT)
max_value = test[col].max()
max_value = pd.to_datetime(max_value)
res_object[_MAX_VALUE] = \
test[col].max().strftime(_STRF_TIME_FORMAT)
max_value.strftime(_STRF_TIME_FORMAT)
else:
col_min = test[col].min()
col_max = test[col].max()
Expand Down
25 changes: 1 addition & 24 deletions responsibleai/tests/rai_insights/test_served_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from unittest import mock

import pytest
import requests
from tests.common_utils import (RandomForecastingModel,
create_tiny_forecasting_dataset)

Expand Down Expand Up @@ -41,7 +40,7 @@ def rai_forecasting_insights_for_served_model():


@mock.patch("requests.post")
@mock.patch.dict("os.environ", {"RAI_MODEL_SERVING_PORT": "5123"})
@mock.patch.dict("os.environ", {"RAI_MODEL_SERVING_PORT": "5432"})
def test_served_model(
mock_post,
rai_forecasting_insights_for_served_model):
Expand All @@ -58,25 +57,3 @@ def test_served_model(
forecasts = rai_insights.model.forecast(X_test)
assert len(forecasts) == len(X_test)
assert mock_post.call_count == 1


@mock.patch("requests.post")
@mock.patch.dict("os.environ", {"RAI_MODEL_SERVING_PORT": "5123"})
def test_served_model_failed(
mock_post,
rai_forecasting_insights_for_served_model):
_, X_test, _, _ = create_tiny_forecasting_dataset()

response = requests.Response()
response.status_code = 400
response._content = b"Could not connect to host since it actively " \
b"refuses the connection."
mock_post.return_value = response

rai_insights = RAIInsights.load(RAI_INSIGHTS_DIR_NAME)
with pytest.raises(
Exception,
match="Could not retrieve predictions. "
"Model server returned status code 400 "
f"and the following response: {response.content}"):
rai_insights.model.forecast(X_test)

0 comments on commit 5036435

Please sign in to comment.