Skip to content

Commit

Permalink
Fix include package itself for module discovery
Browse files Browse the repository at this point in the history
  • Loading branch information
disrupted committed Jun 26, 2024
1 parent 246541f commit f76c90d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 27 deletions.
11 changes: 7 additions & 4 deletions kpops/api/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def __getitem__(self, component_type: str) -> type[PipelineComponent]:
def iter_component_modules() -> Iterator[ModuleType]:
import kpops.components

for _, module_name, _ in _iter_namespace(kpops.components):
yield import_module(module_name)
yield kpops.components
yield from _iter_namespace(kpops.components)


def find_class(modules: Iterable[ModuleType], base: type[T]) -> type[T]:
Expand Down Expand Up @@ -93,5 +93,8 @@ def __filter_internal_kpops_classes(class_module: str, module_name: str) -> bool
)


def _iter_namespace(ns_pkg: ModuleType) -> Iterator[pkgutil.ModuleInfo]:
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
def _iter_namespace(ns_pkg: ModuleType) -> Iterator[ModuleType]:
for _, module_name, _ in pkgutil.iter_modules(
ns_pkg.__path__, ns_pkg.__name__ + "."
):
yield import_module(module_name)
25 changes: 2 additions & 23 deletions tests/api/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def custom_components(mocker: MockerFixture):
@pytest.mark.usefixtures("custom_components")
def test_iter_namespace():
components_module = importlib.import_module("kpops.components")
assert [
module_name for _, module_name, _ in _iter_namespace(components_module)
] == [
assert [module.__name__ for module in _iter_namespace(components_module)] == [
"kpops.components.base_components",
"kpops.components.streams_bootstrap",
"kpops.components.test_components",
Expand All @@ -65,6 +63,7 @@ def test_iter_namespace():
@pytest.mark.usefixtures("custom_components")
def test_iter_component_modules():
assert [module.__name__ for module in Registry.iter_component_modules()] == [
"kpops.components",
"kpops.components.base_components",
"kpops.components.streams_bootstrap",
"kpops.components.test_components",
Expand All @@ -84,26 +83,6 @@ def test_find_classes(module: ModuleType):
next(gen)


def test_find_builtin_classes():
modules = Registry.iter_component_modules()
components = [
class_.__name__ for class_ in _find_classes(modules, base=PipelineComponent)
]
assert len(components) == 10
assert components == [
"HelmApp",
"KafkaApp",
"KafkaConnector",
"KafkaSinkConnector",
"KafkaSourceConnector",
"KubernetesApp",
"PipelineComponent",
"ProducerApp",
"StreamsApp",
"StreamsBootstrap",
]


def test_find_class(module: ModuleType):
assert find_class([module], base=SubComponent) is SubComponent
assert find_class([module], base=PipelineComponent) is SubComponent
Expand Down

0 comments on commit f76c90d

Please sign in to comment.