"""Source code for xpark.dataset.expressions."""

from __future__ import annotations

import functools
import inspect
import logging
from collections.abc import Callable
from dataclasses import dataclass, replace
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Iterator,
    Literal,
    ParamSpec,
    Protocol,
    Self,
    TypedDict,
    TypeVar,
    Unpack,
    cast,
)

import pyarrow as pa
import ray
import ray.data.expressions
from ray.data.block import BatchColumn, DataBatch
from ray.data.datatype import DataType
from ray.data.expressions import BinaryExpr, ColumnExpr, DownloadExpr, Expr, LiteralExpr, StarExpr, UDFExpr, UnaryExpr

from xpark.dataset.constants import DEFAULT_MAP_WORKER_RAY_REMOTE_ARGS
from xpark.dataset.types import MapWorkerType
from xpark.dataset.utils import Count, copy_sig, deep_update

if TYPE_CHECKING:
    from xpark.dataset.namespace_expressions.datetime_namespace import _DatetimeNamespace
    from xpark.dataset.namespace_expressions.list_namespace import _ListNamespace
    from xpark.dataset.namespace_expressions.string_namespace import _StringNamespace
    from xpark.dataset.namespace_expressions.struct_namespace import _StructNamespace

# Reuse Ray's logger so messages from this module flow through Ray's logging setup.
logger = logging.getLogger("ray")

# ParamSpec/TypeVar used to carry the wrapped UDF's call signature through the
# wrapper protocols below. U is covariant because it only appears in return position.
P = ParamSpec("P")
U = TypeVar("U", covariant=True)


class ExprUDFOptions(TypedDict):
    """Options configurable via ``.options(...)`` on expression UDFs.

    The three keys mirror the keyword arguments of :func:`udf` and are stored
    per-UDF; ``ExprVisitor`` later turns them into ``map_batches`` arguments.
    """

    # Desired rows per batch; None uses whole blocks, "default" defers to Ray.
    batch_size: Literal["default"] | int | None
    # Worker counts keyed by worker type ("CPU"/"GPU"/"IO"). An int is a fixed
    # count; the tuple form presumably is a (min, max) range — TODO confirm
    # (a TODO elsewhere in this module says range support is not implemented yet).
    num_workers: dict[MapWorkerType, tuple[int, int] | int] | None
    # Extra ray.remote() arguments per worker type.
    worker_ray_remote_args: dict[MapWorkerType, dict] | None
@dataclass
class ExprReturn:
    """Mutable placeholder for a generated output column name.

    ``index`` starts at -1 and is filled in later by ``ExprVisitor`` (the
    enclosing ``ExtendedUDFExpr`` is frozen, but this nested object is not, so
    the index can still be updated in place). The column name written to the
    dataset is ``str(...)`` of this object.
    """

    index: int


@dataclass(frozen=True, eq=False)
class ExtendedUDFExpr(UDFExpr):
    """A ``UDFExpr`` extended with xpark-specific scheduling metadata.

    Carries the dynamic output column placeholder plus constructor arguments
    (for actor UDFs) and the resource options from ``ExprUDFOptions``.
    """

    column_name: ExprReturn
    fn: Callable[..., BatchColumn] | BatchColumnClassProtocol[..., BatchColumn]
    init_args: tuple[Any, ...]
    init_kwargs: dict[str, Any]
    batch_size: Literal["default"] | int | None = None
    num_workers: dict[MapWorkerType, tuple[int, int] | int] | None = None
    worker_ray_remote_args: dict[MapWorkerType, dict] | None = None


class BatchColumnClassProtocol(Protocol[P, U]):
    """Structural type for callable UDF classes (sync, generator, or async)."""

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> U | Iterator[U] | Awaitable[U]: ...


class ExprUDFProtocol(Protocol[P]):
    """Public interface of a wrapped UDF: call to build an Expr, plus options()."""

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Expr: ...

    def options(self, **kwargs: Unpack[ExprUDFOptions]) -> Self: ...

    def with_column(self, *args, **kwargs) -> Expr: ...


