diff --git a/inject.go b/inject.go index 300b9a3..9fdd15d 100644 --- a/inject.go +++ b/inject.go @@ -254,11 +254,6 @@ StructLoop: ) } - // Don't overwrite existing values. - if !isNilOrZero(field, fieldType) { - continue - } - // Named injects must have been explicitly provided. if tag.Name != "" { existing := g.named[tag.Name] @@ -326,7 +321,19 @@ StructLoop: } // Interface injection is handled in a second pass. + if fieldType.Kind() == reflect.Interface && isNilOrZero(field, fieldType) { + continue + } + if fieldType.Kind() == reflect.Interface { + err := g.Provide(&Object{ + Value: field.Elem().Interface(), + private: true, + embedded: o.reflectType.Elem().Field(i).Anonymous, + }) + if err != nil { + return err + } continue } @@ -383,6 +390,10 @@ StructLoop: } } + if !tag.Override && !isNilOrZero(field, fieldType) { + continue + } + newValue := reflect.New(fieldType.Elem()) newObject := &Object{ Value: newValue.Interface(), @@ -453,11 +464,6 @@ func (g *Graph) populateUnnamedInterface(o *Object) error { ) } - // Don't overwrite existing values. - if !isNilOrZero(field, fieldType) { - continue - } - // Named injects must have already been handled in populateExplicit. if tag.Name != "" { panic(fmt.Sprintf("unhandled named instance with name %s", tag.Name)) @@ -482,6 +488,9 @@ func (g *Graph) populateUnnamedInterface(o *Object) error { existing.reflectValue, ) } + if !tag.Override && !isNilOrZero(field, fieldType) { + continue + } found = existing field.Set(reflect.ValueOf(existing.Value)) if g.Logger != nil { @@ -497,13 +506,14 @@ func (g *Graph) populateUnnamedInterface(o *Object) error { } // If we didn't find an assignable value, we're missing something. - if found == nil { + if found == nil && isNilOrZero(field, fieldType) { return fmt.Errorf( "found no assignable value for field %s in type %s", o.reflectType.Elem().Field(i).Name, o.reflectType, ) } + } return nil } @@ -531,15 +541,17 @@ func (g *Graph) Objects() []*Object { } var ( - injectOnly = &tag{} - injectPrivate = &tag{Private: true} - injectInline = &tag{Inline: true} + injectOnly = &tag{} + injectPrivate = &tag{Private: true} + injectInline = &tag{Inline: true} + injectOverride = &tag{Override: true} ) type tag struct { - Name string - Inline bool - Private bool + Name string + Inline bool + Private bool + Override bool } func parseTag(t string) (*tag, error) { @@ -550,14 +562,15 @@ func parseTag(t string) (*tag, error) { if !found { return nil, nil } - if value == "" { + switch value { + case "": return injectOnly, nil - } - if value == "inline" { + case "inline": return injectInline, nil - } - if value == "private" { + case "private": return injectPrivate, nil + case "override": + return injectOverride, nil } return &tag{Name: value}, nil } diff --git a/inject_test.go b/inject_test.go index 6433eed..148b131 100644 --- a/inject_test.go +++ b/inject_test.go @@ -3,6 +3,7 @@ package inject_test import ( "fmt" "math/rand" + "reflect" "strings" "testing" "time" @@ -114,6 +115,44 @@ func TestInjectSimple(t *testing.T) { } } +func TestInjectOverride(t *testing.T) { + var v struct { + A *TypeAnswerStruct `inject:""` + B *TypeAnswerStruct `inject:"override"` + } + olda, oldb := &TypeAnswerStruct{}, &TypeAnswerStruct{} + v.A, v.B = olda, oldb + if err := inject.Populate(&v); err != nil { + t.Fatal(err) + } + if v.A != olda { + t.Fatal("original A was lost") + } + if v.B == oldb { + t.Fatal("original B was not overridden") + } +} + +type TypeNestedInterfaceStruct struct { + A Answerable `inject:""` +} + +func TestNonEmptyInterfaceTraversal(t *testing.T) { + olda := &TypeNestedStruct{} + v := TypeNestedInterfaceStruct{ + A: olda, + } + if err := inject.Populate(&v); err != nil { + t.Fatal(err) + } + if v.A != olda { + t.Fatal("original A was lost") + } + if olda.A == nil { + t.Fatal("v.A.A is nil") + } +} + func TestDoesNotOverwrite(t *testing.T) { a := &TypeAnswerStruct{} var v struct { @@ -238,7 +277,7 @@ func TestProvideTwoOfTheSame(t *testing.T) { t.Fatal("expected error") } - const msg = "provided two unnamed instances of type *github.com/facebookgo/inject_test.TypeAnswerStruct" + msg := fmt.Sprintf("provided two unnamed instances of type *%s.TypeAnswerStruct", reflect.TypeOf(a).PkgPath()) if err.Error() != msg { t.Fatalf("expected:\n%s\nactual:\n%s", msg, err.Error()) } @@ -251,7 +290,7 @@ func TestProvideTwoOfTheSameWithPopulate(t *testing.T) { t.Fatal("expected error") } - const msg = "provided two unnamed instances of type *github.com/facebookgo/inject_test.TypeAnswerStruct" + msg := fmt.Sprintf("provided two unnamed instances of type *%s.TypeAnswerStruct", reflect.TypeOf(a).PkgPath()) if err.Error() != msg { t.Fatalf("expected:\n%s\nactual:\n%s", msg, err.Error()) }