Commit eaff9cc

standardize partitioned writes and add partitioning tests

kevinzwang committed Sep 16, 2024
1 parent 5f339f8 commit eaff9cc
Showing 8 changed files with 289 additions and 280 deletions.
31 changes: 22 additions & 9 deletions daft/dataframe/dataframe.py
@@ -683,19 +683,28 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") ->
else:
deleted_files = []

schema = table.schema()
partitioning: Dict[str, list] = {schema.find_field(field.source_id).name: [] for field in table.spec().fields}

for data_file in data_files:
operations.append("ADD")
path.append(data_file.file_path)
rows.append(data_file.record_count)
size.append(data_file.file_size_in_bytes)

for field in partitioning.keys():
partitioning[field].append(getattr(data_file.partition, field, None))

for pf in deleted_files:
data_file = pf.file
operations.append("DELETE")
path.append(data_file.file_path)
rows.append(data_file.record_count)
size.append(data_file.file_size_in_bytes)

for field in partitioning.keys():
partitioning[field].append(getattr(data_file.partition, field, None))

if parse(pyiceberg.__version__) >= parse("0.7.0"):
from pyiceberg.table import ALWAYS_TRUE, PropertyUtil, TableProperties

@@ -735,19 +744,23 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") ->

merge.commit()

with_operations = {
"operation": pa.array(operations, type=pa.string()),
"rows": pa.array(rows, type=pa.int64()),
"file_size": pa.array(size, type=pa.int64()),
"file_name": pa.array([fp for fp in path], type=pa.string()),
}

if partitioning:
with_operations["partitioning"] = pa.StructArray.from_arrays(
partitioning.values(), names=partitioning.keys()
)

from daft import from_pydict

with_operations = from_pydict(
{
"operation": pa.array(operations, type=pa.string()),
"rows": pa.array(rows, type=pa.int64()),
"file_size": pa.array(size, type=pa.int64()),
"file_name": pa.array([os.path.basename(fp) for fp in path], type=pa.string()),
}
)
# NOTE: We are losing the history of the plan here.
# This is due to the fact that the logical plan of the write_iceberg returns datafiles but we want to return the above data
return with_operations
return from_pydict(with_operations)

@DataframePublicAPI
def write_deltalake(
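For reference, a minimal sketch (not part of the diff) of how the per-file partition values collected in write_iceberg above can be assembled into the "partitioning" struct column; the partition field names and values here are illustrative only:

import pyarrow as pa

# Hypothetical per-file partition values, keyed by partition field name.
partitioning = {
    "year": pa.array([2023, 2024], type=pa.int32()),
    "id_bucket": pa.array([0, 3], type=pa.int32()),
}

# Mirrors the new result construction: one struct per written data file.
struct_col = pa.StructArray.from_arrays(
    list(partitioning.values()), names=list(partitioning.keys())
)
print(struct_col.to_pylist())  # [{'year': 2023, 'id_bucket': 0}, {'year': 2024, 'id_bucket': 3}]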
214 changes: 78 additions & 136 deletions daft/iceberg/iceberg_write.py
@@ -1,18 +1,35 @@
import datetime
import uuid
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any

from daft import Series
from daft import Expression, Series, col
from daft.table import MicroPartition

if TYPE_CHECKING:
import pyarrow as pa
from pyiceberg.io.pyarrow import _TablePartition
from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec
from pyiceberg.partitioning import PartitionField as IcebergPartitionField
from pyiceberg.schema import Schema as IcebergSchema

from pyiceberg.typedef import Record


def _coerce_pyarrow_table_to_schema(pa_table: "pa.Table", schema: "IcebergSchema") -> "pa.Table":
def add_missing_columns(table: MicroPartition, schema: "pa.Schema") -> MicroPartition:
"""Add null values for columns in the schema that are missing from the table."""

import pyarrow as pa

existing_columns = set(table.column_names())

columns = {}
for name in schema.names:
if name in existing_columns:
columns[name] = table.get_column(name)
else:
columns[name] = Series.from_arrow(pa.nulls(len(table), type=schema.field(name).type), name=name)

return MicroPartition.from_pydict(columns)
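
A hedged usage sketch of add_missing_columns (the schema and column names are made up for illustration; assumes daft and pyarrow are installed):

import pyarrow as pa
from daft.table import MicroPartition

target_schema = pa.schema([("a", pa.int64()), ("b", pa.string())])
mp = MicroPartition.from_pydict({"a": [1, 2, 3]})

# "b" is missing from the table, so it is filled with three string-typed nulls.
filled = add_missing_columns(mp, target_schema)
print(filled.to_arrow())  # column "b" is all nulls with type string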


def coerce_pyarrow_table_to_schema(pa_table: "pa.Table", schema: "pa.Schema") -> "pa.Table":
"""Coerces a PyArrow table to the supplied schema
1. For each field in `pa_table`, cast it to the field in `input_schema` if one with a matching name
@@ -27,154 +44,79 @@ def _coerce_pyarrow_table_to_schema(pa_table: "pa.Table", schema: "IcebergSchema
Args:
pa_table (pa.Table): Table to coerce
schema (IcebergSchema): PyIceberg schema to coerce to
schema (pa.Schema): Iceberg schema to coerce to
Returns:
pa.Table: Table with schema == `input_schema`
pa.Table: Table with schema == `schema`
"""
import pyarrow as pa
from pyiceberg.io.pyarrow import schema_to_pyarrow

input_schema = schema_to_pyarrow(schema)

input_schema_names = set(input_schema.names)
input_schema_names = set(schema.names)

# Perform casting of types to provided schema's types
cast_to_schema = [
(input_schema.field(inferred_field.name) if inferred_field.name in input_schema_names else inferred_field)
(schema.field(inferred_field.name) if inferred_field.name in input_schema_names else inferred_field)
for inferred_field in pa_table.schema
]
casted_table = pa_table.cast(pa.schema(cast_to_schema))

# Reorder and pad columns with a null column where necessary
pa_table_column_names = set(casted_table.column_names)
columns = []
for name in input_schema.names:
for name in schema.names:
if name in pa_table_column_names:
columns.append(casted_table[name])
else:
columns.append(pa.nulls(len(casted_table), type=input_schema.field(name).type))
return pa.table(columns, schema=input_schema)


def _determine_partitions(
spec: "IcebergPartitionSpec", schema: "IcebergSchema", arrow_table: "pa.Table"
) -> List["_TablePartition"]:
"""Based on https://github.com/apache/iceberg-python/blob/d8d509ff1bc33040b9f6c90c28ee47ac7437945d/pyiceberg/io/pyarrow.py#L2669"""

import pyarrow as pa
from pyiceberg.io.pyarrow import _get_table_partitions
from pyiceberg.partitioning import PartitionField
from pyiceberg.transforms import Transform
from pyiceberg.types import NestedField

partition_columns: List[Tuple[PartitionField, NestedField]] = [
(partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
]

def partition_transform(array: pa.Array, transform: Transform) -> Optional[pa.Array]:
from pyiceberg.transforms import (
BucketTransform,
DayTransform,
HourTransform,
IdentityTransform,
MonthTransform,
TruncateTransform,
YearTransform,
)

series = Series.from_arrow(array)

transformed = None
if isinstance(transform, IdentityTransform):
transformed = series
elif isinstance(transform, YearTransform):
transformed = series.partitioning.years()
elif isinstance(transform, MonthTransform):
transformed = series.partitioning.months()
elif isinstance(transform, DayTransform):
transformed = series.partitioning.days()
elif isinstance(transform, HourTransform):
transformed = series.partitioning.hours()
elif isinstance(transform, BucketTransform):
n = transform.num_buckets
transformed = series.partitioning.iceberg_bucket(n)
elif isinstance(transform, TruncateTransform):
w = transform.width
transformed = series.partitioning.iceberg_truncate(w)
else:
warnings.warn(f"{transform} not implemented, Please make an issue!")

return transformed.to_arrow() if transformed is not None else None

partition_values_table = pa.table(
{
str(partition.field_id): partition_transform(arrow_table[field.name], partition.transform)
for partition, field in partition_columns
}
columns.append(pa.nulls(len(casted_table), type=schema.field(name).type))

return pa.table(columns, schema=schema)
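
A small usage sketch of the reworked coerce_pyarrow_table_to_schema (the schema and data are illustrative, not from the diff):

import pyarrow as pa

target = pa.schema([("a", pa.int64()), ("b", pa.string())])
src = pa.table({"a": pa.array([1, 2], type=pa.int32())})

# "a" is cast from int32 to int64; the missing "b" is padded with nulls.
out = coerce_pyarrow_table_to_schema(src, target)
assert out.schema == target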


def partition_field_to_expr(field: "IcebergPartitionField", schema: "IcebergSchema") -> Expression:
from pyiceberg.transforms import (
BucketTransform,
DayTransform,
HourTransform,
IdentityTransform,
MonthTransform,
TruncateTransform,
YearTransform,
)

# Sort by partitions
sort_indices = pa.compute.sort_indices(
partition_values_table,
sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
null_placement="at_end",
).to_pylist()
arrow_table = arrow_table.take(sort_indices)

# Get slice_instructions to group by partitions
partition_values_table = partition_values_table.take(sort_indices)
reversed_indices = pa.compute.sort_indices(
partition_values_table,
sort_keys=[(col, "descending") for col in partition_values_table.column_names],
null_placement="at_start",
).to_pylist()
slice_instructions: List[Dict[str, Any]] = []
last = len(reversed_indices)
reversed_indices_size = len(reversed_indices)
ptr = 0
while ptr < reversed_indices_size:
group_size = last - reversed_indices[ptr]
offset = reversed_indices[ptr]
slice_instructions.append({"offset": offset, "length": group_size})
last = reversed_indices[ptr]
ptr = ptr + group_size

table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)

return table_partitions


def micropartition_to_arrow_tables(
table: MicroPartition, path: str, schema: "IcebergSchema", partition_spec: "IcebergPartitionSpec"
) -> List[Tuple["pa.Table", str, "Record"]]:
"""
Converts a MicroPartition to a list of Arrow tables with paths, partitioning the data if necessary.
Args:
table (MicroPartition): Table to convert
path (str): Base path to write the table to
schema (IcebergSchema): Schema of the Iceberg table
partition_spec (IcebergPartitionSpec): Iceberg partitioning spec
Returns:
List[Tuple[pa.Table, str, Record]]: List of Arrow tables with their paths and partition records
"""
from pyiceberg.typedef import Record

arrow_table = table.to_arrow()
arrow_table = _coerce_pyarrow_table_to_schema(arrow_table, schema)

if partition_spec.is_unpartitioned():
return [(arrow_table, path, Record())]
partition_col = col(schema.find_field(field.source_id).name)

if isinstance(field.transform, IdentityTransform):
return partition_col
elif isinstance(field.transform, YearTransform):
return partition_col.partitioning.years()
elif isinstance(field.transform, MonthTransform):
return partition_col.partitioning.months()
elif isinstance(field.transform, DayTransform):
return partition_col.partitioning.days()
elif isinstance(field.transform, HourTransform):
return partition_col.partitioning.hours()
elif isinstance(field.transform, BucketTransform):
return partition_col.partitioning.iceberg_bucket(field.transform.num_buckets)
elif isinstance(field.transform, TruncateTransform):
return partition_col.partitioning.iceberg_truncate(field.transform.width)
else:
warnings.warn(f"{field.transform} not implemented, Please make an issue!")
return partition_col
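
Illustratively, assuming source columns named "ts" and "id" (not part of the diff), the transform-to-expression mapping above yields Daft expressions like:

from daft import col

day_expr = col("ts").partitioning.days()                 # DayTransform
bucket_expr = col("id").partitioning.iceberg_bucket(16)  # BucketTransform(num_buckets=16)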



def to_partition_representation(value: Any):
if value is None:
return None

if isinstance(value, datetime.datetime):
# Convert to microseconds since epoch
return (value - datetime.datetime(1970, 1, 1)) // datetime.timedelta(microseconds=1)
elif isinstance(value, datetime.date):
# Convert to days since epoch
return (value - datetime.date(1970, 1, 1)) // datetime.timedelta(days=1)
elif isinstance(value, datetime.time):
# Convert to microseconds since midnight
return (value.hour * 60 * 60 + value.minute * 60 + value.second) * 1_000_000 + value.microsecond

elif isinstance(value, uuid.UUID):
return str(value)

else:
partitions = _determine_partitions(partition_spec, schema, arrow_table)

return [
(
partition.arrow_table_partition,
f"{path}/{partition.partition_key.to_path()}",
partition.partition_key.partition,
)
for partition in partitions
]
return value
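
For example (values chosen purely for illustration), to_partition_representation maps temporal values onto Iceberg's integer partition representation:

import datetime

to_partition_representation(datetime.datetime(2024, 1, 1, 0, 0, 1))  # 1704067201000000 (microseconds since epoch)
to_partition_representation(datetime.date(2024, 1, 1))               # 19723 (days since epoch)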