Skip to content

Commit

Permalink
[CodeBuild] removed download sample model fixture because of duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
MarleneKress79789 committed Feb 9, 2024
1 parent 5a52093 commit 93efc77
Showing 1 changed file with 4 additions and 24 deletions.
28 changes: 4 additions & 24 deletions tests/integration_tests/with_db/test_upload_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,7 @@
from tests.integration_tests.with_db.udfs.python_rows_to_sql import python_rows_to_sql
from tests.utils import postprocessing
from tests.utils.parameters import bucketfs_params, model_params

from tests.fixtures.setup_database_fixture import setup_database
from tests.fixtures.database_connection_fixture import pyexasol_connection
from tests.fixtures.language_container_fixture import language_alias

#todo just use download model fixture?
@pytest.fixture(scope='function')
def download_sample_models(tmp_path: Path) -> Path:
tmp_path = Path(tmp_path)
for model_factory in [transformers.AutoModel, transformers.AutoTokenizer]:
print("start download")
model = model_factory.from_pretrained(model_params.base_model, cache_dir=tmp_path / "cache")
model.save_pretrained(tmp_path / "pretrained" / model_params.base_model)
print("model saved local")

yield tmp_path / "pretrained" / model_params.base_model, model_params.base_model
from tests.fixtures.model_fixture import download_model


def adapt_file_to_upload(path: PosixPath, download_path: PosixPath):
Expand All @@ -41,14 +26,13 @@ def adapt_file_to_upload(path: PosixPath, download_path: PosixPath):
return PosixPath(path)


def test_model_upload(setup_database, pyexasol_connection, download_sample_models: Path,
def test_model_upload(setup_database, pyexasol_connection, tmp_path: Path,
bucketfs_location: BucketFSLocation, bucketfs_config: config.BucketFs):
sub_dir = 'sub_dir'
download_path, model_name = download_sample_models
model_name = model_params.base_model
download_path = download_model(model_name, tmp_path)
upload_path = bucketfs_operations.get_model_path_with_pretrained(
sub_dir, model_name)
print("upload path")#todo remove prints
print(upload_path)
parsed_url = urlparse(bucketfs_config.url)
host = parsed_url.netloc.split(":")[0]
port = parsed_url.netloc.split(":")[1]
Expand All @@ -70,8 +54,6 @@ def test_model_upload(setup_database, pyexasol_connection, download_sample_model
runner = CliRunner()
result = runner.invoke(upload_model.main, args_list)
assert result.exit_code == 0
print("ls: . ")
print(bucketfs_location.list_files_in_bucketfs("."))
assert str(upload_path.with_suffix(".tar.gz")) in bucketfs_location.list_files_in_bucketfs(".")

bucketfs_conn_name, schema_name = setup_database
Expand Down Expand Up @@ -102,8 +84,6 @@ def test_model_upload(setup_database, pyexasol_connection, download_sample_model

# execute sequence classification UDF
result = pyexasol_connection.execute(query).fetchall()
print("result:")
print(result)
assert len(result) == 1 and result[0][-1] is None
finally:
postprocessing.cleanup_buckets(bucketfs_location, sub_dir)

0 comments on commit 93efc77

Please sign in to comment.