Skip to content

Commit

Permalink
fix: handle nested pydantic basemodel
Browse files Browse the repository at this point in the history
Signed-off-by: mao3267 <[email protected]>
  • Loading branch information
mao3267 committed Nov 8, 2024
1 parent 818afb7 commit d6468b6
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions flytepropeller/pkg/compiler/validators/typing.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,36 @@ type trivialChecker struct {
literalType *flyte.LiteralType
}

func removeTitleFieldFromProperties(schema map[string]interface{}) {
properties, ok := schema["properties"].(*structpb.Value)
func removeTitleFieldFromProperties(schema map[string]*structpb.Value) {
properties, ok := schema["properties"]
if !ok {
return
}

Check warning on line 25 in flytepropeller/pkg/compiler/validators/typing.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/typing.go#L24-L25

Added lines #L24 - L25 were not covered by tests

for _, p := range properties.GetStructValue().Fields {
if _, ok := p.GetStructValue().Fields["properties"]; ok {
removeTitleFieldFromProperties(p.GetStructValue().Fields)
}
delete(p.GetStructValue().Fields, "title")
}
}

func resolveRef(schema, defs map[string]*structpb.Value) {
// Schema from Pydantic BaseModel includes a $def field, which is a reference to the actual schema.
// We need to resolve the reference to compare the schema with those from marshumaro.
// https://github.com/flyteorg/flytekit/blob/3475ddc41f2ba31d23dd072362be704d7c2470a0/flytekit/core/type_engine.py#L632-L641
for _, p := range schema["properties"].GetStructValue().Fields {
if _, ok := p.GetStructValue().Fields["$ref"]; ok {
propName := strings.TrimPrefix(p.GetStructValue().Fields["$ref"].GetStringValue(), "#/$defs/")
p.GetStructValue().Fields = defs[propName].GetStructValue().Fields
resolveRef(p.GetStructValue().Fields, defs)
delete(p.GetStructValue().Fields, "$ref")
}

Check warning on line 45 in flytepropeller/pkg/compiler/validators/typing.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/typing.go#L35-L45

Added lines #L35 - L45 were not covered by tests
}

delete(schema, "$defs")

Check warning on line 48 in flytepropeller/pkg/compiler/validators/typing.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/typing.go#L48

Added line #L48 was not covered by tests
}

func isSuperTypeInJSON(sourceMetaData, targetMetaData *structpb.Struct) bool {
// Since there are lots of field differences between draft-07 and draft 2020-12,
// we only support json schema with 2020-12 draft, which is generated here: https://github.com/flyteorg/flytekit/blob/ff2d0da686c82266db4dbf764a009896cf062349/flytekit/core/type_engine.py#L630-L639
Expand All @@ -40,8 +59,8 @@ func isSuperTypeInJSON(sourceMetaData, targetMetaData *structpb.Struct) bool {
return false
}

Check warning on line 60 in flytepropeller/pkg/compiler/validators/typing.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/typing.go#L59-L60

Added lines #L59 - L60 were not covered by tests

copySrcSchema := make(map[string]interface{})
copyTgtSchema := make(map[string]interface{})
copySrcSchema := make(map[string]*structpb.Value)
copyTgtSchema := make(map[string]*structpb.Value)

for k, v := range sourceMetaData.Fields {
copySrcSchema[k] = v
Expand All @@ -51,6 +70,13 @@ func isSuperTypeInJSON(sourceMetaData, targetMetaData *structpb.Struct) bool {
copyTgtSchema[k] = v
}

// For nested Pydantic BaseModel, we need to resolve the reference to compare the schema.
if _, ok := copySrcSchema["$defs"]; ok {
resolveRef(copySrcSchema, copySrcSchema["$defs"].GetStructValue().Fields)
}

Check warning on line 76 in flytepropeller/pkg/compiler/validators/typing.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/typing.go#L75-L76

Added lines #L75 - L76 were not covered by tests
if _, ok := copyTgtSchema["$defs"]; ok {
resolveRef(copyTgtSchema, copyTgtSchema["$defs"].GetStructValue().Fields)
}

Check warning on line 79 in flytepropeller/pkg/compiler/validators/typing.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/typing.go#L78-L79

Added lines #L78 - L79 were not covered by tests
// The JSON schema generated by Pydantic.BaseModel includes a title field in its properties, repeatedly recording the property name.
// Since this title field is absent in the JSON schema generated for dataclass, we need to remove the title field from the properties to ensure equivalence.
removeTitleFieldFromProperties(copySrcSchema)
Expand All @@ -63,7 +89,7 @@ func isSuperTypeInJSON(sourceMetaData, targetMetaData *structpb.Struct) bool {
for _, p := range patch {
// If additionalProperties is false, the field is not present in the schema from Pydantic.BaseModel.
// We handle this case by checking the relationships by ourselves.
if p.Type != jsondiff.OperationAdd && p.Path == "/additionalProperties" {
if p.Type != jsondiff.OperationAdd && strings.Contains(p.Path, "additionalProperties") {
if p.Type == jsondiff.OperationRemove || p.Type == jsondiff.OperationReplace {
if p.OldValue != false {
return false
Expand All @@ -89,8 +115,8 @@ func isSameTypeInJSON(sourceMetaData, targetMetaData *structpb.Struct) bool {
return false
}

Check warning on line 116 in flytepropeller/pkg/compiler/validators/typing.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/typing.go#L115-L116

Added lines #L115 - L116 were not covered by tests

copySrcSchema := make(map[string]interface{})
copyTgtSchema := make(map[string]interface{})
copySrcSchema := make(map[string]*structpb.Value)
copyTgtSchema := make(map[string]*structpb.Value)

for k, v := range sourceMetaData.Fields {
copySrcSchema[k] = v
Expand Down

0 comments on commit d6468b6

Please sign in to comment.