diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 09d14e12d0..6b6b748e51 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -683,12 +683,18 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> else: deleted_files = [] + schema = table.schema() + partitioning: Dict[str, list] = {schema.find_field(field.source_id).name: [] for field in table.spec().fields} + for data_file in data_files: operations.append("ADD") path.append(data_file.file_path) rows.append(data_file.record_count) size.append(data_file.file_size_in_bytes) + for field in partitioning.keys(): + partitioning[field].append(getattr(data_file.partition, field, None)) + for pf in deleted_files: data_file = pf.file operations.append("DELETE") @@ -696,6 +702,9 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> rows.append(data_file.record_count) size.append(data_file.file_size_in_bytes) + for field in partitioning.keys(): + partitioning[field].append(getattr(data_file.partition, field, None)) + if parse(pyiceberg.__version__) >= parse("0.7.0"): from pyiceberg.table import ALWAYS_TRUE, PropertyUtil, TableProperties @@ -735,19 +744,23 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> merge.commit() + with_operations = { + "operation": pa.array(operations, type=pa.string()), + "rows": pa.array(rows, type=pa.int64()), + "file_size": pa.array(size, type=pa.int64()), + "file_name": pa.array([fp for fp in path], type=pa.string()), + } + + if partitioning: + with_operations["partitioning"] = pa.StructArray.from_arrays( + partitioning.values(), names=partitioning.keys() + ) + from daft import from_pydict - with_operations = from_pydict( - { - "operation": pa.array(operations, type=pa.string()), - "rows": pa.array(rows, type=pa.int64()), - "file_size": pa.array(size, type=pa.int64()), - "file_name": pa.array([os.path.basename(fp) for fp in path], type=pa.string()), - } - ) # NOTE: We are losing the history of the plan here. 
# This is due to the fact that the logical plan of the write_iceberg returns datafiles but we want to return the above data - return with_operations + return from_pydict(with_operations) @DataframePublicAPI def write_deltalake( diff --git a/daft/iceberg/iceberg_write.py b/daft/iceberg/iceberg_write.py index 0ab4165d81..385a98db7a 100644 --- a/daft/iceberg/iceberg_write.py +++ b/daft/iceberg/iceberg_write.py @@ -1,18 +1,35 @@ +import datetime +import uuid import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any -from daft import Series +from daft import Expression, Series, col from daft.table import MicroPartition if TYPE_CHECKING: import pyarrow as pa - from pyiceberg.io.pyarrow import _TablePartition - from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec + from pyiceberg.partitioning import PartitionField as IcebergPartitionField from pyiceberg.schema import Schema as IcebergSchema - from pyiceberg.typedef import Record -def _coerce_pyarrow_table_to_schema(pa_table: "pa.Table", schema: "IcebergSchema") -> "pa.Table": +def add_missing_columns(table: MicroPartition, schema: "pa.Schema") -> MicroPartition: + """Add null values for columns in the schema that are missing from the table.""" + + import pyarrow as pa + + existing_columns = set(table.column_names()) + + columns = {} + for name in schema.names: + if name in existing_columns: + columns[name] = table.get_column(name) + else: + columns[name] = Series.from_arrow(pa.nulls(len(table), type=schema.field(name).type), name=name) + + return MicroPartition.from_pydict(columns) + + +def coerce_pyarrow_table_to_schema(pa_table: "pa.Table", schema: "pa.Schema") -> "pa.Table": """Coerces a PyArrow table to the supplied schema 1. 
For each field in `pa_table`, cast it to the field in `input_schema` if one with a matching name @@ -27,21 +44,18 @@ def _coerce_pyarrow_table_to_schema(pa_table: "pa.Table", schema: "IcebergSchema Args: pa_table (pa.Table): Table to coerce - schema (IcebergSchema): PyIceberg schema to coerce to + schema (pa.Schema): Iceberg schema to coerce to Returns: - pa.Table: Table with schema == `input_schema` + pa.Table: Table with schema == `schema` """ import pyarrow as pa - from pyiceberg.io.pyarrow import schema_to_pyarrow - - input_schema = schema_to_pyarrow(schema) - input_schema_names = set(input_schema.names) + input_schema_names = set(schema.names) # Perform casting of types to provided schema's types cast_to_schema = [ - (input_schema.field(inferred_field.name) if inferred_field.name in input_schema_names else inferred_field) + (schema.field(inferred_field.name) if inferred_field.name in input_schema_names else inferred_field) for inferred_field in pa_table.schema ] casted_table = pa_table.cast(pa.schema(cast_to_schema)) @@ -49,132 +63,60 @@ def _coerce_pyarrow_table_to_schema(pa_table: "pa.Table", schema: "IcebergSchema # Reorder and pad columns with a null column where necessary pa_table_column_names = set(casted_table.column_names) columns = [] - for name in input_schema.names: + for name in schema.names: if name in pa_table_column_names: columns.append(casted_table[name]) else: - columns.append(pa.nulls(len(casted_table), type=input_schema.field(name).type)) - return pa.table(columns, schema=input_schema) - - -def _determine_partitions( - spec: "IcebergPartitionSpec", schema: "IcebergSchema", arrow_table: "pa.Table" -) -> List["_TablePartition"]: - """Based on https://github.com/apache/iceberg-python/blob/d8d509ff1bc33040b9f6c90c28ee47ac7437945d/pyiceberg/io/pyarrow.py#L2669""" - - import pyarrow as pa - from pyiceberg.io.pyarrow import _get_table_partitions - from pyiceberg.partitioning import PartitionField - from pyiceberg.transforms import Transform - from pyiceberg.types import NestedField - - partition_columns: List[Tuple[PartitionField, NestedField]] = [ - (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields - ] - - def partition_transform(array: pa.Array, transform: Transform) -> Optional[pa.Array]: - from pyiceberg.transforms import ( - BucketTransform, - DayTransform, - HourTransform, - IdentityTransform, - MonthTransform, - TruncateTransform, - YearTransform, - ) - - series = Series.from_arrow(array) - - transformed = None - if isinstance(transform, IdentityTransform): - transformed = series - elif isinstance(transform, YearTransform): - transformed = series.partitioning.years() - elif isinstance(transform, MonthTransform): - transformed = series.partitioning.months() - elif isinstance(transform, DayTransform): - transformed = series.partitioning.days() - elif isinstance(transform, HourTransform): - transformed = series.partitioning.hours() - elif isinstance(transform, BucketTransform): - n = transform.num_buckets - transformed = series.partitioning.iceberg_bucket(n) - elif isinstance(transform, TruncateTransform): - w = transform.width - transformed = series.partitioning.iceberg_truncate(w) - else: - warnings.warn(f"{transform} not implemented, Please make an issue!") - - return transformed.to_arrow() if transformed is not None else None - - partition_values_table = pa.table( - { - str(partition.field_id): partition_transform(arrow_table[field.name], partition.transform) - for partition, field in partition_columns - } + 
columns.append(pa.nulls(len(casted_table), type=schema.field(name).type)) + return pa.table(columns, schema=schema) + + +def partition_field_to_expr(field: "IcebergPartitionField", schema: "IcebergSchema") -> Expression: + from pyiceberg.transforms import ( + BucketTransform, + DayTransform, + HourTransform, + IdentityTransform, + MonthTransform, + TruncateTransform, + YearTransform, ) - # Sort by partitions - sort_indices = pa.compute.sort_indices( - partition_values_table, - sort_keys=[(col, "ascending") for col in partition_values_table.column_names], - null_placement="at_end", - ).to_pylist() - arrow_table = arrow_table.take(sort_indices) - - # Get slice_instructions to group by partitions - partition_values_table = partition_values_table.take(sort_indices) - reversed_indices = pa.compute.sort_indices( - partition_values_table, - sort_keys=[(col, "descending") for col in partition_values_table.column_names], - null_placement="at_start", - ).to_pylist() - slice_instructions: List[Dict[str, Any]] = [] - last = len(reversed_indices) - reversed_indices_size = len(reversed_indices) - ptr = 0 - while ptr < reversed_indices_size: - group_size = last - reversed_indices[ptr] - offset = reversed_indices[ptr] - slice_instructions.append({"offset": offset, "length": group_size}) - last = reversed_indices[ptr] - ptr = ptr + group_size - - table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions) - - return table_partitions - - -def micropartition_to_arrow_tables( - table: MicroPartition, path: str, schema: "IcebergSchema", partition_spec: "IcebergPartitionSpec" -) -> List[Tuple["pa.Table", str, "Record"]]: - """ - Converts a MicroPartition to a list of Arrow tables with paths, partitioning the data if necessary. 
- - Args: - table (MicroPartition): Table to convert - path (str): Base path to write the table to - schema (IcebergSchema): Schema of the Iceberg table - partition_spec (IcebergPartitionSpec): Iceberg partitioning spec - - Returns: - List[Tuple[pa.Table, str, Record]]: List of Arrow tables with their paths and partition records - """ - from pyiceberg.typedef import Record - - arrow_table = table.to_arrow() - arrow_table = _coerce_pyarrow_table_to_schema(arrow_table, schema) - - if partition_spec.is_unpartitioned(): - return [(arrow_table, path, Record())] + partition_col = col(schema.find_field(field.source_id).name) + + if isinstance(field.transform, IdentityTransform): + return partition_col + elif isinstance(field.transform, YearTransform): + return partition_col.partitioning.years() + elif isinstance(field.transform, MonthTransform): + return partition_col.partitioning.months() + elif isinstance(field.transform, DayTransform): + return partition_col.partitioning.days() + elif isinstance(field.transform, HourTransform): + return partition_col.partitioning.hours() + elif isinstance(field.transform, BucketTransform): + return partition_col.partitioning.iceberg_bucket(field.transform.num_buckets) + elif isinstance(field.transform, TruncateTransform): + return partition_col.partitioning.iceberg_truncate(field.transform.width) + else: + warnings.warn(f"{field.transform} not implemented, Please make an issue!") + return partition_col + + +def to_partition_representation(value: Any): + if value is None: + return None + + if isinstance(value, datetime.datetime): + # Convert to microseconds since epoch + return (value - datetime.datetime(1970, 1, 1)) // datetime.timedelta(microseconds=1) + elif isinstance(value, datetime.date): + # Convert to days since epoch + return (value - datetime.date(1970, 1, 1)) // datetime.timedelta(days=1) + elif isinstance(value, datetime.time): + # Convert to microseconds since midnight + return (value.hour * 60 * 60 + value.minute * 60 + value.second) * 1_000_000 + value.microsecond + elif isinstance(value, uuid.UUID): + return str(value) else: - partitions = _determine_partitions(partition_spec, schema, arrow_table) - - return [ - ( - partition.arrow_table_partition, - f"{path}/{partition.partition_key.to_path()}", - partition.partition_key.partition, - ) - for partition in partitions - ] + return value diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 7e8b1b247c..72b3dce508 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -6,8 +6,9 @@ import random import time from collections.abc import Callable, Generator +from dataclasses import dataclass from functools import partial -from typing import IO, TYPE_CHECKING, Any, Union +from typing import IO, TYPE_CHECKING, Any, Iterator, Union from uuid import uuid4 import pyarrow as pa @@ -30,7 +31,6 @@ PythonStorageConfig, StorageConfig, ) -from daft.datatype import DataType from daft.expressions import ExpressionsProjection from daft.expressions.expressions import Expression from daft.filesystem import ( @@ -399,6 +399,42 @@ def read_csv( return _cast_table_to_schema(daft_table, read_options=read_options, schema=schema) +@dataclass +class _TableWriteData: + table: MicroPartition + path: str + partition_values: dict[str, Any] + + +def _table_to_partitions( + table: MicroPartition, + path: str, + partition_keys: ExpressionsProjection | None, + partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__", +) -> Iterator[_TableWriteData]: + if partition_keys is None or len(partition_keys) == 0: 
+ yield _TableWriteData(table, path, {}) + else: + default_part = Series.from_pylist([partition_null_fallback]) + split_tables, partition_values = table.partition_by_value(partition_keys=partition_keys) + assert len(split_tables) == len(partition_values) + pkey_names = partition_values.column_names() + + values_string_values = [] + + for c in pkey_names: + column = partition_values.get_column(c) + string_names = column._to_str_values() + null_filled = column.is_null().if_else(default_part, string_names) + values_string_values.append(null_filled.to_pylist()) + + partition_values_list = partition_values.to_pylist() + for i, (tab, values) in enumerate(zip(split_tables, partition_values_list)): + postfix = "/".join(f"{pkey}={values[i]}" for pkey, values in zip(pkey_names, values_string_values)) + partition_path = f"{path}/{postfix}" + yield _TableWriteData(tab, partition_path, values) + + def write_tabular( table: MicroPartition, file_format: FileFormat, @@ -407,7 +443,6 @@ def write_tabular( partition_cols: ExpressionsProjection | None = None, compression: str | None = None, io_config: IOConfig | None = None, - partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__", ) -> MicroPartition: [resolved_path], fs = _resolve_paths_and_filesystem(path, io_config=io_config) if isinstance(path, pathlib.Path): @@ -420,35 +455,6 @@ def write_tabular( is_local_fs = canonicalized_protocol == "file" - tables_to_write: list[MicroPartition] - part_keys_postfix_per_table: list[str | None] - partition_values = None - if partition_cols and len(partition_cols) > 0: - default_part = Series.from_pylist([partition_null_fallback]) - split_tables, partition_values = table.partition_by_value(partition_keys=partition_cols) - assert len(split_tables) == len(partition_values) - pkey_names = partition_values.column_names() - - values_string_values = [] - - for c in pkey_names: - column = partition_values.get_column(c) - string_names = column._to_str_values() - null_filled = column.is_null().if_else(default_part, string_names) - values_string_values.append(null_filled.to_pylist()) - - part_keys_postfix_per_table = [] - for i in range(len(partition_values)): - postfix = "/".join(f"{pkey}={values[i]}" for pkey, values in zip(pkey_names, values_string_values)) - part_keys_postfix_per_table.append(postfix) - tables_to_write = split_tables - else: - tables_to_write = [table] - part_keys_postfix_per_table = [None] - - visited_paths = [] - partition_idx = [] - execution_config = get_context().daft_execution_config TARGET_ROW_GROUP_SIZE = execution_config.parquet_target_row_group_size @@ -467,12 +473,26 @@ def write_tabular( else: raise ValueError(f"Unsupported file format {file_format}") - for i, (tab, pf) in enumerate(zip(tables_to_write, part_keys_postfix_per_table)): - full_path = resolved_path - if pf is not None and len(pf) > 0: - full_path = f"{full_path}/{pf}" + # I kept this from our original code, but idk why it's the first column name -kevin + path_key = schema.column_names()[0] + + # TODO: when we have a MicroPartition.from_pylist, use a list here instead + data_dict: dict[str, list[Any]] = {path_key: []} + + if partition_cols is not None: + data_dict.update({expr.name(): [] for expr in partition_cols}) + + @dataclass + class FileVisitor: + partition_values: dict[str, Any] - arrow_table = tab.to_arrow() + def __call__(self, written_file): + data_dict[path_key].append(written_file.path) + for c in self.partition_values: + data_dict[c].append(self.partition_values[c]) + + for write_data in _table_to_partitions(table, 
resolved_path, partition_cols): + arrow_table = write_data.table.to_arrow() size_bytes = arrow_table.nbytes @@ -484,38 +504,24 @@ def write_tabular( target_row_groups = max(math.ceil(size_bytes / TARGET_ROW_GROUP_SIZE / inflation_factor), 1) rows_per_row_group = max(min(math.ceil(num_rows / target_row_groups), rows_per_file), 1) - def file_visitor(written_file, i=i): - visited_paths.append(written_file.path) - partition_idx.append(i) - _write_tabular_arrow_table( arrow_table=arrow_table, schema=arrow_table.schema, - full_path=full_path, + full_path=write_data.path, format=format, opts=opts, fs=fs, rows_per_file=rows_per_file, rows_per_row_group=rows_per_row_group, create_dir=is_local_fs, - file_visitor=file_visitor, + file_visitor=FileVisitor(write_data.partition_values), ) - data_dict: dict[str, Any] = { - schema.column_names()[0]: Series.from_pylist(visited_paths, name=schema.column_names()[0]).cast( - DataType.string() - ) - } - - if partition_values is not None: - partition_idx_series = Series.from_pylist(partition_idx).cast(DataType.int64()) - for c_name in partition_values.column_names(): - data_dict[c_name] = partition_values.get_column(c_name).take(partition_idx_series) return MicroPartition.from_pydict(data_dict) def write_iceberg( - mp: MicroPartition, + table: MicroPartition, base_path: str, schema: IcebergSchema, properties: IcebergTableProperties, @@ -531,8 +537,14 @@ def write_iceberg( ) from pyiceberg.manifest import DataFile, DataFileContent from pyiceberg.manifest import FileFormat as IcebergFileFormat + from pyiceberg.typedef import Record as IcebergRecord - from daft.iceberg.iceberg_write import micropartition_to_arrow_tables + from daft.iceberg.iceberg_write import ( + add_missing_columns, + coerce_pyarrow_table_to_schema, + partition_field_to_expr, + to_partition_representation, + ) [resolved_path], fs = _resolve_paths_and_filesystem(base_path, io_config=io_config) if isinstance(base_path, pathlib.Path): @@ -558,11 +570,15 @@ def write_iceberg( file_schema = schema_to_pyarrow(schema) + partition_keys = ExpressionsProjection([partition_field_to_expr(field, schema) for field in partition_spec.fields]) + data_files = [] - for arrow_table, path, partition in micropartition_to_arrow_tables(mp, resolved_path, schema, partition_spec): + @dataclass + class FileVisitor: + partition_record: IcebergRecord - def file_visitor(written_file, protocol=protocol): + def __call__(self, written_file): file_path = f"{protocol}://{written_file.path}" size = written_file.size metadata = written_file.metadata @@ -571,7 +587,7 @@ def file_visitor(written_file, protocol=protocol): "content": DataFileContent.DATA, "file_path": file_path, "file_format": IcebergFileFormat.PARQUET, - "partition": partition, + "partition": self.partition_record, "file_size_in_bytes": size, # After this has been fixed: # https://github.com/apache/iceberg-python/issues/271 @@ -612,6 +628,12 @@ def file_visitor(written_file, protocol=protocol): data_files.append(data_file) + table = add_missing_columns(table, file_schema) + for write_data in _table_to_partitions(table, resolved_path, partition_keys): + arrow_table = write_data.table.to_arrow() + + arrow_table = coerce_pyarrow_table_to_schema(arrow_table, file_schema) + size_bytes = arrow_table.nbytes target_num_files = max(math.ceil(size_bytes / target_file_size / inflation_factor), 1) @@ -622,17 +644,22 @@ def file_visitor(written_file, protocol=protocol): target_row_groups = max(math.ceil(size_bytes / TARGET_ROW_GROUP_SIZE / inflation_factor), 1) rows_per_row_group = 
max(min(math.ceil(num_rows / target_row_groups), rows_per_file), 1) + encoded_partition_values = { + key: to_partition_representation(value) for key, value in write_data.partition_values.items() + } + partition_record = IcebergRecord(**encoded_partition_values) + _write_tabular_arrow_table( arrow_table=arrow_table, schema=file_schema, - full_path=path, + full_path=write_data.path, format=format, opts=opts, fs=fs, rows_per_file=rows_per_file, rows_per_row_group=rows_per_row_group, create_dir=is_local_fs, - file_visitor=file_visitor, + file_visitor=FileVisitor(partition_record), ) return MicroPartition.from_pydict({"data_file": Series.from_pylist(data_files, name="data_file", pyobj="force")}) diff --git a/src/daft-core/src/series/ops/partitioning.rs b/src/daft-core/src/series/ops/partitioning.rs index ef20ae3726..18f6fed67a 100644 --- a/src/daft-core/src/series/ops/partitioning.rs +++ b/src/daft-core/src/series/ops/partitioning.rs @@ -1,6 +1,7 @@ use crate::array::ops::as_arrow::AsArrow; use crate::datatypes::logical::TimestampArray; use crate::datatypes::{Int32Array, Int64Array, TimeUnit}; +use crate::prelude::*; use crate::series::array_impl::IntoSeries; use crate::with_match_integer_daft_types; use crate::{datatypes::DataType, series::Series}; @@ -25,9 +26,7 @@ impl Series { self.data_type() ))), }?; - value - .rename(format!("{}_years", self.name())) - .cast(&DataType::Int32) + value.cast(&DataType::Int32) } pub fn partitioning_months(&self) -> DaftResult { @@ -50,32 +49,34 @@ impl Series { self.data_type() ))), }?; - value - .rename(format!("{}_months", self.name())) - .cast(&DataType::Int32) + value.cast(&DataType::Int32) } pub fn partitioning_days(&self) -> DaftResult { - let result = match self.data_type() { - DataType::Date => Ok(self.clone()), - DataType::Timestamp(_, None) => { - let ts_array = self.downcast::()?; - Ok(ts_array.date()?.into_series()) + match self.data_type() { + DataType::Date => { + let date_array = self.downcast::()?; + Ok(date_array.physical.clone().into_series()) } - - DataType::Timestamp(tu, Some(_)) => { - let array = self.cast(&DataType::Timestamp(*tu, None))?; - let ts_array = array.downcast::()?; - Ok(ts_array.date()?.into_series()) + DataType::Timestamp(unit, _) => { + let ts_array = self.downcast::()?; + let physical = &ts_array.physical.cast(&DataType::Float64)?; + let unit_to_days: f64 = match unit { + TimeUnit::Nanoseconds => 86_400_000_000_000.0, + TimeUnit::Microseconds => 86_400_000_000.0, + TimeUnit::Milliseconds => 86_400_000.0, + TimeUnit::Seconds => 86_400.0, + }; + // TODO: use floor division once it is implemented + let divider = Float64Array::from(("divider", vec![unit_to_days])).into_series(); + let days = (physical / ÷r)?.floor()?; + days.cast(&DataType::Int32) } - _ => Err(DaftError::ComputeError(format!( "Can only run partitioning_days() operation on temporal types, got {}", self.data_type() ))), - }?; - - Ok(result.rename(format!("{}_days", self.name()))) + } } pub fn partitioning_hours(&self) -> DaftResult { @@ -98,9 +99,7 @@ impl Series { self.data_type() ))), }?; - value - .rename(format!("{}_hours", self.name())) - .cast(&DataType::Int32) + value.cast(&DataType::Int32) } pub fn partitioning_iceberg_bucket(&self, n: i32) -> DaftResult { @@ -111,12 +110,12 @@ impl Series { .into_iter() .map(|v| v.map(|v| (v & i32::MAX) % n)); let array = Box::new(arrow2::array::Int32Array::from_iter(buckets)); - Ok(Int32Array::from((format!("{}_bucket", self.name()).as_str(), array)).into_series()) + Ok(Int32Array::from((self.name(), 
array)).into_series()) } pub fn partitioning_iceberg_truncate(&self, w: i64) -> DaftResult { assert!(w > 0, "Expected w to be positive, got {w}"); - let trunc = match self.data_type() { + match self.data_type() { i if i.is_integer() => { with_match_integer_daft_types!(i, |$T| { let downcasted = self.downcast::<<$T as DaftDataType>::ArrayType>()?; @@ -130,8 +129,6 @@ impl Series { "Can only run partitioning_iceberg_truncate() operation on integers, decimal, string, and binary, got {}", self.data_type() ))), - }?; - - Ok(trunc.rename(format!("{}_truncate", self.name()))) + } } } diff --git a/src/daft-dsl/src/functions/partitioning/evaluators.rs b/src/daft-dsl/src/functions/partitioning/evaluators.rs index 0a7acc4ad7..80883c8048 100644 --- a/src/daft-dsl/src/functions/partitioning/evaluators.rs +++ b/src/daft-dsl/src/functions/partitioning/evaluators.rs @@ -23,10 +23,9 @@ macro_rules! impl_func_evaluator_for_partitioning { ) -> DaftResult { match inputs { [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => Ok(Field::new( - format!("{}_{}", field.name, stringify!($op)), - $result_type, - )), + Ok(field) if field.dtype.is_temporal() => { + Ok(Field::new(field.name, $result_type)) + } Ok(field) => Err(DaftError::TypeError(format!( "Expected input to {} to be temporal, got {}", stringify!($op), @@ -55,10 +54,10 @@ macro_rules! impl_func_evaluator_for_partitioning { }; } use crate::functions::FunctionExpr; -use DataType::{Date, Int32}; +use DataType::Int32; impl_func_evaluator_for_partitioning!(YearsEvaluator, years, partitioning_years, Int32); impl_func_evaluator_for_partitioning!(MonthsEvaluator, months, partitioning_months, Int32); -impl_func_evaluator_for_partitioning!(DaysEvaluator, days, partitioning_days, Date); +impl_func_evaluator_for_partitioning!(DaysEvaluator, days, partitioning_days, Int32); impl_func_evaluator_for_partitioning!(HoursEvaluator, hours, partitioning_hours, Int32); pub(super) struct IcebergBucketEvaluator {} @@ -76,14 +75,8 @@ impl FunctionEvaluator for IcebergBucketEvaluator { | DataType::Date | DataType::Timestamp(..) 
| DataType::Utf8 - | DataType::Binary => Ok(Field::new( - format!("{}_bucket", field.name), - DataType::Int32, - )), - v if v.is_integer() => Ok(Field::new( - format!("{}_bucket", field.name), - DataType::Int32, - )), + | DataType::Binary => Ok(Field::new(field.name, DataType::Int32)), + v if v.is_integer() => Ok(Field::new(field.name, DataType::Int32)), _ => Err(DaftError::TypeError(format!( "Expected input to iceberg bucketing to be murmur3 hashable, got {}", field.dtype @@ -126,10 +119,10 @@ impl FunctionEvaluator for IcebergTruncateEvaluator { [input] => match input.to_field(schema) { Ok(field) => match &field.dtype { DataType::Decimal128(_, _) - | DataType::Utf8 => Ok(Field::new(format!("{}_truncate", field.name), field.dtype)), - v if v.is_integer() => Ok(Field::new(format!("{}_truncate", field.name), field.dtype)), + | DataType::Utf8 | DataType::Binary => Ok(Field::new(field.name, field.dtype)), + v if v.is_integer() => Ok(Field::new(field.name, field.dtype)), _ => Err(DaftError::TypeError(format!( - "Expected input to IcebergTruncate to be an Integer, Utf8 or Decimal, got {}", + "Expected input to IcebergTruncate to be an Integer, Utf8, Decimal, or Binary, got {}", field.dtype ))), }, diff --git a/tests/cookbook/test_write.py b/tests/cookbook/test_write.py index c197ed922d..174a169e43 100644 --- a/tests/cookbook/test_write.py +++ b/tests/cookbook/test_write.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import date, datetime +from datetime import datetime import pyarrow as pa import pytest @@ -87,19 +87,18 @@ def test_parquet_write_with_partitioning_readback_values(tmp_path): @pytest.mark.parametrize( - "exp,key,answer", + "exp,answer", [ ( - daft.col("date").partitioning.days(), - "date_days", - [date(2024, 1, 1), date(2024, 2, 1), date(2024, 3, 1), date(2024, 4, 1), date(2024, 5, 1)], + daft.col("date").partitioning.days().alias("date_days"), + [19723, 19754, 19783, 19814, 19844], ), - (daft.col("date").partitioning.hours(), "date_hours", [473352, 474096, 474792, 475536, 476256]), - (daft.col("date").partitioning.months(), "date_months", [648, 649, 650, 651, 652]), - (daft.col("date").partitioning.years(), "date_years", [54]), + (daft.col("date").partitioning.hours().alias("date_hours"), [473352, 474096, 474792, 475536, 476256]), + (daft.col("date").partitioning.months().alias("date_months"), [648, 649, 650, 651, 652]), + (daft.col("date").partitioning.years().alias("date_years"), [54]), ], ) -def test_parquet_write_with_iceberg_date_partitioning(exp, key, answer, tmp_path): +def test_parquet_write_with_iceberg_date_partitioning(exp, answer, tmp_path): data = { "id": [1, 2, 3, 4, 5], "date": [ @@ -110,6 +109,7 @@ def test_parquet_write_with_iceberg_date_partitioning(exp, key, answer, tmp_path datetime(2024, 5, 1), ], } + key = exp.name() df = daft.from_pydict(data) date_files = df.write_parquet(tmp_path, partition_cols=[exp]).sort(by=key) output_dict = date_files.to_pydict() @@ -120,13 +120,13 @@ def test_parquet_write_with_iceberg_date_partitioning(exp, key, answer, tmp_path @pytest.mark.parametrize( - "exp,key,answer", + "exp,answer", [ - (daft.col("id").partitioning.iceberg_bucket(10), "id_bucket", [0, 3, 5, 6, 8]), - (daft.col("id").partitioning.iceberg_truncate(10), "id_truncate", [0, 10, 20, 40]), + (daft.col("id").partitioning.iceberg_bucket(10).alias("id_bucket"), [0, 3, 5, 6, 8]), + (daft.col("id").partitioning.iceberg_truncate(10).alias("id_truncate"), [0, 10, 20, 40]), ], ) -def test_parquet_write_with_iceberg_bucket_and_trunc(exp, key, answer, 
tmp_path): +def test_parquet_write_with_iceberg_bucket_and_trunc(exp, answer, tmp_path): data = { "id": [1, 12, 23, 24, 45], "date": [ @@ -137,6 +137,7 @@ def test_parquet_write_with_iceberg_bucket_and_trunc(exp, key, answer, tmp_path) datetime(2024, 5, 1), ], } + key = exp.name() df = daft.from_pydict(data) date_files = df.write_parquet(tmp_path, partition_cols=[exp]).sort(by=key) output_dict = date_files.to_pydict() diff --git a/tests/io/iceberg/test_iceberg_writes.py b/tests/io/iceberg/test_iceberg_writes.py index b6aa364948..6272941ed7 100644 --- a/tests/io/iceberg/test_iceberg_writes.py +++ b/tests/io/iceberg/test_iceberg_writes.py @@ -156,7 +156,6 @@ def test_missing_columns_write(simple_local_table): assert all(op == "ADD" for op in as_dict["operation"]), as_dict["operation"] assert sum(as_dict["rows"]) == 5, as_dict["rows"] read_back = daft.read_iceberg(simple_local_table) - print("as_dict", as_dict) assert read_back.to_pydict() == {"x": [None] * 5} @@ -239,15 +238,16 @@ def complex_table() -> tuple[pa.Table, Schema]: return table, schema -@pytest.fixture( - params=[ +@pytest.mark.parametrize( + "partition_spec", + [ pytest.param(UNPARTITIONED_PARTITION_SPEC, id="unpartitioned"), pytest.param( PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="a")), id="int_identity_partitioned", ), pytest.param( - PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(4), name="a")), + PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(2), name="a")), id="int_bucket_partitioned", ), pytest.param( @@ -263,7 +263,7 @@ def complex_table() -> tuple[pa.Table, Schema]: id="string_identity_partitioned", ), pytest.param( - PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=BucketTransform(4), name="c")), + PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=BucketTransform(2), name="c")), id="string_bucket_partitioned", ), pytest.param( @@ -275,7 +275,7 @@ def complex_table() -> tuple[pa.Table, Schema]: id="binary_identity_partitioned", ), pytest.param( - PartitionSpec(PartitionField(source_id=4, field_id=1000, transform=BucketTransform(4), name="d")), + PartitionSpec(PartitionField(source_id=4, field_id=1000, transform=BucketTransform(2), name="d")), id="binary_bucket_partitioned", ), pytest.param( @@ -291,7 +291,7 @@ def complex_table() -> tuple[pa.Table, Schema]: id="datetime_identity_partitioned", ), pytest.param( - PartitionSpec(PartitionField(source_id=6, field_id=1000, transform=BucketTransform(4), name="f")), + PartitionSpec(PartitionField(source_id=6, field_id=1000, transform=BucketTransform(2), name="f")), id="datetime_bucket_partitioned", ), pytest.param( @@ -315,7 +315,7 @@ def complex_table() -> tuple[pa.Table, Schema]: id="date_identity_partitioned", ), pytest.param( - PartitionSpec(PartitionField(source_id=7, field_id=1000, transform=BucketTransform(4), name="g")), + PartitionSpec(PartitionField(source_id=7, field_id=1000, transform=BucketTransform(2), name="g")), id="date_bucket_partitioned", ), pytest.param( @@ -335,25 +335,30 @@ def complex_table() -> tuple[pa.Table, Schema]: id="decimal_identity_partitioned", ), pytest.param( - PartitionSpec(PartitionField(source_id=8, field_id=1000, transform=BucketTransform(4), name="h")), + PartitionSpec(PartitionField(source_id=8, field_id=1000, transform=BucketTransform(2), name="h")), id="decimal_bucket_partitioned", ), pytest.param( PartitionSpec(PartitionField(source_id=8, field_id=1000, 
transform=TruncateTransform(2), name="h")), id="decimal_truncate_partitioned", ), - ] + pytest.param( + PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=BucketTransform(2), name="a"), + PartitionField(source_id=3, field_id=1000, transform=TruncateTransform(2), name="c"), + ), + id="double_partitioned", + ), + ], ) -def partition_spec(request) -> PartitionSpec: - return request.param - - def test_complex_table_write_read(local_catalog, complex_table, partition_spec): pa_table, schema = complex_table table = local_catalog.create_table("default.test", schema, partition_spec=partition_spec) df = daft.from_arrow(pa_table) result = df.write_iceberg(table) as_dict = result.to_pydict() + if "partitioning" in as_dict: + print("as_dict[partitioning]", as_dict["partitioning"]) assert all(op == "ADD" for op in as_dict["operation"]), as_dict["operation"] assert sum(as_dict["rows"]) == 3, as_dict["rows"] read_back = daft.read_iceberg(table) diff --git a/tests/series/test_partitioning.py b/tests/series/test_partitioning.py index 10e8eb700f..dea9ee90a8 100644 --- a/tests/series/test_partitioning.py +++ b/tests/series/test_partitioning.py @@ -1,9 +1,11 @@ from __future__ import annotations -from datetime import date, datetime +from datetime import date, datetime, time from decimal import Decimal from itertools import product +import pandas as pd +import pyarrow as pa import pytest from daft import DataType, TimeUnit @@ -29,8 +31,8 @@ def test_partitioning_days(input, dtype, expected): s = Series.from_pylist(input).cast(dtype) d = s.partitioning.days() - assert d.datatype() == DataType.date() - assert d.cast(DataType.int32()).to_pylist() == expected + assert d.datatype() == DataType.int32() + assert d.to_pylist() == expected @pytest.mark.parametrize( @@ -135,6 +137,35 @@ def test_iceberg_bucketing(input, n): seen[v] = b +@pytest.mark.parametrize( + "input,expected", + [ + (pa.array([34], type=pa.int32()), 2017239379), + (pa.array([34], type=pa.int64()), 2017239379), + (pa.array([Decimal("14.20")]), -500754589), + (pa.array([date.fromisoformat("2017-11-16")]), -653330422), + (pa.array([time.fromisoformat("22:31:08")]), -662762989), + (pa.array([datetime.fromisoformat("2017-11-16T22:31:08")]), -2047944441), + (pa.array([datetime.fromisoformat("2017-11-16T22:31:08.000001")]), -1207196810), + (pa.array([datetime.fromisoformat("2017-11-16T14:31:08-08:00")]), -2047944441), + (pa.array([datetime.fromisoformat("2017-11-16T14:31:08.000001-08:00")]), -1207196810), + (pa.array([datetime.fromisoformat("2017-11-16T22:31:08")], type=pa.timestamp("ns")), -2047944441), + (pa.array([pd.to_datetime("2017-11-16T22:31:08.000001001")], type=pa.timestamp("ns")), -1207196810), + (pa.array([datetime.fromisoformat("2017-11-16T14:31:08-08:00")], type=pa.timestamp("ns")), -2047944441), + (pa.array([pd.to_datetime("2017-11-16T14:31:08.000001001-08:00")], type=pa.timestamp("ns")), -1207196810), + (pa.array(["iceberg"]), 1210000089), + (pa.array([b"\x00\x01\x02\x03"]), -188683207), + ], +) +def test_iceberg_bucketing_hash(input, expected): + # https://iceberg.apache.org/spec/#appendix-b-32-bit-hash-requirements + max_buckets = 2**31 - 1 + s = Series.from_arrow(input) + buckets = s.partitioning.iceberg_bucket(max_buckets) + assert buckets.datatype() == DataType.int32() + assert buckets.to_pylist() == [(expected & max_buckets) % max_buckets] + + def test_iceberg_truncate_decimal(): data = ["12.34", "12.30", "12.29", "0.05", "-0.05"] data = [Decimal(v) for v in data] + [None]
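
Note on the naming change in the partitioning evaluators: `partitioning.days()/hours()/months()/years()` and the Iceberg bucket/truncate transforms now keep the source column's name instead of appending a `_days`/`_bucket` style suffix, so callers that want a distinct partition key must alias explicitly, as the updated cookbook tests do. A minimal usage sketch (the output path is a placeholder):

```python
from datetime import datetime

import daft

df = daft.from_pydict({"id": [1, 2], "date": [datetime(2024, 1, 1), datetime(2024, 2, 1)]})

# The transform keeps the column name "date" unless aliased, so alias it to get a
# distinct partition key in the output directories (e.g. ".../date_days=19723/...").
files = df.write_parquet(
    "some/output/path",  # placeholder path
    partition_cols=[daft.col("date").partitioning.days().alias("date_days")],
)
```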
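
The expected values in `test_parquet_write_with_iceberg_date_partitioning` are plain epoch-based encodings: days and hours since 1970-01-01, plus months and years relative to 1970-01. A quick stdlib sanity check, assuming naive timestamps are treated as UTC:

```python
from datetime import datetime, timezone

def epoch_partition_values(ts: datetime) -> tuple[int, int, int, int]:
    # Interpret a naive timestamp as UTC and compute days/hours since the epoch,
    # plus months/years relative to 1970-01.
    epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
    seconds = (ts.replace(tzinfo=timezone.utc) - epoch).total_seconds()
    return (
        int(seconds // 86_400),                 # days
        int(seconds // 3_600),                  # hours
        (ts.year - 1970) * 12 + ts.month - 1,   # months
        ts.year - 1970,                         # years
    )

# Matches the expected values for 2024-01-01 in tests/cookbook/test_write.py
assert epoch_partition_values(datetime(2024, 1, 1)) == (19723, 473352, 648, 54)
```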
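
`test_iceberg_bucketing_hash` pins the hash values from Appendix B of the Iceberg spec; the bucket id itself is `(murmur3_x86_32(bytes) & Integer.MAX_VALUE) % N`, matching the `(v & i32::MAX) % n` mapping in `partitioning_iceberg_bucket`. A sketch of the same computation for integers, assuming the third-party `mmh3` package provides the Murmur3 hash:

```python
import struct

import mmh3  # assumed third-party Murmur3 binding (pip install mmh3)

def iceberg_bucket_int(value: int, num_buckets: int) -> int:
    # Iceberg hashes int/long values as 8-byte little-endian longs with Murmur3-x86-32
    # (seed 0), then maps the hash into a bucket with (hash & Integer.MAX_VALUE) % N.
    h = mmh3.hash(struct.pack("<q", value), signed=True)
    return (h & 0x7FFFFFFF) % num_buckets

# Reference value from the Iceberg spec (also asserted in test_iceberg_bucketing_hash):
# murmur3 hash of 34 is 2017239379.
assert iceberg_bucket_int(34, 16) == 2017239379 % 16
```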
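
`to_partition_representation` stores timestamps as microseconds since the Unix epoch, dates as days since the epoch, times as microseconds since midnight, and UUIDs as strings before the values are packed into the per-file `IcebergRecord`. A couple of concrete encodings using only the stdlib:

```python
import datetime

EPOCH = datetime.datetime(1970, 1, 1)

# Timestamp -> microseconds since the epoch.
assert (datetime.datetime(2024, 1, 1) - EPOCH) // datetime.timedelta(microseconds=1) == 1_704_067_200_000_000

# Date -> days since the epoch.
assert (datetime.date(2024, 1, 1) - datetime.date(1970, 1, 1)) // datetime.timedelta(days=1) == 19_723

# Time -> microseconds since midnight.
t = datetime.time(12, 30, 0)
assert (t.hour * 3600 + t.minute * 60 + t.second) * 1_000_000 + t.microsecond == 45_000_000_000
```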
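
`_table_to_partitions` yields one `_TableWriteData` per partition and builds Hive-style `key=value` directories, substituting `__HIVE_DEFAULT_PARTITION__` for null partition values. A small sketch of the path construction with hypothetical key names:

```python
def hive_partition_path(
    base: str,
    keys: list[str],
    values: list,
    null_fallback: str = "__HIVE_DEFAULT_PARTITION__",
) -> str:
    # Join "key=value" segments onto the base path; nulls fall back to the Hive default name.
    postfix = "/".join(f"{k}={null_fallback if v is None else v}" for k, v in zip(keys, values))
    return f"{base}/{postfix}"

assert (
    hive_partition_path("s3://bucket/table", ["date_days", "id_bucket"], [19723, None])
    == "s3://bucket/table/date_days=19723/id_bucket=__HIVE_DEFAULT_PARTITION__"
)
```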
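
On the `write_iceberg` summary side, the partition value of every written data file is surfaced as a `partitioning` struct column alongside `operation`, `rows`, `file_size`, and `file_name`. A toy PyArrow illustration of how per-field value lists become one struct column (field names `a` and `c` are made up):

```python
import pyarrow as pa

# One list of partition values per partition field, one entry per written data file.
partitioning = {"a": [0, 1, None], "c": ["fo", "ba", None]}

struct_col = pa.StructArray.from_arrays(
    [pa.array(vals) for vals in partitioning.values()],
    names=list(partitioning.keys()),
)

assert struct_col.to_pylist() == [
    {"a": 0, "c": "fo"},
    {"a": 1, "c": "ba"},
    {"a": None, "c": None},
]
```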