diff --git a/backend-app/app/sql_db/file_crud.py b/backend-app/app/sql_db/file_crud.py index 65332798..1d4c2e08 100644 --- a/backend-app/app/sql_db/file_crud.py +++ b/backend-app/app/sql_db/file_crud.py @@ -35,6 +35,8 @@ def create_update_table(df, engine, table_name): def insert_data(db: Session, df: pd.DataFrame, FileTable, update_column_name="id"): + # TODO: error handling + # Ensure all numeric columns are correctly cast to numeric types numeric_cols = df.select_dtypes(include=["number"]).columns df[numeric_cols] = df[numeric_cols].apply(pd.to_numeric, errors="coerce") diff --git a/backend-app/tests/unit/conftest.py b/backend-app/tests/unit/conftest.py index 11e82e3a..b69b395c 100644 --- a/backend-app/tests/unit/conftest.py +++ b/backend-app/tests/unit/conftest.py @@ -1,18 +1,20 @@ -from fastapi.testclient import TestClient +from io import BytesIO + import pandas as pd import pytest -from app.sql_db.database import Base -from app.models.user import UserCreate -from app.models.database import User as db_user -from sqlalchemy.orm import Session +from fastapi.testclient import TestClient from sqlalchemy import create_engine -from io import BytesIO +from sqlalchemy.orm import Session +from app.api.auth import get_current_active_admin, get_current_active_user, get_user_by_email from app.main import app -from app.api.auth import get_current_active_user, get_current_active_admin, get_user_by_email +from app.models import user as api_m +from app.models.database import User as db_user +from app.models.file_db import create_file_table_class, update_schema +from app.models.user import UserCreate from app.sql_db.crud import create_user, get_db, update_is_active, update_is_admin +from app.sql_db.database import Base from app.sql_db.file_crud import create_update_table, insert_data -from app.models import user as api_m @pytest.fixture(scope="session") @@ -117,6 +119,18 @@ def files_good(): return files +@pytest.fixture(scope="function") +def df_files_good(files_good): + df = pd.read_csv(BytesIO(files_good["file"][1].encode())) + return df + + +@pytest.fixture(scope="function") +def file_table_good(df_files_good, db_engine): + file_table, msg = create_update_table(df_files_good, db_engine, "file_table") + return file_table + + @pytest.fixture(scope="function") def files_good_updated(): files = { @@ -128,6 +142,12 @@ def files_good_updated(): return files +@pytest.fixture(scope="function") +def df_files_good_updated(files_good_updated): + df = pd.read_csv(BytesIO(files_good_updated["file"][1].encode())) + return df + + @pytest.fixture(scope="function") def files_bad_type(): files = { diff --git a/backend-app/tests/unit/test_filedb_crud.py b/backend-app/tests/unit/test_filedb_crud.py new file mode 100644 index 00000000..5d03618d --- /dev/null +++ b/backend-app/tests/unit/test_filedb_crud.py @@ -0,0 +1,41 @@ +import logging + +import pytest + +from app.sql_db import file_crud + + +# Test get_db +def test_get_db(db): + session = db + + gen = file_crud.get_db() + db = next(gen) + + try: + assert type(db).__name__ == type(session).__name__ + + finally: + gen.close() + + +def test_create_update_table_new_table(df_files_good, db_engine, caplog): + with caplog.at_level(logging.INFO): + filetable, msg = file_crud.create_update_table(df_files_good, db_engine, "file_table") + assert "Creating new table 'file_table'." in caplog.text + assert msg == "Table with name file_table created" + assert filetable.__tablename__ == "file_table" + + +def test_creat_update_table_existing_table(df_files_good, db_engine, caplog): + with caplog.at_level(logging.INFO): + filetable, msg = file_crud.create_update_table(df_files_good, db_engine, "file_table") + assert "Table 'file_table' already exists. Using existing schema." in caplog.text + assert msg == "Table with name file_table updated" + assert filetable.__tablename__ == "file_table" + + +def test_insert_data_new_table(db, df_files_good, file_table_good, caplog): + with caplog.at_level(logging.INFO): + file_crud.insert_data(db, df_files_good, file_table_good) + assert "Data inserted into file_table" in caplog.text.strip()