diff --git a/doc/changes/changes_0.7.0.md b/doc/changes/changes_0.7.0.md
new file mode 100644
index 00000000..0f3fef39
--- /dev/null
+++ b/doc/changes/changes_0.7.0.md
@@ -0,0 +1,24 @@
+# Transformers Extension 0.7.0, released T.B.D
+
+Code name: T.B.D
+
+
+## Summary
+
+T.B.D
+
+### Features
+
+### Bug Fixes
+
+### Refactorings
+
+ - #144: Extracted base_model_udf.load_models into separate class
+
+
+### Documentation
+
+
+
+### Security
+ - #144: Updated Cryptography to version 41.0.7
\ No newline at end of file
diff --git a/exasol_transformers_extension/udfs/models/base_model_udf.py b/exasol_transformers_extension/udfs/models/base_model_udf.py
index 675ca10d..343db5a0 100644
--- a/exasol_transformers_extension/udfs/models/base_model_udf.py
+++ b/exasol_transformers_extension/udfs/models/base_model_udf.py
@@ -7,6 +7,7 @@
 from exasol_transformers_extension.deployment import constants
 from exasol_transformers_extension.utils import device_management, \
     bucketfs_operations, dataframe_operations
+from exasol_transformers_extension.utils.load_model import LoadModel
 
 
 class BaseModelUDF(ABC):
