# Source code for xpark.dataset.utils

from __future__ import annotations

import asyncio
import contextlib
import contextvars
import enum
import functools
import inspect
import io
import itertools
import logging
import math
import os
import re
import threading
import time
import uuid
from asyncio import events
from collections import deque
from collections.abc import Callable
from functools import wraps
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Coroutine,
    Generator,
    Iterable,
    Literal,
    ParamSpec,
    Tuple,
    TypeVar,
    cast,
)
from urllib.parse import urlparse

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from PIL import Image
from pydantic import BaseModel, model_validator

from xpark.dataset.constants import (
    AUDIO_FORMAT_MAPPING,
    RAY_DOC_REPLACE_PAIRS,
    WRAPPER_ASSIGNMENTS,
)
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.types import KeyType

if TYPE_CHECKING:
    import av
    import fsspec
    import openai
    import pyarrow.compute as pc
    from av.container.input import InputContainer
    from av.container.output import OutputContainer
    from openai.types.chat.chat_completion import ChatCompletion
    from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

    from xpark.dataset.filters.dedup import utils as dedup_utils
else:
    av = lazy_import("av")
    fsspec = lazy_import("fsspec")
    openai = lazy_import("openai")
    pc = lazy_import("pyarrow.compute", rename="pc")
    pa = lazy_import("pyarrow", rename="pa")
    dedup_utils = lazy_import("xpark.dataset.filters.dedup.utils", rename="dedup_utils")

# Module logger. NOTE(review): this uses the "ray" logger name rather than __name__,
# presumably so messages follow Ray's logging configuration — confirm intent.
logger = logging.getLogger("ray")

# Type variables for the generic decorator helpers in this module:
# P captures a wrapped callable's parameters, R its return type, T a generic result type.
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")

# Matches every character that is NOT allowed in a sanitized path segment, so that
# `.sub("", text)` deletes it (used by sanitize_path_segment). Allowed characters are:
# word characters (\w), whitespace, "-", plus explicitly listed ranges for
# CJK Unified Ideographs (U+4E00-U+9FFF), Hiragana (U+3040-U+309F), Katakana (U+30A0-U+30FF),
# Hangul syllables (U+AC00-U+D7AF), CJK Extension A (U+3400-U+4DBF) and
# CJK Extension B (U+20000-U+2A6DF). With Python's default Unicode matching, most of these
# characters are already matched by \w; the explicit ranges make the intent visible.
_PATH_SEGMENT_SANITIZE_PATTERN = re.compile(
    r"[^\w\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7af"
    r"\u3400-\u4dbf\U00020000-\U0002a6df\s-]"
)


def copy_sig(fn: Callable[P, R]) -> Callable[[Callable[..., T]], Callable[P, T]]:
    """Build a decorator that gives the decorated function ``fn``'s public metadata.

    This is a specialised ``functools.wraps``. It is named ``copy_sig`` because PyCharm
    infers wrong types for anything called ``wraps``. Compared with ``functools.wraps``:

    1. only the attributes listed in ``WRAPPER_ASSIGNMENTS`` are copied (intended to skip
       ``__qualname__`` and ``__module__``);
    2. ``__doc__`` becomes ``fn``'s docstring after ``wrap_ray_doc`` rewriting;
    3. the decorated function keeps its *own* return annotation.

    Args:
        fn: The function whose signature and metadata are borrowed.

    Returns:
        A decorator to apply to the implementing function.

    Raises:
        TypeError: (from the returned decorator) if the decorated function has no
            return annotation.
    """

    def _apply(target: Callable[..., T]) -> Callable[P, T]:
        # Read the target's own signature before functools.wraps runs: wraps sets
        # ``__wrapped__``, which inspect.signature would otherwise follow back to ``fn``.
        own_return = inspect.signature(target).return_annotation
        if own_return is inspect.Signature.empty:
            raise TypeError(
                f"The {copy_sig.__name__} decorated function `{target.__name__}` must have a return annotation."
            )

        result = functools.wraps(fn, assigned=WRAPPER_ASSIGNMENTS)(target)
        result.__doc__ = wrap_ray_doc(fn.__doc__)
        # Build a fresh annotations dict before overriding "return", so a dict shared with
        # ``fn`` (if WRAPPER_ASSIGNMENTS copies ``__annotations__``) is never mutated.
        result.__annotations__ = {**result.__annotations__, "return": own_return}
        return cast(Callable[P, T], result)

    return _apply


def deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]:
    """Return a copy of ``mapping`` with every mapping in ``updating_mappings`` merged in.

    Mappings are applied left to right. When both the current value and the new value for
    a key are ``dict`` instances, they are merged recursively. Otherwise the new value
    replaces the old one. The input mappings are not modified; nested dicts that are merged
    are copied, but other values are shared, not deep-copied.
    """
    merged = mapping.copy()
    for overlay in updating_mappings:
        for key, new_value in overlay.items():
            existing = merged.get(key)
            if isinstance(existing, dict) and isinstance(new_value, dict):
                merged[key] = deep_update(existing, new_value)
            else:
                merged[key] = new_value
    return merged


def multi_replace(s: str | None, replace_pairs: dict[str, str]) -> str | None:
    """Apply every ``old -> new`` substitution from ``replace_pairs`` to ``s``, in order.

    Replacements are applied one after another in the dict's iteration order, so a later
    pair also sees the output of earlier pairs. ``None`` is returned unchanged.
    """
    if s is None:
        return None
    return functools.reduce(lambda text, pair: text.replace(pair[0], pair[1]), replace_pairs.items(), s)


def wrap_ray_doc(s: str | None) -> str | None:
    """Rewrite Ray-specific wording in a docstring using ``RAY_DOC_REPLACE_PAIRS``.

    ``None`` (a missing docstring) is passed through unchanged.
    """
    rewritten = multi_replace(s, replace_pairs=RAY_DOC_REPLACE_PAIRS)
    return rewritten


