From 0e7adb3feaeaf947c145b6099609d9dec952df97 Mon Sep 17 00:00:00 2001 From: Max Zhenzhera <59729293+maxzhenzhera@users.noreply.github.com> Date: Wed, 4 Oct 2023 16:50:22 +0300 Subject: [PATCH] Feature/bind by type handling generic (#110) --- .github/workflows/workflow.yaml | 2 +- di/_container.py | 4 +-- pyproject.toml | 8 ++--- tests/test_binding.py | 64 ++++++++++++++++++++++++++++++++- 4 files changed, 69 insertions(+), 9 deletions(-) diff --git a/.github/workflows/workflow.yaml b/.github/workflows/workflow.yaml index a33bc9ac..486ef145 100644 --- a/.github/workflows/workflow.yaml +++ b/.github/workflows/workflow.yaml @@ -42,7 +42,7 @@ jobs: os: ubuntu-latest - python: "3.10" os: ubuntu-latest - - python: "3.11.0-beta.1 - 3.11" + - python: "3.11" os: ubuntu-latest # test OSs - python: "3.x" diff --git a/di/_container.py b/di/_container.py index 44c4b643..e6e0b7a3 100644 --- a/di/_container.py +++ b/di/_container.py @@ -81,7 +81,7 @@ def bind_by_type( def hook( param: inspect.Parameter | None, dependent: DependentBase[Any] ) -> DependentBase[Any] | None: - if dependent.call is dependency: + if dependent.call == dependency: return provider if param is None: return None @@ -89,7 +89,7 @@ def hook( if type_annotation_option is None: return None type_annotation = type_annotation_option.value - if type_annotation is dependency: + if type_annotation == dependency: return provider if covariant: if inspect.isclass(type_annotation) and inspect.isclass(dependency): diff --git a/pyproject.toml b/pyproject.toml index e4f2e033..48922ed3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "di" -version = "0.77.0" +version = "0.78.0" description = "Dependency injection toolkit" authors = ["Adrian Garcia Badaracco "] readme = "README.md" @@ -38,6 +38,7 @@ anyio = ["anyio"] # linting black = "~23" mypy = "~1" +ruff = "^0.0.286" pre-commit = "~2" # testing pytest = "~7" @@ -48,18 +49,15 @@ coverage = { extras = ["toml"], version = "~6" } # docs mkdocs = "~1" mkdocs-material = "~8,!=8.1.3" +mkdocstrings = {version = "^0.19.0", extras = ["python"]} mike = "~1" # benchmarking pyinstrument = "~4" -mkdocstrings = {version = "^0.19.0", extras = ["python"]} -ruff = "^0.0.286" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" -[tool.isort] -profile = "black" [tool.coverage.run] branch = true diff --git a/tests/test_binding.py b/tests/test_binding.py index 887dec9d..3a2ef761 100644 --- a/tests/test_binding.py +++ b/tests/test_binding.py @@ -1,4 +1,6 @@ -from typing import List +import sys +from abc import abstractmethod +from typing import List, TypeVar import pytest @@ -7,6 +9,11 @@ from di.executors import SyncExecutor from di.typing import Annotated +if sys.version_info < (3, 8): # pragma: no cover + from typing_extensions import Protocol +else: # pragma: no cover + from typing import Protocol + class Request: def __init__(self, value: int = 0) -> None: @@ -47,6 +54,61 @@ def __init__(self, v: int = 1) -> None: assert res.v == 1 +T_co = TypeVar("T_co", covariant=True) + + +def test_bind_generic(): + container = Container() + executor = SyncExecutor() + expected = 100 + + class GetterInterface(Protocol[T_co]): + @abstractmethod + def get(self) -> T_co: + ... + + class GetterIntImpl(GetterInterface[int]): + def __init__(self, v: int) -> None: + self.v = v + + def get(self) -> int: + return self.v + + def factory() -> GetterIntImpl: + return GetterIntImpl(expected) + + hook = bind_by_type( + Dependent(factory), + GetterInterface[int], + ) + container.bind(hook) + + # =========================================== + # clean `_tp_cache` + from typing import _cleanups as cache_cleanups # type: ignore[attr-defined] + + for cache_cleanup in cache_cleanups: + cache_cleanup() + # =========================================== + + class IntService: + """Declared after binding and cache clearing.""" + + def __init__(self, getter: GetterInterface[int]) -> None: + self.getter = getter + + scopes = [None] + flat_dependent = Dependent(GetterInterface[int]) + wired_dependent = Dependent(IntService) + with container.enter_scope(None) as state: + flat_solved = container.solve(flat_dependent, scopes) + wired_solved = container.solve(wired_dependent, scopes) + flat = flat_solved.execute_sync(executor, state) + wired = wired_solved.execute_sync(executor, state) + + assert flat.get() == wired.getter.get() == expected + + def test_bind_transitive_dependency_results_skips_subdpendencies(): """If we bind a transitive dependency none of it's sub-dependencies should be executed since they are no longer required.