From 3cfb99d9874736c1240e60b698857aa8d218c3e3 Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Sat, 27 Jul 2024 04:24:02 -0700 Subject: [PATCH] Implement physical plan serialization for json Copy plans (#11645) --- .../core/src/datasource/file_format/json.rs | 3 +- .../proto/datafusion_common.proto | 5 + .../proto-common/src/generated/pbjson.rs | 91 +++++++++++++++++++ .../proto-common/src/generated/prost.rs | 6 ++ datafusion/proto/proto/datafusion.proto | 1 + .../src/generated/datafusion_proto_common.rs | 6 ++ datafusion/proto/src/generated/pbjson.rs | 13 +++ datafusion/proto/src/generated/prost.rs | 4 +- datafusion/proto/src/lib.rs | 2 +- .../proto/src/logical_plan/file_formats.rs | 69 ++++++++++++-- datafusion/proto/src/logical_plan/mod.rs | 25 ++++- .../tests/cases/roundtrip_logical_plan.rs | 84 ++++++++++++++++- 12 files changed, 296 insertions(+), 13 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 9de9c3d7d871..7c579e890c8c 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -57,7 +57,8 @@ use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; #[derive(Default)] /// Factory struct used to create [JsonFormat] pub struct JsonFormatFactory { - options: Option, + /// the options carried by format factory + pub options: Option, } impl JsonFormatFactory { diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 8e8fd2352c6c..85983dddf6ae 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -51,6 +51,11 @@ message ParquetFormat { message AvroFormat {} +message NdJsonFormat { + JsonOptions options = 1; +} + + message PrimaryKeyConstraint{ repeated uint64 indices = 1; } diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index 511072f3cb55..4ac6517ed739 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -4642,6 +4642,97 @@ impl<'de> serde::Deserialize<'de> for Map { deserializer.deserialize_struct("datafusion_common.Map", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for NdJsonFormat { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.NdJsonFormat", len)?; + if let Some(v) = self.options.as_ref() { + struct_ser.serialize_field("options", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for NdJsonFormat { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "options", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Options, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "options" => Ok(GeneratedField::Options), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NdJsonFormat; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.NdJsonFormat") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut options__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Options => { + if options__.is_some() { + return Err(serde::de::Error::duplicate_field("options")); + } + options__ = map_.next_value()?; + } + } + } + Ok(NdJsonFormat { + options: options__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.NdJsonFormat", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ParquetFormat { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index 62919e218b13..bf198a24c811 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -49,6 +49,12 @@ pub struct ParquetFormat { pub struct AvroFormat {} #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct NdJsonFormat { + #[prost(message, optional, tag = "1")] + pub options: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PrimaryKeyConstraint { #[prost(uint64, repeated, tag = "1")] pub indices: ::prost::alloc::vec::Vec, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index e133abd46f43..4c90297263c4 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -90,6 +90,7 @@ message ListingTableScanNode { datafusion_common.CsvFormat csv = 10; datafusion_common.ParquetFormat parquet = 11; datafusion_common.AvroFormat avro = 12; + datafusion_common.NdJsonFormat json = 15; } repeated LogicalExprNodeCollection file_sort_order = 13; } diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 62919e218b13..bf198a24c811 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -49,6 +49,12 @@ pub struct ParquetFormat { pub struct AvroFormat {} #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct NdJsonFormat { + #[prost(message, optional, tag = "1")] + pub options: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PrimaryKeyConstraint { #[prost(uint64, repeated, tag = "1")] pub indices: ::prost::alloc::vec::Vec, diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index c5ec67d72875..163a4c044aeb 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -9031,6 +9031,9 @@ impl serde::Serialize for ListingTableScanNode { listing_table_scan_node::FileFormatType::Avro(v) => { struct_ser.serialize_field("avro", v)?; } + listing_table_scan_node::FileFormatType::Json(v) => { + struct_ser.serialize_field("json", v)?; + } } } struct_ser.end() @@ -9062,6 +9065,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { "csv", "parquet", "avro", + "json", ]; #[allow(clippy::enum_variant_names)] @@ -9079,6 +9083,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { Csv, Parquet, Avro, + Json, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9113,6 +9118,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { "csv" => Ok(GeneratedField::Csv), "parquet" => Ok(GeneratedField::Parquet), "avro" => Ok(GeneratedField::Avro), + "json" => Ok(GeneratedField::Json), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9226,6 +9232,13 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { return Err(serde::de::Error::duplicate_field("avro")); } file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Avro) +; + } + GeneratedField::Json => { + if file_format_type__.is_some() { + return Err(serde::de::Error::duplicate_field("json")); + } + file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Json) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 98b70dc25351..606fe3c1699f 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -118,7 +118,7 @@ pub struct ListingTableScanNode { pub target_partitions: u32, #[prost(message, repeated, tag = "13")] pub file_sort_order: ::prost::alloc::vec::Vec, - #[prost(oneof = "listing_table_scan_node::FileFormatType", tags = "10, 11, 12")] + #[prost(oneof = "listing_table_scan_node::FileFormatType", tags = "10, 11, 12, 15")] pub file_format_type: ::core::option::Option< listing_table_scan_node::FileFormatType, >, @@ -134,6 +134,8 @@ pub mod listing_table_scan_node { Parquet(super::super::datafusion_common::ParquetFormat), #[prost(message, tag = "12")] Avro(super::super::datafusion_common::AvroFormat), + #[prost(message, tag = "15")] + Json(super::super::datafusion_common::NdJsonFormat), } } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index bac31850c875..e7019553f53d 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -124,7 +124,7 @@ pub mod protobuf { pub use datafusion_proto_common::common::proto_error; pub use datafusion_proto_common::protobuf_common::{ ArrowOptions, ArrowType, AvroFormat, AvroOptions, CsvFormat, DfSchema, - EmptyMessage, Field, JoinSide, ParquetFormat, ScalarValue, Schema, + EmptyMessage, Field, JoinSide, NdJsonFormat, ParquetFormat, ScalarValue, Schema, }; pub use datafusion_proto_common::{FromProtoError, ToProtoError}; } diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index 2c4085b88869..ce9d24d94d99 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use datafusion::{ - config::CsvOptions, + config::{CsvOptions, JsonOptions}, datasource::file_format::{ arrow::ArrowFormatFactory, csv::CsvFormatFactory, json::JsonFormatFactory, parquet::ParquetFormatFactory, FileFormatFactory, @@ -31,7 +31,7 @@ use datafusion_common::{ }; use prost::Message; -use crate::protobuf::CsvOptions as CsvOptionsProto; +use crate::protobuf::{CsvOptions as CsvOptionsProto, JsonOptions as JsonOptionsProto}; use super::LogicalExtensionCodec; @@ -222,6 +222,34 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { } } +impl JsonOptionsProto { + fn from_factory(factory: &JsonFormatFactory) -> Self { + if let Some(options) = &factory.options { + JsonOptionsProto { + compression: options.compression as i32, + schema_infer_max_rec: options.schema_infer_max_rec as u64, + } + } else { + JsonOptionsProto::default() + } + } +} + +impl From<&JsonOptionsProto> for JsonOptions { + fn from(proto: &JsonOptionsProto) -> Self { + JsonOptions { + compression: match proto.compression { + 0 => CompressionTypeVariant::GZIP, + 1 => CompressionTypeVariant::BZIP2, + 2 => CompressionTypeVariant::XZ, + 3 => CompressionTypeVariant::ZSTD, + _ => CompressionTypeVariant::UNCOMPRESSED, + }, + schema_infer_max_rec: proto.schema_infer_max_rec as usize, + } + } +} + #[derive(Debug)] pub struct JsonLogicalExtensionCodec; @@ -267,17 +295,44 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { fn try_decode_file_format( &self, - __buf: &[u8], - __ctx: &SessionContext, + buf: &[u8], + _ctx: &SessionContext, ) -> datafusion_common::Result> { - Ok(Arc::new(JsonFormatFactory::new())) + let proto = JsonOptionsProto::decode(buf).map_err(|e| { + DataFusionError::Execution(format!( + "Failed to decode JsonOptionsProto: {:?}", + e + )) + })?; + let options: JsonOptions = (&proto).into(); + Ok(Arc::new(JsonFormatFactory { + options: Some(options), + })) } fn try_encode_file_format( &self, - __buf: &mut Vec, - __node: Arc, + buf: &mut Vec, + node: Arc, ) -> datafusion_common::Result<()> { + let options = if let Some(json_factory) = + node.as_any().downcast_ref::() + { + json_factory.options.clone().unwrap_or_default() + } else { + return Err(DataFusionError::Execution( + "Unsupported FileFormatFactory type".to_string(), + )); + }; + + let proto = JsonOptionsProto::from_factory(&JsonFormatFactory { + options: Some(options), + }); + + proto.encode(buf).map_err(|e| { + DataFusionError::Execution(format!("Failed to encode JsonOptions: {:?}", e)) + })?; + Ok(()) } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 5427f34e8e07..0a91babdfb60 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -38,7 +38,10 @@ use datafusion::datasource::file_format::{ }; use datafusion::{ datasource::{ - file_format::{avro::AvroFormat, csv::CsvFormat, FileFormat}, + file_format::{ + avro::AvroFormat, csv::CsvFormat, json::JsonFormat as OtherNdJsonFormat, + FileFormat, + }, listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, view::ViewTable, TableProvider, @@ -395,7 +398,17 @@ impl AsLogicalPlan for LogicalPlanNode { if let Some(options) = options { csv = csv.with_options(options.try_into()?) } - Arc::new(csv)}, + Arc::new(csv) + }, + FileFormatType::Json(protobuf::NdJsonFormat { + options + }) => { + let mut json = OtherNdJsonFormat::default(); + if let Some(options) = options { + json = json.with_options(options.try_into()?) + } + Arc::new(json) + } FileFormatType::Avro(..) => Arc::new(AvroFormat), }; @@ -996,6 +1009,14 @@ impl AsLogicalPlan for LogicalPlanNode { })); } + if let Some(json) = any.downcast_ref::() { + let options = json.options(); + maybe_some_type = + Some(FileFormatType::Json(protobuf::NdJsonFormat { + options: Some(options.try_into()?), + })) + } + if any.is::() { maybe_some_type = Some(FileFormatType::Avro(protobuf::AvroFormat {})) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 1bd6e9ad34b4..daa92475068f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -23,6 +23,8 @@ use arrow::datatypes::{ IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, DECIMAL256_MAX_PRECISION, }; +use datafusion::datasource::file_format::json::JsonFormatFactory; +use datafusion_common::parsers::CompressionTypeVariant; use prost::Message; use std::any::Any; use std::collections::HashMap; @@ -74,7 +76,8 @@ use datafusion_proto::bytes::{ logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, }; use datafusion_proto::logical_plan::file_formats::{ - ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec, ParquetLogicalExtensionCodec, + ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec, JsonLogicalExtensionCodec, + ParquetLogicalExtensionCodec, }; use datafusion_proto::logical_plan::to_proto::serialize_expr; use datafusion_proto::logical_plan::{ @@ -507,6 +510,73 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_json() -> Result<()> { + let ctx = SessionContext::new(); + + // Assume create_json_scan creates a logical plan for scanning a JSON file + let input = create_json_scan(&ctx).await?; + + let table_options = + TableOptions::default_from_session_config(ctx.state().config_options()); + let mut json_format = table_options.json; + + // Set specific JSON format options + json_format.compression = CompressionTypeVariant::GZIP; + json_format.schema_infer_max_rec = 1000; + + let file_type = format_as_file_type(Arc::new(JsonFormatFactory::new_with_options( + json_format.clone(), + ))); + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.json".to_string(), + partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], + file_type, + options: Default::default(), + }); + + // Assume JsonLogicalExtensionCodec is implemented similarly to CsvLogicalExtensionCodec + let codec = JsonLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.json", copy_to.output_url); + assert_eq!("json".to_string(), copy_to.file_type.get_ext()); + assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); + + let file_type = copy_to + .file_type + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + + let format_factory = file_type.as_format_factory(); + let json_factory = format_factory + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let json_config = json_factory.options.as_ref().unwrap(); + assert_eq!(json_format.compression, json_config.compression); + assert_eq!( + json_format.schema_infer_max_rec, + json_config.schema_infer_max_rec + ); + } + _ => panic!(), + } + + Ok(()) +} + async fn create_csv_scan(ctx: &SessionContext) -> Result { ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) .await?; @@ -515,6 +585,18 @@ async fn create_csv_scan(ctx: &SessionContext) -> Result Result { + ctx.register_json( + "t1", + "../core/tests/data/1.json", + NdJsonReadOptions::default(), + ) + .await?; + + let input = ctx.table("t1").await?.into_optimized_plan()?; + Ok(input) +} + #[tokio::test] async fn roundtrip_logical_plan_distinct_on() -> Result<()> { let ctx = SessionContext::new();