diff --git a/exasol_transformers_extension/utils/huggingface_hub_bucketfs_model_transfer_sp.py b/exasol_transformers_extension/utils/huggingface_hub_bucketfs_model_transfer_sp.py new file mode 100644 index 00000000..49d8e31a --- /dev/null +++ b/exasol_transformers_extension/utils/huggingface_hub_bucketfs_model_transfer_sp.py @@ -0,0 +1,78 @@ +import os +import tempfile +from pathlib import Path +from typing import Protocol, Union, runtime_checkable + +import transformers +from exasol_bucketfs_utils_python.bucketfs_location import BucketFSLocation + +from exasol_transformers_extension.utils.bucketfs_model_uploader import BucketFSModelUploaderFactory +from exasol_transformers_extension.utils.temporary_directory_factory import TemporaryDirectoryFactory + + +@runtime_checkable +class ModelFactoryProtocol(Protocol): + def from_pretrained(self, model_name: str, cache_dir: Path, use_auth_token: str) -> transformers.PreTrainedModel: + pass + + def save_pretrained(self, save_directory: Union[str, Path]): + pass + + +class HuggingFaceHubBucketFSModelTransferSP: + def __init__(self, + bucketfs_location: BucketFSLocation, + model_name: str, + model_path: Path, + local_model_save_path: Path, + token: str, + temporary_directory_factory: TemporaryDirectoryFactory = TemporaryDirectoryFactory(), + bucketfs_model_uploader_factory: BucketFSModelUploaderFactory = BucketFSModelUploaderFactory()): + self._token = token + self._model_name = model_name + self._local_model_save_path = Path(local_model_save_path) + self._temporary_directory_factory = temporary_directory_factory + self._bucketfs_model_uploader = bucketfs_model_uploader_factory.create( + model_path=model_path, + bucketfs_location=bucketfs_location) + self._tmpdir = temporary_directory_factory.create() + self._tmpdir_name = self._tmpdir.__enter__() + + def __enter__(self): + return self + + def __del__(self): + self._tmpdir.cleanup() + + def __exit__(self, exc_type, exc_val, exc_tb): + self._tmpdir.__exit__(exc_type, exc_val, exc_tb) + + def download_from_huggingface_hub_sp(self, model_factory: ModelFactoryProtocol): + """ + Download a model from HuggingFace Hub into a temporary directory and save it with save_pretrained + at _local_model_save_path / _model_name for local storing + """ + model = model_factory.from_pretrained(self._model_name, cache_dir=self._tmpdir_name, use_auth_token=self._token) + path = self._local_model_save_path / self._model_name + model.save_pretrained(path) #todo save in cachedir in assuption will be uploaded and then deleted? + + def upload_to_bucketfs(self) -> Path: + """ + Upload the downloaded models into the BucketFS + """ + return self._bucketfs_model_uploader.upload_directory(self._tmpdir_name) + + +class HuggingFaceHubBucketFSModelTransferSPFactory: + + def create(self, + bucketfs_location: BucketFSLocation, + model_name: str, + model_path: Path, + local_model_save_path: Path, + token: str) -> HuggingFaceHubBucketFSModelTransferSP: + return HuggingFaceHubBucketFSModelTransferSP(bucketfs_location=bucketfs_location, + model_name=model_name, + model_path=model_path, + local_model_save_path=local_model_save_path, + token=token) diff --git a/tests/unit_tests/utils/test_huggingface_hub_bucketfs__model_transfer_sp.py b/tests/unit_tests/utils/test_huggingface_hub_bucketfs__model_transfer_sp.py new file mode 100644 index 00000000..1cae7766 --- /dev/null +++ b/tests/unit_tests/utils/test_huggingface_hub_bucketfs__model_transfer_sp.py @@ -0,0 +1,85 @@ +import tempfile +from pathlib import Path +from typing import Union +from unittest.mock import create_autospec, MagicMock, call + +from exasol_bucketfs_utils_python.bucketfs_location import BucketFSLocation +from transformers import AutoModel, PreTrainedModel + +from exasol_transformers_extension.utils.bucketfs_model_uploader import BucketFSModelUploader, \ + BucketFSModelUploaderFactory +from exasol_transformers_extension.utils.huggingface_hub_bucketfs_model_transfer_sp import ModelFactoryProtocol, \ + HuggingFaceHubBucketFSModelTransferSP +from exasol_transformers_extension.utils.temporary_directory_factory import TemporaryDirectoryFactory +from tests.utils.mock_cast import mock_cast + +from tests.utils.parameters import model_params + +class TestSetup: + def __init__(self, local_model_save_path: Path = "downloaded_models_test"): + self.bucketfs_location_mock: Union[BucketFSLocation, MagicMock] = create_autospec(BucketFSLocation) + self.model_factory_mock: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol) + self.temporary_directory_factory_mock: Union[TemporaryDirectoryFactory, MagicMock] = \ + create_autospec(TemporaryDirectoryFactory) + self.bucketfs_model_uploader_factory_mock: Union[BucketFSModelUploaderFactory, MagicMock] = \ + create_autospec(BucketFSModelUploaderFactory) + self.bucketfs_model_uploader_mock: Union[BucketFSModelUploader, MagicMock] = \ + create_autospec(BucketFSModelUploader) + mock_cast(self.bucketfs_model_uploader_factory_mock.create).side_effect = [self.bucketfs_model_uploader_mock] + + self.token = "token" + model_params_ = model_params.tiny_model + print(model_params_) + self.model_name = model_params_ + self.model_path = Path("test_model_path") + self.downloader = HuggingFaceHubBucketFSModelTransferSP( + bucketfs_location=self.bucketfs_location_mock, + model_path=self.model_path, + model_name=self.model_name, + local_model_save_path=local_model_save_path, + token=self.token, + temporary_directory_factory=self.temporary_directory_factory_mock, + bucketfs_model_uploader_factory=self.bucketfs_model_uploader_factory_mock + ) + + def reset_mocks(self): + self.bucketfs_location_mock.reset_mock() + self.temporary_directory_factory_mock.reset_mock() + self.model_factory_mock.reset_mock() + self.bucketfs_model_uploader_mock.reset_mock() + + +def test_init(): + test_setup = TestSetup() + assert test_setup.temporary_directory_factory_mock.mock_calls == [call.create(), call.create().__enter__()] \ + and test_setup.model_factory_mock.mock_calls == [] \ + and test_setup.bucketfs_location_mock.mock_calls == [] \ + and mock_cast(test_setup.bucketfs_model_uploader_factory_mock.create).mock_calls == [ + call.create(model_path=test_setup.model_path, bucketfs_location=test_setup.bucketfs_location_mock) + ] + + +def test_download_function_call(): + test_setup = TestSetup() + test_setup.downloader.download_from_huggingface_hub_sp(model_factory=test_setup.model_factory_mock) + cache_dir = test_setup.temporary_directory_factory_mock.create().__enter__() + model_save_path = (test_setup.downloader._local_model_save_path/test_setup.model_name) + assert test_setup.model_factory_mock.mock_calls == [ + call.from_pretrained(test_setup.model_name, cache_dir=cache_dir, + use_auth_token=test_setup.token), + call.from_pretrained().save_pretrained(model_save_path)] + + +# todo add test for model already downloaded? + +def test_download_with_model(): + with tempfile.TemporaryDirectory() as folder: + folder_path = Path(folder) + test_setup = TestSetup(local_model_save_path=folder_path/"downloaded_models") + base_model_factory: ModelFactoryProtocol = AutoModel + test_setup.downloader.download_from_huggingface_hub_sp(model_factory=base_model_factory) + assert AutoModel.from_pretrained(folder_path/"downloaded_models"/test_setup.model_name) + test_setup.downloader.__del__() + #todo delete model + +