class _ExprUDFMetadata:
    """Bundle of the wrapped callable, its return dtype, and its default options."""

    def __init__(
        self,
        wrapped: Callable[P, BatchColumn] | type[BatchColumnClassProtocol[P, BatchColumn]],
        return_dtype: DataType,
        options: ExprUDFOptions,
    ):
        self.wrapped = wrapped
        self.return_dtype = return_dtype
        self.options = options


def _args_kwargs_to_expr(args, kwargs) -> tuple[list[Expr], dict[str, Expr]]:
    """Coerce positional/keyword arguments to Exprs (non-Exprs become literals)."""
    # Convert arguments to expressions if they aren't already
    expr_args = []
    for arg in args:
        if isinstance(arg, Expr):
            expr_args.append(arg)
        else:
            expr_args.append(LiteralExpr(arg))
    expr_kwargs = {}
    for k, v in kwargs.items():
        if isinstance(v, Expr):
            expr_kwargs[k] = v
        else:
            expr_kwargs[k] = LiteralExpr(v)
    return expr_args, expr_kwargs


def _validate_worker_options(
    name: str,
    supported_worker_types: set[MapWorkerType],
    num_workers: dict[MapWorkerType, tuple[int, int] | int] | None = None,
    worker_ray_remote_args: dict[MapWorkerType, dict] | None = None,
):
    """Validate that worker-related options only reference supported worker types.

    Raises:
        RuntimeError: if ``supported_worker_types`` is empty.
        TypeError: if an option is set but is not a dict.
        ValueError: if an option mentions a worker type outside the supported set.
    """
    if not supported_worker_types:
        raise RuntimeError(f"The {name} does not support any worker types.")
    for key, value in dict(num_workers=num_workers, worker_ray_remote_args=worker_ray_remote_args).items():
        if value is not None and not isinstance(value, dict):
            raise TypeError(f"The {key} must be a dict, got a {type(value)} instead.")
        if value and not supported_worker_types.issuperset(value.keys()):
            raise ValueError(
                f"The {name} only supports the following worker types: {supported_worker_types}, "
                f"but the specified {key} is: {set(value.keys())}"
            )


