Skip to content

Commit

Permalink
Ignore variadic argument, better signature detection (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
elfjes authored Jan 14, 2022
1 parent b42046b commit 280aa45
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/gimme/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .resolvers import AttributeResolver, TypeHintingResolver
from .types import DependencyInfo

__version__ = "0.2.2"
__version__ = "0.2.3"

__all__ = [
"add",
Expand Down
29 changes: 26 additions & 3 deletions src/gimme/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,38 @@ def get_dependencies(
signature = inspect.signature(factory)
except ValueError: # some builtin types
raise CannotResolve()

try:
type_hints = self.get_type_hints(factory, repository)
except NameError as e:
raise CannotResolve() from e

# The signature of a callable may differ from its type annotations, for example when
# __new__ has been overridden with (*args, **kwargs) `inspect.signature` can then
# not properly determine the signature. We have to make sure resolve any
# non-variadic arguments that have no default value, as well as all parameters in
# __annotations__ for arguments that do not have a default value

nonvariadic_params = {
name
for name, param in signature.parameters.items()
if param.kind not in (param.VAR_KEYWORD, param.VAR_POSITIONAL)
}
default_params = {
name
for name, param in signature.parameters.items()
if param.default is not param.empty
}
all_required = (
(nonvariadic_params | set(type_hints.keys()))
- default_params
- set(kwargs.keys())
- {"return"} # return is a special annotation indiciting the return type
)

dependencies = {}
for key, param in signature.parameters.items():
if key in kwargs or param.default is not inspect.Parameter.empty:
continue

for key in all_required:

try:
annotation = type_hints[key]
Expand Down
36 changes: 31 additions & 5 deletions tests/test_resolvers/test_type_hinting_resolver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple
import typing as t
from unittest import mock
from unittest.mock import Mock, call

Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(self, a: int, b: str, c=None, d: list = None):

assert plugin.get_dependencies(ClassWithDependencies, repo).keys() == {"a", "b"}

assert repo.get.call_args_list == [call(int), call(str)]
assert {call_args[0][0] for call_args in repo.get.call_args_list} == {int, str}


def test_dont_resolve_dependency_that_is_also_in_kwargs(plugin, repo):
Expand Down Expand Up @@ -75,11 +75,11 @@ def function(a: int, b: str, c: list = None, *, d: set, e: tuple = ()):
pass

assert plugin.get_dependencies(function, repo).keys() == {"a", "b", "d"}
assert repo.get.call_args_list == [call(int), call(str), call(set)]
assert {call_args[0][0] for call_args in repo.get.call_args_list} == {int, str, set}


def test_get_dependencies_with_generic_type(plugin, repo):
def function(a: Tuple[int, ...]):
def function(a: t.Tuple[int, ...]):
pass

repo.get.return_value = [1, 2]
Expand All @@ -88,7 +88,7 @@ def function(a: Tuple[int, ...]):


def test_cannot_resolve_invalid_type_hint(plugin, repo):
def function(a: Tuple[int, int]):
def function(a: t.Tuple[int, int]):
pass

with pytest.raises(CannotResolve):
Expand All @@ -99,3 +99,29 @@ def function(a: Tuple[int, int]):
def test_raise_cannot_resolve_on_unresolvable_builtin_types(plugin, repo, tp):
with pytest.raises(CannotResolve):
plugin.get_dependencies(tp, repo)


def test_ignores_variadic_positional_arguments(plugin, repo):
def func(*args):
return 42

assert plugin.get_dependencies(func, repo) == {}


def test_ignores_variadic_keyword_arguments(plugin, repo):
def func(**kwargs):
return 42

assert plugin.get_dependencies(func, repo) == {}


def test_with_overridden_new(plugin, repo):
class MyClass:
def __new__(cls, *args, **kwargs):
return object.__new__(cls)

def __init__(self, dep: int) -> None:
self.dep = dep

plugin.get_dependencies(MyClass, repo)
assert repo.get.call_args == call(int)

0 comments on commit 280aa45

Please sign in to comment.