Skip to content

Commit

Permalink
start
Browse files Browse the repository at this point in the history
  • Loading branch information
MarleneKress79789 committed Nov 30, 2023
1 parent ec5cabd commit 360297c
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 36 deletions.
49 changes: 13 additions & 36 deletions exasol_transformers_extension/udfs/models/base_model_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from exasol_transformers_extension.deployment import constants
from exasol_transformers_extension.utils import device_management, \
bucketfs_operations, dataframe_operations
from exasol_transformers_extension.utils.load_model import LoadModel


class BaseModelUDF(ABC):
Expand All @@ -20,7 +21,7 @@ class BaseModelUDF(ABC):
- creates model pipeline through transformer api
- manages the creation of predictions and the preparation of results.
"""

# todo does the token con change? (if yes need to be give at function call not class creation)
def __init__(self,
exa,
batch_size,
Expand All @@ -36,15 +37,17 @@ def __init__(self,
self.task_name = task_name
self.device = None
self.cache_dir = None
self.last_loaded_model_key = None
self.last_loaded_model = None
self.last_loaded_tokenizer = None
self.last_created_pipeline = None
self.model_loader = None
self.new_columns = []

def run(self, ctx):
device_id = ctx.get_dataframe(1).iloc[0]['device_id']
self.device = device_management.get_torch_device(device_id)
self.model_loader = LoadModel(self.pipeline,
self.base_model,
self.tokenizer,
self.task_name,
self.device)
ctx.reset()

while True:
Expand Down Expand Up @@ -171,11 +174,10 @@ def check_cache(self, model_df: pd.DataFrame) -> None:
token_conn = model_df["token_conn"].iloc[0]

current_model_key = (bucketfs_conn, sub_dir, model_name, token_conn)
if self.last_loaded_model_key != current_model_key:
if self.model_loader.last_loaded_model_key != current_model_key:
self.set_cache_dir(model_name, bucketfs_conn, sub_dir)
self.clear_device_memory()
self.load_models(model_name, token_conn)
self.last_loaded_model_key = current_model_key
self.model_loader.load_models(model_name, current_model_key, self.cache_dir, self.exa.get_connection(token_conn))

def set_cache_dir(
self, model_name: str, bucketfs_conn_name: str,
Expand All @@ -195,40 +197,15 @@ def set_cache_dir(
self.cache_dir = bucketfs_operations.get_local_bucketfs_path(
bucketfs_location=bucketfs_location, model_path=str(model_path))

# todo move this also?
def clear_device_memory(self):
"""
Delete models and free device memory
"""
self.last_loaded_model = None
self.last_loaded_tokenizer = None
self.model_loader.last_loaded_model = None
self.model_loader.last_loaded_tokenizer = None
torch.cuda.empty_cache()

def load_models(self, model_name: str, token_conn_name: str) -> None:
"""
Load model and tokenizer model from the cached location in bucketfs.
If the desired model is not cached, this method will attempt to
download the model to the read-only path /bucket/.. and cause an error.
This error will be addressed in ticket
https://github.com/exasol/transformers-extension/issues/43.
:param model_name: The model name to be loaded
"""
token = False
if token_conn_name:
token_conn_obj = self.exa.get_connection(token_conn_name)
token = token_conn_obj.password

self.last_loaded_model = self.base_model.from_pretrained(
model_name, cache_dir=self.cache_dir, use_auth_token=token)
self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
model_name, cache_dir=self.cache_dir, use_auth_token=token)
self.last_created_pipeline = self.pipeline(
self.task_name,
model=self.last_loaded_model,
tokenizer=self.last_loaded_tokenizer,
device=self.device,
framework="pt")

def get_prediction(self, model_df: pd.DataFrame) -> pd.DataFrame:
"""
Perform prediction of the given model and preparation of the prediction
Expand Down
48 changes: 48 additions & 0 deletions exasol_transformers_extension/utils/load_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@


class LoadModel:
def __init__(self,
pipeline,
base_model,
tokenizer,
task_name,
device
):
self.pipeline = pipeline
self.base_model = base_model
self.tokenizer = tokenizer
self.task_name = task_name
self.device = device
self.last_loaded_model = None
self.last_loaded_tokenizer = None
self.last_created_pipeline = None
self.last_loaded_model_key = None

def load_models(self, model_name: str,
current_model_key,
cache_dir,
token_conn_obj) -> None:
"""
Load model and tokenizer model from the cached location in bucketfs.
If the desired model is not cached, this method will attempt to
download the model to the read-only path /bucket/.. and cause an error.
This error will be addressed in ticket
https://github.com/exasol/transformers-extension/issues/43.
:param model_name: The model name to be loaded
"""
token = False
if token_conn_obj:
token = token_conn_obj.password

self.last_loaded_model = self.base_model.from_pretrained(
model_name, cache_dir=cache_dir, use_auth_token=token)
self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
model_name, cache_dir=cache_dir, use_auth_token=token)
self.last_created_pipeline = self.pipeline(
self.task_name,
model=self.last_loaded_model,
tokenizer=self.last_loaded_tokenizer,
device=self.device,
framework="pt")
self.last_loaded_model_key = current_model_key

0 comments on commit 360297c

Please sign in to comment.