def wrap_ray_class(ray_class: type[T]) -> type[T]:
    """Create a wrapper class that inherits from a Ray class and replaces its documentation.

    This function creates a new class that:
    1. Inherits from the original Ray class
    2. Replaces Ray-specific documentation with Xpark equivalents (via ``wrap_ray_doc``)
    3. Preserves all functionality of the original class

    Similar to `copy_sig`, this function helps maintain Xpark-specific documentation
    without modifying the original Ray classes.

    Args:
        ray_class: The Ray class to wrap.

    Returns:
        A new class that inherits from ray_class with updated documentation. Its
        ``__name__`` and ``__qualname__`` match ``ray_class``; its ``__module__`` is the
        module that defines this function.

    Raises:
        TypeError: If ``ray_class`` is an ``Enum`` class. Enum classes that have members
            cannot be subclassed, so they are rejected up front.

    Example:
        >>> import ray.data.aggregate
        >>> Count = wrap_ray_class(ray.data.aggregate.Count)
        >>> # Count now has Xpark-specific documentation
    """
    # Enum classes cannot be wrapped by inheritance; fail early with a clear message.
    # TypeError subclasses Exception, so existing `except Exception` handlers still work.
    if isinstance(ray_class, enum.EnumMeta):
        raise TypeError("Do not wrap enum class.")

    # For regular classes, use inheritance.
    class WrappedClass(ray_class):  # type: ignore
        __doc__ = wrap_ray_doc(ray_class.__doc__)

    # Present the wrapper under the Ray class's name and qualified name. ``__module__`` is
    # not copied, so it remains the module in which WrappedClass is defined (this one).
    WrappedClass.__name__ = ray_class.__name__
    WrappedClass.__qualname__ = ray_class.__qualname__

    return WrappedClass  # type: ignore


def iter_batch(array: pa.ChunkedArray, batch_size: int) -> Generator[pa.ChunkedArray, None, None]:
    """Yield consecutive slices of ``array`` with at most ``batch_size`` elements each.

    Every slice except possibly the last has exactly ``batch_size`` elements. An empty
    ``array`` yields nothing. ``batch_size`` is used as a ``range`` step, so ``0`` raises
    ``ValueError`` once iteration starts.
    """
    total = len(array)
    yield from (array[start : min(start + batch_size, total)] for start in range(0, total, batch_size))


def qps_limiter(max_qps: int | None = None, max_concurrency: int | None = None):
    """Decorator factory that throttles an async function by call rate and/or concurrency.

    Args:
        max_qps: Maximum number of calls that may *start* within any sliding one-second
            window. ``None`` or ``0`` disables rate limiting.
        max_concurrency: Maximum number of calls that may be running at the same time.
            ``None`` or ``0`` disables the concurrency limit.

    Returns:
        A decorator. If both limits are disabled, it returns the decorated function
        unchanged. Otherwise it returns an async wrapper (built with ``copy_sig``) that
        waits for a free concurrency slot and then a free rate slot, and finally awaits the
        original function, returning its result.

    Raises:
        ValueError: If a limit is given (truthy) but is less than 1.
    """
    if max_qps and max_qps < 1:
        raise ValueError("max_qps must be >= 1")
    if max_concurrency and max_concurrency < 1:
        raise ValueError("max_concurrency must be >= 1")

    def _decorator(func):
        # Falsy limits (None or 0) mean "no limit"; with no limit at all, nothing to wrap.
        if not max_qps and not max_concurrency:
            return func

        # Start times (time.monotonic()) of the calls admitted during the last second.
        timestamps: deque[float] = deque()
        # Guards `timestamps`. It is held only while waiting for a rate slot, never while
        # `func` runs, so admitted calls can execute concurrently.
        lock = asyncio.Lock()
        semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

        async def _acquire_rate_slot(limit: int) -> None:
            """Wait until fewer than `limit` calls started in the last second, then record one."""
            async with lock:
                while True:
                    # Monotonic clock: immune to wall-clock adjustments.
                    now = time.monotonic()
                    # Forget calls that started at least one second ago.
                    while timestamps and timestamps[0] <= now - 1:
                        timestamps.popleft()
                    if len(timestamps) < limit:
                        break
                    # Sleep until the oldest call leaves the window, then re-check.
                    await asyncio.sleep(1 - (now - timestamps[0]))
                timestamps.append(time.monotonic())

        @copy_sig(func)
        async def wrapper(*args, **kwargs) -> Any:
            if semaphore is not None:
                await semaphore.acquire()
            try:
                if max_qps:
                    await _acquire_rate_slot(max_qps)
                # Always invoke func (previously it was skipped when only
                # max_concurrency was set) and do it outside the rate-limit lock.
                return await func(*args, **kwargs)
            finally:
                if semaphore is not None:
                    semaphore.release()

        return wrapper

    return _decorator


class Count:
    """A resettable arithmetic counter built on ``itertools.count``.

    Successive ``next()`` calls return ``firstval``, ``firstval + step``, ... .
    ``last_value`` is the most recently returned value. Before the first call, and after
    ``reset()``, it is ``firstval - step``.
    """

    def __init__(self, firstval=0, step=1):
        self._firstval = firstval
        self._step = step
        self._restart()

    def _restart(self) -> None:
        """Rewind the underlying generator so the next value is ``firstval`` again."""
        self._gen = itertools.count(self._firstval, self._step)
        self._val = self._firstval - self._step

    @property
    def last_value(self) -> int:
        """The value most recently returned by ``next()`` (``firstval - step`` if none)."""
        return self._val

    def range(self) -> Iterable[int]:
        """Return a ``range`` over every value produced since construction or the last reset."""
        past_end = self._val + self._step
        return range(self._firstval, past_end, self._step)

    def next(self) -> int:
        """Advance the counter and return the new value."""
        value = next(self._gen)
        self._val = value
        return value

    def reset(self) -> None:
        """Restart counting from ``firstval``."""
        self._restart()

    def is_reset(self):
        """Return True when no value has been produced since construction or the last reset."""
        return self._val == self._firstval - self._step
def audio_codec_to_format_ext(codec_name: str) -> tuple[str, str]:
    """Convert an audio codec name to a (file format, file extension) pair.

    Any ``pcm_*`` codec maps to ``("wav", "wav")``. Every other codec is looked up,
    case-insensitively, in ``AUDIO_FORMAT_MAPPING``.

    Args:
        codec_name: Audio codec name.

    Returns:
        File format and extension.

    Raises:
        ValueError: If the codec is not PCM and is missing from ``AUDIO_FORMAT_MAPPING``.
    """
    codec_name = codec_name.lower()
    if codec_name.startswith("pcm_"):
        return "wav", "wav"
    if codec_name not in AUDIO_FORMAT_MAPPING:
        raise ValueError(f"Audio codec name not found in AUDIO_FORMAT_MAPPING: {codec_name}")
    return AUDIO_FORMAT_MAPPING[codec_name]


