Skip to content

Commit

Permalink
feat(loader): Adding a loading method that loads the diff of two stru…
Browse files Browse the repository at this point in the history
…cts into one (#10)

* Adding a diff loader

* Adding tests
  • Loading branch information
Jacobbrewer1 authored Oct 7, 2024
1 parent 887a0fc commit 1a94798
Show file tree
Hide file tree
Showing 2 changed files with 428 additions and 0 deletions.
87 changes: 87 additions & 0 deletions loader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package patcher

import (
"errors"
"reflect"
)

var (
// ErrInvalidType is returned when the provided type is not a pointer to a struct
ErrInvalidType = errors.New("invalid type")

// ErrInvalidFieldType is returned when the provided field type is not a struct
ErrInvalidFieldType = errors.New("invalid field type")
)

// LoadDiff inserts the fields provided in the new object into the old object and returns the result.
//
// This can be if you are inserting a patch into an existing object but require a new object to be returned with
// all fields.
func LoadDiff[T any](old T, newT T) error {
return loadDiff(old, newT)
}

func loadDiff[T any](old T, newT T) error {
orv := reflect.ValueOf(old)
if orv.Kind() != reflect.Ptr || orv.IsNil() {
return ErrInvalidType
}

nrv := reflect.ValueOf(newT)
if nrv.Kind() != reflect.Ptr || nrv.IsNil() {
return ErrInvalidType
}

oElem := orv.Elem()
nElem := nrv.Elem()

if oElem.Kind() != reflect.Struct || nElem.Kind() != reflect.Struct {
return ErrInvalidFieldType
}

for i := 0; i < oElem.NumField(); i++ {
// Include only exported fields
if !oElem.Field(i).CanSet() || !nElem.Field(i).CanSet() {
continue
}

// Handle embedded structs (Anonymous fields)
if oElem.Type().Field(i).Anonymous {
// If the embedded field is a pointer, dereference it
if oElem.Field(i).Kind() == reflect.Ptr {
if !oElem.Field(i).IsNil() && !nElem.Field(i).IsNil() {
if err := loadDiff(oElem.Field(i).Interface(), nElem.Field(i).Interface()); err != nil {
return err
}
} else if nElem.Field(i).IsValid() && !nElem.Field(i).IsNil() {
oElem.Field(i).Set(nElem.Field(i))
}

continue
}

if err := loadDiff(oElem.Field(i).Addr().Interface(), nElem.Field(i).Addr().Interface()); err != nil {
return err
}
continue
}

// If the field is a struct, we need to recursively call LoadDiff
if oElem.Field(i).Kind() == reflect.Struct {
if err := loadDiff(oElem.Field(i).Addr().Interface(), nElem.Field(i).Addr().Interface()); err != nil {
return err
}
continue
}

// Compare the old and new fields.
//
// New fields take priority over old fields if they are provided. We ignore zero values as they are not
// provided in the new object.
if !nElem.Field(i).IsZero() {
oElem.Field(i).Set(nElem.Field(i))
}
}

return nil
}
Loading

0 comments on commit 1a94798

Please sign in to comment.