Skip to content

Commit

Permalink
Improve interface assertions to allow val/ptr receivers without data …
Browse files Browse the repository at this point in the history
…loss (#112)
  • Loading branch information
vearutop authored Mar 6, 2024
1 parent fbc1e0d commit 9575eb9
Show file tree
Hide file tree
Showing 3 changed files with 599 additions and 36 deletions.
140 changes: 106 additions & 34 deletions reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,24 +129,13 @@ func checkSchemaSetup(params InterceptSchemaParams) (bool, error) {
v := params.Value
s := params.Schema

vi := v.Interface()
if v.Kind() == reflect.Ptr && v.IsNil() {
vi = reflect.New(v.Type().Elem()).Interface()
}

vpi := reflect.New(v.Type()).Interface()

reflectEnum(s, "", vi)
reflectEnum(s, "", v.Interface())

var e Exposer

if exposer, ok := vi.(Exposer); ok {
e = exposer
}

if exposer, ok := vi.(Exposer); ok {
if exposer, ok := safeInterface(v).(Exposer); ok {
e = exposer
} else if exposer, ok := vpi.(Exposer); ok {
} else if exposer, ok := ptrTo(v).(Exposer); ok {
e = exposer
}

Expand All @@ -164,9 +153,9 @@ func checkSchemaSetup(params InterceptSchemaParams) (bool, error) {
var re RawExposer

// Checking if RawExposer is defined on a current value.
if exposer, ok := vi.(RawExposer); ok {
if exposer, ok := safeInterface(v).(RawExposer); ok {
re = exposer
} else if exposer, ok := vpi.(RawExposer); ok { // Checking if RawExposer is defined on a pointer to current value.
} else if exposer, ok := ptrTo(v).(RawExposer); ok { // Checking if RawExposer is defined on a pointer to current value.
re = exposer
}

Expand All @@ -176,11 +165,15 @@ func checkSchemaSetup(params InterceptSchemaParams) (bool, error) {
return true, err
}

err = json.Unmarshal(schemaBytes, s)
var rs Schema

err = json.Unmarshal(schemaBytes, &rs)
if err != nil {
return true, err
}

*s = rs

return true, nil
}

Expand Down Expand Up @@ -389,6 +382,8 @@ func (r *Reflector) reflectDefer(defName string, typeString refl.TypeString, rc
func (r *Reflector) checkTitle(v reflect.Value, s *Struct, schema *Schema) {
if vd, ok := safeInterface(v).(Described); ok {
schema.WithDescription(vd.Description())
} else if vd, ok := ptrTo(v).(Described); ok {
schema.WithDescription(vd.Description())
}

if s != nil && s.Description != nil {
Expand All @@ -397,6 +392,8 @@ func (r *Reflector) checkTitle(v reflect.Value, s *Struct, schema *Schema) {

if vt, ok := safeInterface(v).(Titled); ok {
schema.WithTitle(vt.Title())
} else if vt, ok := ptrTo(v).(Titled); ok {
schema.WithTitle(vt.Title())
}

if s != nil && s.Title != nil {
Expand Down Expand Up @@ -535,6 +532,10 @@ func (r *Reflector) reflect(i interface{}, rc *ReflectContext, keepType bool, pa
if preparer, ok := safeInterface(v).(Preparer); ok {
err := preparer.PrepareJSONSchema(sp)

return schema, err
} else if preparer, ok := ptrTo(v).(Preparer); ok {
err := preparer.PrepareJSONSchema(sp)

return schema, err
}

Expand All @@ -556,25 +557,48 @@ func checkTextMarshaler(t reflect.Type, schema *Schema) bool {
}

func safeInterface(v reflect.Value) interface{} {
if !v.IsValid() {
return nil
}

if v.Kind() == reflect.Ptr && !v.Elem().IsValid() {
v = reflect.New(v.Type())
v = reflect.New(v.Type().Elem())
}

return v.Interface()
}

func ptrTo(v reflect.Value) interface{} {
if !v.IsValid() {
return nil
}

rd := reflect.New(v.Type())
rd.Elem().Set(v)

return rd.Interface()
}

func (r *Reflector) applySubSchemas(v reflect.Value, rc *ReflectContext, schema *Schema) error {
vi := safeInterface(v)
vp := ptrTo(v)

var oe OneOfExposer
if e, ok := vi.(OneOfExposer); ok {
oe = e
} else if e, ok := vp.(OneOfExposer); ok {
oe = e
}

if oe != nil {
var schemas []SchemaOrBool

for _, item := range e.JSONSchemaOneOf() {
for _, item := range oe.JSONSchemaOneOf() {
rc.Path = append(rc.Path, "oneOf")

s, err := r.reflect(item, rc, false, schema)
if err != nil {
return fmt.Errorf("failed to reflect 'oneOf' values of %T: %w", vi, err)
return fmt.Errorf("failed to reflect 'oneOf' values of %T: %w", oe, err)
}

schemas = append(schemas, s.ToSchemaOrBool())
Expand All @@ -583,15 +607,22 @@ func (r *Reflector) applySubSchemas(v reflect.Value, rc *ReflectContext, schema
schema.OneOf = schemas
}

var ane AnyOfExposer
if e, ok := vi.(AnyOfExposer); ok {
ane = e
} else if e, ok := vp.(AnyOfExposer); ok {
ane = e
}

if ane != nil {
var schemas []SchemaOrBool

for _, item := range e.JSONSchemaAnyOf() {
for _, item := range ane.JSONSchemaAnyOf() {
rc.Path = append(rc.Path, "anyOf")

s, err := r.reflect(item, rc, false, schema)
if err != nil {
return fmt.Errorf("failed to reflect 'anyOf' values of %T: %w", vi, err)
return fmt.Errorf("failed to reflect 'anyOf' values of %T: %w", ane, err)
}

schemas = append(schemas, s.ToSchemaOrBool())
Expand All @@ -600,15 +631,22 @@ func (r *Reflector) applySubSchemas(v reflect.Value, rc *ReflectContext, schema
schema.AnyOf = schemas
}

var ale AllOfExposer
if e, ok := vi.(AllOfExposer); ok {
ale = e
} else if e, ok := vp.(AllOfExposer); ok {
ale = e
}

if ale != nil {
var schemas []SchemaOrBool

for _, item := range e.JSONSchemaAllOf() {
for _, item := range ale.JSONSchemaAllOf() {
rc.Path = append(rc.Path, "allOf")

s, err := r.reflect(item, rc, false, schema)
if err != nil {
return fmt.Errorf("failed to reflect 'allOf' values of %T: %w", vi, err)
return fmt.Errorf("failed to reflect 'allOf' values of %T: %w", ale, err)
}

schemas = append(schemas, s.ToSchemaOrBool())
Expand All @@ -617,45 +655,73 @@ func (r *Reflector) applySubSchemas(v reflect.Value, rc *ReflectContext, schema
schema.AllOf = schemas
}

var ne NotExposer
if e, ok := vi.(NotExposer); ok {
ne = e
} else if e, ok := vp.(NotExposer); ok {
ne = e
}

if ne != nil {
rc.Path = append(rc.Path, "not")

s, err := r.reflect(e.JSONSchemaNot(), rc, false, schema)
s, err := r.reflect(ne.JSONSchemaNot(), rc, false, schema)
if err != nil {
return fmt.Errorf("failed to reflect 'not' value of %T: %w", vi, err)
return fmt.Errorf("failed to reflect 'not' value of %T: %w", ne, err)
}

schema.WithNot(s.ToSchemaOrBool())
}

var ie IfExposer
if e, ok := vi.(IfExposer); ok {
ie = e
} else if e, ok := vp.(IfExposer); ok {
ie = e
}

if ie != nil {
rc.Path = append(rc.Path, "if")

s, err := r.reflect(e.JSONSchemaIf(), rc, false, schema)
s, err := r.reflect(ie.JSONSchemaIf(), rc, false, schema)
if err != nil {
return fmt.Errorf("failed to reflect 'if' value of %T: %w", vi, err)
return fmt.Errorf("failed to reflect 'if' value of %T: %w", ie, err)
}

schema.WithIf(s.ToSchemaOrBool())
}

var te ThenExposer
if e, ok := vi.(ThenExposer); ok {
te = e
} else if e, ok := vp.(ThenExposer); ok {
te = e
}

if te != nil {
rc.Path = append(rc.Path, "if")

s, err := r.reflect(e.JSONSchemaThen(), rc, false, schema)
s, err := r.reflect(te.JSONSchemaThen(), rc, false, schema)
if err != nil {
return fmt.Errorf("failed to reflect 'then' value of %T: %w", vi, err)
return fmt.Errorf("failed to reflect 'then' value of %T: %w", te, err)
}

schema.WithThen(s.ToSchemaOrBool())
}

var ee ElseExposer
if e, ok := vi.(ElseExposer); ok {
ee = e
} else if e, ok := vp.(ElseExposer); ok {
ee = e
}

if ee != nil {
rc.Path = append(rc.Path, "if")

s, err := r.reflect(e.JSONSchemaElse(), rc, false, schema)
s, err := r.reflect(ee.JSONSchemaElse(), rc, false, schema)
if err != nil {
return fmt.Errorf("failed to reflect 'else' value of %T: %w", vi, err)
return fmt.Errorf("failed to reflect 'else' value of %T: %w", ee, err)
}

schema.WithElse(s.ToSchemaOrBool())
Expand Down Expand Up @@ -1359,11 +1425,17 @@ type enum struct {

// loadFromField loads enum from field tag: json array or comma-separated string.
func (enum *enum) loadFromField(fieldTag reflect.StructTag, fieldVal interface{}) {
if e, isEnumer := fieldVal.(NamedEnum); isEnumer {
fv := reflect.ValueOf(fieldVal)

if e, isEnumer := safeInterface(fv).(NamedEnum); isEnumer {
enum.items, enum.names = e.NamedEnum()
} else if e, isEnumer := ptrTo(fv).(NamedEnum); isEnumer {
enum.items, enum.names = e.NamedEnum()
}

if e, isEnumer := fieldVal.(Enum); isEnumer {
if e, isEnumer := safeInterface(fv).(Enum); isEnumer {
enum.items = e.Enum()
} else if e, isEnumer := ptrTo(fv).(Enum); isEnumer {
enum.items = e.Enum()
}

Expand Down
Loading

0 comments on commit 9575eb9

Please sign in to comment.