# Source code for xpark.dataset.utils

from __future__ import annotations

import asyncio
import contextlib
import contextvars
import enum
import functools
import inspect
import io
import itertools
import logging
import math
import os
import threading
import time
from asyncio import events
from collections import deque
from collections.abc import Callable
from functools import wraps
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Coroutine,
    Generator,
    Iterable,
    ParamSpec,
    Tuple,
    TypeVar,
    cast,
)
from urllib.parse import urlparse

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from PIL import Image

from xpark.dataset.constants import (
    AUDIO_FORMAT_MAPPING,
    RAY_DOC_REPLACE_PAIRS,
    WRAPPER_ASSIGNMENTS,
)
from xpark.dataset.context import DatasetContext
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.types import KeyType

if TYPE_CHECKING:
    import av
    import fsspec
    import openai
    import pyarrow.compute as pc
    from av.container.input import InputContainer
    from av.container.output import OutputContainer
    from openai.types.chat.chat_completion import ChatCompletion
    from openai.types.chat.chat_completion_message_param import (
        ChatCompletionMessageParam,
    )

    from xpark.dataset.filters.dedup import utils as dedup_utils
else:
    av = lazy_import("av")
    fsspec = lazy_import("fsspec")
    openai = lazy_import("openai")
    pc = lazy_import("pyarrow.compute", rename="pc")
    pa = lazy_import("pyarrow", rename="pa")
    dedup_utils = lazy_import("xpark.dataset.filters.dedup.utils", rename="dedup_utils")

logger = logging.getLogger("ray")

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")


def copy_sig(fn: Callable[P, R]) -> Callable[[Callable[..., T]], Callable[P, T]]:
    """Copy the signature metadata of ``fn`` onto the decorated function.

    This extends :func:`functools.wraps`. The ``wraps`` name causes incorrect type
    inference in PyCharm, so it has been renamed to ``copy_sig``.

    1. Only the attributes listed in ``WRAPPER_ASSIGNMENTS`` are copied (intended to
       skip ``__qualname__`` and ``__module__``).
    2. ``__doc__`` is taken from ``fn`` and passed through ``wrap_ray_doc``.
    3. The ``return`` annotation comes from the decorated function, not from ``fn``.

    Args:
        fn: The function whose metadata and documentation should be copied.

    Returns:
        A decorator that applies the metadata to its target. The result is typed with
        ``fn``'s parameters and the target's own return type.

    Raises:
        TypeError: (from the returned decorator) if the decorated function has no
            return annotation.
    """

    def decorator(target: Callable[..., T]) -> Callable[P, T]:
        signature = inspect.signature(target)
        # ``inspect.Signature.empty`` is the public name of the private ``inspect._empty`` sentinel.
        if signature.return_annotation is inspect.Signature.empty:
            raise TypeError(
                f"The {copy_sig.__name__} decorated function `{target.__name__}` must have a return annotation."
            )
        wrapped = functools.wraps(fn, assigned=WRAPPER_ASSIGNMENTS)(target)
        wrapped.__doc__ = wrap_ray_doc(fn.__doc__)
        # Copy before editing. If WRAPPER_ASSIGNMENTS includes ``__annotations__``, the
        # dict is shared with ``fn`` and must not be mutated in place.
        wrapped.__annotations__ = wrapped.__annotations__.copy()
        wrapped.__annotations__["return"] = signature.return_annotation
        return cast(Callable[P, T], wrapped)

    return decorator


def deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]:
    """Recursively merge ``updating_mappings`` into a shallow copy of ``mapping``.

    Mappings are applied left to right, so later ones win. When both the existing
    value and the incoming value for a key are dicts, they are merged recursively
    instead of being replaced. ``mapping`` itself is not modified.
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            current = merged.get(key)
            both_dicts = isinstance(current, dict) and isinstance(new_value, dict)
            merged[key] = deep_update(current, new_value) if both_dicts else new_value
    return merged


def multi_replace(s: str | None, replace_pairs: dict[str, str]) -> str | None:
    """Apply each ``old -> new`` substitution in ``replace_pairs`` to ``s``, in order.

    Pairs are applied sequentially in the dict's iteration order, so a later pair
    also sees the output of earlier ones. ``None`` is returned unchanged.
    """
    if s is not None:
        for pattern, replacement in replace_pairs.items():
            s = s.replace(pattern, replacement)
    return s


def wrap_ray_doc(s: str | None) -> str | None:
    """Rewrite Ray-specific wording in a docstring using ``RAY_DOC_REPLACE_PAIRS``.

    ``None`` (e.g. an object without a docstring) is passed through unchanged.
    """
    rewritten = multi_replace(s, replace_pairs=RAY_DOC_REPLACE_PAIRS)
    return rewritten


def wrap_ray_class(ray_class: type[T]) -> type[T]:
    """Create a wrapper class that inherits from a Ray class and replaces its documentation.

    This function creates a new class that:
    1. Inherits from the original Ray class
    2. Automatically replaces Ray-specific documentation with Xpark equivalents
    3. Preserves all functionality of the original class

    Similar to `copy_sig`, this function helps maintain Xpark-specific documentation
    without modifying the original Ray classes.

    Args:
        ray_class: The Ray class to wrap.

    Returns:
        A new class that inherits from ray_class with updated documentation.

    Raises:
        TypeError: If ``ray_class`` is an Enum class. Enum classes are not wrapped.

    Example:
        >>> import ray.data.aggregate
        >>> Count = wrap_ray_class(ray.data.aggregate.Count)
        >>> # Count now has Xpark-specific documentation
    """

    # Enum classes are rejected rather than wrapped. TypeError is a subclass of
    # Exception, so existing ``except Exception`` handlers keep working.
    if isinstance(ray_class, enum.EnumMeta):
        raise TypeError("Do not wrap enum class.")

    # For regular classes, use inheritance
    class WrappedClass(ray_class):  # type: ignore
        __doc__ = wrap_ray_doc(ray_class.__doc__)

    # Preserve the original class name and qualified name. ``__module__`` is not
    # copied, so the wrapper reports the module that defines it (this one).
    WrappedClass.__name__ = ray_class.__name__
    WrappedClass.__qualname__ = ray_class.__qualname__

    return WrappedClass  # type: ignore


def iter_batch(array: pa.ChunkedArray, batch_size: int) -> Generator[pa.ChunkedArray, None, None]:
    """Yield consecutive slices of ``array`` holding at most ``batch_size`` elements each.

    The final slice may be shorter, and an empty ``array`` yields nothing.
    """
    total = len(array)
    bounds = range(0, total, batch_size)
    yield from (array[lo : min(lo + batch_size, total)] for lo in bounds)


def qps_limiter(max_qps: int | None = None, max_concurrency: int | None = None):
    """Decorator factory limiting the call rate and/or concurrency of an async function.

    Args:
        max_qps: Maximum number of calls allowed to *start* within any sliding
            1-second window. ``None`` (or ``0``) disables rate limiting.
        max_concurrency: Maximum number of calls allowed to run at the same time.
            ``None`` (or ``0``) disables the concurrency limit.

    Returns:
        A decorator. When both limits are ``None``, the decorator returns the function
        unchanged. Otherwise it returns a wrapper coroutine function that enforces
        the limits.

    Raises:
        ValueError: If a limit is given but is smaller than 1.
    """
    if max_qps and max_qps < 1:
        raise ValueError("max_qps must be >= 1")
    if max_concurrency and max_concurrency < 1:
        raise ValueError("max_concurrency must be >= 1")

    def _decorator(func):
        if max_qps is None and max_concurrency is None:
            return func

        # Start times (time.monotonic()) of the calls admitted within the last second.
        timestamps: deque[float] = deque()
        # Guards ``timestamps`` only. The wrapped call itself runs outside this lock.
        lock = asyncio.Lock()
        semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

        async def _wait_for_rate_slot() -> None:
            # Block until starting one more call keeps the rate within ``max_qps``.
            async with lock:
                while True:
                    # Monotonic clock: immune to wall-clock jumps.
                    now = time.monotonic()
                    # Drop calls that started at least one second ago.
                    while timestamps and timestamps[0] <= now - 1:
                        timestamps.popleft()
                    if len(timestamps) < max_qps:
                        break
                    # Sleep until the oldest call in the window drops out, then re-check.
                    await asyncio.sleep(1 - (now - timestamps[0]))
                timestamps.append(time.monotonic())

        @copy_sig(func)
        async def wrapper(*args, **kwargs) -> Any:
            # ``contextlib.nullcontext`` supports ``async with`` on Python 3.10+.
            async with semaphore if semaphore is not None else contextlib.nullcontext():
                if max_qps:
                    await _wait_for_rate_slot()
                # Awaited outside the rate lock so admitted calls can overlap. Holding
                # the lock here would serialize every call regardless of ``max_qps``.
                return await func(*args, **kwargs)

        return wrapper

    return _decorator


class Count:
    """A resettable counter built on :func:`itertools.count`.

    Successive ``next()`` calls return ``firstval``, ``firstval + step``, and so on.
    Before the first ``next()``, and right after ``reset()``, ``last_value`` is
    ``firstval - step``.
    """

    def __init__(self, firstval=0, step=1):
        self._firstval = firstval
        self._step = step
        self._restart()

    def _restart(self) -> None:
        """Create a fresh underlying iterator and rewind ``_val`` to its pre-start value."""
        self._gen = itertools.count(self._firstval, self._step)
        self._val = self._firstval - self._step

    @property
    def last_value(self) -> int:
        """The value most recently returned by ``next()`` (``firstval - step`` if none)."""
        return self._val

    def range(self) -> Iterable[int]:
        """Return a ``range`` over every value produced since construction or the last reset."""
        first, step = self._firstval, self._step
        return range(first, self._val + step, step)

    def next(self) -> int:
        """Advance the counter and return the new value."""
        value = next(self._gen)
        self._val = value
        return value

    def reset(self) -> None:
        """Rewind the counter so that the next ``next()`` returns ``firstval`` again."""
        self._restart()

    def is_reset(self):
        """Return True if ``last_value`` still equals its initial ``firstval - step``."""
        return self._val == self._firstval - self._step
def audio_codec_to_format_ext(codec_name: str) -> tuple[str, str]:
    """Convert an audio codec name to a container format and file extension.

    Args:
        codec_name: Audio codec name, matched case-insensitively.

    Returns:
        A ``(format, extension)`` tuple. Every ``pcm_*`` codec maps to
        ``("wav", "wav")``. Any other codec is looked up in ``AUDIO_FORMAT_MAPPING``.

    Raises:
        ValueError: If the codec is not ``pcm_*`` and is missing from ``AUDIO_FORMAT_MAPPING``.
    """
    codec_name = codec_name.lower()
    if codec_name.startswith("pcm_"):
        return "wav", "wav"
    if codec_name not in AUDIO_FORMAT_MAPPING:
        raise ValueError(f"Audio codec name not found in AUDIO_FORMAT_MAPPING: {codec_name}")
    return AUDIO_FORMAT_MAPPING[codec_name]


@contextlib.asynccontextmanager
async def open_video(
    video: pa.BinaryScalar | pa.StringScalar | str | bytes, **storage_options: dict[str, Any]
) -> AsyncGenerator[InputContainer, None]:
    """Open a video file asynchronously.

    Supports multiple input types:
    - pa.StringScalar: PyArrow string scalar (video path)
    - pa.BinaryScalar: PyArrow binary scalar (video data)
    - str: Raw string path
    - bytes: Raw binary data

    HTTP(S) URLs are handed straight to PyAV. Other paths are first opened through
    ``auto_open`` (fsspec). Blocking open calls run in a worker thread.

    Args:
        video: Video path or binary data.
        **storage_options: Storage options for remote paths, keyed by protocol and
            forwarded to ``auto_open``.

    Yields:
        PyAV InputContainer for the video.

    Raises:
        TypeError: If ``video`` is none of the supported types.
    """
    video_binary, video_url = b"", ""
    # pa.StringScalar is a subclass of pa.BinaryScalar, so it is tested first.
    if isinstance(video, str):
        video_url = video
    elif isinstance(video, bytes):
        video_binary = video
    elif isinstance(video, pa.StringScalar):
        video_url = video.as_py()
    elif isinstance(video, pa.BinaryScalar):
        video_binary = video.as_py()
    else:
        raise TypeError(f"Unsupported video type: {type(video).__name__}")

    if video_url != "" and len(video_binary) == 0:
        if video_url.startswith(("http://", "https://")):
            with await asyncio.to_thread(av.open, video_url) as input_container:
                yield input_container
        else:
            with await asyncio.to_thread(auto_open, video_url, "rb", **storage_options) as file_obj:
                with await asyncio.to_thread(av.open, file_obj) as input_container:
                    yield input_container
    else:
        with io.BytesIO(video_binary) as bio:
            with await asyncio.to_thread(av.open, bio, "r") as input_container:
                yield input_container


def assert_type(x: Any, expect_type: type[T]) -> T:
    """The helper function to narrow the types.

    Uses ``assert``, so the check is skipped when Python runs with ``-O``.
    """
    assert isinstance(x, expect_type)
    return x


async def to_thread(func, /, *args, _executor=None, **kwargs):
    """Asynchronously run function *func* in a separate thread.

    Any *args and **kwargs supplied for this function are directly passed
    to *func*. Also, the current :class:`contextvars.Context` is propagated,
    allowing context variables from the main thread to be accessed in the
    separate thread.

    ``_executor`` is passed to ``loop.run_in_executor``; ``None`` selects the
    loop's default executor.

    Return a coroutine that can be awaited to get the eventual result of *func*.
    """
    loop = events.get_running_loop()
    ctx = contextvars.copy_context()
    func_call = functools.partial(ctx.run, func, *args, **kwargs)
    return await loop.run_in_executor(_executor, func_call)


def split_audio(
    audio_data: np.ndarray,
    sample_rate: int,
    max_audio_clip_s: int = 30,
    overlap_chunk_second: int = 1,
    min_energy_split_window_size: int = 1600,
) -> list[np.ndarray]:
    """Split audio along its last axis into chunks of at most ``max_audio_clip_s`` seconds.

    Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/speech_to_text.py

    For every chunk except the last, the cut is placed at the start of the quietest
    window (lowest RMS energy). The search covers the final
    ``max(1, min_energy_split_window_size, overlap)`` samples of the candidate chunk,
    capped at the chunk length. If no usable point is found there, the chunk is cut
    at exactly ``max_audio_clip_s`` seconds.

    Args:
        audio_data: Audio samples with time on the last axis.
        sample_rate: Samples per second.
        max_audio_clip_s: Maximum chunk length in seconds.
        overlap_chunk_second: Length in seconds of the tail region searched for a split point.
        min_energy_split_window_size: Window size in samples for the RMS energy search.

    Returns:
        The chunks in order. Concatenating them along the last axis reproduces ``audio_data``.

    Raises:
        ValueError: If ``sample_rate * max_audio_clip_s`` is not positive.
    """
    chunk_size = sample_rate * max_audio_clip_s
    if chunk_size <= 0:
        # A non-positive chunk size never advances ``i`` below and would loop forever.
        raise ValueError(
            f"sample_rate * max_audio_clip_s must be positive, got {sample_rate} * {max_audio_clip_s}"
        )
    overlap_size = max(0, sample_rate * overlap_chunk_second)
    chunks = []
    i = 0
    total_samples = audio_data.shape[-1]
    while i < total_samples:
        if i + chunk_size >= total_samples:
            # handle last chunk
            chunks.append(audio_data[..., i:])
            break

        search_end = min(i + chunk_size, total_samples)
        search_window = max(1, min_energy_split_window_size, overlap_size)
        search_range = min(chunk_size, search_window)
        search_start = max(i, search_end - search_range)
        split_point = _find_split_point(
            audio_data,
            search_start,
            search_end,
            min_energy_split_window_size,
        )
        # Fall back to a hard cut when no usable quiet point was found.
        if split_point <= i or split_point > search_end:
            split_point = search_end

        # Extract chunk up to the split point
        chunks.append(audio_data[..., i:split_point])
        i = split_point
    return chunks


def _find_split_point(wav: np.ndarray, start_idx: int, end_idx: int, min_energy_split_window_size: int) -> int:
    """Find the best point to split audio by looking for silence or low amplitude.

    Args:
        wav: Audio samples with time on the last axis (e.g. ``[T]`` or ``[1, T]``).
        start_idx: Start index of the search region (inclusive).
        end_idx: End index of the search region (exclusive).
        min_energy_split_window_size: Window size in samples for the RMS energy measurement.

    Returns:
        The start index, in ``wav`` coordinates, of the non-overlapping window with the
        lowest RMS energy. Returns ``0`` if the region does not hold a full window.
    """
    # Slice the time (last) axis so that multi-channel input is searched along time.
    segment = wav[..., start_idx:end_idx]
    if isinstance(segment, np.ndarray) and np.issubdtype(segment.dtype, np.integer):
        # Squaring integer PCM samples (e.g. int16) would overflow; use floats.
        segment = segment.astype(np.float64)
    segment_len = segment.shape[-1]
    window_size = min_energy_split_window_size

    # Calculate RMS energy in small windows
    min_energy = math.inf
    quietest_idx = 0
    # "+ 1" so that the final full window ending exactly at ``end_idx`` is also considered.
    for i in range(0, segment_len - window_size + 1, window_size):
        window = segment[..., i : i + window_size]
        energy = (window**2).mean() ** 0.5
        if energy < min_energy:
            quietest_idx = i + start_idx
            min_energy = energy
    return quietest_idx


async def auto_batch(
    items: list[Any],
    *,
    map_fn: Callable[[int, Any], Coroutine[None, None, tuple[int, Any]] | tuple[int, Any]],
    filter_fn: Callable[[Any], bool] | None = None,
) -> AsyncGenerator[tuple[list[int], list[Any]], None]:
    """Auto batch process the items by the map function.

    Items are mapped concurrently. Each yielded batch contains everything that
    became available since the previous batch, always at least one entry.

    Args:
        items: A list of items.
        map_fn: The sync or async map function to process one item. It should accept
            (index, item) and return (index, value). A sync function runs in a worker thread.
        filter_fn: A sync filter function to determine which items to process. Items for
            which it returns False are yielded unchanged, with their index, without
            calling ``map_fn``. All the items are processed if ``filter_fn`` is None.

    Yields:
        ``(indices, values)`` tuples in completion order, not input order. Every batch
        uses new lists, so earlier batches stay valid after iteration continues.
        Pending tasks are cancelled if the consumer stops early or a task raises.
    """
    batch_index: list[int] = []
    batch_items: list[Any] = []
    async_tasks: list[asyncio.Task] = []

    if not inspect.iscoroutinefunction(map_fn):
        # Run sync map functions in a worker thread so they do not block the event loop.
        map_fn = cast(
            Callable[[int, Any], Coroutine[None, None, tuple[int, Any]]],
            functools.partial(asyncio.to_thread, map_fn),  # type: ignore[call-arg]
        )

    if filter_fn is not None:
        for index, item in enumerate(items):
            if filter_fn(item):
                async_tasks.append(asyncio.create_task(map_fn(index, item)))
            else:
                # Filtered-out items pass through unchanged.
                batch_index.append(index)
                batch_items.append(item)
    else:
        async_tasks = [asyncio.create_task(map_fn(index, item)) for index, item in enumerate(items)]

    try:
        while batch_index or async_tasks:
            if async_tasks:
                # Collect whatever has already finished, without blocking.
                ready, remaining = await asyncio.wait(async_tasks, timeout=0)
                for task in ready:
                    idx, data = task.result()
                    batch_index.append(idx)
                    batch_items.append(data)
                async_tasks[:] = remaining
                if not batch_items:
                    # We got nothing, so wait until at least one task finishes.
                    ready, remaining = await asyncio.wait(async_tasks, return_when=asyncio.FIRST_COMPLETED)
                    for task in ready:
                        idx, data = task.result()
                        batch_index.append(idx)
                        batch_items.append(data)
                    async_tasks[:] = remaining
            assert len(batch_index) == len(batch_items)
            yield batch_index, batch_items
            # Start new lists so batches already handed to the consumer are not cleared.
            batch_index, batch_items = [], []
    finally:
        # Do not leak work if the consumer stops early or a task raised.
        for task in async_tasks:
            task.cancel()


@contextlib.asynccontextmanager
async def open_image(
    image: pa.BinaryScalar | pa.StringScalar | pa.ExtensionScalar,
    source_mode: str | None = None,
    **storage_options: dict[str, Any],
) -> AsyncGenerator[Image.Image, None]:
    """Open an image given as a path, as encoded bytes, or as an extension-typed value.

    - pa.StringScalar: image path, opened through ``auto_open``.
    - pa.ExtensionScalar: its ``as_py()`` value is passed to ``Image.fromarray``
      (presumably an ndarray; TODO confirm against the extension type).
    - pa.BinaryScalar: encoded image bytes.

    Args:
        image: The image value.
        source_mode: PIL mode to produce. Defaults to the image's own mode.
        **storage_options: Storage options keyed by protocol, forwarded to ``auto_open``.

    Yields:
        A PIL image.

    Raises:
        TypeError: If ``image`` is none of the supported scalar types.
    """
    if isinstance(image, pa.StringScalar):
        image_url = image.as_py()
        with await asyncio.to_thread(auto_open, image_url, **storage_options) as file_obj:
            # Close the opened image too; ``convert`` returns a separate copy.
            with Image.open(file_obj) as img:
                # convert PIL.ImageFile.ImageFile to Image.Image
                yield img.convert(source_mode if source_mode is not None else img.mode)
    elif isinstance(image, pa.ExtensionScalar):
        image_ndarray = image.as_py()
        yield Image.fromarray(image_ndarray, source_mode)
    # Note: pa.StringScalar is a subclass of pa.BinaryScalar; it was handled above.
    elif isinstance(image, pa.BinaryScalar):
        image_binary = image.as_py()
        with io.BytesIO(image_binary) as bio:
            with Image.open(bio) as img:
                # convert PIL.ImageFile.ImageFile to Image.Image
                yield img.convert(source_mode if source_mode is not None else img.mode)
    else:
        raise TypeError(f"Unsupported image type: {type(image).__name__}")


async def read_image(
    image: pa.BinaryScalar | pa.StringScalar | pa.ExtensionScalar,
    source_mode: str | None = None,
    **storage_options: dict[str, Any],
) -> Image.Image:
    """Open ``image`` with :func:`open_image` and return the resulting PIL image."""
    async with open_image(image, source_mode, **storage_options) as img:
        return img


def _register_cos_protocol():
    """Register an s3fs-backed ``cos`` protocol with fsspec if it is not registered yet.

    Raises:
        Exception: If the protocol is still missing after the registration attempt.
    """
    # Import s3fs takes more than 200ms, lazy import it here.
    cos_protocol = "cos"
    if cos_protocol not in fsspec.registry:
        from s3fs import S3FileSystem

        class COSFileSystem(S3FileSystem):
            protocol = cos_protocol

        try:
            fsspec.register_implementation(cos_protocol, COSFileSystem)
        except ValueError:
            # Raised when a different class is already registered under this name
            # (e.g. a concurrent registration). The check below confirms availability.
            pass
        if cos_protocol not in fsspec.registry:
            raise Exception("Register COSFileSystem failed!")


def _register_hf_protocol():
    """Best-effort registration of ``huggingface_hub.HfFileSystem`` as fsspec's ``hf`` protocol."""
    hf_protocol = "hf"
    if hf_protocol not in fsspec.registry:
        try:
            from huggingface_hub import HfFileSystem

            fsspec.register_implementation(hf_protocol, HfFileSystem)
        except (ImportError, ValueError):
            # huggingface_hub is optional; registration failures are ignored here.
            pass