def _wrap_class(metadata: _ExprUDFMetadata) -> type[ExprUDFProtocol[P]]:
    """Wrap a callable UDF class into an ``ExprUDFProtocol`` actor factory.

    Instantiating the returned class captures constructor args; calling the
    instance builds an ``ExtendedUDFExpr`` whose ``fn`` is a dynamically created
    subclass that evaluates the captured argument expressions against each batch
    and appends the UDF's output as a new column.
    """
    cls: type = cast(type, metadata.wrapped)
    if not issubclass(cls, Callable):  # type: ignore[arg-type]
        raise TypeError(f"The class `{cls}` should be a callable class.")
    signature = inspect.signature(cls.__init__)

    class _ExprActor(ExprUDFProtocol[P]):
        __metadata__ = metadata

        @copy_sig(cls.__init__)  # type: ignore[misc]
        def __init__(self, *args, **kwargs) -> None:
            # Check args and kwargs before execution.
            signature.bind(self, *args, **kwargs)
            # Constructor args are shipped to the actor as-is; expressions are
            # only meaningful as __call__ arguments, so reject them here.
            if any(isinstance(arg, Expr) for arg in args):
                raise TypeError("Can't pass Expr to __init__ arguments.")
            if any(isinstance(arg, Expr) for arg in kwargs.values()):
                raise TypeError("Can't pass Expr to __init__ keyword arguments.")
            self._args = args
            self._kwargs = kwargs
            # Copy so per-instance .options(...) never mutates the shared defaults.
            self._options: ExprUDFOptions = self.__metadata__.options.copy()

        def options(self, **kwargs: Unpack[ExprUDFOptions]) -> Self:
            for k, v in kwargs.items():
                if k not in self._options:
                    raise ValueError(
                        f"Unknown option for actor: {k}, available options: {ExprUDFOptions.__annotations__.keys()}"
                    )
            self._options.update(kwargs)
            return self

        @copy_sig(cls.__call__)
        def __call__(self, *args, **kwargs) -> Expr:
            # Validate worker options.
            _validate_worker_options(
                cls.__name__,
                supported_worker_types={"CPU", "GPU", "IO"},
                num_workers=self._options.get("num_workers"),
                worker_ray_remote_args=self._options.get("worker_ray_remote_args"),
            )
            expr_args, expr_kwargs = _args_kwargs_to_expr(args, kwargs)
            # The column_name has a dynamic value, so we set it to -1 here, and it will be updated in ExprVisitor
            column_name = ExprReturn(-1)
            _BatchProcessor: type  # Makes mypy happy
            if inspect.iscoroutinefunction(cls.__call__):
                # Async actor
                class _AsyncBatchProcessor(cls):
                    async def __call__(self, batch: pa.Table) -> pa.Table:
                        # Deferred import: Ray-internal module, resolved on the worker.
                        from ray.data._internal.planner.plan_expression.expression_evaluator import eval_expr

                        eval_args = [eval_expr(arg, batch) for arg in expr_args]
                        eval_kwargs = {k: eval_expr(v, batch) for k, v in expr_kwargs.items()}
                        new_column = await super().__call__(*eval_args, **eval_kwargs)
                        return batch.append_column(str(column_name), new_column)

                _BatchProcessor = _AsyncBatchProcessor
            else:
                # Sync actor
                class _SyncBatchProcessor(cls):
                    def __call__(self, batch: pa.Table) -> pa.Table:
                        # Deferred import: Ray-internal module, resolved on the worker.
                        from ray.data._internal.planner.plan_expression.expression_evaluator import eval_expr

                        eval_args = [eval_expr(arg, batch) for arg in expr_args]
                        eval_kwargs = {k: eval_expr(v, batch) for k, v in expr_kwargs.items()}
                        new_column = super().__call__(*eval_args, **eval_kwargs)
                        return batch.append_column(str(column_name), new_column)

                _BatchProcessor = _SyncBatchProcessor
            # Make the generated subclass present itself as the user's class
            # (nicer repr in Ray dashboards / logs).
            _BatchProcessor.__module__ = cls.__module__
            _BatchProcessor.__name__ = cls.__name__
            _BatchProcessor.__qualname__ = cls.__qualname__
            _BatchProcessor.__doc__ = cls.__doc__
            return ExtendedUDFExpr(
                column_name=column_name,
                data_type=self.__metadata__.return_dtype,
                fn=_BatchProcessor,
                args=expr_args,
                kwargs=expr_kwargs,
                init_args=self._args,
                init_kwargs=self._kwargs,
                **self._options,
            )

        # Alias so wrapped UDFs satisfy the full ExprUDFProtocol surface.
        with_column = __call__

    # Present the wrapper as the user's class for introspection/doc tools.
    _ExprActor.__module__ = cls.__module__
    _ExprActor.__name__ = cls.__name__
    _ExprActor.__qualname__ = cls.__qualname__
    _ExprActor.__doc__ = cls.__doc__
    return _ExprActor
class ExprTask(ExprUDFProtocol[P]):
    """Wrapper turning a plain UDF function into an Expr factory (Ray task path)."""

    def __init__(self, metadata: _ExprUDFMetadata) -> None:
        self.metadata = metadata
        self._signature = inspect.signature(metadata.wrapped)
        # Copy so per-task .options(...) never mutates the shared defaults.
        self._options: ExprUDFOptions = self.metadata.options.copy()

    def options(self, **kwargs: Unpack[ExprUDFOptions]) -> Self:
        for k, v in kwargs.items():
            if k not in self._options:
                raise ValueError(
                    f"Unknown option for task: {k}, available options: {ExprUDFOptions.__annotations__.keys()}"
                )
        self._options.update(kwargs)
        return self

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Expr:
        # Check args and kwargs before execution.
        self._signature.bind(*args, **kwargs)
        # Validate worker options.
        _validate_worker_options(
            self.metadata.wrapped.__name__,
            supported_worker_types={"CPU", "GPU", "IO"},
            num_workers=self._options.get("num_workers"),
            worker_ray_remote_args=self._options.get("worker_ray_remote_args"),
        )
        # Convert arguments to expressions if they aren't already
        expr_args, expr_kwargs = _args_kwargs_to_expr(args, kwargs)
        return ExtendedUDFExpr(
            column_name=ExprReturn(-1),  # Filled in later by ExprVisitor.
            data_type=self.metadata.return_dtype,
            fn=self.metadata.wrapped,
            args=expr_args,
            kwargs=expr_kwargs,
            init_args=tuple(),
            init_kwargs={},
            **self._options,
        )

    # Alias so wrapped UDFs satisfy the full ExprUDFProtocol surface.
    with_column = __call__

    def __repr__(self):
        return f"{self.__class__.__name__}({self.metadata.wrapped.__name__})"