@contextlib.asynccontextmanager
async def open_video(
    video: pa.BinaryScalar | pa.StringScalar | str | bytes, **storage_options: dict[str, Any]
) -> AsyncGenerator[InputContainer, None]:
    """Open a video file asynchronously.

    Supports multiple input types:
    - pa.StringScalar: PyArrow string scalar (video path)
    - pa.BinaryScalar: PyArrow binary scalar (video data)
    - str: Raw string path
    - bytes: Raw binary data

    ``http://`` / ``https://`` URLs are passed straight to ``av.open``. Other paths are
    opened through ``auto_open`` with ``storage_options``. Binary data is read from memory.
    Blocking open calls run in a worker thread via ``asyncio.to_thread``.

    Args:
        video: Video path or binary data.
        **storage_options: Per-protocol storage options forwarded to ``auto_open`` for
            non-HTTP paths.

    Yields:
        PyAV InputContainer for the video. It, and any file object opened for it, is
        closed when the context exits.
    """
    video_binary, video_url = b"", ""
    # Handle different input types.
    if isinstance(video, str):
        video_url = video
    elif isinstance(video, bytes):
        video_binary = video
    elif isinstance(video, pa.StringScalar):
        video_url = video.as_py()
    # pa.StringScalar is a subclass of pa.BinaryScalar, so exclude it explicitly.
    elif isinstance(video, pa.BinaryScalar) and not isinstance(video, pa.StringScalar):
        video_binary = video.as_py()

    if video_url != "" and len(video_binary) == 0:
        if video_url.startswith(("http://", "https://")):
            with await asyncio.to_thread(av.open, video_url) as input_container:
                yield input_container
        else:
            with await asyncio.to_thread(auto_open, video_url, "rb", **storage_options) as file_obj:
                with await asyncio.to_thread(av.open, file_obj) as input_container:
                    yield input_container
    else:
        # Binary input. An unrecognized or empty input also ends up here as empty bytes.
        with io.BytesIO(video_binary) as bio:
            with await asyncio.to_thread(av.open, bio, "r") as input_container:
                yield input_container


def assert_type(x: Any, expect_type: type[T]) -> T:
    """The helper function to narrow the types.

    Asserts that ``x`` is an instance of ``expect_type`` and returns it unchanged. Being
    an ``assert``, the check is skipped when Python runs with ``-O``.
    """
    assert isinstance(x, expect_type)
    return x


async def to_thread(func, /, *args, _executor=None, **kwargs):
    """Asynchronously run function *func* in a separate thread.

    Any *args and **kwargs supplied for this function are directly passed
    to *func*. Also, the current :class:`contextvars.Context` is propagated,
    allowing context variables from the main thread to be accessed in the
    separate thread.

    Unlike :func:`asyncio.to_thread`, an explicit executor can be chosen with
    ``_executor``; ``None`` uses the running loop's default executor.

    Return a coroutine that can be awaited to get the eventual result of *func*.
    """
    loop = events.get_running_loop()
    ctx = contextvars.copy_context()
    func_call = functools.partial(ctx.run, func, *args, **kwargs)
    return await loop.run_in_executor(_executor, func_call)


def split_audio(
    audio_data: np.ndarray,
    sample_rate: int,
    max_audio_clip_s: int = 30,
    overlap_chunk_second: int = 1,
    min_energy_split_window_size: int = 1600,
) -> list[np.ndarray]:
    """Split audio into consecutive chunks of at most ``max_audio_clip_s`` seconds.

    Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/speech_to_text.py

    The last axis is treated as time. Each cut except the final one is placed at the start
    of the quietest ``min_energy_split_window_size``-sample window found in the tail of
    the chunk. The searched tail is ``max(1, min_energy_split_window_size,
    overlap_chunk_second * sample_rate)`` samples long, capped at the chunk length. If no
    usable point is found, the chunk is cut at its maximum length. The chunks do not
    overlap: ``overlap_chunk_second`` only sizes the search region.

    Args:
        audio_data: Audio samples; time is the last axis.
        sample_rate: Samples per second.
        max_audio_clip_s: Maximum chunk duration in seconds.
        overlap_chunk_second: Length, in seconds, of the tail region searched for a quiet
            split point.
        min_energy_split_window_size: Window length, in samples, used for the energy scan.

    Returns:
        The list of chunks, which concatenated along the last axis reproduce ``audio_data``.

    Raises:
        ValueError: If ``sample_rate * max_audio_clip_s`` is not positive. Such a chunk
            size would keep the loop from ever advancing.
    """
    chunk_size = sample_rate * max_audio_clip_s
    if chunk_size <= 0:
        raise ValueError("sample_rate * max_audio_clip_s must be > 0")
    overlap_size = max(0, sample_rate * overlap_chunk_second)
    chunks = []
    i = 0
    total_samples = audio_data.shape[-1]
    while i < total_samples:
        if i + chunk_size >= total_samples:
            # Handle the last chunk.
            chunks.append(audio_data[..., i:])
            break

        search_end = min(i + chunk_size, total_samples)
        search_window = max(1, min_energy_split_window_size, overlap_size)
        search_range = min(chunk_size, search_window)
        search_start = max(i, search_end - search_range)
        split_point = _find_split_point(
            audio_data,
            search_start,
            search_end,
            min_energy_split_window_size,
        )
        # Reject split points that would not advance or that overshoot the chunk.
        if split_point <= i or split_point > search_end:
            split_point = search_end
        # Extract the chunk up to the split point.
        chunks.append(audio_data[..., i:split_point])
        i = split_point
    return chunks


def _find_split_point(wav: np.ndarray, start_idx: int, end_idx: int, min_energy_split_window_size: int) -> int:
    """Find the best point to split audio by looking for silence or low amplitude.

    The region ``wav[..., start_idx:end_idx]`` is scanned along the last (time) axis in
    non-overlapping windows of ``min_energy_split_window_size`` samples. The window with
    the lowest RMS energy wins.

    Args:
        wav: Audio array with time on the last axis (e.g. ``[T]`` or ``[1, T]``).
        start_idx: Start index of search region.
        end_idx: End index of search region.
        min_energy_split_window_size: Window length in samples.

    Returns:
        Absolute time index of the quietest window's start, or ``0`` if no full window
        fits in the region.
    """
    # Slice the time axis. Plain wav[start:end] would slice axis 0, which is wrong for [1, T].
    segment = wav[..., start_idx:end_idx]
    window_size = min_energy_split_window_size
    num_samples = segment.shape[-1]

    # Calculate RMS energy in small windows.
    min_energy = math.inf
    quietest_idx = 0
    # "+ 1" so that a window ending exactly at the region's end is also considered.
    for offset in range(0, num_samples - window_size + 1, window_size):
        # Use float64 so that integer PCM samples cannot overflow when squared.
        window = np.asarray(segment[..., offset : offset + window_size], dtype=np.float64)
        energy = (window**2).mean() ** 0.5
        if energy < min_energy:
            quietest_idx = offset + start_idx
            min_energy = energy
    return quietest_idx