def get_filesystem(path: str) -> Tuple["fsspec.AbstractFileSystem", str]:
    """Get filesystem and normalized path from a path string.

    The protocol is the URL scheme, or ``file`` when there is none. Filesystem
    options come from ``DatasetContext.get_current().storage_options[protocol]``.
    For ``hf``, only a token is passed. It is read from ``HF_TOKEN`` first, then from
    ``storage_options["hf"]["token"]``.

    Returns:
        ``(filesystem, path)``. For remote protocols ``path`` has its ``scheme://``
        prefix removed; for local files the parsed URL path is used.
    """
    parsed = urlparse(path)
    protocol = parsed.scheme or "file"
    if protocol == "cos":
        _register_cos_protocol()
    elif protocol == "hf":
        _register_hf_protocol()
    ctx = DatasetContext.get_current()
    storage_options = ctx.storage_options
    if protocol == "hf":
        hf_token = os.environ.get("HF_TOKEN") or storage_options.get("hf", {}).get("token")
        proto_options = {"token": hf_token} if hf_token else {}
    else:
        proto_options = storage_options.get(protocol, {})
    fs = fsspec.filesystem(protocol, **proto_options)
    if protocol != "file":
        normalized_path = path.split("://", 1)[1] if "://" in path else path
    else:
        normalized_path = parsed.path or path
    return fs, normalized_path


