Output label transformer support (#8)
Nepherhotep authored Dec 12, 2024
1 parent 28d7524 commit 0b0248b
Showing 6 changed files with 2,029 additions and 28 deletions.
40 changes: 21 additions & 19 deletions orient_express/model_wrapper.py
@@ -1,4 +1,6 @@
import os
from typing import Optional

import joblib
import logging
import pandas as pd
@@ -9,20 +11,20 @@
class ModelExpress:
def __init__(
self,
model_name,
project_name,
bucket_name,
model_version=None,
model=None,
region="us-central1",
serialized_model_path="model.joblib",
serving_container_image_uri="us-west1-docker.pkg.dev/shiftsmart-api/orient-express/xgboost-scikit-learn:latest",
serving_container_predict_route="/v1/models/orient-express-model:predict",
serving_container_health_route="/v1/models/orient-express-model",
endpoint_name=None,
model_name: str,
project_name: str,
bucket_name: str,
model_version: Optional[int] = None,
model: object = None,
region: str = "us-central1",
serialized_model_path: str = "model.joblib",
serving_container_image_uri: str = "us-west1-docker.pkg.dev/shiftsmart-api/orient-express/xgboost-scikit-learn:latest",
serving_container_predict_route: str = "/v1/models/orient-express-model:predict",
serving_container_health_route: str = "/v1/models/orient-express-model",
endpoint_name: Optional[str] = None,
machine_type="n1-standard-4",
min_replica_count=1,
max_replica_count=1,
min_replica_count: int = 1,
max_replica_count: int = 1,
):
self.model = model
self.model_name = model_name
@@ -56,7 +58,7 @@ def _vertex_init(self):
aiplatform.init(project=self.project_name, location=self.region)
self._vertex_initialized = True

def get_latest_vertex_model(self, model_name):
def get_latest_vertex_model(self, model_name: str):
"""If there are a few models with the same name, load the most recent one.
It's highly recommended to keep only 1 model with the same name to avoid the confusion
"""
@@ -96,7 +98,7 @@ def upload(self):

return self.create_model_version(new_version, last_model)

def get_artifacts_path(self, version, file_name=None):
def get_artifacts_path(self, version: int, file_name: str = None):
dir_name = f"models/{self.model_name}/{version}"
if file_name:
return f"{dir_name}/{file_name}"
@@ -159,7 +161,7 @@ def deploy(self):
traffic_percentage=100,
)

def remote_predict(self, input_df):
def remote_predict(self, input_df: pd.DataFrame):
self._vertex_init()

if not self.endpoint:
@@ -174,15 +176,15 @@ def remote_predict(self, input_df):
predictions = self.endpoint.predict(instances=instances)
return predictions.predictions

def local_predict(self, input_df):
def local_predict(self, input_df: pd.DataFrame):
self._vertex_init()

if not self.model:
self.load_model_from_registry()

return self.model.predict(input_df)

def local_predict_proba(self, input_df):
def local_predict_proba(self, input_df: pd.DataFrame):
if not self.model:
self.load_model_from_registry()

@@ -207,7 +209,7 @@ def load_model_from_registry(self):

self.model = joblib.load(self.serialized_model_path)

def download_artifacts(self, artifact_uri):
def download_artifacts(self, artifact_uri: str):
storage_client = storage.Client()
bucket_name, artifact_path = artifact_uri.replace("gs://", "").split("/", 1)
bucket = storage_client.bucket(bucket_name)
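For context, a minimal usage sketch of the ModelExpress wrapper whose constructor is annotated above. It is illustrative only and not part of the commit: the project, bucket, model, and endpoint names are placeholders, the import path simply mirrors the file location, and running it end to end would require valid GCP credentials plus an existing project and bucket.

import pandas as pd
from sklearn.linear_model import LogisticRegression
from orient_express.model_wrapper import ModelExpress

# Tiny illustrative training set
X_train = pd.DataFrame({"feature_a": [0.1, 0.9, 0.4], "feature_b": [1.0, 0.2, 0.5]})
y_train = [0, 1, 0]

wrapper = ModelExpress(
    model=LogisticRegression().fit(X_train, y_train),
    model_name="demo-model",            # placeholder
    project_name="my-gcp-project",      # placeholder
    bucket_name="my-model-bucket",      # placeholder
    endpoint_name="demo-endpoint",      # placeholder
)

wrapper.upload()                        # register a new model version in Vertex AI
wrapper.deploy()                        # route 100% of endpoint traffic to it
print(wrapper.local_predict(X_train))   # predict with the in-memory model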
81 changes: 81 additions & 0 deletions orient_express/sklearn_pipeline.py
@@ -0,0 +1,81 @@
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import LabelEncoder


class LabelEncoderTransformer(BaseEstimator, TransformerMixin):
"""
A wrapper class that integrates a model and a label encoder to handle
encoded predictions and their corresponding probabilities.
Attributes:
model (object): A machine learning model with `fit`, `predict`, and `predict_proba` methods.
label_encoder (object): A label encoder that maps encoded class labels to their original string labels.
Methods:
fit(X, y):
Trains the model on the provided features and labels.
predict(X):
Predicts the classes for the given features and returns the original string labels.
predict_proba(X):
Returns the probabilities for each class for the given features, along with their original class labels.
"""

def __init__(self, model: BaseEstimator, label_encoder: LabelEncoder):
"""
Initializes the LabelEncoderTransformer.
Args:
model (BaseEstimator): A machine learning model.
label_encoder (LabelEncoder): A label encoder with `fit` and `inverse_transform` methods.
"""
self.model = model
self.label_encoder = label_encoder

def fit(self, X, y):
"""
Fits the model to the training data.
Args:
X (array-like): Feature matrix.
y (array-like): Target labels.
Returns:
self: Fitted LabelEncoderTransformer instance.
"""
self.model.fit(X, y)
return self

def predict(self, X):
"""
Predicts the target labels and returns the original string labels.
Args:
X (array-like): Feature matrix for predictions.
Returns:
array-like: Predicted class labels in their original form.
"""
encoded_predictions = self.model.predict(X)
return self.label_encoder.inverse_transform(encoded_predictions)

def predict_proba(self, X):
"""
Returns class probabilities along with their original labels.
Args:
X (array-like): Feature matrix for predictions.
Returns:
list: A list of lists, where each inner list contains `[class_name, probability]` pairs.
"""
# Get raw probabilities
probabilities = self.model.predict_proba(X)

# Get all class labels
class_names = self.label_encoder.classes_
# Combine class names with probabilities
combined_output = [
[[class_name, prob] for class_name, prob in zip(class_names, sample_probs)]
for sample_probs in probabilities
]
return combined_output
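A short usage sketch of the new transformer (illustrative only; the import path mirrors the file location above, and the classifier and data are made up):

from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from orient_express.sklearn_pipeline import LabelEncoderTransformer

X = [[5.1, 3.5], [6.2, 2.9], [4.7, 3.2], [6.9, 3.1]]
labels = ["setosa", "versicolor", "setosa", "versicolor"]

# Encode string labels to integers; the transformer maps predictions back.
encoder = LabelEncoder()
y = encoder.fit_transform(labels)

clf = LabelEncoderTransformer(model=RandomForestClassifier(), label_encoder=encoder)
clf.fit(X, y)

print(clf.predict(X))        # original string labels, e.g. ['setosa' 'versicolor' ...]
print(clf.predict_proba(X))  # per sample: [['setosa', 0.97], ['versicolor', 0.03]]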