async def auto_batch(
    items: list[Any],
    *,
    map_fn: Callable[[int, Any], Coroutine[None, None, tuple[int, Any]] | tuple[int, Any]],
    filter_fn: Callable[[Any], bool] | None = None,
) -> AsyncGenerator[tuple[list[int], list[Any]], None]:
    """Auto batch process the items by the map function.

    Every item accepted by ``filter_fn`` (all items when ``filter_fn`` is None) is
    scheduled as its own task running ``map_fn(index, item)``. A synchronous ``map_fn``
    runs in a worker thread via ``asyncio.to_thread``. Items rejected by ``filter_fn`` are
    passed through unchanged in the first batch.

    Each yielded batch holds everything that is ready at that moment, and always at least
    one entry. Results therefore arrive in completion order; use the indices to map them
    back to ``items``.

    Args:
        items: A list of items.
        map_fn: The sync or async map function to process one item. It should accept
            (index, item) and return (index, value).
        filter_fn: A sync filter function to determine which items to process. All the
            items will be processed if the filter_fn is None.

    Returns:
        An async generator that yields a tuple of (batch index, batch items). Each batch
        is a fresh pair of equal-length lists, so callers may keep references to them.
        Tasks still pending when the generator is closed early, or raises, are cancelled.
    """
    batch_index: list[int] = []
    batch_items: list[Any] = []
    pending: list[asyncio.Task] = []

    if not inspect.iscoroutinefunction(map_fn):
        map_fn = cast(
            Callable[[int, Any], Coroutine[None, None, tuple[int, Any]]],
            functools.partial(asyncio.to_thread, map_fn),  # type: ignore[call-arg]
        )

    for index, item in enumerate(items):
        if filter_fn is None or filter_fn(item):
            pending.append(asyncio.create_task(map_fn(index, item)))
        else:
            batch_index.append(index)
            batch_items.append(item)

    def _collect(done: set[asyncio.Task], indices: list[int], values: list[Any]) -> None:
        """Append the (index, value) result of each finished task to the batch lists."""
        for task in done:
            idx, data = task.result()
            indices.append(idx)
            values.append(data)

    try:
        while batch_index or pending:
            if pending:
                done, still_pending = await asyncio.wait(pending, timeout=0)
                _collect(done, batch_index, batch_items)
                pending = list(still_pending)
            if not batch_items:
                # We got nothing, so wait until at least one task finishes.
                done, still_pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
                _collect(done, batch_index, batch_items)
                pending = list(still_pending)
            assert len(batch_index) == len(batch_items)
            yield batch_index, batch_items
            # Start new lists rather than clearing the yielded ones, which the caller may hold.
            batch_index, batch_items = [], []
    finally:
        # Do not leave orphaned tasks running if the consumer stops early or a task failed.
        for task in pending:
            task.cancel()


@contextlib.asynccontextmanager
async def open_image(
    image: pa.BinaryScalar | pa.StringScalar | pa.ExtensionScalar,
    source_mode: str | None = None,
    **storage_options: dict[str, Any],
) -> AsyncGenerator[Image.Image, None]:
    """Open an image from a path, raw bytes, or an array-valued extension scalar.

    - ``pa.StringScalar``: a path opened with ``auto_open`` and ``storage_options``.
    - ``pa.ExtensionScalar``: ``as_py()`` is passed to ``Image.fromarray``. This assumes it
      returns an array-like such as a NumPy array; NOTE(review): confirm the extension type.
    - ``pa.BinaryScalar``: encoded image bytes.

    Args:
        image: The image source.
        source_mode: Optional PIL mode to convert to, or the mode given to
            ``Image.fromarray``. ``None`` keeps the image's own mode.
        **storage_options: Per-protocol storage options forwarded to ``auto_open``.

    Yields:
        A fully loaded ``PIL.Image.Image`` (a converted copy, independent of the source file).
    """
    if isinstance(image, pa.StringScalar):
        image_url = image.as_py()
        with await asyncio.to_thread(auto_open, image_url, **storage_options) as file_obj:
            # Close the lazily opened ImageFile; convert() returns an independent copy.
            with Image.open(file_obj) as img:
                # Convert PIL.ImageFile.ImageFile to Image.Image.
                yield img.convert(source_mode if source_mode is not None else img.mode)
    elif isinstance(image, pa.ExtensionScalar):
        image_ndarray = image.as_py()
        yield Image.fromarray(image_ndarray, source_mode)
    # Note: pa.StringScalar is a subclass of pa.BinaryScalar, so we need to exclude it explicitly.
    elif isinstance(image, pa.BinaryScalar) and not isinstance(image, pa.StringScalar):
        image_binary = image.as_py()
        with io.BytesIO(image_binary) as bio:
            with Image.open(bio) as img:
                # Convert PIL.ImageFile.ImageFile to Image.Image.
                yield img.convert(source_mode if source_mode is not None else img.mode)


async def read_image(
    image: pa.BinaryScalar | pa.StringScalar | pa.ExtensionScalar,
    source_mode: str | None = None,
    **storage_options: dict[str, Any],
) -> Image.Image:
    """Load an image with ``open_image`` and return it.

    The returned image stays valid after the context closes because ``open_image`` yields
    converted or freshly created images.
    """
    async with open_image(image, source_mode, **storage_options) as img:
        return img


def _register_cos_protocol():
    """Register a ``cos`` fsspec protocol implemented by s3fs's ``S3FileSystem``.

    Does nothing if ``cos`` is already registered. A concurrent registration (which
    raises ValueError) is tolerated.

    Raises:
        Exception: If ``cos`` is still not registered afterwards.
    """
    # Import s3fs takes more than 200ms, lazy import it here.
    cos_protocol = "cos"
    if cos_protocol not in fsspec.registry:
        from s3fs import S3FileSystem

        class COSFileSystem(S3FileSystem):
            protocol = cos_protocol

        try:
            fsspec.register_implementation(cos_protocol, COSFileSystem)
        except ValueError:
            pass

    if cos_protocol not in fsspec.registry:
        raise Exception("Register COSFileSystem failed!")