def _uses_cos_protocol(urlpath: str, protocol: str | None) -> bool:
    """Return True if ``protocol`` or any ``::``-chained component of ``urlpath`` is ``cos``."""
    if protocol == "cos":
        return True
    return any(bit == "cos" or bit.startswith("cos://") for bit in urlpath.split("::"))


def auto_open(
    urlpath,
    mode="rb",
    compression=None,
    encoding="utf8",
    errors=None,
    protocol=None,
    newline=None,
    expand=None,
    **storage_options,
) -> fsspec.core.OpenFile:
    """Auto select storage options based on protocol.

    Wraps :func:`fsspec.open`. ``storage_options`` is keyed by protocol (e.g.
    ``s3={...}``). Only the entries whose protocol appears in the (possibly
    ``::``-chained) URL are forwarded. An explicit ``protocol`` applies to the last
    (innermost) component of the URL, as in fsspec.

    Raises:
        ValueError: If ``urlpath`` is an empty sequence.
    """
    if isinstance(urlpath, (list, tuple, set)):
        if not urlpath:
            raise ValueError("empty urlpath sequence")
        urlpath0 = fsspec.core.stringify_path(next(iter(urlpath)))
    else:
        urlpath0 = fsspec.core.stringify_path(urlpath)

    # fsspec cannot parse a cos:// URL until the protocol is registered. Registering
    # imports s3fs (slow, and possibly not installed), so only do it when cos is used.
    if _uses_cos_protocol(urlpath0, protocol):
        _register_cos_protocol()

    # Resolve the chain the way fsspec will, honouring an explicit ``protocol``.
    # ``chain_protocol`` is a separate name so that the caller's ``protocol`` survives.
    chain_kwargs = {"protocol": protocol} if protocol else {}
    selected_storage_options = {}
    for _bit, chain_protocol, _kw in fsspec.core._un_chain(urlpath0, chain_kwargs):
        so = storage_options.get(chain_protocol)
        if so:
            selected_storage_options[chain_protocol] = so
    return fsspec.open(
        urlpath,
        mode=mode,
        compression=compression,
        encoding=encoding,
        errors=errors,
        protocol=protocol,
        newline=newline,
        expand=expand,
        **selected_storage_options,
    )


