Added class for loading locally saved model
MarleneKress79789 committed Jan 17, 2024
1 parent 0508cea commit 0e4ac27
Showing 5 changed files with 159 additions and 4 deletions.
2 changes: 1 addition & 1 deletion doc/changes/changes_0.8.0.md
@@ -9,7 +9,7 @@ T.B.D

### Features

-- n/a
+- #145: Added load function for loading local models

### Bug Fixes

58 changes: 58 additions & 0 deletions exasol_transformers_extension/utils/load_local_model.py
@@ -0,0 +1,58 @@
import torch
import transformers.pipelines
from transformers import AutoModel, AutoTokenizer
from pathlib import Path


class LoadLocalModel:
    """
    Class for loading locally saved models and tokenizers.
    Also stores information about the current model and pipeline.

    :param pipeline: pipeline factory used to create the model pipeline
    :param task_name: name of the current task
    :param device: device to be used for pipeline creation
    """
    def __init__(self,
                 pipeline,
                 task_name,
                 device
                 ):
        self.pipeline = pipeline
        self.task_name = task_name
        self.device = device
        self.last_loaded_model = None
        self.last_loaded_tokenizer = None
        self.last_loaded_model_key = None

    def load_models(self, model_name: str,
                    current_model_key,
                    cache_dir: Path
                    ) -> transformers.pipelines.Pipeline:
        """
        Loads a locally saved model and tokenizer from cache_dir / "pretrained" / model_name.
        Returns a new pipeline corresponding to the model and task.

        :param model_name: name of the model to be loaded
        :param current_model_key: key of the model to be loaded
        :param cache_dir: location of the saved model
        """
        # Model and tokenizer are loaded from the same local directory.
        model_path = str(cache_dir / "pretrained" / model_name)
        self.last_loaded_model = AutoModel.from_pretrained(model_path)
        self.last_loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)

        last_created_pipeline = self.pipeline(
            self.task_name,
            model=self.last_loaded_model,
            tokenizer=self.last_loaded_tokenizer,
            device=self.device,
            framework="pt")
        self.last_loaded_model_key = current_model_key
        return last_created_pipeline

    def clear_device_memory(self):
        """
        Delete the loaded model and tokenizer and free device memory.
        """
        self.last_loaded_model = None
        self.last_loaded_tokenizer = None
        torch.cuda.empty_cache()
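
For orientation, a minimal usage sketch of the new class (not part of this commit; the task name, model name, cache key, and cache path below are illustrative, and a model is assumed to have been saved beforehand under cache_dir / "pretrained" / model_name):

from pathlib import Path

import transformers

from exasol_transformers_extension.utils.load_local_model import LoadLocalModel

# Build a loader around the stock transformers.pipeline factory.
loader = LoadLocalModel(
    transformers.pipeline,
    task_name="feature-extraction",  # illustrative task name
    device=-1)                       # CPU; pass a GPU index such as 0 if available

# Assumes model files were saved earlier to <cache_dir>/pretrained/<model_name>.
pipe = loader.load_models(
    model_name="prajjwal1/bert-tiny",    # illustrative model name
    current_model_key="bert-tiny-key",   # illustrative cache key
    cache_dir=Path("/tmp/model_cache"))

features = pipe("some input text")
loader.clear_device_memory()  # drop model references and free GPU memory when done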
15 changes: 12 additions & 3 deletions exasol_transformers_extension/utils/model_factory_protocol.py
@@ -1,5 +1,5 @@
from pathlib import Path
-from typing import Protocol, Union, runtime_checkable
+from typing import Protocol, Union, runtime_checkable, Optional

import transformers

@@ -9,8 +9,17 @@ class ModelFactoryProtocol(Protocol):
"""
Protocol for better type hints.
"""
def from_pretrained(self, model_name: str, cache_dir: Path, use_auth_token: str) -> transformers.PreTrainedModel:
def from_pretrained(self, model_name: str, cache_dir: Optional[Path]=None, use_auth_token: Optional[str]=None) \
-> transformers.PreTrainedModel:
"""
Either downloads a model from Huggingface Hub(all parameters required),
or loads a locally saved model from file (only requires filepath)
:model_name: model name, or path to locally saved model files
:cache_dir: optional. Path where downloaded model should be cached
:use_auth_token: optional. token for Huggingface hub private models
"""
pass

def save_pretrained(self, save_directory: Union[str, Path]):
pass
pass
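
As a sketch of the two call shapes the widened signature now describes (not part of this commit; transformers.AutoModel serves as an example of a factory with this shape, and the model name and paths are illustrative):

from pathlib import Path

from transformers import AutoModel

# Download from the Huggingface Hub: all parameters supplied.
model = AutoModel.from_pretrained(
    "prajjwal1/bert-tiny",               # illustrative model name
    cache_dir=Path("/tmp/model_cache"),
    use_auth_token=None)                 # a token string is only needed for private models

# Load a locally saved model: only the path to the saved files is required.
model = AutoModel.from_pretrained("/tmp/model_cache/pretrained/prajjwal1/bert-tiny")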
43 changes: 43 additions & 0 deletions tests/integration_tests/without_db/utils/test_load_local_model.py
@@ -0,0 +1,43 @@
import tempfile
from pathlib import Path

from transformers import AutoModel, AutoTokenizer

from exasol_transformers_extension.utils.load_local_model import LoadLocalModel
from tests.utils.parameters import model_params
from tests.fixtures.bucketfs_fixture import bucketfs_location


class TestSetup:
    def __init__(self, bucketfs_location):
        self.bucketfs_location = bucketfs_location  # currently unused by this test

        self.token = "token"
        self.model_name = model_params.tiny_model

        self.mock_current_model_key = None
        mock_pipeline = lambda task_name, model, tokenizer, device, framework: None
        self.loader = LoadLocalModel(
            mock_pipeline,
            task_name="test_task",
            device=0)


def test_integration(bucketfs_location):
    test_setup = TestSetup(bucketfs_location)

    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir_path = Path(tmpdir)
        model_save_path = tmpdir_path / "pretrained" / test_setup.model_name
        # Download a model and tokenizer and save them locally.
        model = AutoModel.from_pretrained(test_setup.model_name)
        tokenizer = AutoTokenizer.from_pretrained(test_setup.model_name)
        model.save_pretrained(model_save_path)
        tokenizer.save_pretrained(model_save_path)

        test_setup.loader.load_models(model_name=test_setup.model_name,
                                      current_model_key=test_setup.mock_current_model_key,
                                      cache_dir=tmpdir_path)
45 changes: 45 additions & 0 deletions tests/unit_tests/utils/test_load_local_model.py
@@ -0,0 +1,45 @@
import tempfile
from pathlib import Path
from typing import Union
from unittest.mock import create_autospec, MagicMock, call

from exasol_bucketfs_utils_python.bucketfs_location import BucketFSLocation

from exasol_transformers_extension.utils.model_factory_protocol import ModelFactoryProtocol
from exasol_transformers_extension.utils.load_local_model import LoadLocalModel

from tests.utils.parameters import model_params


class TestSetup:
    def __init__(self):

        self.bucketfs_location_mock: Union[BucketFSLocation, MagicMock] = create_autospec(BucketFSLocation)

        self.token = "token"
        self.model_name = model_params.tiny_model

        # TODO: do we want a real pipeline here and a check of its creation?
        mock_pipeline = lambda task_name, model, tokenizer, device, framework: None
        self.loader = LoadLocalModel(
            mock_pipeline,
            task_name="test_task",
            device=0)


# TODO: test current_model_key, test loading a model twice, test a wrong model name
def test_load_function_call():
    test_setup = TestSetup()
    mock_current_model_key = None
    with tempfile.TemporaryDirectory() as tmpdir:
        cache_dir = Path(tmpdir)
        model_save_path = cache_dir / "pretrained" / test_setup.model_name

        test_setup.loader.load_models(model_name=test_setup.model_name,
                                      current_model_key=mock_current_model_key,
                                      cache_dir=cache_dir)

        # assert test_setup.model_factory_mock.mock_calls == [
        #     call.from_pretrained(str(model_save_path))]