def _register_hf_protocol():
    """Best-effort registration of huggingface_hub's ``HfFileSystem`` as the ``hf`` protocol.

    Silently does nothing if huggingface_hub is not installed or registration fails.
    """
    hf_protocol = "hf"
    if hf_protocol not in fsspec.registry:
        try:
            from huggingface_hub import HfFileSystem

            fsspec.register_implementation(hf_protocol, HfFileSystem)
        except (ImportError, ValueError):
            pass


def get_filesystem(path: str, **storage_options) -> Tuple["fsspec.AbstractFileSystem", str]:
    """Get filesystem and normalized path from a path string.

    The protocol comes from the URL scheme, or ``"file"`` when there is none. ``cos`` and
    ``hf`` implementations are registered on demand. ``storage_options`` is keyed by
    protocol, and only ``storage_options[protocol]`` is passed to the filesystem.

    For ``hf``, only a token is used. The ``HF_TOKEN`` environment variable takes
    precedence over ``storage_options["hf"]["token"]``, and other ``hf`` options are ignored.

    Returns:
        ``(fs, normalized_path)``. For remote protocols the path is the part after
        ``"://"``. For local paths it is the parsed URL path, or ``path`` itself if that is
        empty.
    """
    parsed = urlparse(path)
    protocol = parsed.scheme or "file"

    if protocol == "cos":
        _register_cos_protocol()
    elif protocol == "hf":
        _register_hf_protocol()

    if protocol == "hf":
        hf_token = os.environ.get("HF_TOKEN") or storage_options.get("hf", {}).get("token")
        proto_options = {"token": hf_token} if hf_token else {}
    else:
        proto_options = storage_options.get(protocol, {})

    fs = fsspec.filesystem(protocol, **proto_options)

    if protocol != "file":
        normalized_path = path.split("://", 1)[1] if "://" in path else path
    else:
        normalized_path = parsed.path or path

    return fs, normalized_path


def auto_open(
    urlpath,
    mode="rb",
    compression=None,
    encoding="utf8",
    errors=None,
    protocol=None,
    newline=None,
    expand=None,
    **storage_options,
) -> fsspec.core.OpenFile:
    """Auto select storage options based on protocol, then call ``fsspec.open``.

    ``storage_options`` is keyed by protocol, e.g. ``{"s3": {...}, "cos": {...}}``. Only the
    entries whose protocol appears in the resolved URL chain are forwarded. When
    ``urlpath`` is a sequence, the chain is resolved from its first element only. The
    ``cos`` protocol is registered first so that ``cos://`` URLs resolve.

    All other arguments are forwarded to ``fsspec.open`` unchanged, including an explicit
    ``protocol``.

    Raises:
        ValueError: If ``urlpath`` is an empty sequence.
    """
    if isinstance(urlpath, (list, tuple, set)):
        if not urlpath:
            raise ValueError("empty urlpath sequence")
        urlpath0 = fsspec.core.stringify_path(next(iter(urlpath)))
    else:
        urlpath0 = fsspec.core.stringify_path(urlpath)

    # Ensure that the cos protocol is registered before open.
    _register_cos_protocol()

    # NOTE: `_un_chain` is a private fsspec helper and may change between fsspec versions.
    # The loop variable must not be called `protocol`: that would overwrite the caller's
    # `protocol` argument before it is forwarded to fsspec.open below.
    selected_storage_options = {}
    for _bit, chain_protocol, _kw in fsspec.core._un_chain(urlpath0, {}):
        so = storage_options.get(chain_protocol)
        if so:
            selected_storage_options[chain_protocol] = so

    return fsspec.open(
        urlpath,
        mode=mode,
        compression=compression,
        encoding=encoding,
        errors=errors,
        protocol=protocol,
        newline=newline,
        expand=expand,
        **selected_storage_options,
    )