def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
    """Cancel every task still pending on the stopped ``loop`` and report unexpected errors."""
    to_cancel = asyncio.all_tasks(loop)
    if not to_cancel:
        return
    for task in to_cancel:
        task.cancel()
    loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
    for task in to_cancel:
        if task.cancelled():
            continue
        if task.exception() is not None:
            loop.call_exception_handler(
                {
                    "message": "unhandled exception during isolated loop shutdown",
                    "exception": task.exception(),
                    "task": task,
                }
            )


@contextlib.contextmanager
def isolated_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """A simpler version of `loop_in_thread`, don't need the ThreadPoolExecutor.

    Runs a new event loop in a dedicated thread for the duration of the ``with``
    block. Submit work with ``asyncio.run_coroutine_threadsafe``:
    https://docs.python.org/3/library/asyncio-task.html#asyncio.run_coroutine_threadsafe

    On exit, whether normal or via an exception, the loop is stopped. Its remaining
    tasks are cancelled, async generators and the default executor are shut down,
    the loop is closed, and the thread is joined.
    """

    def _run_loop(_loop: asyncio.AbstractEventLoop):
        asyncio.set_event_loop(_loop)
        try:
            _loop.run_forever()
            # The loop has been stopped; clean up fully before closing it.
            _cancel_all_tasks(_loop)
            _loop.run_until_complete(_loop.shutdown_asyncgens())
            _loop.run_until_complete(_loop.shutdown_default_executor())
        finally:
            asyncio.set_event_loop(None)
            _loop.close()

    new_loop = asyncio.new_event_loop()
    loop_thread = threading.Thread(target=_run_loop, args=(new_loop,))
    loop_thread.start()
    try:
        yield new_loop
    finally:
        # Always stop and join. Otherwise an exception in the ``with`` body would leave
        # a non-daemon thread running the loop forever.
        new_loop.call_soon_threadsafe(new_loop.stop)
        loop_thread.join()


