Source code for xpark.dataset.dataset

from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, Iterable, List

import ray.data
import ray.data.grouped_data
from ray.data._internal.compute import ComputeStrategy
from ray.data.block import DataBatch
from ray.data.expressions import Expr
from ray.util.annotations import AnnotationType

from xpark.dataset.expressions import ExprVisitor
from xpark.dataset.filters.dedup.dedup import DedupOp
from xpark.dataset.utils import copy_sig, wrap_ray_doc

if TYPE_CHECKING:
    import pandas as pd
    import pyarrow as pa
    from ray import ObjectRef


class Schema(ray.data.Schema):
    __doc__ = wrap_ray_doc(ray.data.Schema.__doc__)


class GroupedData:
    """Represents a grouped dataset created by calling ``Dataset.groupby()``.

    The actual groupby is deferred until an aggregation is applied.
    """

    def __init__(self, ray_grouped_data: ray.data.grouped_data.GroupedData):
        """Construct a dataset grouped by key (internal API).

        The constructor is not part of the GroupedData API. Use the
        ``Dataset.groupby()`` method to construct one.
        """
        self._ray_grouped_data = ray_grouped_data
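

# The code below mirrors every public-API method of
# ray.data.grouped_data.GroupedData onto the GroupedData wrapper above,
# re-wrapping each method's ray.data.Dataset return value in an xpark Dataset.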
def _is_public_member(o):
    return (
        inspect.isfunction(o)
        and getattr(o, "_annotated_type", None) == AnnotationType.PUBLIC_API
    )


def _wrap_public_member(fn):
    @copy_sig(fn)
    def _wrapper(self, *args, **kwargs) -> Dataset:
        return Dataset(fn(self._ray_grouped_data, *args, **kwargs))

    return _wrapper


for name, method in inspect.getmembers(
    ray.data.grouped_data.GroupedData, predicate=_is_public_member
):
    assert method.__annotations__["return"] in [ray.data.Dataset, "Dataset"]
    setattr(GroupedData, name, _wrap_public_member(method))
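
# Illustrative usage (a sketch; assumes xpark's ``from_range`` reader and that
# ``count`` is among the mirrored public GroupedData aggregations):
#
#     from xpark.dataset import from_range
#     grouped = from_range(100).groupby("id")  # deferred; no groupby runs yet
#     counts = grouped.count()                 # aggregation triggers execution
#     assert isinstance(counts, Dataset)       # results come back as xpark Datasets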


class Dataset:
    """Construct a :class:`Dataset` from a
    `ray.data.Dataset <https://docs.ray.io/en/latest/data/api/dataset.html>`_.

    In general, you do not need to call this constructor manually. Please use
    the :ref:`read-api` to construct it.

    Args:
        ray_dataset: An instance of ``ray.data.Dataset``.
    """

    def __init__(self, ray_dataset: ray.data.Dataset):
        self._dataset = ray_dataset
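
    # Illustrative construction (a sketch; in practice, use the readers in the
    # :ref:`read-api` rather than wrapping a ray.data.Dataset yourself):
    #
    #     import ray.data
    #     ds = Dataset(ray.data.range(10))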

    @copy_sig(ray.data.Dataset.show)
    def show(self, *args, **kwargs) -> None:
        return self._dataset.show(*args, **kwargs)

    @copy_sig(ray.data.Dataset.take)
    def take(self, *args, **kwargs) -> list[dict[str, Any]]:
        return self._dataset.take(*args, **kwargs)

    @copy_sig(ray.data.Dataset.take_all)
    def take_all(self, *args, **kwargs) -> list[dict[str, Any]]:
        return self._dataset.take_all(*args, **kwargs)

    @copy_sig(ray.data.Dataset.take_batch)
    def take_batch(self, *args, **kwargs) -> DataBatch:
        return self._dataset.take_batch(*args, **kwargs)

    @copy_sig(ray.data.Dataset.iter_batches)
    def iter_batches(self, *args, **kwargs) -> Iterable[DataBatch]:
        return self._dataset.iter_batches(*args, **kwargs)

    @copy_sig(ray.data.Dataset.count)
    def count(self) -> int:
        return self._dataset.count()

    def copy(self, deep_copy: bool = False) -> Dataset:
        """Copy the Dataset."""
        return Dataset(ray.data.Dataset.copy(self._dataset, _deep_copy=deep_copy))

    @copy_sig(ray.data.Dataset.materialize)
    def materialize(self) -> Dataset:
        return Dataset(self._dataset.materialize())

    @copy_sig(ray.data.Dataset.unique)
    def unique(self, *args, **kwargs) -> list[Any]:
        return self._dataset.unique(*args, **kwargs)

    @copy_sig(ray.data.Dataset.aggregate)
    def aggregate(self, *args, **kwargs) -> Any | dict[str, Any]:
        return self._dataset.aggregate(*args, **kwargs)

    @copy_sig(ray.data.Dataset.sum)
    def sum(self, *args, **kwargs) -> Any | dict[str, Any]:
        return self._dataset.sum(*args, **kwargs)

    @copy_sig(ray.data.Dataset.min)
    def min(self, *args, **kwargs) -> Any | dict[str, Any]:
        return self._dataset.min(*args, **kwargs)

    @copy_sig(ray.data.Dataset.max)
    def max(self, *args, **kwargs) -> Any | dict[str, Any]:
        return self._dataset.max(*args, **kwargs)

    @copy_sig(ray.data.Dataset.mean)
    def mean(self, *args, **kwargs) -> Any | dict[str, Any]:
        return self._dataset.mean(*args, **kwargs)

    @copy_sig(ray.data.Dataset.std)
    def std(self, *args, **kwargs) -> Any | dict[str, Any]:
        return self._dataset.std(*args, **kwargs)

    @copy_sig(ray.data.Dataset.sort)
    def sort(self, *args, **kwargs) -> Dataset:
        return Dataset(self._dataset.sort(*args, **kwargs))

    @copy_sig(ray.data.Dataset.map)
    def map(self, *args, **kwargs) -> Dataset:
        return Dataset(self._dataset.map(*args, **kwargs))

    @copy_sig(ray.data.Dataset.map_batches)
    def map_batches(self, *args, **kwargs) -> Dataset:
        return Dataset(self._dataset.map_batches(*args, **kwargs))

    @copy_sig(ray.data.Dataset.repartition)
    def repartition(self, *args, **kwargs) -> Dataset:
        return Dataset(self._dataset.repartition(*args, **kwargs))

    @copy_sig(ray.data.Dataset.limit)
    def limit(self, *args, **kwargs) -> Dataset:
        return Dataset(self._dataset.limit(*args, **kwargs))

    @copy_sig(ray.data.Dataset.groupby)
    def groupby(self, *args, **kwargs) -> GroupedData:
        return GroupedData(self._dataset.groupby(*args, **kwargs))

    @copy_sig(ray.data.Dataset.select_columns)
    def select_columns(self, *args, **kwargs) -> Dataset:
        return Dataset(self._dataset.select_columns(*args, **kwargs))

    @copy_sig(ray.data.Dataset.drop_columns)
    def drop_columns(self, *args, **kwargs) -> Dataset:
        return Dataset(self._dataset.drop_columns(*args, **kwargs))

    def with_column(
        self,
        column_name: str,
        expr: Expr,
        **ray_remote_args,
    ) -> Dataset:
        """Add a new column to the dataset via an expression.

        The expression can be composed of existing columns, literals, and
        user-defined functions (UDFs).

        Examples:
            >>> from xpark.dataset import from_range
            >>> from xpark.dataset.expressions import col
            >>> ds = from_range(100)

            >>> # Add a new column 'id_2' by multiplying 'id' by 2.
            >>> ds.with_column("id_2", col("id") * 2).show(2)
            {'id': 0, 'id_2': 0}
            {'id': 1, 'id_2': 2}

            >>> # Using a UDF with with_column
            >>> from xpark.dataset.datatype import DataType
            >>> from xpark.dataset.expressions import udf
            >>> import pyarrow.compute as pc
            >>>
            >>> @udf(return_dtype=DataType.int32())
            ... def add_one(column):
            ...     return pc.add(column, 1)
            >>>
            >>> ds.with_column("id_plus_one", add_one(col("id"))).show(2)
            {'id': 0, 'id_plus_one': 1}
            {'id': 1, 'id_plus_one': 2}

            >>> # Using an actor UDF with with_column
            >>> import pyarrow
            >>> @udf(return_dtype=DataType.int32(), num_workers={"CPU": 1})
            ... class Add:
            ...     def __init__(self, value):
            ...         self.value = value
            ...
            ...     def __call__(self, input: pyarrow.Array) -> pyarrow.Array:
            ...         return pc.add(input, self.value)
            >>>
            >>> ds.with_column("id_plus_two", Add(2)(col("id"))).show(2)
            {'id': 0, 'id_plus_two': 2}
            {'id': 1, 'id_plus_two': 3}

        Args:
            column_name: The name of the new column.
            expr: An expression that defines the new column values.
            **ray_remote_args: Additional resource requirements to request from
                Ray for the map tasks (e.g., ``num_gpus=1``).

        Returns:
            A new dataset with the added column evaluated via the expression.
        """
        return Dataset(
            ExprVisitor(self._dataset).transform(
                column_name=column_name, expr=expr, **ray_remote_args
            )
        )

    def filter(
        self,
        expr: Expr | DedupOp,
        compute: ComputeStrategy | None = None,
        **ray_remote_args,
    ) -> Dataset:
        """Filter out rows that don't satisfy the given predicate.

        You can use a function, a callable class, or an expression to perform
        the transformation. For functions, Ray Data uses stateless Ray tasks;
        for classes, Ray Data uses stateful Ray actors. For more information,
        see :ref:`Stateful Transforms <stateful_transforms>`.

        Examples:
            >>> from xpark.dataset import from_range
            >>> from xpark.dataset.expressions import col
            >>> ds = from_range(100)

            >>> # Using predicate expressions (preferred)
            >>> ds.filter(expr=(col("id") > 10) & (col("id") < 20)).take_all()
            [{'id': 11}, {'id': 12}, {'id': 13}, {'id': 14}, {'id': 15}, {'id': 16}, {'id': 17}, {'id': 18}, {'id': 19}]

        Time complexity: O(dataset size / parallelism)

        Args:
            expr: An expression that represents a predicate (boolean condition)
                for filtering. Can be either a predicate expression from
                ``ray.data.expressions`` or an xpark filter.
            compute: The compute strategy to use for the map operation.

                * If ``compute`` is not specified for a function, uses
                  ``ray.data.TaskPoolStrategy()`` to launch concurrent tasks
                  based on the available resources and number of input blocks.
                * Use ``ray.data.TaskPoolStrategy(size=n)`` to launch at most
                  ``n`` concurrent Ray tasks.
                * If ``compute`` is not specified for a callable class, uses
                  ``ray.data.ActorPoolStrategy(min_size=1, max_size=None)`` to
                  launch an autoscaling actor pool from 1 to unlimited workers.
                * Use ``ray.data.ActorPoolStrategy(size=n)`` to use a fixed-size
                  actor pool of ``n`` workers.
                * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` to
                  use an autoscaling actor pool from ``m`` to ``n`` workers.
                * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n,
                  initial_size=initial)`` to use an autoscaling actor pool from
                  ``m`` to ``n`` workers, with an initial size of ``initial``.
            ray_remote_args: Additional resource requirements to request from
                Ray (e.g., ``num_gpus=1`` to request GPUs for the map tasks).
                See :func:`ray.remote` for details.
        """
        # TODO(baobliu): We should support enhanced expr for filter.
        if isinstance(expr, DedupOp):
            return Dataset(expr.plan(self._dataset, compute=compute, **ray_remote_args))
        else:
            return Dataset(
                self._dataset.filter(expr=expr, compute=compute, **ray_remote_args)
            )

    @copy_sig(ray.data.Dataset.write_json)
    def write_json(self, *args, **kwargs) -> None:
        self._dataset.write_json(*args, **kwargs)

    @copy_sig(ray.data.Dataset.write_csv)
    def write_csv(self, *args, **kwargs) -> None:
        self._dataset.write_csv(*args, **kwargs)

    @copy_sig(ray.data.Dataset.write_numpy)
    def write_numpy(self, *args, **kwargs) -> None:
        self._dataset.write_numpy(*args, **kwargs)

    @copy_sig(ray.data.Dataset.write_parquet)
    def write_parquet(self, *args, **kwargs) -> None:
        self._dataset.write_parquet(*args, **kwargs)

    @copy_sig(ray.data.Dataset.write_iceberg)
    def write_iceberg(self, *args, **kwargs) -> None:
        self._dataset.write_iceberg(*args, **kwargs)

    @copy_sig(ray.data.Dataset.write_lance)
    def write_lance(self, *args, **kwargs) -> None:
        self._dataset.write_lance(*args, **kwargs)

    @copy_sig(ray.data.Dataset.to_pandas)
    def to_pandas(self, *args, **kwargs) -> pd.DataFrame:
        return self._dataset.to_pandas(*args, **kwargs)

    @copy_sig(ray.data.Dataset.to_arrow_refs)
    def to_arrow_refs(self, *args, **kwargs) -> List[ObjectRef["pa.Table"]]:
        return self._dataset.to_arrow_refs(*args, **kwargs)

    @copy_sig(ray.data.Dataset.schema)
    def schema(self, *args, **kwargs) -> Schema | None:
        return self._dataset.schema(*args, **kwargs)

    @copy_sig(ray.data.Dataset.columns)
    def columns(self, *args, **kwargs) -> list[str] | None:
        return self._dataset.columns(*args, **kwargs)

    @copy_sig(ray.data.Dataset.size_bytes)
    def size_bytes(self, *args, **kwargs) -> int:
        return self._dataset.size_bytes(*args, **kwargs)

    @copy_sig(ray.data.Dataset.input_files)
    def input_files(self, *args, **kwargs) -> list[str]:
        return self._dataset.input_files(*args, **kwargs)

    def __repr__(self) -> str:
        # Render the underlying Ray execution plan, labeled with this wrapper class.
        return self._dataset._plan.get_plan_as_string(self.__class__)  # type: ignore[arg-type]

    def __str__(self) -> str:
        return repr(self)