def _wrap_function(fn: Callable[P, BatchColumn], task: ExprTask[P]) -> ExprUDFProtocol[P]:
    """Wrapping a function into a function, rather than using a callable, is better for documentation."""

    @copy_sig(fn)
    def _wrapped(*args: P.args, **kwargs: P.kwargs) -> Expr:
        return task(*args, **kwargs)

    # Attach the protocol surface onto the plain function object.
    _wrapped.__metadata__ = task.metadata
    _wrapped.options = task.options
    _wrapped.with_column = task.__call__
    return cast(ExprUDFProtocol[P], _wrapped)


def _batch_column_task(fn: Callable[..., BatchColumn], return_column: ExprReturn) -> Callable[..., DataBatch]:
    """Adapt a column-level UDF into a batch mapper that appends its output column."""

    @functools.wraps(fn)
    def _wrapper(batch: pa.Table, *args, **kwargs) -> pa.Table:
        # Deferred import: Ray-internal module, resolved on the worker.
        from ray.data._internal.planner.plan_expression.expression_evaluator import eval_expr

        # args/kwargs arrive as Exprs; evaluate each against the current batch.
        eval_args = [eval_expr(arg, batch) for arg in args]
        eval_kwargs = {k: eval_expr(v, batch) for k, v in kwargs.items()}
        new_column = fn(*eval_args, **eval_kwargs)
        return batch.append_column(str(return_column), new_column)

    return _wrapper


def _hybrid_compute(
    num_workers: dict[MapWorkerType, tuple[int, int] | int],
    worker_ray_remote_args: dict[MapWorkerType, dict] | None = None,
    ray_remote_args_fn: Callable[[], dict[str, Any]] | None = None,
) -> tuple[int | tuple[int, int] | tuple[int, int, int], Callable[[], dict[str, Any]] | None, dict[str, Any]]:
    """Support hybrid computation, e.g. CPU + GPU

    However, Ray only provide ``ray_remote_args_fn``, if the actor pool is downscaled
    and upscaled again, then we don't know which actor resources should be used.

    # TODO(baobliu): Support range for num_workers.
    # TODO(baobliu): Handle actor pool autoscale.
    """
    logger.warning("Hybrid computation is experimental.")
    num_cpu_workers = num_workers.get("CPU", 0)
    num_gpu_workers = num_workers.get("GPU", 0)
    num_io_workers = num_workers.get("IO", 0)
    # Ranges ((min, max) tuples) are not supported on the hybrid path yet.
    if not isinstance(num_cpu_workers, int):
        raise TypeError("Hybrid computation num_workers CPU must be an int")
    if not isinstance(num_gpu_workers, int):
        raise TypeError("Hybrid computation num_workers GPU must be an int")
    if not isinstance(num_io_workers, int):
        raise TypeError("Hybrid computation num_workers IO must be an int")
    worker_ray_remote_args = worker_ray_remote_args or {}
    # Merge user overrides on top of the per-type defaults.
    cpu_worker_ray_remote_args = deep_update(
        DEFAULT_MAP_WORKER_RAY_REMOTE_ARGS["CPU"], worker_ray_remote_args.get("CPU", {})
    )
    gpu_worker_ray_remote_args = deep_update(
        DEFAULT_MAP_WORKER_RAY_REMOTE_ARGS["GPU"], worker_ray_remote_args.get("GPU", {})
    )
    io_worker_ray_remote_args = deep_update(
        DEFAULT_MAP_WORKER_RAY_REMOTE_ARGS["IO"], worker_ray_remote_args.get("IO", {})
    )
    _overwrite = ray_remote_args_fn or (lambda: {})

    def _resource_gen():
        # Endless round-robin over GPU -> CPU -> IO resource specs; each call to
        # the returned fn assigns the next worker's remote args.
        while True:
            # Workaround for Ray 2.49
            for _ in range(num_gpu_workers):
                yield deep_update(gpu_worker_ray_remote_args, _overwrite())
            for _ in range(num_cpu_workers):
                yield deep_update(cpu_worker_ray_remote_args, _overwrite())
            for _ in range(num_io_workers):
                yield deep_update(io_worker_ray_remote_args, _overwrite())

    resource_gen = _resource_gen()

    def _ray_remote_args_fn():
        return next(resource_gen)

    return num_cpu_workers + num_gpu_workers + num_io_workers, _ray_remote_args_fn, {}


