Skip to content

Commit

Permalink
Add vertex init calls
Browse files Browse the repository at this point in the history
  • Loading branch information
Nepherhotep committed Nov 13, 2024
1 parent eedada7 commit 8b04aa1
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion orient_express/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,17 @@ def __init__(
else:
self.endpoint_name = endpoint_name

self._vertex_initialized = False

def colab_auth(self):
from google.colab import auth

auth.authenticate_user()

def _vertex_init(self):
aiplatform.init(project=self.project_name, location=self.region)
if not self._vertex_initialized:
aiplatform.init(project=self.project_name, location=self.region)
self._vertex_initialized = True

def get_latest_vertex_model(self, model_name):
self._vertex_init()
Expand Down Expand Up @@ -147,6 +151,8 @@ def deploy(self):
)

def remote_predict(self, input_df):
self._vertex_init()

if not self.endpoint:
endpoint = self.get_endpoint()
if not endpoint:
Expand All @@ -160,6 +166,8 @@ def remote_predict(self, input_df):
return predictions.predictions

def local_predict(self, input_df):
self._vertex_init()

if not self.model:
self.load_model_from_registry()

Expand All @@ -172,6 +180,8 @@ def local_predict_proba(self, input_df):
return self.model.predict_proba(input_df)

def load_model_from_registry(self):
self._vertex_init()

if self.model_version:
vertex_model = aiplatform.Model(
model_name=self.model_name, version=self.model_version
Expand Down

0 comments on commit 8b04aa1

Please sign in to comment.