From 22b5c0a0fb2b0f92e463a294e1b5589c4d5a77b7 Mon Sep 17 00:00:00 2001 From: tdakkota Date: Mon, 25 Oct 2021 11:33:30 +0300 Subject: [PATCH] internal/wire: add method providers support --- internal/wire/parse.go | 54 ++++++++++++-- .../wire/testdata/MethodProvider/foo/foo.go | 74 +++++++++++++++++++ .../wire/testdata/MethodProvider/foo/wire.go | 27 +++++++ internal/wire/testdata/MethodProvider/pkg | 1 + .../MethodProvider/want/program_out.txt | 1 + .../testdata/MethodProvider/want/wire_gen.go | 19 +++++ 6 files changed, 171 insertions(+), 5 deletions(-) create mode 100644 internal/wire/testdata/MethodProvider/foo/foo.go create mode 100644 internal/wire/testdata/MethodProvider/foo/wire.go create mode 100644 internal/wire/testdata/MethodProvider/pkg create mode 100644 internal/wire/testdata/MethodProvider/want/program_out.txt create mode 100644 internal/wire/testdata/MethodProvider/want/wire_gen.go diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 93fbda85..af12ec7b 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -425,6 +425,7 @@ type objectCache struct { type objRef struct { importPath string + recvType string name string } @@ -493,6 +494,10 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { pkgPath := obj.Pkg().Path() return oc.processExpr(oc.packages[pkgPath].TypesInfo, pkgPath, spec.Values[i], obj.Name()) case *types.Func: + sig := obj.Type().(*types.Signature) + if recv := sig.Recv(); recv != nil { + ref.recvType = recv.Type().String() + } return processFuncProvider(oc.fset, obj) default: return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} @@ -659,14 +664,42 @@ func qualifiedIdentObject(info *types.Info, expr ast.Expr) types.Object { case *ast.Ident: return info.ObjectOf(expr) case *ast.SelectorExpr: - pkgName, ok := expr.X.(*ast.Ident) - if !ok { - return nil + x := astutil.Unparen(expr.X) + if star, ok := x.(*ast.StarExpr); ok { + x = star.X } - if _, ok := info.ObjectOf(pkgName).(*types.PkgName); !ok { + switch x := x.(type) { + case *ast.Ident: + switch obj := info.ObjectOf(x); obj.(type) { + case *types.PkgName: + case *types.TypeName: + named, ok := obj.Type().(*types.Named) + if !ok { + return nil + } + + t := named.Underlying() + if ptr, ok := t.(*types.Pointer); ok { + t = ptr.Elem() + } + default: + return nil + } + + return info.ObjectOf(expr.Sel) + case *ast.SelectorExpr: + pkgName, ok := x.X.(*ast.Ident) + if !ok { + return nil + } + if _, ok := info.ObjectOf(pkgName).(*types.PkgName); !ok { + return nil + } + return info.ObjectOf(expr.Sel) + default: return nil } - return info.ObjectOf(expr.Sel) + default: return nil } @@ -680,7 +713,17 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro if err != nil { return nil, []error{notePosition(fset.Position(fpos), fmt.Errorf("wrong signature for provider %s: %v", fn.Name(), err))} } + params := sig.Params() + if recv := sig.Recv(); recv != nil { + newParams := make([]*types.Var, params.Len()+1) + newParams[0] = recv + for i := 0; i < params.Len(); i++ { + newParams[i+1] = params.At(i) + } + params = types.NewTuple(newParams...) + } + provider := &Provider{ Pkg: fn.Pkg(), Name: fn.Name(), @@ -691,6 +734,7 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro HasCleanup: providerSig.cleanup, HasErr: providerSig.err, } + for i := 0; i < params.Len(); i++ { provider.Args[i] = ProviderInput{ Type: params.At(i).Type(), diff --git a/internal/wire/testdata/MethodProvider/foo/foo.go b/internal/wire/testdata/MethodProvider/foo/foo.go new file mode 100644 index 00000000..038aa747 --- /dev/null +++ b/internal/wire/testdata/MethodProvider/foo/foo.go @@ -0,0 +1,74 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + + "github.com/google/wire" +) + +func main() { + fmt.Println(injectFooBar()) +} + +type ( + Foo int + FooBar int + Baz int +) + +var Set = wire.NewSet( + newFooProvider, + fooProvider.provideFoo, + newFooBarProvider, + (*fooBarProvider).provideFooBar, + newBazProvider, + bazProvider.provideBaz, +) + +type fooProvider struct{} + +func newFooProvider() fooProvider { + return fooProvider{} +} + +func (f fooProvider) provideFoo() Foo { + return 40 +} + +type fooBarProvider struct{} + +func newFooBarProvider() *fooBarProvider { + return &fooBarProvider{} +} + +func (*fooBarProvider) provideFooBar(foo Foo) FooBar { + return FooBar(foo) + 1 +} + +type bazProvider interface { + provideBaz(f FooBar) Baz +} + +func newBazProvider() bazProvider { + return bazProviderImpl{} +} + +type bazProviderImpl struct{} + +func (bazProviderImpl) provideBaz(fooBar FooBar) Baz { + return Baz(fooBar) + 1 +} diff --git a/internal/wire/testdata/MethodProvider/foo/wire.go b/internal/wire/testdata/MethodProvider/foo/wire.go new file mode 100644 index 00000000..21af4bb3 --- /dev/null +++ b/internal/wire/testdata/MethodProvider/foo/wire.go @@ -0,0 +1,27 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build wireinject +// +build wireinject + +package main + +import ( + "github.com/google/wire" +) + +func injectFooBar() Baz { + wire.Build(Set) + return 0 +} diff --git a/internal/wire/testdata/MethodProvider/pkg b/internal/wire/testdata/MethodProvider/pkg new file mode 100644 index 00000000..f7a5c8ce --- /dev/null +++ b/internal/wire/testdata/MethodProvider/pkg @@ -0,0 +1 @@ +example.com/foo diff --git a/internal/wire/testdata/MethodProvider/want/program_out.txt b/internal/wire/testdata/MethodProvider/want/program_out.txt new file mode 100644 index 00000000..d81cc071 --- /dev/null +++ b/internal/wire/testdata/MethodProvider/want/program_out.txt @@ -0,0 +1 @@ +42 diff --git a/internal/wire/testdata/MethodProvider/want/wire_gen.go b/internal/wire/testdata/MethodProvider/want/wire_gen.go new file mode 100644 index 00000000..20ef618a --- /dev/null +++ b/internal/wire/testdata/MethodProvider/want/wire_gen.go @@ -0,0 +1,19 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run github.com/google/wire/cmd/wire +//go:build !wireinject +// +build !wireinject + +package main + +// Injectors from wire.go: + +func injectFooBar() Baz { + mainBazProvider := newBazProvider() + mainFooBarProvider := newFooBarProvider() + mainFooProvider := newFooProvider() + foo := provideFoo(mainFooProvider) + fooBar := provideFooBar(mainFooBarProvider, foo) + baz := provideBaz(mainBazProvider, fooBar) + return baz +}