def _simple_compute(
    num_workers: dict[MapWorkerType, tuple[int, int] | int],
    worker_ray_remote_args: dict[MapWorkerType, dict] | None = None,
    ray_remote_args_fn: Callable[[], dict[str, Any]] | None = None,
) -> tuple[int | tuple[int, int] | tuple[int, int, int], Callable[[], dict[str, Any]] | None, dict[str, Any]]:
    """Single-worker-type scheduling: concurrency plus merged remote args.

    NOTE(review): assumes ``num_workers`` has exactly one entry (callers route
    multi-entry dicts to ``_hybrid_compute``); an empty dict would raise
    StopIteration here.
    """
    worker_type, concurrency = next(iter(num_workers.items()))
    worker_ray_remote_args = worker_ray_remote_args or {}
    return (
        concurrency,
        ray_remote_args_fn,
        deep_update(
            DEFAULT_MAP_WORKER_RAY_REMOTE_ARGS[worker_type],
            worker_ray_remote_args.get(worker_type, {}),
        ),
    )


class GenericExprVisitor(object):
    """Depth-first Expr tree walker; subclasses hook per-node behavior."""

    def visit(self, expr: Expr) -> Expr:
        """Visit a node."""
        if isinstance(expr, UDFExpr):
            # We have to update inplace.
            expr.args[:] = [self.visit(arg) for arg in expr.args]
            expr.kwargs.update({k: self.visit(v) for k, v in expr.kwargs.items()})
        if isinstance(expr, BinaryExpr):
            return replace(expr, left=self.visit(expr.left), right=self.visit(expr.right))
        if isinstance(expr, UnaryExpr):
            return replace(expr, operand=self.visit(expr.operand))
        return expr


