Commit

[CHORE] Decouple Ray tensor types from main Daft logic (Eventual-Inc#2829)

Prior to this PR, Ray Data extension tensor types were handled in Daft's
Arrow import logic. This PR moves that logic into the Ray runner code
that creates a partition set from a Ray dataset.
desmondcheongzx authored Sep 11, 2024
1 parent 7048b97 commit c2d7d08
Showing 4 changed files with 99 additions and 73 deletions.
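For orientation before the diffs (an editor's sketch, not part of the commit): Ray Data stores tensor columns as Arrow extension arrays, and the crux of this PR is where the mapping from those extension types to Daft's tensor dtype lives. The snippet below reconstructs that mapping using only calls that appear in the diffs (`ArrowTensorArray.from_numpy`, the extension type's `scalar_type` and `shape` attributes, and `DataType.tensor`); it assumes Ray >= 2.2 and Daft installed.

    import numpy as np
    from ray.data.extensions import ArrowTensorArray

    from daft import DataType

    # Ray Data represents a column of fixed-shape tensors as an Arrow extension array.
    arr = ArrowTensorArray.from_numpy(np.arange(12, dtype=np.int64).reshape(3, 2, 2))

    # The extension type carries the element dtype and per-element shape,
    # which is exactly what Daft needs to build its own tensor dtype.
    scalar = DataType.from_arrow_type(arr.type.scalar_type)      # DataType.int64()
    daft_dtype = DataType.tensor(scalar, tuple(arr.type.shape))  # tensor of (2, 2) elements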
28 changes: 0 additions & 28 deletions daft/datatype.py
@@ -11,30 +11,6 @@
 if TYPE_CHECKING:
     import numpy as np
 
-_RAY_DATA_EXTENSIONS_AVAILABLE = True
-_TENSOR_EXTENSION_TYPES = []
-try:
-    import ray
-except ImportError:
-    _RAY_DATA_EXTENSIONS_AVAILABLE = False
-else:
-    _RAY_VERSION = tuple(int(s) for s in ray.__version__.split(".")[0:3])
-    try:
-        # Variable-shaped tensor column support was added in Ray 2.1.0.
-        if _RAY_VERSION >= (2, 2, 0):
-            from ray.data.extensions import (
-                ArrowTensorType,
-                ArrowVariableShapedTensorType,
-            )
-
-            _TENSOR_EXTENSION_TYPES = [ArrowTensorType, ArrowVariableShapedTensorType]
-        else:
-            from ray.data.extensions import ArrowTensorType
-
-            _TENSOR_EXTENSION_TYPES = [ArrowTensorType]
-    except ImportError:
-        _RAY_DATA_EXTENSIONS_AVAILABLE = False
-
 
 class TimeUnit:
     _timeunit: PyTimeUnit
@@ -412,10 +388,6 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
                 key_type=cls.from_arrow_type(arrow_type.key_type),
                 value_type=cls.from_arrow_type(arrow_type.item_type),
             )
-        elif _RAY_DATA_EXTENSIONS_AVAILABLE and isinstance(arrow_type, tuple(_TENSOR_EXTENSION_TYPES)):
-            scalar_dtype = cls.from_arrow_type(arrow_type.scalar_type)
-            shape = arrow_type.shape if isinstance(arrow_type, ArrowTensorType) else None
-            return cls.tensor(scalar_dtype, shape)
         elif isinstance(arrow_type, getattr(pa, "FixedShapeTensorType", ())):
             scalar_dtype = cls.from_arrow_type(arrow_type.value_type)
             return cls.tensor(scalar_dtype, tuple(arrow_type.shape))
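Worth noting what survives in `from_arrow_type` after the removal: the canonical Arrow fixed-shape tensor branch (kept in the context lines above) still maps to a Daft tensor dtype. A quick sketch, assuming a pyarrow version that provides `pa.fixed_shape_tensor`:

    import pyarrow as pa

    from daft import DataType

    fst = pa.fixed_shape_tensor(pa.int64(), (2, 2))
    # Retained branch: value_type becomes the scalar dtype, and the shape is preserved.
    assert DataType.from_arrow_type(fst) == DataType.tensor(DataType.int64(), (2, 2))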
100 changes: 98 additions & 2 deletions daft/runners/ray_runner.py
@@ -11,11 +11,15 @@

 import pyarrow as pa
 
+from daft.arrow_utils import ensure_array
 from daft.context import execution_config_ctx, get_context
+from daft.daft import PyTable as _PyTable
 from daft.expressions import ExpressionsProjection
 from daft.logical.builder import LogicalPlanBuilder
 from daft.plan_scheduler import PhysicalPlanScheduler
 from daft.runners.progress_bar import ProgressBar
+from daft.series import Series, item_to_series
+from daft.table import Table
 
 logger = logging.getLogger(__name__)
 
@@ -80,6 +84,42 @@
 except ImportError:
     _RAY_DATA_ARROW_TENSOR_TYPE_AVAILABLE = False
 
+_RAY_DATA_EXTENSIONS_AVAILABLE = True
+_TENSOR_EXTENSION_TYPES = []
+try:
+    import ray
+except ImportError:
+    _RAY_DATA_EXTENSIONS_AVAILABLE = False
+else:
+    _RAY_VERSION = tuple(int(s) for s in ray.__version__.split(".")[0:3])
+    try:
+        # Variable-shaped tensor column support was added in Ray 2.1.0.
+        if _RAY_VERSION >= (2, 2, 0):
+            from ray.data.extensions import (
+                ArrowTensorType,
+                ArrowVariableShapedTensorType,
+            )
+
+            _TENSOR_EXTENSION_TYPES = [ArrowTensorType, ArrowVariableShapedTensorType]
+        else:
+            from ray.data.extensions import ArrowTensorType
+
+            _TENSOR_EXTENSION_TYPES = [ArrowTensorType]
+    except ImportError:
+        _RAY_DATA_EXTENSIONS_AVAILABLE = False
+
+_NUMPY_AVAILABLE = True
+try:
+    import numpy as np
+except ImportError:
+    _NUMPY_AVAILABLE = False
+
+_PANDAS_AVAILABLE = True
+try:
+    import pandas as pd
+except ImportError:
+    _PANDAS_AVAILABLE = False
+
 
 @ray.remote
 def _glob_path_into_file_infos(
@@ -135,11 +175,56 @@ def _make_ray_block_from_micropartition(partition: MicroPartition) -> RayDataset
     return partition.to_pylist()
 
 
+def _series_from_arrow_with_ray_data_extensions(
+    array: pa.Array | pa.ChunkedArray, name: str = "arrow_series"
+) -> Series:
+    if isinstance(array, pa.Array):
+        # TODO(desmond): This might be dead code since `ArrowTensorType`s are `numpy.ndarray` under
+        # the hood and are not instances of `pyarrow.Array`. Should follow up and check if this code
+        # can be removed.
+        array = ensure_array(array)
+        if _RAY_DATA_EXTENSIONS_AVAILABLE and isinstance(array.type, ArrowTensorType):
+            storage_series = _series_from_arrow_with_ray_data_extensions(array.storage, name=name)
+            series = storage_series.cast(
+                DataType.fixed_size_list(
+                    _from_arrow_type_with_ray_data_extensions(array.type.scalar_type),
+                    int(np.prod(array.type.shape)),
+                )
+            )
+            return series.cast(DataType.from_arrow_type(array.type))
+        elif _RAY_DATA_EXTENSIONS_AVAILABLE and isinstance(array.type, ArrowVariableShapedTensorType):
+            return Series.from_numpy(array.to_numpy(zero_copy_only=False), name=name)
+    return Series.from_arrow(array, name)
+
+
+def _micropartition_from_arrow_with_ray_data_extensions(arrow_table: pa.Table) -> MicroPartition:
+    assert isinstance(arrow_table, pa.Table)
+    non_native_fields = []
+    for arrow_field in arrow_table.schema:
+        dt = _from_arrow_type_with_ray_data_extensions(arrow_field.type)
+        if dt == DataType.python() or dt._is_tensor_type() or dt._is_fixed_shape_tensor_type():
+            non_native_fields.append(arrow_field.name)
+    if non_native_fields:
+        # If there are any contained Arrow types that are not natively supported, convert each
+        # series while checking for Ray Data extension types.
+        logger.debug("Unsupported Arrow types detected for columns: %s", non_native_fields)
+        series_dict = dict()
+        for name, column in zip(arrow_table.column_names, arrow_table.columns):
+            series = (
+                _series_from_arrow_with_ray_data_extensions(column, name)
+                if isinstance(column, (pa.Array, pa.ChunkedArray))
+                else item_to_series(name, column)
+            )
+            series_dict[name] = series._series
+        return MicroPartition._from_tables([Table._from_pytable(_PyTable.from_pylist_series(series_dict))])
+    return MicroPartition.from_arrow(arrow_table)
+
+
 @ray.remote
 def _make_daft_partition_from_ray_dataset_blocks(
     ray_dataset_block: pa.MicroPartition, daft_schema: Schema
 ) -> MicroPartition:
-    return MicroPartition.from_arrow(ray_dataset_block)
+    return _micropartition_from_arrow_with_ray_data_extensions(ray_dataset_block)
 
 
 @ray.remote(num_returns=2)
@@ -259,6 +344,14 @@ def wait(self) -> None:
         ray.wait(list(deduped_object_refs))
 
 
+def _from_arrow_type_with_ray_data_extensions(arrow_type: pa.lib.DataType) -> DataType:
+    if _RAY_DATA_EXTENSIONS_AVAILABLE and isinstance(arrow_type, tuple(_TENSOR_EXTENSION_TYPES)):
+        scalar_dtype = _from_arrow_type_with_ray_data_extensions(arrow_type.scalar_type)
+        shape = arrow_type.shape if isinstance(arrow_type, ArrowTensorType) else None
+        return DataType.tensor(scalar_dtype, shape)
+    return DataType.from_arrow_type(arrow_type)
+
+
 class RayRunnerIO(runner_io.RunnerIO):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -300,7 +393,10 @@ def partition_set_from_ray_dataset(
         arrow_schema = pa.schema({name: t for name, t in zip(arrow_schema.names, arrow_schema.types)})
 
         daft_schema = Schema._from_field_name_and_types(
-            [(arrow_field.name, DataType.from_arrow_type(arrow_field.type)) for arrow_field in arrow_schema]
+            [
+                (arrow_field.name, _from_arrow_type_with_ray_data_extensions(arrow_field.type))
+                for arrow_field in arrow_schema
+            ]
         )
         block_refs = ds.get_internal_block_refs()
 
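To make the new entry point concrete, here is a hedged sketch of what `_from_arrow_type_with_ray_data_extensions` returns for the two Ray extension types. It imports a private helper, so treat it as illustration rather than supported API; it assumes Ray >= 2.2 so both extension types are importable.

    import numpy as np
    from ray.data.extensions import ArrowTensorArray

    from daft import DataType
    from daft.runners.ray_runner import _from_arrow_type_with_ray_data_extensions

    # Fixed-shape: ArrowTensorType carries a static per-element shape, so the
    # resulting Daft dtype preserves it.
    fixed = ArrowTensorArray.from_numpy(np.zeros((4, 2, 3), dtype=np.float32)).type
    assert _from_arrow_type_with_ray_data_extensions(fixed) == DataType.tensor(DataType.float32(), (2, 3))

    # Variable-shaped: ArrowVariableShapedTensorType has no static shape, so the
    # helper passes shape=None and the result is a tensor dtype of unknown shape.

The companion `_series_from_arrow_with_ray_data_extensions` then materializes values by casting the extension array's storage to `fixed_size_list(scalar, prod(shape))` before casting to the tensor dtype, which is why `np.prod(array.type.shape)` appears in the diff.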
22 changes: 1 addition & 21 deletions daft/series.py
@@ -9,15 +9,6 @@
 from daft.datatype import DataType
 from daft.utils import pyarrow_supports_fixed_shape_tensor
 
-_RAY_DATA_EXTENSIONS_AVAILABLE = True
-try:
-    from ray.data.extensions import (
-        ArrowTensorType,
-        ArrowVariableShapedTensorType,
-    )
-except ImportError:
-    _RAY_DATA_EXTENSIONS_AVAILABLE = False
-
 _NUMPY_AVAILABLE = True
 try:
     import numpy as np
@@ -63,18 +54,7 @@ def from_arrow(array: pa.Array | pa.ChunkedArray, name: str = "arrow_series") ->
             return Series.from_pylist(array.to_pylist(), name=name, pyobj="force")
         elif isinstance(array, pa.Array):
             array = ensure_array(array)
-            if _RAY_DATA_EXTENSIONS_AVAILABLE and isinstance(array.type, ArrowTensorType):
-                storage_series = Series.from_arrow(array.storage, name=name)
-                series = storage_series.cast(
-                    DataType.fixed_size_list(
-                        DataType.from_arrow_type(array.type.scalar_type),
-                        int(np.prod(array.type.shape)),
-                    )
-                )
-                return series.cast(DataType.from_arrow_type(array.type))
-            elif _RAY_DATA_EXTENSIONS_AVAILABLE and isinstance(array.type, ArrowVariableShapedTensorType):
-                return Series.from_numpy(array.to_numpy(zero_copy_only=False), name=name)
-            elif isinstance(array.type, getattr(pa, "FixedShapeTensorType", ())):
+            if isinstance(array.type, getattr(pa, "FixedShapeTensorType", ())):
                 series = Series.from_arrow(array.storage, name=name)
                 return series.cast(DataType.from_arrow_type(array.type))
             else:
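After this change `Series.from_arrow` only special-cases pyarrow's canonical `FixedShapeTensorType`; Ray extension arrays are handled by the runner helper above. A sketch of the retained path, assuming a pyarrow version with the canonical tensor extension type:

    import pyarrow as pa

    from daft import DataType, Series

    # A fixed-shape tensor column is a fixed-size-list storage array wrapped
    # in the canonical extension type.
    storage = pa.array([[1, 2, 3, 4], [5, 6, 7, 8]], type=pa.list_(pa.int64(), 4))
    tensor_type = pa.fixed_shape_tensor(pa.int64(), (2, 2))
    arr = pa.ExtensionArray.from_storage(tensor_type, storage)

    s = Series.from_arrow(arr, name="t")
    assert s.datatype() == DataType.tensor(DataType.int64(), (2, 2))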
22 changes: 0 additions & 22 deletions tests/series/test_concat.py
@@ -5,7 +5,6 @@
 import numpy as np
 import pyarrow as pa
 import pytest
-from ray.data.extensions import ArrowTensorArray
 
 from daft import DataType, Series
 from daft.context import get_context
@@ -138,27 +137,6 @@ def test_series_concat_struct_array(chunks) -> None:
         counter += 1
 
 
-@pytest.mark.parametrize("chunks", [1, 2, 3, 10])
-def test_series_concat_tensor_array_ray(chunks) -> None:
-    element_shape = (2, 2)
-    num_elements_per_tensor = np.prod(element_shape)
-    chunk_size = 3
-    chunk_shape = (chunk_size,) + element_shape
-    chunks = [
-        np.arange(
-            i * chunk_size * num_elements_per_tensor, (i + 1) * chunk_size * num_elements_per_tensor, dtype=np.int64
-        ).reshape(chunk_shape)
-        for i in range(chunks)
-    ]
-    series = [Series.from_arrow(ArrowTensorArray.from_numpy(chunk)) for chunk in chunks]
-
-    concated = Series.concat(series)
-
-    assert concated.datatype() == DataType.tensor(DataType.int64(), element_shape)
-    expected = [chunk[i] for chunk in chunks for i in range(len(chunk))]
-    np.testing.assert_equal(concated.to_pylist(), expected)
-
-
 @pytest.mark.skipif(
     not pyarrow_supports_fixed_shape_tensor(),
     reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
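The deleted test exercised `Series.from_arrow` on Ray tensor arrays, which the series module no longer recognizes. A hypothetical port through the new runner helper (a private function, so this is illustrative only, not a suggested test):

    import numpy as np
    from ray.data.extensions import ArrowTensorArray

    from daft import DataType, Series
    from daft.runners.ray_runner import _series_from_arrow_with_ray_data_extensions

    # Two chunks of 3 tensors each, every tensor shaped (2, 2).
    chunks = [np.arange(i * 12, (i + 1) * 12, dtype=np.int64).reshape((3, 2, 2)) for i in range(2)]
    series = [_series_from_arrow_with_ray_data_extensions(ArrowTensorArray.from_numpy(c)) for c in chunks]

    concated = Series.concat(series)
    assert concated.datatype() == DataType.tensor(DataType.int64(), (2, 2))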
