Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
disrupted committed Jun 26, 2024
1 parent 628ca21 commit c40ae5c
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 23 deletions.
9 changes: 5 additions & 4 deletions kpops/api/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import pkgutil
import sys
from collections.abc import Iterable
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -38,7 +39,7 @@ def find_components(self) -> None:
:param module_name: name of the python module.
"""
custom_modules = self.iter_component_modules()
for _class in _find_classes(*custom_modules, base=PipelineComponent):
for _class in _find_classes(custom_modules, base=PipelineComponent):
self._classes[_class.type] = _class

def __getitem__(self, component_type: str) -> type[PipelineComponent]:
Expand All @@ -56,9 +57,9 @@ def iter_component_modules() -> Iterator[ModuleType]:
yield import_module(module_name)


def find_class(*modules: ModuleType, base: type[T]) -> type[T]:
def find_class(modules: Iterable[ModuleType], base: type[T]) -> type[T]:
try:
return next(_find_classes(*modules, base=base))
return next(_find_classes(modules, base=base))
except StopIteration as e:
raise ClassNotFoundError from e

Expand All @@ -76,7 +77,7 @@ def import_module(module_name: str) -> ModuleType:
return module


def _find_classes(*modules: ModuleType, base: type[T]) -> Iterator[type[T]]:
def _find_classes(modules: Iterable[ModuleType], base: type[T]) -> Iterator[type[T]]:
for module in modules:
for _, _class in inspect.getmembers(module, inspect.isclass):
if not __filter_internal_kpops_classes(
Expand Down
75 changes: 56 additions & 19 deletions tests/api/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,26 @@
import importlib
import shutil
from pathlib import Path
from types import ModuleType

import pytest
from pytest_mock import MockerFixture

from kpops.api.exception import ClassNotFoundError
from kpops.api.registry import Registry, _find_classes, _iter_namespace, find_class
from kpops.component_handlers.schema_handler.schema_provider import SchemaProvider
from kpops.components.base_components.pipeline_component import PipelineComponent
from kpops.components import (
HelmApp,
KafkaApp,
KafkaConnector,
KafkaSinkConnector,
KafkaSourceConnector,
KubernetesApp,
PipelineComponent,
ProducerApp,
StreamsApp,
StreamsBootstrap,
)
from tests.cli.resources.custom_module import CustomSchemaProvider


Expand All @@ -24,7 +36,10 @@ class Unrelated:
pass


@pytest.fixture(autouse=True)
MODULE = SubComponent.__module__


@pytest.fixture()
def custom_components(mocker: MockerFixture):
src = Path("tests/pipeline/test_components")
dst = Path("kpops/components/test_components")
Expand All @@ -35,29 +50,45 @@ def custom_components(mocker: MockerFixture):
shutil.rmtree(dst)


@pytest.mark.usefixtures("custom_components")
def test_iter_namespace():
components_module = importlib.import_module("kpops.components")
assert [
module_name for _, module_name, _ in _iter_namespace(components_module)
] == [
"base_components",
"streams_bootstrap",
"test_components",
"kpops.components.base_components",
"kpops.components.streams_bootstrap",
"kpops.components.test_components",
]


@pytest.mark.usefixtures("custom_components")
def test_iter_component_modules():
assert [module.__name__ for module in Registry.iter_component_modules()] == [
"kpops.components.base_components",
"kpops.components.streams_bootstrap",
"kpops.components.test_components",
]


@pytest.mark.skip()
def test_find_classes():
gen = _find_classes(PipelineComponent)
@pytest.fixture()
def module() -> ModuleType:
return importlib.import_module(MODULE)


def test_find_classes(module: ModuleType):
gen = _find_classes([module], PipelineComponent)
assert next(gen) is SubComponent
assert next(gen) is SubSubComponent
with pytest.raises(StopIteration):
next(gen)


@pytest.mark.skip()
def test_find_builtin_classes():
components = [class_.__name__ for class_ in _find_classes(PipelineComponent)]
modules = Registry.iter_component_modules()
components = [
class_.__name__ for class_ in _find_classes(modules, base=PipelineComponent)
]
assert len(components) == 10
assert components == [
"HelmApp",
Expand All @@ -73,23 +104,29 @@ def test_find_builtin_classes():
]


@pytest.mark.skip()
def test_find_class():
assert find_class(SubComponent) is SubComponent
assert find_class(PipelineComponent) is SubComponent
assert find_class(SchemaProvider) is CustomSchemaProvider
def test_find_class(module: ModuleType):
assert find_class([module], base=SubComponent) is SubComponent
assert find_class([module], base=PipelineComponent) is SubComponent
assert find_class([module], base=SchemaProvider) is CustomSchemaProvider
with pytest.raises(ClassNotFoundError):
find_class(dict)
find_class([module], base=dict)


@pytest.mark.skip()
def test_registry():
registry = Registry()
assert registry._classes == {}
registry.find_components()
assert registry._classes == {
"sub-component": SubComponent,
"sub-sub-component": SubSubComponent,
"helm-app": HelmApp,
"kafka-app": KafkaApp,
"kafka-connector": KafkaConnector,
"kafka-sink-connector": KafkaSinkConnector,
"kafka-source-connector": KafkaSourceConnector,
"kubernetes-app": KubernetesApp,
"pipeline-component": PipelineComponent,
"producer-app": ProducerApp,
"streams-app": StreamsApp,
"streams-bootstrap": StreamsBootstrap,
}
assert registry["sub-component"] is SubComponent
assert registry["sub-sub-component"] is SubSubComponent
Expand Down

0 comments on commit c40ae5c

Please sign in to comment.