-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added class for loading locally saved model
- Loading branch information
1 parent
0508cea
commit 0e4ac27
Showing
5 changed files
with
159 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ T.B.D | |
|
||
### Features | ||
|
||
- n/a | ||
- #145: Added load function for loading local models | ||
|
||
### Bug Fixes | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import torch | ||
import transformers.pipelines | ||
from transformers import AutoModel, AutoTokenizer | ||
from pathlib import Path | ||
|
||
|
||
class LoadLocalModel:
    """
    Loads a locally saved model and tokenizer and creates a transformers
    pipeline from them. Also stores information regarding the most recently
    loaded model, tokenizer and model key.

    :pipeline: pipeline factory (e.g. ``transformers.pipeline``) called to
        build a new pipeline
    :task_name: name of the current task
    :device: device to be used for pipeline creation
    """
    def __init__(self,
                 pipeline,
                 task_name,
                 device
                 ):
        self.pipeline = pipeline
        self.task_name = task_name
        self.device = device
        # Most recently loaded artifacts; all None until load_models() runs.
        self.last_loaded_model = None
        self.last_loaded_tokenizer = None
        self.last_loaded_model_key = None

    def load_models(self, model_name: str,
                    current_model_key,
                    cache_dir: Path
                    ) -> "transformers.pipelines.Pipeline":
        """
        Loads a locally saved model and tokenizer from
        "cache_dir / "pretrained" / model_name".
        Returns a new pipeline corresponding to the model and task.

        :model_name: name of the model to be loaded
        :current_model_key: key of the model to be loaded
        :cache_dir: location of the saved model
        """
        # Compute the on-disk model location once instead of per loader call.
        model_path = str(cache_dir / "pretrained" / model_name)
        # NOTE(review): AutoModel loads the base architecture without a
        # task-specific head — confirm whether a task-specific Auto class
        # (or only the tokenizer) should be loaded here.
        self.last_loaded_model = AutoModel.from_pretrained(model_path)
        self.last_loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)

        last_created_pipeline = self.pipeline(
            self.task_name,
            model=self.last_loaded_model,
            tokenizer=self.last_loaded_tokenizer,
            device=self.device,
            framework="pt")
        self.last_loaded_model_key = current_model_key
        return last_created_pipeline

    def clear_device_memory(self):
        """
        Drop references to the loaded model and tokenizer and free device
        memory.
        """
        self.last_loaded_model = None
        self.last_loaded_tokenizer = None
        # Release cached GPU memory held by the dropped tensors
        # (a no-op when no CUDA device is present).
        torch.cuda.empty_cache()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 43 additions & 0 deletions
43
tests/integration_tests/without_db/utils/test_load_local_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from pathlib import Path | ||
from transformers import AutoModel, AutoTokenizer | ||
|
||
from exasol_transformers_extension.utils.load_local_model import LoadLocalModel | ||
|
||
from tests.utils.parameters import model_params | ||
|
||
from tests.fixtures.bucketfs_fixture import bucketfs_location | ||
|
||
import tempfile | ||
|
||
|
||
class TestSetup:
    """Collects the objects the integration test needs: a model name and a
    LoadLocalModel wired with a no-op pipeline factory."""

    def __init__(self, bucketfs_location):
        # NOTE(review): stored but not referenced by this test — confirm
        # whether the bucketfs location is still needed here.
        self.bucketfs_location = bucketfs_location
        self.token = "token"
        self.model_name = model_params.tiny_model
        self.mock_current_model_key = None

        def noop_pipeline(task_name, model, tokenizer, device, framework):
            return None

        self.loader = LoadLocalModel(
            noop_pipeline,
            task_name="test_task",
            device=0)
|
||
|
||
def test_integration(bucketfs_location):
    """Round trip: save a tiny model and tokenizer to a local directory,
    then load them back through LoadLocalModel and build a pipeline."""
    test_setup = TestSetup(bucketfs_location)

    # `tmp_dir` instead of `dir` — the original shadowed the builtin.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)
        model_save_path = tmp_path / "pretrained" / test_setup.model_name
        # Download a model and save it where load_models() expects it.
        model = AutoModel.from_pretrained(test_setup.model_name)
        tokenizer = AutoTokenizer.from_pretrained(test_setup.model_name)
        model.save_pretrained(model_save_path)
        tokenizer.save_pretrained(model_save_path)

        test_setup.loader.load_models(model_name=test_setup.model_name,
                                      current_model_key=test_setup.mock_current_model_key,
                                      cache_dir=tmp_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import tempfile | ||
from pathlib import Path | ||
from typing import Union | ||
from unittest.mock import create_autospec, MagicMock, call | ||
|
||
from exasol_bucketfs_utils_python.bucketfs_location import BucketFSLocation | ||
|
||
from exasol_transformers_extension.utils.model_factory_protocol import ModelFactoryProtocol | ||
from exasol_transformers_extension.utils.load_local_model import LoadLocalModel | ||
|
||
from tests.utils.parameters import model_params | ||
|
||
|
||
class TestSetup:
    """Builds a LoadLocalModel wired with mocks for the unit test below."""

    def __init__(self):
        # Autospec'd stand-in for a BucketFSLocation.
        self.bucketfs_location_mock: Union[BucketFSLocation, MagicMock] = \
            create_autospec(BucketFSLocation)

        self.token = "token"
        self.model_name = model_params.tiny_model

        # TODO: decide whether a real pipeline should be used so its
        # creation can be checked.
        def fake_pipeline(task_name, model, tokenizer, device, framework):
            return None

        self.loader = LoadLocalModel(
            fake_pipeline,
            task_name="test_task",
            device=0)
|
||
|
||
#todo test current model key? test load model twice, test wrong model given | ||
def test_load_function_call():
    """Calls load_models with a temporary cache directory.

    NOTE(review): the previously commented-out assertion referenced a
    ``model_factory_mock`` that this file's TestSetup never creates, so
    this test currently verifies only that load_models does not raise —
    restore a real assertion once the mock exists.
    """
    test_setup = TestSetup()
    mock_current_model_key = None
    # `tmp_dir` instead of `dir` — the original shadowed the builtin.
    with tempfile.TemporaryDirectory() as tmp_dir:
        cache_dir = Path(tmp_dir)

        test_setup.loader.load_models(model_name=test_setup.model_name,
                                      current_model_key=mock_current_model_key,
                                      cache_dir=cache_dir)

        # TODO: assert on the model factory once TestSetup provides one, e.g.:
        #   expected_path = cache_dir / "pretrained" / test_setup.model_name
        #   assert test_setup.model_factory_mock.mock_calls == [
        #       call.from_pretrained(str(expected_path))]
|