@@ -20,7 +21,6 @@ class BaseModelUDF(ABC):
     - creates model pipeline through transformer api
     - manages the creation of predictions and the preparation of results.
     """
-
     def __init__(self,
                  exa,
                  batch_size,
@@ -36,15 +36,14 @@ def __init__(self,
         self.task_name = task_name
         self.device = None
         self.cache_dir = None
-        self.last_loaded_model_key = None
-        self.last_loaded_model = None
-        self.last_loaded_tokenizer = None
+        self.model_loader = None
         self.last_created_pipeline = None
         self.new_columns = []
 
     def run(self, ctx):
         device_id = ctx.get_dataframe(1).iloc[0]['device_id']
         self.device = device_management.get_torch_device(device_id)
+        self.create_model_loader()
         ctx.reset()
 
         while True:
@@ -54,7 +53,17 @@ def run(self, ctx):
             predictions_df = self.get_predictions_from_batch(batch_df)
             ctx.emit(predictions_df)
 
-        self.clear_device_memory()
+        self.model_loader.clear_device_memory()
+
+    def create_model_loader(self):
+        """
+        Creates the model_loader.
+ """ + self.model_loader = LoadModel(self.pipeline, + self.base_model, + self.tokenizer, + self.task_name, + self.device) def get_predictions_from_batch(self, batch_df: pd.DataFrame) -> pd.DataFrame: """ @@ -171,11 +180,17 @@ def check_cache(self, model_df: pd.DataFrame) -> None: token_conn = model_df["token_conn"].iloc[0] current_model_key = (bucketfs_conn, sub_dir, model_name, token_conn) - if self.last_loaded_model_key != current_model_key: + if self.model_loader.last_loaded_model_key != current_model_key: self.set_cache_dir(model_name, bucketfs_conn, sub_dir) - self.clear_device_memory() - self.load_models(model_name, token_conn) - self.last_loaded_model_key = current_model_key + self.model_loader.clear_device_memory() + if token_conn: + token_conn_obj = self.exa.get_connection(token_conn) + else: + token_conn_obj = None + self.last_created_pipeline = self.model_loader.load_models(model_name, + current_model_key, + self.cache_dir, + token_conn_obj) def set_cache_dir( self, model_name: str, bucketfs_conn_name: str, @@ -195,39 +210,6 @@ def set_cache_dir( self.cache_dir = bucketfs_operations.get_local_bucketfs_path( bucketfs_location=bucketfs_location, model_path=str(model_path)) - def clear_device_memory(self): - """ - Delete models and free device memory - """ - self.last_loaded_model = None - self.last_loaded_tokenizer = None - torch.cuda.empty_cache() - - def load_models(self, model_name: str, token_conn_name: str) -> None: - """ - Load model and tokenizer model from the cached location in bucketfs. - If the desired model is not cached, this method will attempt to - download the model to the read-only path /bucket/.. and cause an error. - This error will be addressed in ticket - https://github.com/exasol/transformers-extension/issues/43. - - :param model_name: The model name to be loaded - """ - token = False - if token_conn_name: - token_conn_obj = self.exa.get_connection(token_conn_name) - token = token_conn_obj.password - - self.last_loaded_model = self.base_model.from_pretrained( - model_name, cache_dir=self.cache_dir, use_auth_token=token) - self.last_loaded_tokenizer = self.tokenizer.from_pretrained( - model_name, cache_dir=self.cache_dir, use_auth_token=token) - self.last_created_pipeline = self.pipeline( - self.task_name, - model=self.last_loaded_model, - tokenizer=self.last_loaded_tokenizer, - device=self.device, - framework="pt") def get_prediction(self, model_df: pd.DataFrame) -> pd.DataFrame: """ diff --git a/exasol_transformers_extension/utils/load_model.py b/exasol_transformers_extension/utils/load_model.py new file mode 100644 index 00000000..74c5b881 --- /dev/null +++ b/exasol_transformers_extension/utils/load_model.py @@ -0,0 +1,56 @@ +import torch + +class LoadModel: + def __init__(self, + pipeline, + base_model, + tokenizer, + task_name, + device + ): + self.pipeline = pipeline + self.base_model = base_model + self.tokenizer = tokenizer + self.task_name = task_name + self.device = device + self.last_loaded_model = None + self.last_loaded_tokenizer = None + self.last_loaded_model_key = None + + def load_models(self, model_name: str, + current_model_key, + cache_dir, + token_conn_obj) -> None: + """ + Load model and tokenizer model from the cached location in bucketfs. + If the desired model is not cached, this method will attempt to + download the model to the read-only path /bucket/.. and cause an error. + This error will be addressed in ticket + https://github.com/exasol/transformers-extension/issues/43. 
+
+        :param model_name: The model name to be loaded
+        """
+        token = False
+        if token_conn_obj:
+            token = token_conn_obj.password
+
+        self.last_loaded_model = self.base_model.from_pretrained(
+            model_name, cache_dir=cache_dir, use_auth_token=token)
+        self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
+            model_name, cache_dir=cache_dir, use_auth_token=token)
+        last_created_pipeline = self.pipeline(
+            self.task_name,
+            model=self.last_loaded_model,
+            tokenizer=self.last_loaded_tokenizer,
+            device=self.device,
+            framework="pt")
+        self.last_loaded_model_key = current_model_key
+        return last_created_pipeline
+
+    def clear_device_memory(self):
+        """
+        Delete models and free device memory
+        """
+        self.last_loaded_model = None
+        self.last_loaded_tokenizer = None
+        torch.cuda.empty_cache()
\ No newline at end of file
diff --git a/poetry.lock b/poetry.lock
index d3660465..87464f1b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -304,35 +304,35 @@ development = ["black", "flake8", "mypy", "pytest", "types-colorama"]
 
 [[package]]
 name = "cryptography"
-version = "41.0.5"
+version = "41.0.7"
 description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
 category = "main"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "cryptography-41.0.5-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:da6a0ff8f1016ccc7477e6339e1d50ce5f59b88905585f77193ebd5068f1e797"},
-    {file = "cryptography-41.0.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:b948e09fe5fb18517d99994184854ebd50b57248736fd4c720ad540560174ec5"},
-    {file = "cryptography-41.0.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d38e6031e113b7421db1de0c1b1f7739564a88f1684c6b89234fbf6c11b75147"},
-    {file = "cryptography-41.0.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e270c04f4d9b5671ebcc792b3ba5d4488bf7c42c3c241a3748e2599776f29696"},
-    {file = "cryptography-41.0.5-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ec3b055ff8f1dce8e6ef28f626e0972981475173d7973d63f271b29c8a2897da"},
-    {file = "cryptography-41.0.5-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:7d208c21e47940369accfc9e85f0de7693d9a5d843c2509b3846b2db170dfd20"},
-    {file = "cryptography-41.0.5-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:8254962e6ba1f4d2090c44daf50a547cd5f0bf446dc658a8e5f8156cae0d8548"},
-    {file = "cryptography-41.0.5-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:a48e74dad1fb349f3dc1d449ed88e0017d792997a7ad2ec9587ed17405667e6d"},
-    {file = "cryptography-41.0.5-cp37-abi3-win32.whl", hash = "sha256:d3977f0e276f6f5bf245c403156673db103283266601405376f075c849a0b936"},
-    {file = "cryptography-41.0.5-cp37-abi3-win_amd64.whl", hash = "sha256:73801ac9736741f220e20435f84ecec75ed70eda90f781a148f1bad546963d81"},
-    {file = "cryptography-41.0.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3be3ca726e1572517d2bef99a818378bbcf7d7799d5372a46c79c29eb8d166c1"},
-    {file = "cryptography-41.0.5-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:e886098619d3815e0ad5790c973afeee2c0e6e04b4da90b88e6bd06e2a0b1b72"},
-    {file = "cryptography-41.0.5-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:573eb7128cbca75f9157dcde974781209463ce56b5804983e11a1c462f0f4e88"},
-    {file = "cryptography-41.0.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:0c327cac00f082013c7c9fb6c46b7cc9fa3c288ca702c74773968173bda421bf"},
-    {file = "cryptography-41.0.5-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:227ec057cd32a41c6651701abc0328135e472ed450f47c2766f23267b792a88e"},
-    {file = "cryptography-41.0.5-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:22892cc830d8b2c89ea60148227631bb96a7da0c1b722f2aac8824b1b7c0b6b8"},
-    {file = "cryptography-41.0.5-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:5a70187954ba7292c7876734183e810b728b4f3965fbe571421cb2434d279179"},
-    {file = "cryptography-41.0.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:88417bff20162f635f24f849ab182b092697922088b477a7abd6664ddd82291d"},
-    {file = "cryptography-41.0.5-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c707f7afd813478e2019ae32a7c49cd932dd60ab2d2a93e796f68236b7e1fbf1"},
-    {file = "cryptography-41.0.5-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:580afc7b7216deeb87a098ef0674d6ee34ab55993140838b14c9b83312b37b86"},
-    {file = "cryptography-41.0.5-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:fba1e91467c65fe64a82c689dc6cf58151158993b13eb7a7f3f4b7f395636723"},
-    {file = "cryptography-41.0.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:0d2a6a598847c46e3e321a7aef8af1436f11c27f1254933746304ff014664d84"},
-    {file = "cryptography-41.0.5.tar.gz", hash = "sha256:392cb88b597247177172e02da6b7a63deeff1937fa6fec3bbf902ebd75d97ec7"},
+    {file = "cryptography-41.0.7-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:3c78451b78313fa81607fa1b3f1ae0a5ddd8014c38a02d9db0616133987b9cdf"},
+    {file = "cryptography-41.0.7-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:928258ba5d6f8ae644e764d0f996d61a8777559f72dfeb2eea7e2fe0ad6e782d"},
+    {file = "cryptography-41.0.7-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a1b41bc97f1ad230a41657d9155113c7521953869ae57ac39ac7f1bb471469a"},
+    {file = "cryptography-41.0.7-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:841df4caa01008bad253bce2a6f7b47f86dc9f08df4b433c404def869f590a15"},
+    {file = "cryptography-41.0.7-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5429ec739a29df2e29e15d082f1d9ad683701f0ec7709ca479b3ff2708dae65a"},
+    {file = "cryptography-41.0.7-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:43f2552a2378b44869fe8827aa19e69512e3245a219104438692385b0ee119d1"},
+    {file = "cryptography-41.0.7-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:af03b32695b24d85a75d40e1ba39ffe7db7ffcb099fe507b39fd41a565f1b157"},
+    {file = "cryptography-41.0.7-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:49f0805fc0b2ac8d4882dd52f4a3b935b210935d500b6b805f321addc8177406"},
+    {file = "cryptography-41.0.7-cp37-abi3-win32.whl", hash = "sha256:f983596065a18a2183e7f79ab3fd4c475205b839e02cbc0efbbf9666c4b3083d"},
+    {file = "cryptography-41.0.7-cp37-abi3-win_amd64.whl", hash = "sha256:90452ba79b8788fa380dfb587cca692976ef4e757b194b093d845e8d99f612f2"},
+    {file = "cryptography-41.0.7-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:079b85658ea2f59c4f43b70f8119a52414cdb7be34da5d019a77bf96d473b960"},
+    {file = "cryptography-41.0.7-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:b640981bf64a3e978a56167594a0e97db71c89a479da8e175d8bb5be5178c003"},
+    {file = "cryptography-41.0.7-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e3114da6d7f95d2dee7d3f4eec16dacff819740bbab931aff8648cb13c5ff5e7"},
+    {file = "cryptography-41.0.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d5ec85080cce7b0513cfd233914eb8b7bbd0633f1d1703aa28d1dd5a72f678ec"},
+    {file = "cryptography-41.0.7-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7a698cb1dac82c35fcf8fe3417a3aaba97de16a01ac914b89a0889d364d2f6be"},
+    {file = "cryptography-41.0.7-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:37a138589b12069efb424220bf78eac59ca68b95696fc622b6ccc1c0a197204a"},
+    {file = "cryptography-41.0.7-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:68a2dec79deebc5d26d617bfdf6e8aab065a4f34934b22d3b5010df3ba36612c"},
+    {file = "cryptography-41.0.7-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:09616eeaef406f99046553b8a40fbf8b1e70795a91885ba4c96a70793de5504a"},
+    {file = "cryptography-41.0.7-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:48a0476626da912a44cc078f9893f292f0b3e4c739caf289268168d8f4702a39"},
+    {file = "cryptography-41.0.7-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c7f3201ec47d5207841402594f1d7950879ef890c0c495052fa62f58283fde1a"},
+    {file = "cryptography-41.0.7-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c5ca78485a255e03c32b513f8c2bc39fedb7f5c5f8535545bdc223a03b24f248"},
+    {file = "cryptography-41.0.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:d6c391c021ab1f7a82da5d8d0b3cee2f4b2c455ec86c8aebbc84837a631ff309"},
+    {file = "cryptography-41.0.7.tar.gz", hash = "sha256:13f93ce9bea8016c253b34afc6bd6a75993e5c40672ed5405a9c832f0d4a00bc"},
 ]
 
 [package.dependencies]
diff --git a/tests/unit_tests/udfs/base_model_dummy_implementation.py b/tests/unit_tests/udfs/base_model_dummy_implementation.py
index 801549ab..f15ea1f7 100644
--- a/tests/unit_tests/udfs/base_model_dummy_implementation.py
+++ b/tests/unit_tests/udfs/base_model_dummy_implementation.py
@@ -44,9 +44,3 @@ def create_dataframes_from_predictions(
             results_df_list.append(result_df)
 
         return results_df_list
-    def load_models(self, model_name: str, token_conn_name: str) -> None:
-        token = False
-        self.last_loaded_model = self.base_model.from_pretrained(
-            model_name, cache_dir=self.cache_dir, use_auth_token=token)
-        self.last_loaded_tokenizer = self.tokenizer.from_pretrained(
-            model_name, cache_dir=self.cache_dir, use_auth_token=token)
\ No newline at end of file
diff --git a/tests/unit_tests/udfs/test_base_udf.py b/tests/unit_tests/udfs/test_base_udf.py
index 6f8328fc..a43c21cb 100644
--- a/tests/unit_tests/udfs/test_base_udf.py
+++ b/tests/unit_tests/udfs/test_base_udf.py
@@ -10,6 +10,7 @@
 from tests.unit_tests.utils_for_udf_tests import create_mock_exa_environment, create_mock_udf_context
 from tests.unit_tests.udfs.base_model_dummy_implementation import DummyImplementationUDF
 from exasol_transformers_extension.utils.huggingface_hub_bucketfs_model_transfer import ModelFactoryProtocol
+from exasol_transformers_extension.utils.load_model import LoadModel
 from tests.utils.mock_cast import mock_cast
 import re
 
@@ -78,10 +79,13 @@ def setup_tests_and_run(bucketfs_conn_name, bucketfs_conn, sub_dir, model_name):
         mock_meta,
         '',
         None)
+
+    mock_pipeline = lambda task_name, model, tokenizer, device, framework: None
     mock_ctx = create_mock_udf_context(input_data, mock_meta)
     udf = DummyImplementationUDF(exa=mock_exa,
-                                 base_model=mock_base_model_factory,
-                                 tokenizer=mock_tokenizer_factory)
+                                  base_model=mock_base_model_factory,
+                                  tokenizer=mock_tokenizer_factory,
+                                  pipeline=mock_pipeline)
     udf.run(mock_ctx)
     res = mock_ctx.output
     return res, mock_meta