Source code for ray.data.namespace_expressions.struct_namespace

"""Struct namespace for expression operations on struct-typed columns."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import pyarrow
import pyarrow.compute as pc

from ray.data.datatype import DataType
from ray.data.expressions import pyarrow_udf

if TYPE_CHECKING:
    from ray.data.expressions import Expr, UDFExpr


@dataclass
class _StructNamespace:
    """Namespace for struct operations on expression columns.

    This namespace provides methods for operating on struct-typed columns using
    PyArrow compute functions.

    Example:
        >>> from ray.data.expressions import col
        >>> # Access a field using method
        >>> expr = col("user_record").struct.field("age")
        >>> # Access a field using bracket notation
        >>> expr = col("user_record").struct["age"]
        >>> # Access nested field
        >>> expr = col("user_record").struct["address"].struct["city"]
    """

    _expr: Expr

    def __getitem__(self, field_name: str) -> "UDFExpr":
        """Extract a field using bracket notation.

        Args:
            field_name: The name of the field to extract.

        Returns:
            UDFExpr that extracts the specified field from each struct.

        Example:
            >>> col("user").struct["age"]  # Get age field  # doctest: +SKIP
            >>> col("user").struct["address"].struct["city"]  # Get nested city field  # doctest: +SKIP
        """
        return self.field(field_name)

[docs] def field(self, field_name: str) -> "UDFExpr": """Extract a field from a struct. Args: field_name: The name of the field to extract. Returns: UDFExpr that extracts the specified field from each struct. """ # Infer return type from the struct field type return_dtype = DataType(object) # fallback if self._expr.data_type.is_arrow_type(): arrow_type = self._expr.data_type.to_arrow_dtype() if pyarrow.types.is_struct(arrow_type): try: field_type = arrow_type.field(field_name).type return_dtype = DataType.from_arrow(field_type) except KeyError: # Field not found in schema, fallback to object pass @pyarrow_udf(return_dtype=return_dtype) def _struct_field(arr: pyarrow.Array) -> pyarrow.Array: return pc.struct_field(arr, field_name) return _struct_field(self._expr)