def safe_run(main: Coroutine) -> Any:
    """Run coroutine ``main`` to completion from synchronous code and return its result.

    If ``asyncio.get_event_loop()`` raises ``RuntimeError``, ``asyncio.run`` is used.
    Otherwise ``main`` runs on a separate loop in a helper thread (see
    :func:`isolated_loop`), and the calling thread blocks until it finishes.
    """
    # NOTE(review): ``asyncio.get_event_loop()`` is deprecated when it has to create a
    # loop on recent Python versions. ``asyncio.get_running_loop()`` is the modern check,
    # but then ``asyncio.run`` would reset any loop previously set on this thread. Confirm
    # before switching.
    try:
        asyncio.get_event_loop()
    except RuntimeError:
        return asyncio.run(main)
    with isolated_loop() as loop:
        return asyncio.run_coroutine_threadsafe(main, loop).result()


def copy_video_stream(
    input_container: InputContainer,
    output_container: OutputContainer,
    *,
    start_second: float = 0,
    end_second: float | None = None,
):
    """Copies the video stream without re-encoding.

    This method relies on locating the I-frame in the video, which may result in slight inaccuracies.
    For example, if you extract a 10-second segment, the final video might be 10.3 seconds long.
    This behavior is equivalent to using the '-ss' option before '-i' and '-c:v copy' in ffmpeg.

    All video and audio streams of ``input_container`` are copied; other stream types are
    dropped. Each stream's timestamps are shifted so that its first copied packet starts at 0.

    Raises:
        ValueError: If the first video stream has no time base.
    """
    # Import av takes more than 200ms, lazy import it here.
    from av.audio.stream import AudioStream
    from av.packet import Packet
    from av.stream import Stream
    from av.video.stream import VideoStream

    streams_map: dict[Stream, VideoStream | AudioStream] = {}
    video_stream = input_container.streams.video[0]
    if video_stream.time_base is None:
        raise ValueError("video stream time base is zero, file may be broken.")
    for stream in input_container.streams:
        if stream.type == "video":
            streams_map[stream] = output_container.add_stream_from_template(template=cast(VideoStream, stream))
        elif stream.type == "audio":
            streams_map[stream] = output_container.add_stream_from_template(template=cast(AudioStream, stream))

    start_pts_map = {}
    running_stream = set(streams_map.keys())
    input_container.seek(int(start_second / video_stream.time_base), stream=video_stream)
    # TODO(anthonycai) add precise segmentation logic, use decode method
    for packet in input_container.demux(*streams_map.keys()):
        if not isinstance(packet, Packet) or packet.pts is None:
            continue
        current_time_sec = packet.pts * packet.time_base
        if end_second is not None and current_time_sec > end_second:
            # This stream is past the end; stop once every stream has passed it.
            if packet.stream in running_stream:
                running_stream.remove(packet.stream)
            if len(running_stream) == 0:
                break
            continue

        input_stream = packet.stream
        if input_stream not in start_pts_map:
            start_pts_map[input_stream] = packet.pts
        # Shift timestamps so the first copied packet of each stream starts at 0.
        offset = start_pts_map[input_stream]
        packet.pts -= offset
        if packet.dts is not None:
            packet.dts -= offset
        packet.stream = streams_map[input_stream]
        output_container.mux(packet)

    # flush encoder
    for stream in streams_map.values():
        if stream.type == "video":
            output_container.mux(stream.encode(None))