class ExprVisitor(GenericExprVisitor):
    """Rewrites ExtendedUDFExprs into materialized map_batches stages.

    Each resource-bearing UDF becomes its own ``map_batches`` call that writes a
    temporary column (named via ``ExprReturn``); the expression tree is rewritten
    to reference that column, and temporaries are dropped once consumed.
    """

    def __init__(self, dataset: ray.data.Dataset):
        self._dataset = dataset
        self._return_index = Count()  # Monotonic counter for temp column indexes.
        self._dropped_indexes: set[int] = set()

    def visit(self, expr: Expr) -> Expr:
        current_index = self._return_index.last_value
        expr = super().visit(expr)  # Children first (may allocate new indexes).
        # Then process current node
        method = "visit_" + expr.__class__.__name__
        visitor: Callable[[Expr, set[int]], Expr] = getattr(self, method, self.generic_visit)
        # ref_indexes = temp columns created while visiting this subtree and
        # not yet dropped.
        return visitor(expr, set(range(current_index + 1, self._return_index.last_value + 1)) - self._dropped_indexes)

    def generic_visit(self, expr: Expr, _ref_indexes: set[int]) -> Expr:
        return expr

    def visit_ExtendedUDFExpr(self, expr: ExtendedUDFExpr, ref_indexes: set[int]) -> Expr:
        # If any distributed resources are specified.
        if any([expr.batch_size, expr.num_workers, expr.worker_ray_remote_args]):
            if isinstance(expr.fn, type):
                # Actor path: fn is a class, constructed on each worker.
                if expr.num_workers is None:
                    raise ValueError("Actor UDF requires num_workers specified.")
                # Dynamic return column
                expr.column_name.index = self._return_index.next()
                compute_func = _hybrid_compute if len(expr.num_workers) > 1 else _simple_compute
                concurrency, ray_remote_args_fn, ray_remote_args = compute_func(
                    num_workers=expr.num_workers, worker_ray_remote_args=expr.worker_ray_remote_args
                )
                self._dataset = self._dataset.map_batches(
                    expr.fn,
                    fn_constructor_args=expr.init_args,
                    fn_constructor_kwargs=expr.init_kwargs,
                    concurrency=concurrency,
                    batch_size=expr.batch_size,
                    batch_format="pyarrow",  # We use pyarrow by default
                    zero_copy_batch=True,  # We use zero copy by default
                    ray_remote_args_fn=ray_remote_args_fn,
                    **ray_remote_args,
                )
            else:
                # Task path: fn is a plain function.
                num_workers = expr.num_workers
                if num_workers is None:
                    num_workers = {"CPU": None}  # type: ignore[dict-item]
                # Dynamic return column
                expr.column_name.index = self._return_index.next()
                compute_func = _hybrid_compute if len(num_workers) > 1 else _simple_compute
                concurrency, ray_remote_args_fn, ray_remote_args = compute_func(
                    num_workers=num_workers, worker_ray_remote_args=expr.worker_ray_remote_args
                )
                self._dataset = self._dataset.map_batches(
                    _batch_column_task(expr.fn, expr.column_name),
                    fn_args=expr.args,
                    fn_kwargs=expr.kwargs,
                    concurrency=concurrency,
                    batch_size=expr.batch_size,
                    batch_format="pyarrow",  # We use pyarrow by default
                    zero_copy_batch=True,  # We use zero copy by default
                    ray_remote_args_fn=ray_remote_args_fn,
                    **ray_remote_args,
                )
            # Temp columns consumed by this stage are no longer needed.
            if ref_indexes:
                drop_columns = [str(ExprReturn(index)) for index in ref_indexes]
                self._dataset = self._dataset.drop_columns(drop_columns)
                self._dropped_indexes.update(ref_indexes)
            return ColumnExpr(str(expr.column_name))
        # If no special resources provided, treat the task as Ray data's implementation:
        # The UDF function will be executed together with other expressions.
        return UDFExpr(
            fn=expr.fn,
            args=expr.args,
            kwargs=expr.kwargs,
            data_type=expr.data_type,
        )

    def transform(self, column_name: str, expr: Expr, **ray_remote_args: dict[str, Any]) -> ray.data.Dataset:
        """Materialize ``expr`` as ``column_name`` on the dataset and clean up temps."""
        expr = self.visit(expr)
        final_column_name = str(ExprReturn(self._return_index.last_value))
        if expr.structurally_equals(ColumnExpr(final_column_name)):
            # The whole expression collapsed to the last temp column: just rename it.
            dataset = self._dataset.rename_columns({final_column_name: column_name})
        else:
            dataset = self._dataset.with_column(
                column_name,
                expr,
                **ray_remote_args,
            )
        # Drop any remaining temp columns that were never consumed/dropped above.
        drop_columns = [str(ExprReturn(index)) for index in set(self._return_index.range()) - self._dropped_indexes]
        if drop_columns:
            dataset = dataset.drop_columns(drop_columns)
        return dataset
