from __future__ import annotations
import inspect
from typing import TYPE_CHECKING, Any, Iterable, List
import ray.data
import ray.data.grouped_data
from ray.data._internal.compute import ComputeStrategy
from ray.data.block import DataBatch
from ray.data.expressions import Expr
from ray.util.annotations import AnnotationType
from xpark.dataset.expressions import ExprVisitor
from xpark.dataset.filters.dedup.dedup import DedupOp
from xpark.dataset.utils import copy_sig, wrap_ray_doc
if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa
from ray import ObjectRef
class Schema(ray.data.Schema):
    # Re-export ray.data's Schema with its docstring rewrapped for xpark docs.
    __doc__ = wrap_ray_doc(ray.data.Schema.__doc__)
class GroupedData:
"""Represents a grouped dataset created by calling ``Dataset.groupby()``.
The actual groupby is deferred until an aggregation is applied.
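
    Examples:
        >>> # A minimal sketch; assumes ``count()`` is among the ray.data
        >>> # GroupedData methods mirrored onto this wrapper below.
        >>> from xpark.dataset import from_range
        >>> from_range(100).groupby("id").count()  # doctest: +SKIP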
"""
def __init__(self, ray_grouped_data: ray.data.grouped_data.GroupedData):
"""Construct a dataset grouped by key (internal API).
The constructor is not part of the GroupedData API.
Use the ``Dataset.groupby()`` method to construct one.
"""
self._ray_grouped_data = ray_grouped_data
def _is_public_member(o) -> bool:
    """Return True if ``o`` is a function that Ray marks as ``@PublicAPI``."""
    return inspect.isfunction(o) and getattr(o, "_annotated_type", None) == AnnotationType.PUBLIC_API


def _wrap_public_member(fn):
    """Wrap a ``ray.data`` GroupedData method so it returns an xpark ``Dataset``."""

    @copy_sig(fn)
    def _wrapper(self, *args, **kwargs) -> Dataset:
        return Dataset(fn(self._ray_grouped_data, *args, **kwargs))

    return _wrapper
# Mirror every @PublicAPI method from ray.data's GroupedData onto the xpark
# wrapper, re-wrapping each returned ray.data.Dataset as an xpark Dataset.
for name, method in inspect.getmembers(ray.data.grouped_data.GroupedData, predicate=_is_public_member):
    # Only Dataset-returning methods can be forwarded transparently.
    assert method.__annotations__["return"] in [ray.data.Dataset, "Dataset"]
    setattr(GroupedData, name, _wrap_public_member(method))
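
# A minimal usage sketch of the generated methods (assuming ray.data marks
# ``GroupedData.count`` as @PublicAPI); for illustration only:
#
#     grouped = Dataset(ray.data.range(10)).groupby("id")
#     counts = grouped.count()  # an xpark Dataset, not a raw ray.data.Dataset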
class Dataset:
"""Construct a :class:`Dataset` from a `ray.data.Dataset <https://docs.ray.io/en/\
latest/data/api/dataset.html>`_. In general, you do not need to
manually call this constructor. Please use the :ref:`read-api` to construct it.
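    Examples:
        >>> # A minimal construction sketch (assumes a running Ray cluster);
        >>> # real pipelines would use the read API instead.
        >>> import ray.data
        >>> ds = Dataset(ray.data.range(3))
        >>> ds.count()
        3
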
Args:
ray_dataset: An instance of ``ray.data.Dataset``.
"""
def __init__(self, ray_dataset: ray.data.Dataset):
self._dataset = ray_dataset
@copy_sig(ray.data.Dataset.show)
def show(self, *args, **kwargs) -> None:
return self._dataset.show(*args, **kwargs)
@copy_sig(ray.data.Dataset.take)
def take(self, *args, **kwargs) -> list[dict[str, Any]]:
return self._dataset.take(*args, **kwargs)
@copy_sig(ray.data.Dataset.take_all)
def take_all(self, *args, **kwargs) -> list[dict[str, Any]]:
return self._dataset.take_all(*args, **kwargs)
@copy_sig(ray.data.Dataset.take_batch)
def take_batch(self, *args, **kwargs) -> DataBatch:
return self._dataset.take_batch(*args, **kwargs)
@copy_sig(ray.data.Dataset.iter_batches)
def iter_batches(self, *args, **kwargs) -> Iterable[DataBatch]:
return self._dataset.iter_batches(*args, **kwargs)
@copy_sig(ray.data.Dataset.count)
def count(self) -> int:
return self._dataset.count()
    def copy(self, deep_copy: bool = False) -> Dataset:
        """Copy the Dataset, deep-copying the underlying dataset state if
        ``deep_copy`` is ``True`` instead of sharing it with this Dataset."""
        return Dataset(ray.data.Dataset.copy(self._dataset, _deep_copy=deep_copy))
@copy_sig(ray.data.Dataset.materialize)
def materialize(self) -> Dataset:
return Dataset(self._dataset.materialize())
@copy_sig(ray.data.Dataset.unique)
def unique(self, *args, **kwargs) -> list[Any]:
return self._dataset.unique(*args, **kwargs)
@copy_sig(ray.data.Dataset.aggregate)
def aggregate(self, *args, **kwargs) -> Any | dict[str, Any]:
return self._dataset.aggregate(*args, **kwargs)
@copy_sig(ray.data.Dataset.sum)
def sum(self, *args, **kwargs) -> Any | dict[str, Any]:
return self._dataset.sum(*args, **kwargs)
@copy_sig(ray.data.Dataset.min)
def min(self, *args, **kwargs) -> Any | dict[str, Any]:
return self._dataset.min(*args, **kwargs)
@copy_sig(ray.data.Dataset.max)
def max(self, *args, **kwargs) -> Any | dict[str, Any]:
return self._dataset.max(*args, **kwargs)
@copy_sig(ray.data.Dataset.mean)
def mean(self, *args, **kwargs) -> Any | dict[str, Any]:
return self._dataset.mean(*args, **kwargs)
@copy_sig(ray.data.Dataset.std)
def std(self, *args, **kwargs) -> Any | dict[str, Any]:
return self._dataset.std(*args, **kwargs)
@copy_sig(ray.data.Dataset.sort)
def sort(self, *args, **kwargs) -> Dataset:
return Dataset(self._dataset.sort(*args, **kwargs))
@copy_sig(ray.data.Dataset.map)
def map(self, *args, **kwargs) -> Dataset:
return Dataset(self._dataset.map(*args, **kwargs))
@copy_sig(ray.data.Dataset.map_batches)
def map_batches(self, *args, **kwargs) -> Dataset:
return Dataset(self._dataset.map_batches(*args, **kwargs))
@copy_sig(ray.data.Dataset.repartition)
def repartition(self, *args, **kwargs) -> Dataset:
return Dataset(self._dataset.repartition(*args, **kwargs))
@copy_sig(ray.data.Dataset.limit)
def limit(self, *args, **kwargs) -> Dataset:
return Dataset(self._dataset.limit(*args, **kwargs))
@copy_sig(ray.data.Dataset.groupby)
def groupby(self, *args, **kwargs) -> GroupedData:
return GroupedData(self._dataset.groupby(*args, **kwargs))
@copy_sig(ray.data.Dataset.select_columns)
def select_columns(self, *args, **kwargs) -> Dataset:
return Dataset(self._dataset.select_columns(*args, **kwargs))
@copy_sig(ray.data.Dataset.drop_columns)
def drop_columns(self, *args, **kwargs) -> Dataset:
return Dataset(self._dataset.drop_columns(*args, **kwargs))
def with_column(
self,
column_name: str,
expr: Expr,
**ray_remote_args,
) -> Dataset:
"""
Add a new column to the dataset via an expression.
This method allows you to add a new column to a dataset by applying an
expression. The expression can be composed of existing columns, literals,
and user-defined functions (UDFs).
Examples:
>>> from xpark.dataset import from_range
>>> from xpark.dataset.expressions import col
>>> ds = from_range(100)
>>> # Add a new column 'id_2' by multiplying 'id' by 2.
>>> ds.with_column("id_2", col("id") * 2).show(2)
{'id': 0, 'id_2': 0}
{'id': 1, 'id_2': 2}
>>> # Using a UDF with with_column
>>> from xpark.dataset.datatype import DataType
>>> from xpark.dataset.expressions import udf
>>> import pyarrow.compute as pc
>>>
>>> @udf(return_dtype=DataType.int32())
... def add_one(column):
            ...     return pc.add(column, 1)
>>>
>>> ds.with_column("id_plus_one", add_one(col("id"))).show(2)
{'id': 0, 'id_plus_one': 1}
{'id': 1, 'id_plus_one': 2}
            >>> # Using an actor UDF with with_column
            >>> import pyarrow
            >>> @udf(return_dtype=DataType.int32(), num_workers={"CPU": 1})
            ... class Add:
            ...     def __init__(self, value):
            ...         self.value = value
            ...
            ...     def __call__(self, input: pyarrow.Array) -> pyarrow.Array:
            ...         return pc.add(input, self.value)
            >>>
            >>> ds.with_column("id_plus_two", Add(2)(col("id"))).show(2)
            {'id': 0, 'id_plus_two': 2}
            {'id': 1, 'id_plus_two': 3}

Args:
column_name: The name of the new column.
expr: An expression that defines the new column values.
            **ray_remote_args: Additional resource requirements to request from
                Ray for the map tasks (e.g., ``num_gpus=1``).

        Returns:
            A new dataset with the added column evaluated via the expression.
"""
return Dataset(ExprVisitor(self._dataset).transform(column_name=column_name, expr=expr, **ray_remote_args))
def filter(
self,
expr: Expr | DedupOp,
compute: ComputeStrategy | None = None,
**ray_remote_args,
) -> Dataset:
"""Filter out rows that don't satisfy the given predicate.
You can use either a function or a callable class or an expression to
perform the transformation.
For functions, Ray Data uses stateless Ray tasks. For classes, Ray Data uses
stateful Ray actors. For more information, see
:ref:`Stateful Transforms <stateful_transforms>`.
Examples:
>>> from xpark.dataset import from_range
>>> from xpark.dataset.expressions import col
>>> ds = from_range(100)
>>> # Using predicate expressions (preferred)
>>> ds.filter(expr=(col("id") > 10) & (col("id") < 20)).take_all()
[{'id': 11}, {'id': 12}, {'id': 13}, {'id': 14}, {'id': 15}, {'id': 16}, {'id': 17}, {'id': 18}, {'id': 19}]
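            >>> # Bounding concurrency with an explicit task pool (a sketch;
            >>> # ``compute`` is forwarded straight to ray.data):
            >>> import ray.data
            >>> ds.filter(
            ...     expr=(col("id") > 10) & (col("id") < 20),
            ...     compute=ray.data.TaskPoolStrategy(size=4),
            ... ).count()
            9
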
Time complexity: O(dataset size / parallelism)

        Args:
            expr: An expression that represents a predicate (boolean condition)
                for filtering. Can be either a predicate expression from
                ``ray.data.expressions`` or an xpark filter.
            compute: The compute strategy to use for the map operation.

                * If ``compute`` is not specified for a function, uses
                  ``ray.data.TaskPoolStrategy()`` to launch concurrent tasks
                  based on the available resources and number of input blocks.
                * Use ``ray.data.TaskPoolStrategy(size=n)`` to launch at most
                  ``n`` concurrent Ray tasks.
                * If ``compute`` is not specified for a callable class, uses
                  ``ray.data.ActorPoolStrategy(min_size=1, max_size=None)`` to
                  launch an autoscaling actor pool from 1 to unlimited workers.
                * Use ``ray.data.ActorPoolStrategy(size=n)`` for a fixed-size
                  actor pool of ``n`` workers.
                * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)``
                  for an autoscaling actor pool from ``m`` to ``n`` workers.
                * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n,
                  initial_size=initial)`` for an autoscaling actor pool from
                  ``m`` to ``n`` workers, with an initial size of ``initial``.
            **ray_remote_args: Additional resource requirements to request from
                Ray (e.g., ``num_gpus=1`` to request GPUs for the map tasks).
                See :func:`ray.remote` for details.
"""
# TODO(baobliu): We should support enhanced expr for filter.
if isinstance(expr, DedupOp):
return Dataset(expr.plan(self._dataset, compute=compute, **ray_remote_args))
else:
return Dataset(self._dataset.filter(expr=expr, compute=compute, **ray_remote_args))
@copy_sig(ray.data.Dataset.write_json)
def write_json(self, *args, **kwargs) -> None:
self._dataset.write_json(*args, **kwargs)
@copy_sig(ray.data.Dataset.write_csv)
def write_csv(self, *args, **kwargs) -> None:
self._dataset.write_csv(*args, **kwargs)
@copy_sig(ray.data.Dataset.write_numpy)
def write_numpy(self, *args, **kwargs) -> None:
self._dataset.write_numpy(*args, **kwargs)
@copy_sig(ray.data.Dataset.write_parquet)
def write_parquet(self, *args, **kwargs) -> None:
self._dataset.write_parquet(*args, **kwargs)
@copy_sig(ray.data.Dataset.write_iceberg)
def write_iceberg(self, *args, **kwargs) -> None:
self._dataset.write_iceberg(*args, **kwargs)
@copy_sig(ray.data.Dataset.write_lance)
def write_lance(self, *args, **kwargs) -> None:
self._dataset.write_lance(*args, **kwargs)
@copy_sig(ray.data.Dataset.to_pandas)
def to_pandas(self, *args, **kwargs) -> pd.DataFrame:
return self._dataset.to_pandas(*args, **kwargs)
@copy_sig(ray.data.Dataset.to_arrow_refs)
def to_arrow_refs(self, *args, **kwargs) -> List[ObjectRef["pa.Table"]]:
return self._dataset.to_arrow_refs(*args, **kwargs)
@copy_sig(ray.data.Dataset.schema)
def schema(self, *args, **kwargs) -> Schema | None:
return self._dataset.schema(*args, **kwargs)
@copy_sig(ray.data.Dataset.columns)
def columns(self, *args, **kwargs) -> list[str] | None:
return self._dataset.columns(*args, **kwargs)
@copy_sig(ray.data.Dataset.size_bytes)
def size_bytes(self, *args, **kwargs) -> int:
return self._dataset.size_bytes(*args, **kwargs)
def __repr__(self) -> str:
return self._dataset._plan.get_plan_as_string(self.__class__) # type: ignore[arg-type]
def __str__(self) -> str:
return repr(self)