def desire_video_format(list_format: list[str]) -> str:
    """
    Select a single, explicitly supported output container format (FFmpeg muxer name)
    from the list of possible container formats reported by FFmpeg.

    Why is this function needed?
    1. `input_container.format.name` returns a string like "mov,mp4,m4a,3gp,3g2,mj2",
       which lists **all possible container formats** that FFmpeg believes the file could
       be compatible with (comma-separated). Passing this full string directly to
       `av.open(..., format=...)` will fail because FFmpeg cannot recognize a combined name
       like "mov,mp4,m4a,...". It expects **one valid muxer name**.
    2. The `format` parameter in `av.open(..., mode="w", format=...)` **must be a single
       valid muxer name** (e.g., "mp4", "mov", "matroska"). Otherwise, it raises:
       `ValueError: unknown format`.
    3. Priority strategy:
       - Prefer **mp4**: most widely supported, excellent web and mobile compatibility.
       - Then **mov**: common in Apple ecosystem, good compatibility with QuickTime.
       - Then **mkv** (as "matroska"): powerful container, supports many codecs. FFmpeg
         reports Matroska inputs as "matroska,webm", so "matroska" is matched as well as "mkv".
       - Fallback: use the first format in the list (guaranteed to be recognized by FFmpeg).

    This ensures we always pass a **valid and highly compatible** output format to `av.open`.
    """
    if "mp4" in list_format:
        output_format = "mp4"
    elif "mov" in list_format:
        output_format = "mov"
    elif "mkv" in list_format or "matroska" in list_format:
        output_format = "matroska"
    else:
        output_format = list_format[0]
    return output_format


def cut_video_by_seconds(
    input_container: InputContainer,
    start_second: float,
    end_second: float | None = None,
) -> bytes:
    """Cut the span from ``start_second`` to ``end_second`` out of the video without re-encoding.

    The cut is keyframe-aligned (see :func:`copy_video_stream`), so it may be slightly
    longer than requested. The output container format is chosen by
    :func:`desire_video_format`.

    Returns:
        The bytes of the cut video.

    Raises:
        ValueError: If the first video stream has no time base, or a zero time base.
    """
    input_video_stream = input_container.streams.video[0]
    if input_video_stream.time_base is None or input_video_stream.time_base == 0:
        raise ValueError("video stream time base is zero, file may be broken.")
    with io.BytesIO() as output_buffer:
        with av.open(
            output_buffer, mode="w", format=desire_video_format(input_container.format.name.split(","))
        ) as output_container:
            copy_video_stream(input_container, output_container, start_second=start_second, end_second=end_second)
        # Read the buffer only after the output container has been closed.
        return output_buffer.getvalue()


def deep_getattr(obj: object, attr: str) -> Any:
    """Resolve a dotted attribute path; ``deep_getattr(x, "a.b")`` is ``x.a.b``.

    Raises:
        AttributeError: If any component is missing. The message names the type of
            the root ``obj`` and the full dotted ``attr``, and chains the original error.
    """
    current: Any = obj
    try:
        for part in attr.split("."):
            current = getattr(current, part)
    except AttributeError as err:
        raise AttributeError(f"'{obj.__class__.__name__}' object has no attribute '{attr}'") from err
    return current