def udf(
    *,
    return_dtype: DataType,
    batch_size: Literal["default"] | int | None = None,
    num_workers: dict[MapWorkerType, tuple[int, int] | int] | None = None,
    worker_ray_remote_args: dict[MapWorkerType, dict] | None = None,
):
    """
    Decorator to convert a UDF into an expression-compatible function.

    This decorator allows UDFs to be used seamlessly within the expression system,
    enabling schema inference and integration with other expressions.

    IMPORTANT: UDFs operate on batches of data, not individual rows. When your UDF
    is called, each column argument will be passed as a PyArrow Array containing
    multiple values from that column across the batch. Under the hood, when working
    with multiple columns, they get translated to PyArrow arrays (one array per
    column).

    Args:
        return_dtype: The data type of the return value of the UDF
        batch_size: The desired number of rows in each batch, or ``None`` to use
            entire blocks as batches (blocks may contain different numbers of rows).
            The actual size of the batch provided to ``processor`` may be smaller
            than ``batch_size`` if ``batch_size`` doesn't evenly divide the block(s)
            sent to a given map task. Default ``batch_size`` is ``None``.
        num_workers: The number of worker processes to use for batch inference, the
            available worker types are ``CPU``, ``GPU`` and ``IO``. Actual number of
            workers will be equal or less than ``num_workers``.
        worker_ray_remote_args: Additional resource requirements for each type of
            map worker. See :func:`ray.remote` for details.

    Returns:
        A callable that creates UDFExpr instances when called with expressions

    Example:
        >>> from xpark.dataset import from_items
        >>> from xpark.dataset.expressions import col, udf
        >>> import pyarrow as pa
        >>> import pyarrow.compute as pc
        >>> import ray
        >>>
        >>> # UDF that operates on a batch of values (PyArrow Array)
        >>> @udf(return_dtype=DataType.int32())
        ... def add_one(x: pa.Array) -> pa.Array:
        ...     return pc.add(x, 1)  # Vectorized operation on the entire Array
        >>>
        >>> # UDF that combines multiple columns (each as a PyArrow Array)
        >>> @udf(return_dtype=DataType.string())
        ... def format_name(first: pa.Array, last: pa.Array) -> pa.Array:
        ...     return pc.binary_join_element_wise(first, last, " ")  # Vectorized string concatenation
        >>>
        >>> # Use in dataset operations
        >>> ds = from_items([
        ...     {"value": 5, "first": "John", "last": "Doe"},
        ...     {"value": 10, "first": "Jane", "last": "Smith"}
        ... ])
        >>>
        >>> # Single column transformation (operates on batches)
        >>> ds_incremented = ds.with_column("value_plus_one", add_one(col("value")))
        >>>
        >>> # Multi-column transformation (each column becomes a PyArrow Array)
        >>> ds_formatted = ds.with_column("full_name", format_name(col("first"), col("last")))
        >>>
        >>> # Can also be used in complex expressions
        >>> ds_complex = ds.with_column("doubled_plus_one", add_one(col("value")) * 2)
        >>>
        >>> # UDF can be an actor
        >>> @udf(return_dtype=DataType.int32(), num_workers={"CPU": 1})
        ... class Add:
        ...     def __init__(self, value):
        ...         self.value = value
        ...
        ...     def __call__(self, array: pa.Array) -> pa.Array:
        ...         return pc.add(array, self.value)
        >>>
        >>> add_two = Add(2)
        >>> ds_actor_udf = ds.with_column("value_plus_two", add_two(col("value")))
    """

    def _udf_wrapper(
        fn: Callable[P, BatchColumn] | type[BatchColumnClassProtocol[P, BatchColumn]],
    ) -> ExprUDFProtocol[P] | type[ExprUDFProtocol[P]]:
        metadata = _ExprUDFMetadata(
            wrapped=fn,
            return_dtype=return_dtype,
            options={
                "batch_size": batch_size,
                "num_workers": num_workers,
                "worker_ray_remote_args": worker_ray_remote_args,
            },
        )
        if isinstance(fn, type):
            # Class-based UDFs become actor wrappers.
            return _wrap_class(metadata)
        else:
            # Plain functions become Ray-task wrappers.
            assert callable(fn)
            task: ExprTask = ExprTask(metadata)
            return _wrap_function(fn, task)

    return _udf_wrapper
@copy_sig(ray.data.expressions.star)
def star(*args, **kwargs) -> StarExpr:
    """Forward to :func:`ray.data.expressions.star`, re-exported for xpark users."""
    expr = ray.data.expressions.star(*args, **kwargs)
    return expr
@copy_sig(ray.data.expressions.col)
def col(*args, **kwargs) -> ColumnExpr:
    """Forward to :func:`ray.data.expressions.col`, re-exported for xpark users."""
    expr = ray.data.expressions.col(*args, **kwargs)
    return expr
@copy_sig(ray.data.expressions.lit)
def lit(*args, **kwargs) -> LiteralExpr:
    """Forward to :func:`ray.data.expressions.lit`, re-exported for xpark users."""
    expr = ray.data.expressions.lit(*args, **kwargs)
    return expr
