From 8b04aa166a94845ca5975a861c211007a1fea7ed Mon Sep 17 00:00:00 2001 From: Alexey Zankevich Date: Wed, 13 Nov 2024 16:58:35 -0500 Subject: [PATCH] Add vertex init calls --- orient_express/model_wrapper.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/orient_express/model_wrapper.py b/orient_express/model_wrapper.py index 52005f4..18442fc 100644 --- a/orient_express/model_wrapper.py +++ b/orient_express/model_wrapper.py @@ -40,13 +40,17 @@ def __init__( else: self.endpoint_name = endpoint_name + self._vertex_initialized = False + def colab_auth(self): from google.colab import auth auth.authenticate_user() def _vertex_init(self): - aiplatform.init(project=self.project_name, location=self.region) + if not self._vertex_initialized: + aiplatform.init(project=self.project_name, location=self.region) + self._vertex_initialized = True def get_latest_vertex_model(self, model_name): self._vertex_init() @@ -147,6 +151,8 @@ def deploy(self): ) def remote_predict(self, input_df): + self._vertex_init() + if not self.endpoint: endpoint = self.get_endpoint() if not endpoint: @@ -160,6 +166,8 @@ def remote_predict(self, input_df): return predictions.predictions def local_predict(self, input_df): + self._vertex_init() + if not self.model: self.load_model_from_registry() @@ -172,6 +180,8 @@ def local_predict_proba(self, input_df): return self.model.predict_proba(input_df) def load_model_from_registry(self): + self._vertex_init() + if self.model_version: vertex_model = aiplatform.Model( model_name=self.model_name, version=self.model_version