diff --git a/src/koza/cli_runner.py b/src/koza/cli_utils.py similarity index 92% rename from src/koza/cli_runner.py rename to src/koza/cli_utils.py index ab4a76e..93e87bb 100644 --- a/src/koza/cli_runner.py +++ b/src/koza/cli_utils.py @@ -36,7 +36,7 @@ def get_koza_app(source_name) -> Optional[KozaApp]: def transform_source( source: str, output_dir: str, - output_format: OutputFormat = OutputFormat('tsv'), + output_format: OutputFormat = OutputFormat("tsv"), global_table: str = None, local_table: str = None, schema: str = None, @@ -61,9 +61,9 @@ def transform_source( """ logger = get_logger(name=Path(source).name if log else None, verbose=verbose) - with open(source, 'r') as source_fh: + with open(source, "r") as source_fh: source_config = PrimaryFileConfig(**yaml.load(source_fh, Loader=UniqueIncludeLoader)) - + # TODO: Try moving this to source_config class if not source_config.name: source_config.name = Path(source).stem @@ -73,7 +73,7 @@ def transform_source( if not Path(filename).exists(): filename = Path(source).parent / "transform.py" if not Path(filename).exists(): - raise FileNotFoundError(f"Could not find transform file for {source}") + raise FileNotFoundError(f"Could not find transform file for {source}") source_config.transform_code = filename koza_source = Source(source_config, row_limit) @@ -94,7 +94,7 @@ def transform_source( def validate_file( file: str, format: FormatType = FormatType.csv, - delimiter: str = ',', + delimiter: str = ",", header_delimiter: str = None, skip_blank_lines: bool = True, ): @@ -149,14 +149,14 @@ def get_translation_table( logger.debug("No global table used for transform") else: if isinstance(global_table, str): - with open(global_table, 'r') as global_tt_fh: + with open(global_table, "r") as global_tt_fh: global_tt = yaml.safe_load(global_tt_fh) elif isinstance(global_table, Dict): global_tt = global_table if local_table: if isinstance(local_table, str): - with open(local_table, 'r') as local_tt_fh: + with open(local_table, "r") as local_tt_fh: local_tt = yaml.safe_load(local_tt_fh) elif isinstance(local_table, Dict): local_tt = local_table @@ -170,8 +170,8 @@ def get_translation_table( def _set_koza_app( source: Source, translation_table: TranslationTable = None, - output_dir: str = './output', - output_format: OutputFormat = OutputFormat('tsv'), + output_dir: str = "./output", + output_format: OutputFormat = OutputFormat("tsv"), schema: str = None, node_type: str = None, edge_type: str = None, @@ -184,9 +184,3 @@ def _set_koza_app( ) logger.debug(f"koza_apps entry created for {source.config.name}: {koza_apps[source.config.name]}") return koza_apps[source.config.name] - - -def test_koza(koza: KozaApp): - """Manually sets KozaApp (for testing)""" - global koza_app - koza_app = koza diff --git a/src/koza/converter/biolink_converter.py b/src/koza/converter/biolink_converter.py index e5d6747..af6eb60 100644 --- a/src/koza/converter/biolink_converter.py +++ b/src/koza/converter/biolink_converter.py @@ -1,6 +1,6 @@ from biolink_model.datamodel.pydanticmodel_v2 import Gene -from koza.cli_runner import koza_app +from koza.cli_utils import koza_app def gpi2gene(row: dict) -> Gene: diff --git a/src/koza/main.py b/src/koza/main.py index 74bc55b..b349a56 100755 --- a/src/koza/main.py +++ b/src/koza/main.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Optional -from koza.cli_runner import transform_source, validate_file +from koza.cli_utils import transform_source, validate_file from koza.model.config.source_config import FormatType, OutputFormat import typer diff --git a/src/koza/utils/testing_utils.py b/src/koza/utils/testing_utils.py new file mode 100644 index 0000000..b4bf0bb --- /dev/null +++ b/src/koza/utils/testing_utils.py @@ -0,0 +1,78 @@ +import types +from typing import Iterable + +from loguru import logger + +from koza.app import KozaApp +from koza.cli_utils import get_koza_app, get_translation_table, _set_koza_app +from koza.model.config.source_config import PrimaryFileConfig +from koza.model.source import Source + +def test_koza(koza: KozaApp): + """Manually sets KozaApp for testing""" + global koza_app + koza_app = koza + +def mock_koza(): + """Mock KozaApp for testing""" + def _mock_write(self, *entities): + if hasattr(self, '_entities'): + self._entities.extend(list(entities)) + else: + self._entities = list(entities) + + def _make_mock_koza_app( + name: str, + data: Iterable, + transform_code: str, + map_cache=None, + filters=None, + global_table=None, + local_table=None, + ): + mock_source_file_config = PrimaryFileConfig( + name=name, + files=[], + transform_code=transform_code, + ) + mock_source_file = Source(mock_source_file_config) + mock_source_file._reader = data + + _set_koza_app( + source=mock_source_file, + translation_table=get_translation_table(global_table, local_table, logger), + logger=logger, + ) + koza = get_koza_app(name) + + # TODO filter mocks + koza._map_cache = map_cache + koza.write = types.MethodType(_mock_write, koza) + + return koza + + def _transform( + name: str, + data: Iterable, + transform_code: str, + map_cache=None, + filters=None, + global_table=None, + local_table=None, + ): + koza_app = _make_mock_koza_app( + name, + data, + transform_code, + map_cache=map_cache, + filters=filters, + global_table=global_table, + local_table=local_table, + ) + test_koza(koza_app) + koza_app.process_sources() + if not hasattr(koza_app, '_entities'): + koza_app._entities = [] + return koza_app._entities + + return _transform