@copy_sig(ray.data.expressions.download)
def download(*args, **kwargs) -> DownloadExpr:
    """Forward to :func:`ray.data.expressions.download`, re-exported for xpark users."""
    expr = ray.data.expressions.download(*args, **kwargs)
    return expr
class _PatchExpr:
    """This is for routing the expr namespace to xpark implementation, we can extend the namespace by ourselves."""

    @property
    def list(self) -> "_ListNamespace":
        """Access list operations for this expression.

        Returns:
            A _ListNamespace that provides list-specific operations.

        Example:
            >>> from xpark.dataset.expressions import col
            >>> from xpark.dataset import from_items
            >>> ds = from_items([
            ...     {"items": [1, 2, 3]},
            ...     {"items": [4, 5]}
            ... ])
            >>> ds = ds.with_column("num_items", col("items").list.len())
            >>> ds = ds.with_column("first_item", col("items").list[0])
            >>> ds = ds.with_column("slice", col("items").list[1:3])
        """
        # Local import to avoid a circular import with the namespace modules.
        from xpark.dataset.namespace_expressions.list_namespace import _ListNamespace

        return _ListNamespace(self)

    @property
    def str(self) -> "_StringNamespace":
        """Access string operations for this expression.

        Returns:
            A _StringNamespace that provides string-specific operations.

        Example:
            >>> from xpark.dataset.expressions import col
            >>> from xpark.dataset import from_items
            >>> ds = from_items([
            ...     {"name": "Alice"},
            ...     {"name": "Bob"}
            ... ])
            >>> ds = ds.with_column("upper_name", col("name").str.upper())
            >>> ds = ds.with_column("name_len", col("name").str.len())
            >>> ds = ds.with_column("starts_a", col("name").str.starts_with("A"))
        """
        # Local import to avoid a circular import with the namespace modules.
        from xpark.dataset.namespace_expressions.string_namespace import _StringNamespace

        return _StringNamespace(self)

    @property
    def struct(self) -> "_StructNamespace":
        """Access struct operations for this expression.

        Returns:
            A _StructNamespace that provides struct-specific operations.

        Example:
            >>> from xpark.dataset.expressions import col
            >>> from xpark.dataset import from_arrow
            >>> import pyarrow as pa
            >>> ds = from_arrow(pa.table({
            ...     "user": pa.array([
            ...         {"name": "Alice", "age": 30}
            ...     ], type=pa.struct([
            ...         pa.field("name", pa.string()),
            ...         pa.field("age", pa.int32())
            ...     ]))
            ... }))
            >>> ds = ds.with_column("age", col("user").struct["age"])  # doctest: +SKIP
        """
        # Local import to avoid a circular import with the namespace modules.
        from xpark.dataset.namespace_expressions.struct_namespace import _StructNamespace

        return _StructNamespace(self)

    @property
    def dt(self) -> "_DatetimeNamespace":
        """Access datetime operations for this expression.

        Returns:
            A _DatetimeNamespace that provides datetime-specific operations.

        Example:
            >>> from xpark.dataset.expressions import col
            >>> from xpark.dataset import from_items
            >>> from datetime import datetime
            >>> ds = from_items([{"date": datetime.now()}])
            >>> ds = ds.with_column("year", col("date").dt.year())
            >>> ds = ds.with_column("month", col("date").dt.month())
        """
        # Local import to avoid a circular import with the namespace modules.
        from xpark.dataset.namespace_expressions.datetime_namespace import _DatetimeNamespace

        return _DatetimeNamespace(self)


# Monkey-patch Ray's Expr so these namespaces resolve to the xpark versions.
# NOTE(review): only members that already exist on Expr are replaced — the
# RuntimeError guards against patching a name Ray's Expr doesn't define
# (presumably to catch Ray API drift); confirm against the pinned Ray version.
for name, value in _PatchExpr.__dict__.items():
    if isinstance(value, property):
        m = Expr.__dict__.get(name)
        if m is None:
            raise RuntimeError(f"Unexpected patch member: {name}")
        # Skip if Expr already holds the identical property object.
        if value is not m:
            setattr(Expr, name, value)