@contextlib.contextmanager
def isolated_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Run a fresh event loop in a background thread for the duration of the context.

    This is a simpler version of `loop_in_thread` that does not need the
    ThreadPoolExecutor. Submit work with ``asyncio.run_coroutine_threadsafe(coro, loop)``:
    https://docs.python.org/3/library/asyncio-task.html#asyncio.run_coroutine_threadsafe

    On exit, whether normal or via an exception, the loop is stopped. Its remaining tasks
    are cancelled, async generators and the default executor are shut down, the loop is
    closed, and the thread is joined.
    """

    def _run_loop(_loop: asyncio.AbstractEventLoop):
        asyncio.set_event_loop(_loop)
        _loop.run_forever()
        try:
            to_cancel = asyncio.tasks.all_tasks(_loop)
            if not to_cancel:
                return

            for task in to_cancel:
                task.cancel()

            _loop.run_until_complete(asyncio.tasks.gather(*to_cancel, return_exceptions=True))

            for task in to_cancel:
                if task.cancelled():
                    continue
                if task.exception() is not None:
                    _loop.call_exception_handler(
                        {
                            "message": "unhandled exception during isolated loop shutdown",
                            "exception": task.exception(),
                            "task": task,
                        }
                    )
            _loop.run_until_complete(_loop.shutdown_asyncgens())
            _loop.run_until_complete(_loop.shutdown_default_executor())
        finally:
            asyncio.set_event_loop(None)
            _loop.close()

    new_loop = asyncio.new_event_loop()
    loop_thread = threading.Thread(target=_run_loop, args=(new_loop,))
    loop_thread.start()
    try:
        yield new_loop
    finally:
        # Stop and join even if the with-body raised. Otherwise the loop thread would run
        # forever, and interpreter shutdown waits on non-daemon threads.
        new_loop.call_soon_threadsafe(new_loop.stop)
        loop_thread.join()


def safe_run(main: Coroutine) -> Any:
    """Run coroutine ``main`` to completion from synchronous code and return its result.

    If ``asyncio.get_event_loop()`` raises RuntimeError (no usable loop in this thread),
    ``asyncio.run`` is used. Otherwise the coroutine runs on a separate loop in a
    background thread (see ``isolated_loop``), and this thread blocks until it finishes.

    NOTE(review): ``asyncio.get_event_loop()``'s behavior without a running loop differs
    across Python versions (deprecation warning / implicit creation / RuntimeError).
    """
    try:
        asyncio.get_event_loop()
    except RuntimeError:
        return asyncio.run(main)

    with isolated_loop() as loop:
        return asyncio.run_coroutine_threadsafe(main, loop).result()


def copy_video_stream(
    input_container: InputContainer,
    output_container: OutputContainer,
    *,
    start_second: float = 0,
    end_second: float | None = None,
):
    """Copies the video stream without re-encoding.

    This method relies on locating the I-frame in the video, which may result in slight inaccuracies.
    For example, if you extract a 10-second segment, the final video might be 10.3 seconds long.
    This behavior is equivalent to using the '-ss' option before '-i' and '-c:v copy' in ffmpeg.

    All video and audio streams of ``input_container`` are copied. Each stream's
    timestamps are shifted so that its first copied packet starts at 0. A stream stops
    being copied once its packets pass ``end_second``.

    Raises:
        ValueError: If the first video stream has no time base.
    """
    # Import av takes more than 200ms, lazy import it here.
    from av.audio.stream import AudioStream
    from av.packet import Packet
    from av.stream import Stream
    from av.video.stream import VideoStream

    streams_map: dict[Stream, VideoStream | AudioStream] = {}
    video_stream = input_container.streams.video[0]
    if video_stream.time_base is None:
        raise ValueError("video stream time base is zero, file may be broken.")

    for stream in input_container.streams:
        if stream.type == "video":
            streams_map[stream] = output_container.add_stream_from_template(template=cast(VideoStream, stream))
        elif stream.type == "audio":
            streams_map[stream] = output_container.add_stream_from_template(template=cast(AudioStream, stream))

    start_pts_map = {}
    running_stream = set(streams_map.keys())
    input_container.seek(int(start_second / video_stream.time_base), stream=video_stream)
    # TODO(anthonycai) add precise segmentation logic, use decode method
    for packet in input_container.demux(*streams_map.keys()):
        if not isinstance(packet, Packet) or packet.pts is None:
            continue

        current_time_sec = packet.pts * packet.time_base
        if end_second is not None and current_time_sec > end_second:
            if packet.stream in running_stream:
                running_stream.remove(packet.stream)
            if len(running_stream) == 0:
                break
            continue

        input_stream = packet.stream
        # Re-base each stream's timestamps so its first copied packet starts at 0.
        if input_stream not in start_pts_map:
            start_pts_map[input_stream] = packet.pts
        offset = start_pts_map[input_stream]
        packet.pts -= offset
        if packet.dts is not None:
            packet.dts -= offset
        packet.stream = streams_map[input_stream]
        output_container.mux(packet)

    # Flush encoder.
    for stream in streams_map.values():
        if stream.type == "video":
            output_container.mux(stream.encode(None))


def desire_video_format(list_format: list[str]) -> str:
    """
    Select a single, explicitly supported output container format (FFmpeg muxer name)
    from the list of possible container formats reported by FFmpeg.

    Why is this function needed?

    1. `input_container.format.name` returns a string like "mov,mp4,m4a,3gp,3g2,mj2",
       which lists **all possible container formats** that FFmpeg believes the file could be
       compatible with (comma-separated). Passing this full string directly to
       `av.open(..., format=...)` will fail because FFmpeg cannot recognize a combined name
       like "mov,mp4,m4a,..." — it expects **one valid muxer name**.

    2. The `format` parameter in `av.open(..., mode="w", format=...)` **must be a single valid
       muxer name** (e.g., "mp4", "mov", "matroska"). Otherwise, it raises:
       `ValueError: unknown format`.

    3. Priority strategy:
       - Prefer **mp4**: most widely supported, excellent web and mobile compatibility.
       - Then **mov**: common in Apple ecosystem, good compatibility with QuickTime.
       - Then **mkv** (muxer "matroska"; FFmpeg reports Matroska inputs as "matroska,webm",
         so both "mkv" and "matroska" are recognized): powerful container, supports many codecs.
       - Fallback: use the first format in the list (guaranteed to be recognized by FFmpeg).

    This ensures we always pass a **valid and highly compatible** output format to `av.open`.
    """
    if "mp4" in list_format:
        output_format = "mp4"
    elif "mov" in list_format:
        output_format = "mov"
    elif "mkv" in list_format or "matroska" in list_format:
        output_format = "matroska"
    else:
        output_format = list_format[0]
    return output_format


def cut_video_by_seconds(
    input_container: InputContainer,
    start_second: float,
    end_second: float | None = None,
) -> bytes:
    """Cut ``[start_second, end_second]`` out of the input video without re-encoding.

    The cut uses ``copy_video_stream``, so it starts at a keyframe and may be slightly
    longer than requested.

    Returns:
        The encoded clip bytes, in a container chosen by ``desire_video_format``.

    Raises:
        ValueError: If the first video stream has no (or a zero) time base.
    """
    input_video_stream = input_container.streams.video[0]
    if input_video_stream.time_base is None or input_video_stream.time_base == 0:
        raise ValueError("video stream time base is zero, file may be broken.")
    with io.BytesIO() as output_buffer:
        with av.open(
            output_buffer, mode="w", format=desire_video_format(input_container.format.name.split(","))
        ) as output_container:
            copy_video_stream(input_container, output_container, start_second=start_second, end_second=end_second)
        return output_buffer.getvalue()


def deep_getattr(obj: object, attr: str) -> Any:
    """Resolve a dotted attribute path such as ``"a.b.c"`` starting from ``obj``.

    Raises:
        AttributeError: If any component is missing. The message names the type of the
            original ``obj`` and the full dotted path.
    """
    current = obj
    try:
        for part in attr.split("."):
            current = getattr(current, part)
    except AttributeError as err:
        # Report the root object's type. `current` may already point at an intermediate object.
        raise AttributeError(f"'{obj.__class__.__name__}' object has no attribute '{attr}'") from err
    return current


class LabelSpec(BaseModel):
    """A label with an optional description for richer LLM prompt hints."""

    label: str
    description: str | None = None

    @model_validator(mode="after")
    def strip_and_validate(self) -> LabelSpec:
        """Strip whitespace, reject an empty label and normalize a blank description to None."""
        self.label = self.label.strip()
        if not self.label:
            raise ValueError("label must not be empty")
        if self.description is not None:
            self.description = self.description.strip() or None
        return self


def _format_labels(specs: list[LabelSpec]) -> str:
    """Format label specs into a prompt-friendly string, one spec per line.

    A spec with a description renders as ``"label: <label> description: <description>"``;
    a spec without one renders as the bare label.
    """
    return "\n".join(
        f"label: {spec.label} description: {spec.description}" if spec.description else spec.label
        for spec in specs
    )


class LLMChatCompletions:
    """Async client for an OpenAI-compatible chat-completions endpoint.

    Requests go through ``openai.AsyncClient`` and are throttled by ``qps_limiter(max_qps)``.

    Args:
        base_url: Endpoint base URL.
        model: Model name sent with every request.
        api_key: API key for the endpoint.
        max_qps: Optional limit on requests started per second (``None`` = unlimited).
        max_retries: Retry count passed to ``openai.AsyncClient``.
        fallback_response: Value returned by ``call_with_fallback`` when a call fails or
            returns no content. ``None`` makes such failures raise instead.
        response_format: ``"text"`` or ``"json_object"``.
        **kwargs: Extra keyword arguments forwarded to every ``chat.completions.create`` call.

    Raises:
        ValueError: If ``response_format`` is unsupported.
    """

    def __init__(
        self,
        base_url: str,
        model: str,
        api_key: str,
        max_qps: int | None = None,
        max_retries: int = 0,
        fallback_response: str | float | list[str] | None = None,
        response_format: str = "text",
        **kwargs: Any,
    ):
        logger.info("Using Remote Http LLMChatCompletions")
        self.model = model
        self.kwargs = kwargs
        self.fallback_response = fallback_response
        supported_formats = {"json_object", "text"}
        if response_format not in supported_formats:
            raise ValueError(f"Unsupported response_format: {response_format}")
        self.response_format = {"type": response_format}
        client = openai.AsyncClient(api_key=api_key, base_url=base_url, max_retries=max_retries)
        self.create_completions = qps_limiter(max_qps)(client.chat.completions.create)

    async def __call__(self, messages: Iterable[ChatCompletionMessageParam]) -> ChatCompletion | None:
        """Send one chat-completions request; return the raw response, or None on any error (logged)."""
        try:
            return await self.create_completions(
                messages=messages, model=self.model, response_format=self.response_format, **self.kwargs
            )
        except Exception as e:
            logger.error("create_completions failed: %s", e)
            return None

    async def call_with_fallback(
        self,
        messages: Iterable[ChatCompletionMessageParam],
        post_process: Callable[[str], str | float | list[str]] | None = None,
    ) -> str | float | list[str] | None:
        """Call the LLM and extract the response content, with fallback handling on failure.

        Unlike ``__call__``, which returns the raw ``ChatCompletion`` object (or ``None`` on
        network/API errors), this method goes one step further: it validates the response,
        extracts the text content, applies an optional post-processing function, and handles
        the failure case via ``fallback_response``.

        Args:
            messages: The message list to send to the LLM.
            post_process: An optional function applied to the stripped response text.
                If ``None``, the raw text is returned as-is.

        Returns:
            The (optionally post-processed) response string or numeric value, or
            ``fallback_response`` on failure.

        Raises:
            ValueError: If the LLM call fails or returns empty content and no
                ``fallback_response`` is configured.
        """
        response = await self(messages=messages)
        if response is not None and len(response.choices) > 0 and response.choices[0].message.content is not None:
            content = response.choices[0].message.content.strip()
            return content if post_process is None else post_process(content)
        else:
            logger.error("request failed, response is: %s", response)
            if self.fallback_response is not None:
                return self.fallback_response
            else:
                raise ValueError("LLM call failed or returned empty content.")

    async def batch_generate(
        self,
        texts: pa.ChunkedArray,
        build_prompt: Callable[[str], Iterable[ChatCompletionMessageParam]],
        post_process: Callable[[str], str | float | list[str]] | None = None,
        datatype: pa.DataType = pa.string(),
    ) -> pa.Array:
        """Run ``call_with_fallback`` for every text concurrently and collect the results.

        Returns:
            A pyarrow array of ``datatype``, aligned with ``texts``.
        """
        # NOTE(review): iterating a pa.ChunkedArray yields pyarrow scalars, so build_prompt
        # receives scalar objects rather than str despite its annotation; confirm callers expect this.
        requests = [self.call_with_fallback(messages=build_prompt(text), post_process=post_process) for text in texts]
        results = await asyncio.gather(*requests)
        return pa.array(results, type=datatype)


def _split_text_with_regex(text: str, separator: str, *, keep_separator: bool | Literal["start", "end"]) -> list[str]:
    """Split a text into chunks using re.split with separator.

    ``separator`` is a regex. With ``keep_separator`` set to True/"start", each matched
    separator is prepended to the following piece; with "end" it is appended to the
    preceding piece. An empty ``separator`` splits into single characters. Empty pieces
    are dropped.
    """
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            splits_ = re.split(f"({separator})", text)
            splits = (
                ([splits_[i] + splits_[i + 1] for i in range(0, len(splits_) - 1, 2)])
                if keep_separator == "end"
                else ([splits_[i] + splits_[i + 1] for i in range(1, len(splits_), 2)])
            )
            if len(splits_) % 2 == 0:
                splits += splits_[-1:]
            splits = ([*splits, splits_[-1]]) if keep_separator == "end" else ([splits_[0], *splits])
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s]


class RecursiveCharacterTextSplitter:
    """Splitting text by recursively look at characters.

    Reference: https://github.com/langchain-ai/langchain/blob/v0.1.16/libs/text-splitters/langchain_text_splitters/character.py

    Recursively tries to split by different characters to find one that works.

    The splitter attempts each separator in order. If a separator is found in the text, it
    is used to split the text, and the remaining separators are passed recursively for
    further splitting of oversized chunks. The resulting splits are then merged back into
    chunks that respect the ``chunk_size`` and ``chunk_overlap`` constraints.

    Args:
        separators: An ordered list of separator strings to try when splitting.
            The splitter iterates through this list and uses the first separator
            that is found in the text. If none match, the last separator is used
            as a fallback. Defaults to ``["\\n\\n", "\\n", " ", ""]``, which
            progressively splits on paragraph breaks, line breaks, spaces, and
            finally individual characters.
        keep_separator: Controls whether the matched separator is retained in
            the output chunks and where it is placed.

            - ``True`` or ``"start"``: the separator is prepended to the
              following chunk.
            - ``"end"``: the separator is appended to the preceding chunk.
            - ``False``: the separator is discarded entirely.

            Defaults to ``True``.
        chunk_size: The maximum number of characters allowed in a single output
            chunk. Defaults to ``4000``.
        chunk_overlap: The number of characters that consecutive chunks are
            allowed to share at their boundaries. A non-zero overlap helps
            preserve context across chunk boundaries. Must not be greater than
            ``chunk_size``. Defaults to ``200``.
        is_separator_regex: If ``True``, the strings in ``separators`` are
            treated as regular expression patterns rather than plain strings.
            Defaults to ``False``.
        strip_whitespace: If ``True``, leading and trailing whitespace is
            stripped from each output chunk before it is emitted. Chunks that
            become empty after stripping are discarded. Defaults to ``True``.

    Raises:
        ValueError: If `chunk_size` is less than or equal to 0
        ValueError: If `chunk_overlap` is less than 0
        ValueError: If `chunk_overlap` is greater than `chunk_size`
    """

    def __init__(
        self,
        separators: list[str] | None = None,
        keep_separator: bool | Literal["start", "end"] = True,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        is_separator_regex: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        if chunk_size <= 0:
            raise ValueError("chunk_size must be > 0")
        if chunk_overlap < 0:
            raise ValueError("chunk_overlap must be >= 0")
        if chunk_overlap > chunk_size:
            raise ValueError("chunk_overlap must be less than chunk_size")
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex
        self._length_function = len
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._keep_separator = keep_separator
        self._strip_whitespace = strip_whitespace

    def _join_docs(self, docs: list[str], separator: str) -> str | None:
        """Join pieces with ``separator`` (optionally stripped); return None if the result is empty."""
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        return text or None

    def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
        """Greedily merge small splits into chunks of at most ``chunk_size``.

        Each new chunk begins with trailing splits of the previous chunk, keeping at most
        ``chunk_overlap`` characters of overlap.
        """
        separator_len = self._length_function(separator)

        docs = []
        current_doc: list[str] = []
        total = 0
        for split in splits:
            chunk_len = self._length_function(split)
            if total + chunk_len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
                if total > self._chunk_size:
                    logger.warning(
                        "Created a chunk of size %d, which is longer than the specified %d",
                        total,
                        self._chunk_size,
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + chunk_len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
                        current_doc = current_doc[1:]
            current_doc.append(split)
            total += chunk_len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use.
        separator = separators[-1]
        new_separators = []
        for i, s_ in enumerate(separators):
            escaped_sep = s_ if self._is_separator_regex else re.escape(s_)
            if not s_:
                separator = s_
                break
            if re.search(escaped_sep, text):
                separator = s_
                new_separators = separators[i + 1 :]
                break

        separator_ = separator if self._is_separator_regex else re.escape(separator)
        raw_splits = _split_text_with_regex(text, separator_, keep_separator=self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        good_splits = []
        separator_ = "" if self._keep_separator else separator
        for s in raw_splits:
            if self._length_function(s) < self._chunk_size:
                good_splits.append(s)
            else:
                if good_splits:
                    merged_text = self._merge_splits(good_splits, separator_)
                    final_chunks.extend(merged_text)
                    good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if good_splits:
            merged_text = self._merge_splits(good_splits, separator_)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> list[str]:
        """Split the input text into smaller chunks based on predefined separators.

        Args:
            text: The input text to be split.

        Returns:
            A list of text chunks obtained after splitting.
        """
        return self._split_text(text, self._separators)


def skip_empty_texts(func: Callable | None = None, empty_response: str = ""):
    """Decorator to skip empty or null texts in text processing.

    The decorated async method is called only with the non-empty, non-null texts (as one
    filtered array). Its results are scattered back into their original positions, and
    every skipped position gets ``empty_response``. If every text is empty, the method is
    not called at all. The result is a ``pa.string()`` array.

    Usable both as ``@skip_empty_texts`` and as ``@skip_empty_texts(empty_response=...)``.
    """

    def decorator(f: Callable) -> Callable:
        @wraps(f)
        async def wrapper(self, texts: pa.ChunkedArray) -> pa.Array:
            combined = texts.combine_chunks()
            # Null texts count as empty.
            is_not_empty = pc.fill_null(pc.not_equal(pc.utf8_length(combined), 0), False)
            valid_indices = pc.indices_nonzero(is_not_empty)

            if len(valid_indices) == 0:
                return pa.array([empty_response] * len(combined), type=pa.string())

            valid_texts = pc.filter(combined, is_not_empty)
            processed_results = await f(self, valid_texts)

            results = [empty_response] * len(combined)
            # Convert once to Python lists instead of boxing one scalar per element.
            processed = processed_results.to_pylist()
            for position, target in enumerate(valid_indices.to_pylist()):
                results[target] = processed[position]

            return pa.array(results, type=pa.string())

        return wrapper

    # Support both @skip_empty_texts and @skip_empty_texts(...).
    if func is not None:
        return decorator(func)
    return decorator


def normalize_labels(labels: list[str], labels_name: str = "labels") -> list[str]:
    """Strip labels, drop blank ones and de-duplicate them, keeping first-seen order.

    Raises:
        ValueError: If no non-blank label remains (the message uses ``labels_name``).
    """
    cleaned = list(dict.fromkeys(stripped for label in labels if (stripped := label.strip())))
    if not cleaned:
        raise ValueError(f"{labels_name} must not be empty")
    return cleaned


def text_tokenize(arr: pa.Array, regex_patten: str | None = None, cjk=False) -> pa.Array:
    """Split each string in ``arr`` into a list of tokens.

    With ``regex_patten`` set, strings are split on that regex. Otherwise they are split
    on whitespace, or with ``dedup_utils.utf8_split_mixed`` when ``cjk`` is True.
    """
    if regex_patten is None:
        if cjk:
            tokens = dedup_utils.utf8_split_mixed(arr)
        else:
            tokens = pc.utf8_split_whitespace(arr)
    else:
        tokens = pc.split_pattern_regex(arr, pattern=regex_patten)
    return tokens


def sanitize_path_segment(text: str, max_len: int = 50, default_name: str = "section") -> str:
    """Return a filesystem-safe slug from *text*, preserving CJK characters.

    Disallowed characters are removed, whitespace runs become ``_``, and leading and
    trailing ``_`` are stripped. An empty result becomes ``default_name``. Slugs longer than
    ``max_len`` are truncated and given a random ``_<8 hex>`` suffix to keep them distinct.
    When ``max_len`` is too small to fit that suffix, the slug is simply truncated.
    """
    safe = _PATH_SEGMENT_SANITIZE_PATTERN.sub("", text)
    safe = re.sub(r"\s+", "_", safe).strip("_")
    if not safe:
        return default_name
    if len(safe) > max_len:
        suffix_len = 9  # "_" + 8 hex characters
        if max_len >= suffix_len:
            u = uuid.uuid4().hex[:8]
            safe = f"{safe[: max_len - suffix_len]}_{u}"
        else:
            # A negative slice here would cut from the wrong end and overshoot max_len.
            safe = safe[: max(max_len, 0)] or default_name
    return safe