class LLMChatCompletions:
    """Async client for an OpenAI-compatible chat-completions endpoint.

    Requests go through ``openai.AsyncClient`` and are optionally rate-limited with
    :func:`qps_limiter`.
    """

    def __init__(
        self,
        base_url: str,
        model: str,
        api_key: str,
        max_qps: int | None = None,
        max_retries: int = 0,
        fallback_response: str | float | None = None,
        response_format: str = "text",
        **kwargs: Any,
    ):
        """
        Args:
            base_url: Base URL of the API.
            model: Model name sent with every request.
            api_key: API key.
            max_qps: Optional request-rate limit (see :func:`qps_limiter`).
            max_retries: Retry count passed to ``openai.AsyncClient``.
            fallback_response: Value used by ``batch_generate`` for failed or empty
                responses. If ``None``, such responses raise ``ValueError`` instead.
            response_format: ``"text"`` or ``"json_object"``.
            **kwargs: Extra keyword arguments forwarded to every ``chat.completions.create`` call.

        Raises:
            ValueError: If ``response_format`` is not supported.
        """
        logger.info("Using Remote Http LLMChatCompletions")
        self.model = model
        self.kwargs = kwargs
        self.fallback_response = fallback_response
        supported_formats = {"json_object", "text"}
        if response_format not in supported_formats:
            raise ValueError(f"Unsupported response_format: {response_format}")
        self.response_format = {"type": response_format}
        client = openai.AsyncClient(api_key=api_key, base_url=base_url, max_retries=max_retries)
        self.create_completions = qps_limiter(max_qps)(client.chat.completions.create)

    async def __call__(self, messages: Iterable[ChatCompletionMessageParam]) -> ChatCompletion | None:
        """Send one chat-completion request. On any error, log it and return ``None``."""
        try:
            return await self.create_completions(
                messages=messages, model=self.model, response_format=self.response_format, **self.kwargs
            )
        except Exception as e:
            logger.error(f"create_completions failed: {e}")
            return None

    async def batch_generate(
        self,
        texts: pa.ChunkedArray,
        build_prompt: Callable[[str], Iterable[ChatCompletionMessageParam]],
        post_process: Callable[[str], str | float] | None = None,
        datatype: pa.DataType = pa.string(),
    ) -> pa.Array:
        """Send one request per element of ``texts`` concurrently and collect the replies.

        Args:
            texts: Inputs. Each element, as produced by iterating ``texts``, is passed to
                ``build_prompt``.
            build_prompt: Builds the chat messages for one input.
            post_process: Optional transform applied to each stripped reply.
            datatype: Arrow type of the returned array.

        Returns:
            An array of type ``datatype``, aligned with ``texts``.

        Raises:
            ValueError: If a request fails or returns empty content and no
                ``fallback_response`` is configured.
        """
        requests = [self(messages=build_prompt(text)) for text in texts]
        responses = await asyncio.gather(*requests)
        results = []
        for response in responses:
            if response is not None and len(response.choices) > 0 and response.choices[0].message.content is not None:
                results.append(
                    response.choices[0].message.content.strip()
                    if post_process is None
                    else post_process(response.choices[0].message.content.strip())
                )
            else:
                logger.error(f"request failed, response is: {response}")
                if self.fallback_response is not None:
                    results.append(self.fallback_response)
                else:
                    raise ValueError("LLM call failed or returned empty content.")
        return pa.array(results, type=datatype)


def skip_empty_texts(func: Callable | None = None, empty_response: str = ""):
    """Decorator to skip empty or null texts in text processing.

    The decorated coroutine method receives only the non-empty, non-null elements of
    ``texts``, combined into one array. Its results are scattered back to the original
    positions, and skipped positions get ``empty_response``. The result is always a
    ``pa.string()`` array. The wrapped function is not called when every text is empty
    or null.

    Supports both ``@skip_empty_texts`` and ``@skip_empty_texts(empty_response=...)``.
    """

    def decorator(f: Callable) -> Callable:
        @wraps(f)
        async def wrapper(self, texts: pa.ChunkedArray) -> pa.Array:
            combined = texts.combine_chunks()
            # Null strings give null lengths; treat them as empty.
            is_not_empty = pc.fill_null(pc.not_equal(pc.utf8_length(combined), 0), False)
            valid_indices = pc.indices_nonzero(is_not_empty)
            if len(valid_indices) == 0:
                return pa.array([empty_response] * len(combined), type=pa.string())
            valid_texts = pc.filter(combined, is_not_empty)
            processed_results = await f(self, valid_texts)
            results = [empty_response] * len(combined)
            # Convert to Python once instead of boxing every scalar separately.
            positions = valid_indices.to_pylist()
            values = processed_results.to_pylist()
            for idx, valid_idx in enumerate(positions):
                results[valid_idx] = values[idx]
            return pa.array(results, type=pa.string())

        return wrapper

    # Support both @skip_empty_texts and @skip_empty_texts(...)
    if func is not None:
        return decorator(func)
    return decorator


def normalize_labels(labels: list[str], labels_name: str = "labels") -> list[str]:
    """Strip whitespace, drop blank labels, and de-duplicate while keeping first-seen order.

    Raises:
        ValueError: If no non-blank label remains; the message uses ``labels_name``.
    """
    labels = list(dict.fromkeys(label.strip() for label in labels if label.strip()))
    if len(labels) == 0:
        raise ValueError(f"{labels_name} must not be empty")
    return labels


def text_tokenize(arr: pa.Array, regex_patten: str | None = None, cjk=False) -> pa.Array:
    """Split each string in ``arr`` into tokens.

    Args:
        arr: Strings to tokenize.
        regex_patten: If given, split on this regex with ``pc.split_pattern_regex``.
            This takes precedence over ``cjk``.
        cjk: When no regex is given, use ``dedup_utils.utf8_split_mixed`` instead of
            ``pc.utf8_split_whitespace`` (presumably CJK-aware; see that helper).

    Returns:
        The token lists produced by the selected splitter.
    """
    if regex_patten is None:
        if cjk:
            tokens = dedup_utils.utf8_split_mixed(arr)
        else:
            tokens = pc.utf8_split_whitespace(arr)
    else:
        tokens = pc.split_pattern_regex(arr, pattern=regex_patten)
    return tokens