diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml
index cdd464f38a76..4d8d1c159f70 100644
--- a/datafusion/proto/Cargo.toml
+++ b/datafusion/proto/Cargo.toml
@@ -56,4 +56,4 @@ serde_json = { workspace = true, optional = true }
 [dev-dependencies]
 doc-comment = { workspace = true }
 strum = { version = "0.26.1", features = ["derive"] }
-tokio = "1.18"
+tokio = { version = "1.18", features = ["rt-multi-thread"] }
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index 563508006a1c..628ee5ad9b7a 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -530,13 +530,25 @@ pub fn parse_protobuf_file_scan_config(
         true => ObjectStoreUrl::local_filesystem(),
     };
 
-    // extract types of partition columns
+    // Reacquire the partition column types from the schema before removing them below.
     let table_partition_cols = proto
        .table_partition_cols
        .iter()
        .map(|col| Ok(schema.field_with_name(col)?.clone()))
        .collect::<Result<Vec<Field>>>()?;
 
+    // Remove partition columns from the schema after recreating table_partition_cols
+    // because the partition columns are not in the file. They are present to allow
+    // the partition column types to be reconstructed after serde.
+    let file_schema = Arc::new(Schema::new(
+        schema
+            .fields()
+            .iter()
+            .filter(|field| !table_partition_cols.contains(field))
+            .cloned()
+            .collect::<Vec<_>>(),
+    ));
+
     let mut output_ordering = vec![];
     for node_collection in &proto.output_ordering {
         let sort_expr = node_collection
@@ -562,7 +574,7 @@
 
     Ok(FileScanConfig {
         object_store_url,
-        file_schema: schema,
+        file_schema,
         file_groups,
         statistics,
         projection,
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index 96d43e7e08ca..ce3df8183dc9 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -738,6 +738,17 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf {
             output_orderings.push(expr_node_vec)
         }
 
+        // Fields must be added to the schema so that they can persist in the protobuf,
+        // and then removed again in `parse_protobuf_file_scan_config`.
+        let mut fields = conf
+            .file_schema
+            .fields()
+            .iter()
+            .cloned()
+            .collect::<Vec<_>>();
+        fields.extend(conf.table_partition_cols.iter().cloned().map(Arc::new));
+        let schema = Arc::new(datafusion::arrow::datatypes::Schema::new(fields.clone()));
+
         Ok(protobuf::FileScanExecConf {
             file_groups,
             statistics: Some((&conf.statistics).into()),
@@ -749,7 +760,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf {
                 .iter()
                 .map(|n| *n as u32)
                 .collect(),
-            schema: Some(conf.file_schema.as_ref().try_into()?),
+            schema: Some(schema.as_ref().try_into()?),
             table_partition_cols: conf
                 .table_partition_cols
                 .iter()
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index f42698b69c83..e44f1863891a 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -18,6 +18,7 @@
 use arrow::csv::WriterBuilder;
 use std::ops::Deref;
 use std::sync::Arc;
+use std::vec;
 
 use datafusion::arrow::array::ArrayRef;
 use datafusion::arrow::compute::kernels::sort::SortOptions;
@@ -28,7 +29,8 @@ use datafusion::datasource::file_format::parquet::ParquetSink;
 use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile};
 use datafusion::datasource::object_store::ObjectStoreUrl;
 use datafusion::datasource::physical_plan::{
-    FileScanConfig, FileSinkConfig, ParquetExec,
+    wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig,
+    FileSinkConfig, ParquetExec,
 };
 use datafusion::execution::context::ExecutionProps;
 use datafusion::logical_expr::{
@@ -561,6 +563,32 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> {
     )))
 }
 
+#[tokio::test]
+async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> {
+    let mut file_group =
+        PartitionedFile::new("/path/to/part=0/file.parquet".to_string(), 1024);
+    file_group.partition_values =
+        vec![wrap_partition_value_in_dict(ScalarValue::Int64(Some(0)))];
+    let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)]));
+
+    let scan_config = FileScanConfig {
+        object_store_url: ObjectStoreUrl::local_filesystem(),
+        file_groups: vec![vec![file_group]],
+        statistics: Statistics::new_unknown(&schema),
+        file_schema: schema,
+        projection: Some(vec![0, 1]),
+        limit: None,
+        table_partition_cols: vec![Field::new(
+            "part".to_string(),
+            wrap_partition_type_in_dict(DataType::Int16),
+            false,
+        )],
+        output_ordering: vec![],
+    };
+
+    roundtrip_test(Arc::new(ParquetExec::new(scan_config, None, None)))
+}
+
 #[test]
 fn roundtrip_builtin_scalar_function() -> Result<()> {
     let field_a = Field::new("a", DataType::Int64, false);