diff --git a/orient_express/model_wrapper.py b/orient_express/model_wrapper.py index d0e5c35..52005f4 100644 --- a/orient_express/model_wrapper.py +++ b/orient_express/model_wrapper.py @@ -165,6 +165,12 @@ def local_predict(self, input_df): return self.model.predict(input_df) + def local_predict_proba(self, input_df): + if not self.model: + self.load_model_from_registry() + + return self.model.predict_proba(input_df) + def load_model_from_registry(self): if self.model_version: vertex_model = aiplatform.Model( diff --git a/pyproject.toml b/pyproject.toml index 9c7c8d2..7dc7055 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "orient_express" -version = "0.2.5" +version = "0.3.1" description = "A library to simplify model deployment to Vertex AI" authors = ["Alexey Zankevich "] readme = "README.md"