diff --git a/flyteadmin/pkg/manager/impl/util/digests_test.go b/flyteadmin/pkg/manager/impl/util/digests_test.go index bbed2bbde8..ee3ea93d19 100644 --- a/flyteadmin/pkg/manager/impl/util/digests_test.go +++ b/flyteadmin/pkg/manager/impl/util/digests_test.go @@ -7,13 +7,13 @@ import ( "path/filepath" "testing" - "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/ptypes/duration" _struct "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytestdlib/utils" ) var testLaunchPlanDigest = []byte{ @@ -92,7 +92,7 @@ func getCompiledWorkflow() (*core.CompiledWorkflowClosure, error) { if err != nil { return nil, err } - err = jsonpb.UnmarshalString(string(workflowJSON), &compiledWorkflow) + err = utils.UnmarshalBytesToPb(workflowJSON, &compiledWorkflow) if err != nil { return nil, err } diff --git a/flyteadmin/pkg/repositories/transformers/task_execution_test.go b/flyteadmin/pkg/repositories/transformers/task_execution_test.go index e1e0fd973e..5fc5430192 100644 --- a/flyteadmin/pkg/repositories/transformers/task_execution_test.go +++ b/flyteadmin/pkg/repositories/transformers/task_execution_test.go @@ -22,6 +22,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flyte/flytestdlib/promutils" "github.com/flyteorg/flyte/flytestdlib/storage" + "github.com/flyteorg/flyte/flytestdlib/utils" ) var taskEventOccurredAt = time.Now().UTC() @@ -63,7 +64,7 @@ func transformMapToStructPB(t *testing.T, thing map[string]string) *structpb.Str } thingAsCustom := &structpb.Struct{} - if err := jsonpb.UnmarshalString(string(b), thingAsCustom); err != nil { + if err := utils.UnmarshalBytesToPb(b, thingAsCustom); err != nil { t.Fatal(t, err) } return thingAsCustom diff --git a/flytectl/cmd/get/node_execution.go b/flytectl/cmd/get/node_execution.go index 7b2e7a9e88..89c902ddbd 100644 --- a/flytectl/cmd/get/node_execution.go +++ b/flytectl/cmd/get/node_execution.go @@ -1,7 +1,6 @@ package get import ( - "bytes" "context" "fmt" "sort" @@ -13,7 +12,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/golang/protobuf/jsonpb" + "github.com/flyteorg/flyte/flytestdlib/utils" ) var nodeExecutionColumns = []printer.Column{ @@ -50,18 +49,13 @@ type TaskExecution struct { // MarshalJSON overridden method to json marshalling to use jsonpb func (in *TaskExecution) MarshalJSON() ([]byte, error) { - var buf bytes.Buffer - marshaller := jsonpb.Marshaler{} - if err := marshaller.Marshal(&buf, in.TaskExecution); err != nil { - return nil, err - } - return buf.Bytes(), nil + return utils.MarshalPbToBytes(in.TaskExecution) } // UnmarshalJSON overridden method to json unmarshalling to use jsonpb func (in *TaskExecution) UnmarshalJSON(b []byte) error { in.TaskExecution = &admin.TaskExecution{} - return jsonpb.Unmarshal(bytes.NewReader(b), in.TaskExecution) + return utils.UnmarshalBytesToPb(b, in.TaskExecution) } type NodeExecution struct { @@ -70,18 +64,13 @@ type NodeExecution struct { // MarshalJSON overridden method to json marshalling to use jsonpb func (in *NodeExecution) MarshalJSON() ([]byte, error) { - var buf bytes.Buffer - marshaller := jsonpb.Marshaler{} - if err := marshaller.Marshal(&buf, in.NodeExecution); err != nil { - return nil, err - } - return buf.Bytes(), nil + return utils.MarshalPbToBytes(in.NodeExecution) } // UnmarshalJSON overridden method to json unmarshalling to use jsonpb func (in *NodeExecution) UnmarshalJSON(b []byte) error { *in = NodeExecution{} - return jsonpb.Unmarshal(bytes.NewReader(b), in) + return utils.UnmarshalBytesToPb(b, in.NodeExecution) } // NodeExecutionClosure forms a wrapper around admin.NodeExecution and also fetches the childnodes , task execs diff --git a/flytectl/pkg/visualize/graphviz_test.go b/flytectl/pkg/visualize/graphviz_test.go index 72dcd3a69b..9480ea3053 100644 --- a/flytectl/pkg/visualize/graphviz_test.go +++ b/flytectl/pkg/visualize/graphviz_test.go @@ -1,7 +1,6 @@ package visualize import ( - "bytes" "fmt" "io/ioutil" "testing" @@ -10,7 +9,6 @@ import ( "github.com/flyteorg/flyte/flytectl/pkg/visualize/mocks" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytestdlib/utils" - "github.com/golang/protobuf/jsonpb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -24,10 +22,8 @@ func TestRenderWorkflowBranch(t *testing.T) { r, err := ioutil.ReadFile(fmt.Sprintf("testdata/%s.json", s)) assert.NoError(t, err) - i := bytes.NewReader(r) - c := &core.CompiledWorkflowClosure{} - err = jsonpb.Unmarshal(i, c) + err = utils.UnmarshalBytesToPb(r, c) assert.NoError(t, err) b, err := RenderWorkflow(c) fmt.Println(b) diff --git a/flyteidl/clients/go/coreutils/literals.go b/flyteidl/clients/go/coreutils/literals.go index f3277d0886..3527ac246b 100644 --- a/flyteidl/clients/go/coreutils/literals.go +++ b/flyteidl/clients/go/coreutils/literals.go @@ -377,7 +377,8 @@ func MakeLiteralForSimpleType(t core.SimpleType, s string) (*core.Literal, error switch t { case core.SimpleType_STRUCT: st := &structpb.Struct{} - err := jsonpb.UnmarshalString(s, st) + unmarshaler := jsonpb.Unmarshaler{AllowUnknownFields: true} + err := unmarshaler.Unmarshal(strings.NewReader(s), st) if err != nil { return nil, errors.Wrapf(err, "failed to load generic type as json.") } diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go b/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go index 51e36fe395..d437b47a0e 100755 --- a/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go @@ -15,6 +15,7 @@ var jsonPbUnmarshaler = &jsonpb.Unmarshaler{ AllowUnknownFields: true, } +// Deprecated: Use flytestdlib/utils.UnmarshalStructToPb instead. func UnmarshalStruct(structObj *structpb.Struct, msg proto.Message) error { if structObj == nil { return fmt.Errorf("nil Struct Object passed") @@ -32,6 +33,7 @@ func UnmarshalStruct(structObj *structpb.Struct, msg proto.Message) error { return nil } +// Deprecated: Use flytestdlib/utils.MarshalPbToStruct instead. func MarshalStruct(in proto.Message, out *structpb.Struct) error { if out == nil { return fmt.Errorf("nil Struct Object passed") @@ -49,11 +51,12 @@ func MarshalStruct(in proto.Message, out *structpb.Struct) error { return nil } +// Deprecated: Use flytestdlib/utils.MarshalToString instead. func MarshalToString(msg proto.Message) (string, error) { return jsonPbMarshaler.MarshalToString(msg) } -// TODO: Use the stdlib version in the future, or move there if not there. +// Deprecated: Use flytestdlib/utils.MarshalObjToStruct instead. // Don't use this if input is a proto Message. func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) { b, err := json.Marshal(input) @@ -69,6 +72,7 @@ func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) { return structObj, nil } +// Deprecated: Use flytestdlib/utils.UnmarshalStructToObj instead. // Don't use this if the unmarshalled obj is a proto message. func UnmarshalStructToObj(structObj *structpb.Struct, obj interface{}) error { if structObj == nil { diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go index fdb3e74182..eba53067ef 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go @@ -7,7 +7,6 @@ import ( "time" daskAPI "github.com/dask/dask-kubernetes/v2023/dask_kubernetes/operator/go_client/pkg/apis/kubernetes.dask.org/v1" - "github.com/golang/protobuf/jsonpb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "google.golang.org/protobuf/types/known/structpb" @@ -25,6 +24,7 @@ import ( pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils" ) const ( @@ -122,7 +122,7 @@ func dummyDaskTaskTemplate(customImage string, resources *core.Resources, podTem } structObj := structpb.Struct{} - err = jsonpb.UnmarshalString(daskJobJSON, &structObj) + err = stdlibUtils.UnmarshalStringToPb(daskJobJSON, &structObj) if err != nil { panic(err) } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 7db8269eaf..6c0080d45a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" mpiOp "github.com/kubeflow/common/pkg/apis/common/v1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" @@ -29,6 +28,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" + stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils" ) const testImage = "image://" @@ -99,7 +99,7 @@ func dummyMPITaskTemplate(id string, args ...interface{}) *core.TaskTemplate { structObj := structpb.Struct{} - err = jsonpb.UnmarshalString(mpiObjJSON, &structObj) + err = stdlibUtils.UnmarshalStringToPb(mpiObjJSON, &structObj) if err != nil { panic(err) } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 6284b4d8f3..546b42d7df 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" commonOp "github.com/kubeflow/common/pkg/apis/common/v1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" @@ -29,6 +28,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" + stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils" ) const testImage = "image://" @@ -105,7 +105,7 @@ func dummyPytorchTaskTemplate(id string, args ...interface{}) *core.TaskTemplate structObj := structpb.Struct{} - err = jsonpb.UnmarshalString(ptObjJSON, &structObj) + err = stdlibUtils.UnmarshalStringToPb(ptObjJSON, &structObj) if err != nil { panic(err) } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 8206bda130..d4d6e6da17 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" commonOp "github.com/kubeflow/common/pkg/apis/common/v1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" @@ -29,6 +28,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" + stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils" ) const testImage = "image://" @@ -100,7 +100,7 @@ func dummyTensorFlowTaskTemplate(id string, args ...interface{}) *core.TaskTempl structObj := structpb.Struct{} - err = jsonpb.UnmarshalString(tfObjJSON, &structObj) + err = stdlibUtils.UnmarshalStringToPb(tfObjJSON, &structObj) if err != nil { panic(err) } diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index a560544228..7ea6c42be2 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -9,7 +9,6 @@ import ( sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" - "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -26,6 +25,7 @@ import ( pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils" ) const sparkMainClass = "MainClass" @@ -318,7 +318,7 @@ func dummySparkTaskTemplateContainer(id string, sparkConf map[string]string) *co structObj := structpb.Struct{} - err = jsonpb.UnmarshalString(sparkJobJSON, &structObj) + err = stdlibUtils.UnmarshalStringToPb(sparkJobJSON, &structObj) if err != nil { panic(err) } @@ -346,7 +346,7 @@ func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec * structObj := structpb.Struct{} - err = jsonpb.UnmarshalString(sparkJobJSON, &structObj) + err = stdlibUtils.UnmarshalStringToPb(sparkJobJSON, &structObj) if err != nil { panic(err) } diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/create.go b/flytepropeller/cmd/kubectl-flyte/cmd/create.go index 24688ef7e2..2feeb8ec8e 100644 --- a/flytepropeller/cmd/kubectl-flyte/cmd/create.go +++ b/flytepropeller/cmd/kubectl-flyte/cmd/create.go @@ -1,14 +1,12 @@ package cmd import ( - "bytes" "context" "encoding/json" "fmt" "io/ioutil" "github.com/ghodss/yaml" - "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" "github.com/pkg/errors" "github.com/spf13/cobra" @@ -20,6 +18,7 @@ import ( "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/common" compilerErrors "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/errors" "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/transformers/k8s" + "github.com/flyteorg/flyte/flytestdlib/utils" ) const ( @@ -89,7 +88,7 @@ func unmarshal(in []byte, format format, message proto.Message) (err error) { case formatProto: err = proto.Unmarshal(in, message) case formatJSON: - err = jsonpb.Unmarshal(bytes.NewReader(in), message) + err = utils.UnmarshalBytesToPb(in, message) if err != nil { err = errors.Wrapf(err, "Failed to unmarshal converted Json. [%v]", string(in)) } @@ -105,19 +104,12 @@ func unmarshal(in []byte, format format, message proto.Message) (err error) { return } -var jsonPbMarshaler = jsonpb.Marshaler{} - func marshal(message proto.Message, format format) (raw []byte, err error) { switch format { case formatProto: return proto.Marshal(message) case formatJSON: - b := &bytes.Buffer{} - err := jsonPbMarshaler.Marshal(b, message) - if err != nil { - return nil, errors.Wrapf(err, "Failed to marshal Json.") - } - return b.Bytes(), nil + return utils.MarshalPbToBytes(message) case formatYaml: b, err := marshal(message, formatJSON) if err != nil { diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch.go index 37a54dfffe..eb1fd2a6c0 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch.go @@ -1,11 +1,8 @@ package v1alpha1 import ( - "bytes" - - "github.com/golang/protobuf/jsonpb" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytestdlib/utils" ) type BooleanExpression struct { @@ -13,20 +10,12 @@ type BooleanExpression struct { } func (in BooleanExpression) MarshalJSON() ([]byte, error) { - if in.BooleanExpression == nil { - return nilJSON, nil - } - - var buf bytes.Buffer - if err := marshaler.Marshal(&buf, in.BooleanExpression); err != nil { - return nil, err - } - return buf.Bytes(), nil + return utils.MarshalPbToBytes(in.BooleanExpression) } func (in *BooleanExpression) UnmarshalJSON(b []byte) error { in.BooleanExpression = &core.BooleanExpression{} - return jsonpb.Unmarshal(bytes.NewReader(b), in.BooleanExpression) + return utils.UnmarshalBytesToPb(b, in.BooleanExpression) } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/error.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/error.go index 39ec19c165..6c72699980 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/error.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/error.go @@ -1,11 +1,8 @@ package v1alpha1 import ( - "bytes" - - "github.com/golang/protobuf/jsonpb" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytestdlib/utils" ) // Wrapper around core.Execution error. Execution Error has a protobuf enum and hence needs to be wrapped by custom marshaller @@ -13,20 +10,13 @@ type ExecutionError struct { *core.ExecutionError } -func (in *ExecutionError) UnmarshalJSON(b []byte) error { - in.ExecutionError = &core.ExecutionError{} - return jsonpb.Unmarshal(bytes.NewReader(b), in.ExecutionError) +func (in *ExecutionError) MarshalJSON() ([]byte, error) { + return utils.MarshalPbToBytes(in.ExecutionError) } -func (in *ExecutionError) MarshalJSON() ([]byte, error) { - if in == nil { - return nilJSON, nil - } - var buf bytes.Buffer - if err := marshaler.Marshal(&buf, in.ExecutionError); err != nil { - return nil, err - } - return buf.Bytes(), nil +func (in *ExecutionError) UnmarshalJSON(b []byte) error { + in.ExecutionError = &core.ExecutionError{} + return utils.UnmarshalBytesToPb(b, in.ExecutionError) } func (in *ExecutionError) DeepCopyInto(out *ExecutionError) { diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/gate.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/gate.go index a7ffa799fa..670a18ddbd 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/gate.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/gate.go @@ -1,11 +1,8 @@ package v1alpha1 import ( - "bytes" - - "github.com/golang/protobuf/jsonpb" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytestdlib/utils" ) type ConditionKind string @@ -25,20 +22,12 @@ type ApproveCondition struct { } func (in ApproveCondition) MarshalJSON() ([]byte, error) { - if in.ApproveCondition == nil { - return nilJSON, nil - } - - var buf bytes.Buffer - if err := marshaler.Marshal(&buf, in.ApproveCondition); err != nil { - return nil, err - } - return buf.Bytes(), nil + return utils.MarshalPbToBytes(in.ApproveCondition) } func (in *ApproveCondition) UnmarshalJSON(b []byte) error { in.ApproveCondition = &core.ApproveCondition{} - return jsonpb.Unmarshal(bytes.NewReader(b), in.ApproveCondition) + return utils.UnmarshalBytesToPb(b, in.ApproveCondition) } type SignalCondition struct { @@ -46,20 +35,12 @@ type SignalCondition struct { } func (in SignalCondition) MarshalJSON() ([]byte, error) { - if in.SignalCondition == nil { - return nilJSON, nil - } - - var buf bytes.Buffer - if err := marshaler.Marshal(&buf, in.SignalCondition); err != nil { - return nil, err - } - return buf.Bytes(), nil + return utils.MarshalPbToBytes(in.SignalCondition) } func (in *SignalCondition) UnmarshalJSON(b []byte) error { in.SignalCondition = &core.SignalCondition{} - return jsonpb.Unmarshal(bytes.NewReader(b), in.SignalCondition) + return utils.UnmarshalBytesToPb(b, in.SignalCondition) } type SleepCondition struct { @@ -67,20 +48,12 @@ type SleepCondition struct { } func (in SleepCondition) MarshalJSON() ([]byte, error) { - if in.SleepCondition == nil { - return nilJSON, nil - } - - var buf bytes.Buffer - if err := marshaler.Marshal(&buf, in.SleepCondition); err != nil { - return nil, err - } - return buf.Bytes(), nil + return utils.MarshalPbToBytes(in.SleepCondition) } func (in *SleepCondition) UnmarshalJSON(b []byte) error { in.SleepCondition = &core.SleepCondition{} - return jsonpb.Unmarshal(bytes.NewReader(b), in.SleepCondition) + return utils.UnmarshalBytesToPb(b, in.SleepCondition) } type GateNodeSpec struct { diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/identifier.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/identifier.go index 7d6a3622c8..e77838fdd4 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/identifier.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/identifier.go @@ -1,28 +1,21 @@ package v1alpha1 import ( - "bytes" - - "github.com/golang/protobuf/jsonpb" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytestdlib/utils" ) type Identifier struct { *core.Identifier } -func (in *Identifier) UnmarshalJSON(b []byte) error { - in.Identifier = &core.Identifier{} - return jsonpb.Unmarshal(bytes.NewReader(b), in.Identifier) +func (in *Identifier) MarshalJSON() ([]byte, error) { + return utils.MarshalPbToBytes(in.Identifier) } -func (in *Identifier) MarshalJSON() ([]byte, error) { - var buf bytes.Buffer - if err := marshaler.Marshal(&buf, in.Identifier); err != nil { - return nil, err - } - return buf.Bytes(), nil +func (in *Identifier) UnmarshalJSON(b []byte) error { + in.Identifier = &core.Identifier{} + return utils.UnmarshalBytesToPb(b, in.Identifier) } func (in *Identifier) DeepCopyInto(out *Identifier) { diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go index ef402d724d..bcd1064e67 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -21,8 +21,6 @@ import ( //go:generate mockery -all -var nilJSON, _ = json.Marshal(nil) - type CustomState map[string]interface{} type WorkflowID = string type TaskID = string diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go index 6554357031..bb1db2453d 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go @@ -1,34 +1,26 @@ package v1alpha1 import ( - "bytes" "time" - "github.com/golang/protobuf/jsonpb" typesv1 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytestdlib/utils" ) -var marshaler = jsonpb.Marshaler{} - type OutputVarMap struct { *core.VariableMap } func (in *OutputVarMap) MarshalJSON() ([]byte, error) { - var buf bytes.Buffer - if err := marshaler.Marshal(&buf, in.VariableMap); err != nil { - return nil, err - } - - return buf.Bytes(), nil + return utils.MarshalPbToBytes(in.VariableMap) } func (in *OutputVarMap) UnmarshalJSON(b []byte) error { in.VariableMap = &core.VariableMap{} - return jsonpb.Unmarshal(bytes.NewReader(b), in.VariableMap) + return utils.UnmarshalBytesToPb(b, in.VariableMap) } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. @@ -42,18 +34,13 @@ type Binding struct { *core.Binding } -func (in *Binding) UnmarshalJSON(b []byte) error { - in.Binding = &core.Binding{} - return jsonpb.Unmarshal(bytes.NewReader(b), in.Binding) -} - func (in *Binding) MarshalJSON() ([]byte, error) { - var buf bytes.Buffer - if err := marshaler.Marshal(&buf, in.Binding); err != nil { - return nil, err - } + return utils.MarshalPbToBytes(in.Binding) +} - return buf.Bytes(), nil +func (in *Binding) UnmarshalJSON(b []byte) error { + in.Binding = &core.Binding{} + return utils.UnmarshalBytesToPb(b, in.Binding) } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. @@ -98,17 +85,12 @@ type ExtendedResources struct { } func (in *ExtendedResources) MarshalJSON() ([]byte, error) { - var buf bytes.Buffer - if err := marshaler.Marshal(&buf, in.ExtendedResources); err != nil { - return nil, err - } - - return buf.Bytes(), nil + return utils.MarshalPbToBytes(in.ExtendedResources) } func (in *ExtendedResources) UnmarshalJSON(b []byte) error { in.ExtendedResources = &core.ExtendedResources{} - return jsonpb.Unmarshal(bytes.NewReader(b), in.ExtendedResources) + return utils.UnmarshalBytesToPb(b, in.ExtendedResources) } func (in *ExtendedResources) DeepCopyInto(out *ExtendedResources) { diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/tasks.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/tasks.go index 23b6b2cb47..17b747f0c6 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/tasks.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/tasks.go @@ -1,11 +1,8 @@ package v1alpha1 import ( - "bytes" - - "github.com/golang/protobuf/jsonpb" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytestdlib/utils" ) type TaskSpec struct { @@ -27,14 +24,10 @@ func (in *TaskSpec) DeepCopyInto(out *TaskSpec) { } func (in *TaskSpec) MarshalJSON() ([]byte, error) { - var buf bytes.Buffer - if err := marshaler.Marshal(&buf, in.TaskTemplate); err != nil { - return nil, err - } - return buf.Bytes(), nil + return utils.MarshalPbToBytes(in.TaskTemplate) } func (in *TaskSpec) UnmarshalJSON(b []byte) error { in.TaskTemplate = &core.TaskTemplate{} - return jsonpb.Unmarshal(bytes.NewReader(b), in.TaskTemplate) + return utils.UnmarshalBytesToPb(b, in.TaskTemplate) } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow.go index 1d45dc6578..22ed947f11 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow.go @@ -1,17 +1,16 @@ package v1alpha1 import ( - "bytes" "context" "encoding/json" - "github.com/golang/protobuf/jsonpb" "github.com/pkg/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytestdlib/storage" + "github.com/flyteorg/flyte/flytestdlib/utils" ) // Defines a non-configurable keyspace size for shard keys. This needs to be a small value because we use label @@ -194,21 +193,13 @@ type Inputs struct { *core.LiteralMap } -func (in *Inputs) UnmarshalJSON(b []byte) error { - in.LiteralMap = &core.LiteralMap{} - return jsonpb.Unmarshal(bytes.NewReader(b), in.LiteralMap) -} - func (in *Inputs) MarshalJSON() ([]byte, error) { - if in == nil || in.LiteralMap == nil { - return nilJSON, nil - } + return utils.MarshalPbToBytes(in.LiteralMap) +} - var buf bytes.Buffer - if err := marshaler.Marshal(&buf, in.LiteralMap); err != nil { - return nil, err - } - return buf.Bytes(), nil +func (in *Inputs) UnmarshalJSON(b []byte) error { + in.LiteralMap = &core.LiteralMap{} + return utils.UnmarshalBytesToPb(b, in.LiteralMap) } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. diff --git a/flytepropeller/pkg/compiler/test/compiler_test.go b/flytepropeller/pkg/compiler/test/compiler_test.go index ae0322b66b..355fc4a15b 100644 --- a/flytepropeller/pkg/compiler/test/compiler_test.go +++ b/flytepropeller/pkg/compiler/test/compiler_test.go @@ -1,7 +1,6 @@ package test import ( - "bytes" "encoding/json" "flag" "io/ioutil" @@ -27,6 +26,7 @@ import ( "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/errors" "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/transformers/k8s" "github.com/flyteorg/flyte/flytepropeller/pkg/visualize" + "github.com/flyteorg/flyte/flytestdlib/utils" ) var update = flag.Bool("update", false, "Update .golden files") @@ -117,7 +117,7 @@ func TestDynamic(t *testing.T) { raw, err := ioutil.ReadFile(path) assert.NoError(t, err) wf := &core.DynamicJobSpec{} - err = jsonpb.UnmarshalString(string(raw), wf) + err = utils.UnmarshalBytesToPb(raw, wf) if !assert.NoError(t, err) { t.FailNow() } @@ -362,8 +362,7 @@ func runCompileTest(t *testing.T, dirName string) { taskBytes, err := os.ReadFile(taskFile) assert.NoError(t, err) compiledTaskFromFile := &core.CompiledTask{} - reader := bytes.NewReader(taskBytes) - err = jsonpb.Unmarshal(reader, compiledTaskFromFile) + err = utils.UnmarshalBytesToPb(taskBytes, compiledTaskFromFile) assert.NoError(t, err) assert.True(t, proto.Equal(task, compiledTaskFromFile)) }) @@ -440,7 +439,7 @@ func runCompileTest(t *testing.T, dirName string) { } compiledWfc := &core.CompiledWorkflowClosure{} - if !assert.NoError(t, jsonpb.UnmarshalString(string(raw), compiledWfc)) { + if !assert.NoError(t, utils.UnmarshalBytesToPb(raw, compiledWfc)) { t.FailNow() } diff --git a/flytepropeller/pkg/compiler/transformers/k8s/workflow_test.go b/flytepropeller/pkg/compiler/transformers/k8s/workflow_test.go index ae3b82b8b9..dbb51e25eb 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/workflow_test.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/workflow_test.go @@ -1,11 +1,9 @@ package k8s import ( - "bytes" "io/ioutil" "testing" - "github.com/golang/protobuf/jsonpb" "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/util/sets" @@ -13,6 +11,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/common" "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/errors" + "github.com/flyteorg/flyte/flytestdlib/utils" ) func createSampleMockWorkflow() *mockWorkflow { @@ -329,10 +328,8 @@ func TestBuildFlyteWorkflow_withBranch(t *testing.T) { c, err := ioutil.ReadFile("testdata/compiled_closure_branch_nested.json") assert.NoError(t, err) - r := bytes.NewReader(c) - w := &core.CompiledWorkflowClosure{} - assert.NoError(t, jsonpb.Unmarshal(r, w)) + assert.NoError(t, utils.UnmarshalBytesToPb(c, w)) assert.Len(t, w.Primary.Connections.Downstream, 2) ids := w.Primary.Connections.Downstream["start-node"] diff --git a/flytestdlib/utils/marshal_utils.go b/flytestdlib/utils/marshal_utils.go index 3f68a1667d..555daacb75 100644 --- a/flytestdlib/utils/marshal_utils.go +++ b/flytestdlib/utils/marshal_utils.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "strings" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" @@ -12,6 +13,9 @@ import ( ) var jsonPbMarshaler = jsonpb.Marshaler{} +var jsonPbUnmarshaler = jsonpb.Unmarshaler{ + AllowUnknownFields: true, +} // UnmarshalStructToPb unmarshals a proto struct into a proto message using jsonPb marshaler. func UnmarshalStructToPb(structObj *structpb.Struct, msg proto.Message) error { @@ -28,7 +32,7 @@ func UnmarshalStructToPb(structObj *structpb.Struct, msg proto.Message) error { return errors.WithMessage(err, "Failed to marshal strcutObj input") } - if err = jsonpb.UnmarshalString(jsonObj, msg); err != nil { + if err = UnmarshalStringToPb(jsonObj, msg); err != nil { return errors.WithMessage(err, "Failed to unmarshal json obj into proto") } @@ -47,7 +51,7 @@ func MarshalPbToStruct(in proto.Message) (out *structpb.Struct, err error) { } out = &structpb.Struct{} - if err = jsonpb.Unmarshal(bytes.NewReader(buf.Bytes()), out); err != nil { + if err = UnmarshalBytesToPb(buf.Bytes(), out); err != nil { return nil, errors.WithMessage(err, "Failed to unmarshal json object into struct") } @@ -59,6 +63,26 @@ func MarshalPbToString(msg proto.Message) (string, error) { return jsonPbMarshaler.MarshalToString(msg) } +// UnmarshalStringToPb unmarshals a string to a proto message +func UnmarshalStringToPb(s string, msg proto.Message) error { + return jsonPbUnmarshaler.Unmarshal(strings.NewReader(s), msg) +} + +// MarshalPbToBytes marshals a proto message to a byte slice +func MarshalPbToBytes(msg proto.Message) ([]byte, error) { + var buf bytes.Buffer + err := jsonPbMarshaler.Marshal(&buf, msg) + if err != nil { + return nil, err + } + return buf.Bytes(), err +} + +// UnmarshalBytesToPb unmarshals a byte slice to a proto message +func UnmarshalBytesToPb(b []byte, msg proto.Message) error { + return jsonPbUnmarshaler.Unmarshal(bytes.NewReader(b), msg) +} + // MarshalObjToStruct marshals obj into a struct. Will use jsonPb if input is a proto message, otherwise, it'll use json // marshaler. func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) { @@ -73,7 +97,7 @@ func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) { // Turn JSON into a protobuf struct structObj := &structpb.Struct{} - if err := jsonpb.Unmarshal(bytes.NewReader(b), structObj); err != nil { + if err := UnmarshalBytesToPb(b, structObj); err != nil { return nil, errors.WithMessage(err, "Failed to unmarshal json object into struct") } diff --git a/flytestdlib/utils/marshal_utils_test.go b/flytestdlib/utils/marshal_utils_test.go index 482e0c0a72..f9b51e420f 100644 --- a/flytestdlib/utils/marshal_utils_test.go +++ b/flytestdlib/utils/marshal_utils_test.go @@ -184,3 +184,56 @@ func TestUnmarshalStructToObj(t *testing.T) { } }) } + +func TestMarshalPbToBytes(t *testing.T) { + type args struct { + msg proto.Message + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + {"empty", args{msg: &prototest.TestProto{}}, []byte("{}"), false}, + {"has value", args{msg: &prototest.TestProto{StringValue: "hello"}}, []byte(`{"stringValue":"hello"}`), false}, + {"nil input", args{msg: nil}, []byte(nil), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MarshalPbToBytes(tt.args.msg) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalPbToBytes() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, got, "MarshalPbToBytes() = %v, want %v", got, tt.want) + }) + } +} + +func TestUnmarshalBytesToPb(t *testing.T) { + type args struct { + b []byte + } + tests := []struct { + name string + args args + want proto.Message + wantErr bool + }{ + {"empty", args{b: []byte("{}")}, &prototest.TestProto{}, false}, + {"has value", args{b: []byte(`{"stringValue":"hello"}`)}, &prototest.TestProto{StringValue: "hello"}, false}, + {"nil input", args{b: []byte(nil)}, &prototest.TestProto{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &prototest.TestProto{} + err := UnmarshalBytesToPb(tt.args.b, m) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalBytesToPb() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.True(t, proto.Equal(tt.want, m), "UnmarshalBytesToPb() = %v, want %v", m, tt.want) + }) + } +}