from __future__ import annotations
import asyncio
import contextlib
import contextvars
import enum
import functools
import inspect
import io
import itertools
import logging
import math
import os
import re
import threading
import time
import uuid
from asyncio import events
from collections import deque
from collections.abc import Callable
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Coroutine,
Generator,
Iterable,
Literal,
ParamSpec,
Tuple,
TypeVar,
cast,
)
from urllib.parse import urlparse
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from PIL import Image
from pydantic import BaseModel, model_validator
from xpark.dataset.constants import (
AUDIO_FORMAT_MAPPING,
RAY_DOC_REPLACE_PAIRS,
WRAPPER_ASSIGNMENTS,
)
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.types import KeyType
if TYPE_CHECKING:
import av
import fsspec
import openai
import pyarrow.compute as pc
from av.container.input import InputContainer
from av.container.output import OutputContainer
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from xpark.dataset.filters.dedup import utils as dedup_utils
else:
av = lazy_import("av")
fsspec = lazy_import("fsspec")
openai = lazy_import("openai")
pc = lazy_import("pyarrow.compute", rename="pc")
pa = lazy_import("pyarrow", rename="pa")
dedup_utils = lazy_import("xpark.dataset.filters.dedup.utils", rename="dedup_utils")
# Module logger; records go to the logger named "ray" (presumably so they follow Ray's
# logging configuration — confirm).
logger = logging.getLogger("ray")
# Type variables shared by the generic helpers in this module (copy_sig, wrap_ray_class,
# assert_type, ...).
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
# Matches every character that sanitize_path_segment() removes: anything that is NOT a word
# character, a character in one of the listed CJK / Hiragana / Katakana / Hangul ranges,
# whitespace, or a hyphen.
# NOTE(review): for str patterns ``\w`` already matches most of these Unicode letters; the
# explicit ranges additionally keep non-letter code points inside those blocks.
_PATH_SEGMENT_SANITIZE_PATTERN = re.compile(
    r"[^\w\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7af"
    r"\u3400-\u4dbf\U00020000-\U0002a6df\s-]"
)
def copy_sig(fn: Callable[P, R]) -> Callable[[Callable[..., T]], Callable[P, T]]:
    """Copy ``fn``'s signature metadata onto the decorated function.

    This extends :func:`functools.wraps`. The ``wraps`` name confuses PyCharm's type
    inference, so the helper is named ``copy_sig``. Compared with ``functools.wraps`` it:

    1. copies only the attributes listed in ``WRAPPER_ASSIGNMENTS`` (so ``__qualname__`` and
       ``__module__`` can be left alone);
    2. rewrites ``__doc__`` with :func:`wrap_ray_doc`;
    3. keeps the *decorated* function's own return annotation.

    Args:
        fn: The function whose signature/metadata is copied.

    Returns:
        A decorator that applies the metadata to its target and returns it.

    Raises:
        TypeError: If the decorated function has no return annotation.
    """

    def decorator(target: Callable[..., T]) -> Callable[P, T]:
        signature = inspect.signature(target)
        # Inspect the target *before* wrapping: afterwards __wrapped__ would point to fn.
        if signature.return_annotation is inspect.Signature.empty:
            raise TypeError(
                f"The {copy_sig.__name__} decorated function `{target.__name__}` must have a return annotation."
            )
        wrapped = functools.wraps(fn, assigned=WRAPPER_ASSIGNMENTS)(target)
        wrapped.__doc__ = wrap_ray_doc(fn.__doc__)
        # Copy before mutating: if WRAPPER_ASSIGNMENTS includes "__annotations__", the dict
        # is shared with fn and must not be modified in place.
        wrapped.__annotations__ = wrapped.__annotations__.copy()
        wrapped.__annotations__["return"] = signature.return_annotation
        return cast(Callable[P, T], wrapped)

    return decorator
def deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]:
    """Return a copy of ``mapping`` with ``updating_mappings`` merged in recursively.

    Later mappings win. When both the existing value and the new value for a key are dicts,
    they are merged recursively. Otherwise the new value replaces the old one. ``mapping``
    itself is never mutated.
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            current = merged.get(key)
            both_dicts = isinstance(current, dict) and isinstance(new_value, dict)
            merged[key] = deep_update(current, new_value) if both_dicts else new_value
    return merged
def multi_replace(s: str | None, replace_pairs: dict[str, str]) -> str | None:
    """Apply every ``old -> new`` pair of ``replace_pairs`` to ``s``, in iteration order.

    Replacements are applied sequentially, so a later pair sees the output of earlier ones.
    ``None`` is passed through unchanged.
    """
    if s is not None:
        for old, new in replace_pairs.items():
            s = s.replace(old, new)
    return s
def wrap_ray_doc(s: str | None) -> str | None:
    """Rewrite Ray-specific wording in ``s`` using ``RAY_DOC_REPLACE_PAIRS``; ``None`` stays ``None``."""
    replacements = RAY_DOC_REPLACE_PAIRS
    return multi_replace(s, replacements)
def wrap_ray_class(ray_class: type[T]) -> type[T]:
    """Create a wrapper class that inherits from a Ray class and replaces its documentation.

    This function creates a new class that:
    1. Inherits from the original Ray class
    2. Replaces Ray-specific documentation with Xpark equivalents (via ``wrap_ray_doc``)
    3. Keeps all functionality of the original class (no members are overridden)

    Similar to `copy_sig`, this function helps maintain Xpark-specific documentation
    without modifying the original Ray classes.

    Args:
        ray_class: The Ray class to wrap.

    Returns:
        A new class that inherits from ray_class with updated documentation.

    Raises:
        TypeError: If ``ray_class`` is an Enum class (enums with members cannot be subclassed).

    Example:
        >>> import ray.data.aggregate
        >>> Count = wrap_ray_class(ray.data.aggregate.Count)
        >>> # Count now has Xpark-specific documentation
    """
    # Enum classes cannot be wrapped through inheritance; refuse them explicitly.
    # TypeError subclasses Exception, so existing ``except Exception`` handlers still apply.
    if isinstance(ray_class, enum.EnumMeta):
        raise TypeError("Do not wrap enum class.")

    # For regular classes, use inheritance.
    class WrappedClass(ray_class):  # type: ignore
        __doc__ = wrap_ray_doc(ray_class.__doc__)

    # Copy the original class name and qualified name. __module__ is not copied, so it still
    # names the module that defines WrappedClass.
    WrappedClass.__name__ = ray_class.__name__
    WrappedClass.__qualname__ = ray_class.__qualname__
    return WrappedClass  # type: ignore
def iter_batch(array: pa.ChunkedArray, batch_size: int) -> Generator[pa.ChunkedArray, None, None]:
    """Yield consecutive slices of ``array`` with at most ``batch_size`` elements each.

    The final slice may be shorter than ``batch_size``.
    """
    total = len(array)
    offsets = range(0, total, batch_size)
    for start in offsets:
        yield array[start : min(start + batch_size, total)]
def qps_limiter(max_qps: int | None = None, max_concurrency: int | None = None):
    """Decorator to limit QPS (queries per second) and concurrency for async functions.

    Args:
        max_qps: Maximum number of calls started within any sliding 1-second window.
            ``None`` (or ``0``) disables rate limiting.
        max_concurrency: Maximum number of calls in flight at the same time.
            ``None`` (or ``0``) disables the concurrency cap.

    Returns:
        A decorator. When both limits are ``None`` it returns the function unchanged;
        otherwise it returns an async wrapper that enforces the limits.

    Raises:
        ValueError: If a non-zero limit is smaller than 1.
    """
    if max_qps and max_qps < 1:
        raise ValueError("max_qps must be >= 1")
    if max_concurrency and max_concurrency < 1:
        raise ValueError("max_concurrency must be >= 1")

    def _decorator(func):
        if max_qps is None and max_concurrency is None:
            return func

        # Start times (monotonic clock) of recent calls, oldest first.
        timestamps: deque[float] = deque()
        lock = asyncio.Lock()
        semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

        @copy_sig(func)
        async def wrapper(*args, **kwargs) -> Any:
            if semaphore is not None:
                await semaphore.acquire()
            try:
                if max_qps:
                    # The lock serializes admission so the sliding window stays consistent.
                    async with lock:
                        # Use a monotonic clock: wall-clock adjustments (NTP, manual changes)
                        # could otherwise produce huge or bogus sleep times.
                        now = time.monotonic()
                        # Drop calls that fell out of the 1-second window.
                        while timestamps and timestamps[0] < now - 1:
                            timestamps.popleft()
                        if len(timestamps) >= max_qps:
                            # Wait until the oldest call in the window is 1 second old.
                            sleep_time = 1 - (now - timestamps[0])
                            await asyncio.sleep(sleep_time)
                        timestamps.append(time.monotonic())
                return await func(*args, **kwargs)
            finally:
                if semaphore is not None:
                    semaphore.release()

        return wrapper

    return _decorator
class Count:
    """A resettable counter built on :func:`itertools.count`.

    Args:
        firstval: The first value returned by :meth:`next`.
        step: The increment between consecutive values.
    """

    def __init__(self, firstval=0, step=1):
        self._firstval = firstval
        self._step = step
        self._gen = itertools.count(firstval, step)
        # One step before firstval marks "next() not called yet" (see is_reset()).
        self._val = firstval - step

    @property
    def last_value(self) -> int:
        """The value most recently returned by :meth:`next` (``firstval - step`` if none)."""
        return self._val

    def range(self) -> Iterable[int]:
        """All values produced so far, from ``firstval`` through :attr:`last_value`."""
        return range(self._firstval, self._val + self._step, self._step)

    def next(self) -> int:
        """Advance the counter and return the new value."""
        self._val = val = next(self._gen)
        return val

    def reset(self) -> None:
        """Restart counting from ``firstval``."""
        self._gen = itertools.count(self._firstval, self._step)
        self._val = self._firstval - self._step

    def is_reset(self) -> bool:
        """Return True if :meth:`next` has not been called since construction or reset."""
        return self._val == self._firstval - self._step
def audio_codec_to_format_ext(codec_name: str) -> tuple[str, str]:
    """Map an audio codec name to a ``(file_format, extension)`` pair.

    Matching is case-insensitive. Every ``pcm_*`` codec maps to WAV; other codecs are looked
    up in ``AUDIO_FORMAT_MAPPING``.

    Args:
        codec_name: Audio codec name.

    Returns:
        File format and extension.

    Raises:
        ValueError: If the codec is neither PCM nor present in ``AUDIO_FORMAT_MAPPING``.
    """
    normalized = codec_name.lower()
    if normalized.startswith("pcm_"):
        return "wav", "wav"
    if normalized in AUDIO_FORMAT_MAPPING:
        return AUDIO_FORMAT_MAPPING[normalized]
    raise ValueError(f"Audio codec name not found in AUDIO_FORMAT_MAPPING: {normalized}")
@contextlib.asynccontextmanager
async def open_video(
    video: pa.BinaryScalar | pa.StringScalar | str | bytes, **storage_options: dict[str, Any]
) -> AsyncGenerator[InputContainer, None]:
    """Open a video file asynchronously.

    Supports multiple input types:
    - pa.StringScalar: PyArrow string scalar (video path)
    - pa.BinaryScalar: PyArrow binary scalar (video data)
    - str: Raw string path
    - bytes / bytearray / memoryview: Raw binary data

    ``http://`` and ``https://`` URLs are passed straight to ``av.open``. Any other path is
    opened via ``auto_open`` with ``storage_options`` first. Blocking open calls run in a
    worker thread.

    Args:
        video: Video path or binary data.
        **storage_options: Storage options for remote paths.

    Yields:
        PyAV InputContainer for the video.

    Raises:
        TypeError: If ``video`` is not one of the supported input types.
        ValueError: If ``video`` is a null PyArrow scalar.
    """
    video_binary, video_url = b"", ""
    # Handle different input types
    if isinstance(video, str):
        video_url = video
    elif isinstance(video, (bytes, bytearray, memoryview)):
        video_binary = bytes(video)
    elif isinstance(video, pa.StringScalar):
        video_url = video.as_py()
    elif isinstance(video, pa.BinaryScalar):
        # pa.StringScalar subclasses pa.BinaryScalar but was already handled above.
        video_binary = video.as_py()
    else:
        # Previously unsupported inputs silently became an empty buffer and failed inside av.
        raise TypeError(f"Unsupported video input type: {type(video).__name__}")
    if video_url is None or video_binary is None:
        raise ValueError("Cannot open a null video scalar.")

    if video_url != "" and len(video_binary) == 0:
        if video_url.startswith("http://") or video_url.startswith("https://"):
            with await asyncio.to_thread(av.open, video_url) as input_container:
                yield input_container
        else:
            with await asyncio.to_thread(auto_open, video_url, "rb", **storage_options) as file_obj:
                with await asyncio.to_thread(av.open, file_obj) as input_container:
                    yield input_container
    else:
        with io.BytesIO(video_binary) as bio:
            with await asyncio.to_thread(av.open, bio, "r") as input_container:
                yield input_container
def assert_type(x: Any, expect_type: type[T]) -> T:
    """Narrow ``x`` to ``expect_type`` for type checkers and return it unchanged.

    The runtime check is an ``assert``, so it raises ``AssertionError`` on mismatch and is
    skipped entirely under ``python -O``.
    """
    narrowed = x
    assert isinstance(narrowed, expect_type)
    return narrowed
async def to_thread(func, /, *args, _executor=None, **kwargs):
    """Asynchronously run function *func* in a separate thread.

    Works like :func:`asyncio.to_thread`, but accepts an explicit ``_executor``
    (``None`` selects the loop's default executor). ``args`` and ``kwargs`` are forwarded to
    *func*. The caller's :class:`contextvars.Context` is copied, so context variables set in
    the calling task are visible inside the worker thread.

    Return a coroutine that can be awaited to get the eventual result of *func*.
    """
    running_loop = asyncio.get_running_loop()
    context = contextvars.copy_context()
    bound_call = functools.partial(context.run, func, *args, **kwargs)
    return await running_loop.run_in_executor(_executor, bound_call)
def split_audio(
    audio_data: np.ndarray,
    sample_rate: int,
    max_audio_clip_s: int = 30,
    overlap_chunk_second: int = 1,
    min_energy_split_window_size: int = 1600,
) -> list[np.ndarray]:
    """Split audio into chunks of at most ``max_audio_clip_s`` seconds, cutting at quiet points.

    Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/speech_to_text.py

    Samples are indexed along the last axis of ``audio_data``. Near each chunk boundary, the
    lowest-energy window (see ``_find_split_point``) inside the last
    ``max(min_energy_split_window_size, overlap)`` samples is used as the split point. If no
    usable point is found, the chunk is cut at the hard limit.

    Args:
        audio_data: Audio samples; time runs along the last axis.
        sample_rate: Samples per second.
        max_audio_clip_s: Maximum chunk length in seconds.
        overlap_chunk_second: Seconds searched before the hard limit for a quiet split point.
        min_energy_split_window_size: Window size (in samples) used for the energy search.

    Returns:
        A list of chunks (views into ``audio_data``); empty input gives an empty list.

    Raises:
        ValueError: If ``sample_rate * max_audio_clip_s`` is not positive (the loop could
            otherwise never advance).
    """
    chunk_size = sample_rate * max_audio_clip_s
    if chunk_size <= 0:
        raise ValueError("sample_rate * max_audio_clip_s must be > 0")
    overlap_size = max(0, sample_rate * overlap_chunk_second)
    chunks = []
    i = 0
    total_samples = audio_data.shape[-1]
    while i < total_samples:
        if i + chunk_size >= total_samples:
            # handle last chunk
            chunks.append(audio_data[..., i:])
            break
        search_end = min(i + chunk_size, total_samples)
        search_window = max(1, min_energy_split_window_size, overlap_size)
        search_range = min(chunk_size, search_window)
        search_start = max(i, search_end - search_range)
        split_point = _find_split_point(
            audio_data,
            search_start,
            search_end,
            min_energy_split_window_size,
        )
        # Guarantee forward progress and never exceed the hard chunk limit.
        if split_point <= i or split_point > search_end:
            split_point = search_end
        # Extract chunk up to the split point
        chunks.append(audio_data[..., i:split_point])
        i = split_point
    return chunks
def _find_split_point(wav: np.ndarray, start_idx: int, end_idx: int, min_energy_split_window_size: int) -> int:
"""Find the best point to split audio by
looking for silence or low amplitude.
Args:
wav: Audio tensor [1, T]
start_idx: Start index of search region
end_idx: End index of search region
Returns:
Index of best splitting point
"""
segment = wav[start_idx:end_idx]
# Calculate RMS energy in small windows
min_energy = math.inf
quietest_idx = 0
min_energy_window = min_energy_split_window_size
assert min_energy_window is not None
for i in range(0, len(segment) - min_energy_window, min_energy_window):
window = segment[i : i + min_energy_window]
energy = (window**2).mean() ** 0.5
if energy < min_energy:
quietest_idx = i + start_idx
min_energy = energy
return quietest_idx
async def auto_batch(
    items: list[Any],
    *,
    map_fn: Callable[[int, Any], Coroutine[None, None, tuple[int, Any]] | tuple[int, Any]],
    filter_fn: Callable[[Any], bool] | None = None,
) -> AsyncGenerator[tuple[list[int], list[Any]], None]:
    """Auto batch process the items by the map function.

    Items accepted by ``filter_fn`` (every item when ``filter_fn`` is None) are scheduled as
    tasks running ``map_fn``. Rejected items are passed through unchanged. Each yielded batch
    holds whatever results are available at that moment, so results arrive roughly in
    completion order, not input order.

    Args:
        items: A list of items
        map_fn: The sync or async map function to process one item.
            It should accept (index, item) and return (index, value).
            Sync functions are run with ``asyncio.to_thread``.
        filter_fn: A sync filter function to determine which items to process.
            All the items will be processed if the filter_fn is None.

    Yields:
        ``(batch_index, batch_items)`` — two lists of equal length. Each batch uses fresh
        lists, so callers may keep references to earlier batches.

    Raises:
        Any exception raised by ``map_fn``. Remaining tasks are cancelled on error or when the
        generator is closed early.
    """
    batch_index: list[int] = []
    batch_items: list[Any] = []
    pending: set[asyncio.Task] = set()
    if not inspect.iscoroutinefunction(map_fn):
        map_fn = cast(
            Callable[[int, Any], Coroutine[None, None, tuple[int, Any]]],
            functools.partial(asyncio.to_thread, map_fn),  # type: ignore[call-arg]
        )

    def _collect(done: set[asyncio.Task], indices: list[int], values: list[Any]) -> None:
        # Append (index, value) results of finished tasks; re-raises task exceptions.
        for task in done:
            idx, data = task.result()
            indices.append(idx)
            values.append(data)

    try:
        for index, item in enumerate(items):
            if filter_fn is None or filter_fn(item):
                pending.add(asyncio.create_task(map_fn(index, item)))
            else:
                batch_index.append(index)
                batch_items.append(item)

        while batch_index or pending:
            if pending:
                # Drain whatever has already finished without blocking.
                done, pending = await asyncio.wait(pending, timeout=0)
                _collect(done, batch_index, batch_items)
                if not batch_items:
                    # We got nothing, so wait until at least one task completes.
                    done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
                    _collect(done, batch_index, batch_items)
            assert len(batch_index) == len(batch_items)
            yield batch_index, batch_items
            # Start fresh lists: the yielded ones now belong to the caller.
            batch_index, batch_items = [], []
    finally:
        # On error or early exit (e.g. aclose()), don't leave orphaned tasks running.
        for task in pending:
            task.cancel()
@contextlib.asynccontextmanager
async def open_image(
    image: pa.BinaryScalar | pa.StringScalar | pa.ExtensionScalar,
    source_mode: str | None = None,
    **storage_options: dict[str, Any],
) -> AsyncGenerator[Image.Image, None]:
    """Open an image stored in a PyArrow scalar and yield it as a PIL ``Image.Image``.

    Supported inputs:
    - pa.StringScalar: image path/URL, opened through ``auto_open`` with ``storage_options``.
    - pa.ExtensionScalar: its ``as_py()`` value is passed to ``Image.fromarray`` (presumably a
      tensor extension type yielding an ndarray — TODO confirm).
    - pa.BinaryScalar: encoded image bytes.

    For path and binary inputs the yielded image is a converted copy, independent of the
    underlying file.

    Args:
        image: The image scalar.
        source_mode: PIL mode to convert to; ``None`` keeps the image's own mode.
        **storage_options: Per-protocol storage options forwarded to ``auto_open``.

    Raises:
        TypeError: If ``image`` is not one of the supported scalar types.
    """
    if isinstance(image, pa.StringScalar):
        image_url = image.as_py()
        with await asyncio.to_thread(auto_open, image_url, **storage_options) as file_obj:
            # Close the lazily-loaded source image; convert() returns an independent copy.
            with Image.open(file_obj) as img:
                # convert PIL.ImageFile.ImageFile to Image.Image
                yield img.convert(source_mode if source_mode is not None else img.mode)
    elif isinstance(image, pa.ExtensionScalar):
        image_ndarray = image.as_py()
        yield Image.fromarray(image_ndarray, source_mode)
    # Note: pa.StringScalar is a subclass of pa.BinaryScalar; it is handled by the first branch.
    elif isinstance(image, pa.BinaryScalar):
        image_binary = image.as_py()
        with io.BytesIO(image_binary) as bio:
            with Image.open(bio) as img:
                # convert PIL.ImageFile.ImageFile to Image.Image
                yield img.convert(source_mode if source_mode is not None else img.mode)
    else:
        # Without this, asynccontextmanager would fail with "generator didn't yield".
        raise TypeError(f"Unsupported image input type: {type(image).__name__}")
async def read_image(
    image: pa.BinaryScalar | pa.StringScalar | pa.ExtensionScalar,
    source_mode: str | None = None,
    **storage_options: dict[str, Any],
) -> Image.Image:
    """Load an image scalar into a PIL image.

    This is a thin wrapper around ``open_image`` that returns the image it yields, after the
    context (and any underlying file) has been closed.
    """
    image_cm = open_image(image, source_mode, **storage_options)
    async with image_cm as loaded:
        return loaded
def _register_cos_protocol():
    """Register an S3-compatible ``cos`` filesystem with fsspec if it is not registered yet.

    s3fs is imported lazily because importing it takes more than 200ms.

    Raises:
        Exception: If the protocol is still missing from ``fsspec.registry`` afterwards.
    """
    cos_protocol = "cos"
    if cos_protocol in fsspec.registry:
        return

    from s3fs import S3FileSystem

    class COSFileSystem(S3FileSystem):
        protocol = cos_protocol

    # A concurrent registration makes register_implementation raise ValueError; that is fine.
    with contextlib.suppress(ValueError):
        fsspec.register_implementation(cos_protocol, COSFileSystem)
    if cos_protocol not in fsspec.registry:
        raise Exception("Register COSFileSystem failed!")
def _register_hf_protocol():
    """Best-effort registration of ``huggingface_hub.HfFileSystem`` under the ``hf`` protocol.

    Silently does nothing if huggingface_hub is not installed or registration races with
    another registration.
    """
    hf_protocol = "hf"
    if hf_protocol in fsspec.registry:
        return
    with contextlib.suppress(ImportError, ValueError):
        from huggingface_hub import HfFileSystem

        fsspec.register_implementation(hf_protocol, HfFileSystem)
def get_filesystem(path: str, **storage_options) -> Tuple["fsspec.AbstractFileSystem", str]:
    """Get filesystem and normalized path from a path string.

    Args:
        path: A URL such as ``s3://bucket/key``, ``cos://...``, ``hf://...``,
            ``file:///tmp/x``, or a plain local path.
        **storage_options: Per-protocol option dicts, e.g. ``s3={...}``. For ``hf``, the
            ``HF_TOKEN`` environment variable takes precedence over ``hf={"token": ...}``;
            other ``hf`` options are forwarded as-is.

    Returns:
        ``(filesystem, path_without_protocol)``.
    """
    parsed = urlparse(path)
    protocol = parsed.scheme or "file"
    if protocol == "cos":
        _register_cos_protocol()
    elif protocol == "hf":
        _register_hf_protocol()

    if protocol == "hf":
        proto_options = dict(storage_options.get("hf", {}))
        explicit_token = proto_options.pop("token", None)
        hf_token = os.environ.get("HF_TOKEN") or explicit_token
        if hf_token:
            proto_options["token"] = hf_token
    else:
        proto_options = storage_options.get(protocol, {})
    fs = fsspec.filesystem(protocol, **proto_options)

    if protocol != "file":
        normalized_path = path.split("://", 1)[1] if "://" in path else path
    elif not parsed.scheme:
        # Plain local path: use it verbatim. urlparse treats '?', '#' and ';' as URL delimiters
        # and would truncate legitimate file names such as "clip?.mp4".
        normalized_path = path
    else:
        normalized_path = parsed.path or path
    return fs, normalized_path
def auto_open(
    urlpath,
    mode="rb",
    compression=None,
    encoding="utf8",
    errors=None,
    protocol=None,
    newline=None,
    expand=None,
    **storage_options,
) -> fsspec.core.OpenFile:
    """Open ``urlpath`` with :func:`fsspec.open`, forwarding only the relevant storage options.

    ``storage_options`` maps protocol names to option dicts (e.g. ``s3={...}``). Only entries
    for protocols that appear in the (possibly chained) URL are forwarded. For a sequence of
    paths, the first path decides which protocols are selected.

    Args:
        urlpath: A path/URL, or a non-empty list/tuple/set of them.
        mode, compression, encoding, errors, protocol, newline, expand: Passed to
            :func:`fsspec.open` unchanged.
        **storage_options: Per-protocol option dicts.

    Returns:
        The ``fsspec.core.OpenFile`` returned by :func:`fsspec.open`.

    Raises:
        ValueError: If ``urlpath`` is an empty sequence.
    """
    if isinstance(urlpath, (list, tuple, set)):
        if not urlpath:
            raise ValueError("empty urlpath sequence")
        urlpath0 = fsspec.core.stringify_path(next(iter(urlpath)))
    else:
        urlpath0 = fsspec.core.stringify_path(urlpath)
    # Ensure that the cos protocol is registered before open (and before _un_chain, which
    # resolves each protocol's filesystem class).
    # NOTE(review): this imports s3fs whenever "cos" is not registered yet, even for local paths.
    _register_cos_protocol()
    selected_storage_options = {}
    # Use a distinct loop variable: reusing ``protocol`` here used to overwrite the caller's
    # ``protocol`` argument with the last protocol in the chain.
    for _bit, chain_protocol, _kw in fsspec.core._un_chain(urlpath0, {}):
        so = storage_options.get(chain_protocol)
        if so:
            selected_storage_options[chain_protocol] = so
    return fsspec.open(
        urlpath,
        mode=mode,
        compression=compression,
        encoding=encoding,
        errors=errors,
        protocol=protocol,
        newline=newline,
        expand=expand,
        **selected_storage_options,
    )
@contextlib.contextmanager
def isolated_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Run a fresh event loop in a dedicated thread for the duration of the ``with`` block.

    A simpler version of `loop_in_thread`, which does not need a ThreadPoolExecutor. Submit
    work with ``asyncio.run_coroutine_threadsafe(coro, loop)``; see
    https://docs.python.org/3/library/asyncio-task.html#asyncio.run_coroutine_threadsafe

    On exit, even if the block raises, the loop is stopped. Leftover tasks are cancelled, async
    generators and the default executor are shut down, the loop is closed, and the thread is
    joined.
    """

    def _cancel_remaining_tasks(_loop: asyncio.AbstractEventLoop) -> None:
        # Cancel and drain unfinished tasks, reporting any non-cancellation exceptions.
        to_cancel = asyncio.all_tasks(_loop)
        if not to_cancel:
            return
        for task in to_cancel:
            task.cancel()
        _loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
        for task in to_cancel:
            if task.cancelled():
                continue
            if task.exception() is not None:
                _loop.call_exception_handler(
                    {
                        "message": "unhandled exception during isolated loop shutdown",
                        "exception": task.exception(),
                        "task": task,
                    }
                )

    def _run_loop(_loop: asyncio.AbstractEventLoop) -> None:
        asyncio.set_event_loop(_loop)
        try:
            _loop.run_forever()
            _cancel_remaining_tasks(_loop)
            # Always run these, even when there were no tasks left to cancel.
            _loop.run_until_complete(_loop.shutdown_asyncgens())
            _loop.run_until_complete(_loop.shutdown_default_executor())
        finally:
            asyncio.set_event_loop(None)
            _loop.close()

    new_loop = asyncio.new_event_loop()
    loop_thread = threading.Thread(target=_run_loop, args=(new_loop,))
    loop_thread.start()
    try:
        yield new_loop
    finally:
        # Without finally, an exception in the with-body would leave the (non-daemon) loop
        # thread running forever.
        new_loop.call_soon_threadsafe(new_loop.stop)
        loop_thread.join()
def safe_run(main: Coroutine) -> Any:
    """Run coroutine ``main`` to completion from synchronous code and return its result.

    If no event loop is running in the current thread, ``asyncio.run`` is used. Otherwise
    (sync code called from inside a running loop, where ``asyncio.run`` would raise), the
    coroutine runs on a private loop in a helper thread via ``isolated_loop``. The calling
    thread, and with it the running loop, blocks until the result is ready.
    """
    try:
        # get_running_loop() is the reliable "are we inside a running loop?" check;
        # get_event_loop() may return or create a *non-running* loop instead.
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(main)
    with isolated_loop() as loop:
        return asyncio.run_coroutine_threadsafe(main, loop).result()
def copy_video_stream(
    input_container: InputContainer,
    output_container: OutputContainer,
    *,
    start_second: float = 0,
    end_second: float | None = None,
):
    """Copies the video stream without re-encoding.

    This method relies on locating the I-frame in the video, which may result in slight
    inaccuracies. For example, if you extract a 10-second segment, the final video might be
    10.3 seconds long. This behavior is equivalent to using the '-ss' option before '-i' and
    '-c:v copy' in ffmpeg.

    All video and audio streams are remuxed. For each stream, timestamps are shifted so its
    first copied packet starts at pts 0.

    Args:
        input_container: Source container; it is seeked on the first video stream.
        output_container: Destination container; one output stream is added per copied stream.
        start_second: Seek target in seconds.
        end_second: Stop once every stream has passed this time; ``None`` copies to the end.

    Raises:
        ValueError: If the first video stream has no (or a zero) time base.
    """
    # Import av takes more than 200ms, lazy import it here.
    from av.audio.stream import AudioStream
    from av.packet import Packet
    from av.stream import Stream
    from av.video.stream import VideoStream

    streams_map: dict[Stream, VideoStream | AudioStream] = {}
    video_stream = input_container.streams.video[0]
    # Reject both None and 0: a zero time base would divide by zero in the seek below.
    if not video_stream.time_base:
        raise ValueError("video stream time base is zero, file may be broken.")
    for stream in input_container.streams:
        if stream.type == "video":
            streams_map[stream] = output_container.add_stream_from_template(template=cast(VideoStream, stream))
        elif stream.type == "audio":
            streams_map[stream] = output_container.add_stream_from_template(template=cast(AudioStream, stream))
    start_pts_map = {}
    running_stream = set(streams_map.keys())
    input_container.seek(int(start_second / video_stream.time_base), stream=video_stream)
    # TODO(anthonycai) add precise segmentation logic, use decode method
    for packet in input_container.demux(*streams_map.keys()):
        # Skip flush/empty packets that carry no timestamp.
        if not isinstance(packet, Packet) or packet.pts is None:
            continue
        current_time_sec = packet.pts * packet.time_base
        if end_second is not None and current_time_sec > end_second:
            # This stream is finished; stop once every stream has passed end_second.
            if packet.stream in running_stream:
                running_stream.remove(packet.stream)
            if len(running_stream) == 0:
                break
            continue
        input_stream = packet.stream
        if input_stream not in start_pts_map:
            start_pts_map[input_stream] = packet.pts
        offset = start_pts_map[input_stream]
        packet.pts -= offset
        if packet.dts is not None:
            packet.dts -= offset
        packet.stream = streams_map[input_stream]
        output_container.mux(packet)
    # flush encoder
    # NOTE(review): streams are remuxed, not encoded; confirm encode(None) is needed here.
    for stream in streams_map.values():
        if stream.type == "video":
            output_container.mux(stream.encode(None))
def desire_video_format(list_format: list[str]) -> str:
    """
    Select a single, explicitly supported output container format (FFmpeg muxer name)
    from the list of possible container formats reported by FFmpeg.

    Why is this function needed?
    1. `input_container.format.name` returns a string like "mov,mp4,m4a,3gp,3g2,mj2",
       which lists **all possible container formats** that FFmpeg believes the file
       could be compatible with (comma-separated).
       Passing this full string directly to `av.open(..., format=...)` will fail
       because FFmpeg cannot recognize a combined name like "mov,mp4,m4a,..." —
       it expects **one valid muxer name**.
    2. The `format` parameter in `av.open(..., mode="w", format=...)` **must be a single
       valid muxer name** (e.g., "mp4", "mov", "matroska"). Otherwise, it raises:
       `ValueError: unknown format`.
    3. Priority strategy:
       - Prefer **mp4**: most widely supported, excellent web and mobile compatibility.
       - Then **mov**: common in Apple ecosystem, good compatibility with QuickTime.
       - Then Matroska (listed as "mkv" or, as FFmpeg's demuxer reports it, "matroska"),
         written with the "matroska" muxer: powerful container, supports many codecs.
       - Fallback: use the first format in the list.

    This ensures we always pass a **valid and highly compatible** output format to `av.open`.

    Raises:
        ValueError: If ``list_format`` is empty.
    """
    if not list_format:
        raise ValueError("list_format must contain at least one container format")
    if "mp4" in list_format:
        return "mp4"
    if "mov" in list_format:
        return "mov"
    # FFmpeg names its Matroska demuxer "matroska,webm", so "mkv" alone never matched it.
    if "mkv" in list_format or "matroska" in list_format:
        return "matroska"
    return list_format[0]
def cut_video_by_seconds(
    input_container: InputContainer,
    start_second: float,
    end_second: float | None = None,
) -> bytes:
    """Cut a time range out of ``input_container`` by stream copy and return the new file's bytes.

    The output uses the container format chosen by ``desire_video_format`` from the input's
    format names. Cutting follows ``copy_video_stream`` semantics: keyframe-aligned, so it is
    not frame-accurate.

    Raises:
        ValueError: If the first video stream has no (or a zero) time base.
    """
    time_base = input_container.streams.video[0].time_base
    if time_base is None or time_base == 0:
        raise ValueError("video stream time base is zero, file may be broken.")
    muxer_name = desire_video_format(input_container.format.name.split(","))
    with io.BytesIO() as output_buffer:
        with av.open(output_buffer, mode="w", format=muxer_name) as output_container:
            copy_video_stream(input_container, output_container, start_second=start_second, end_second=end_second)
        # Read only after the output container is closed, so the trailer has been written.
        return output_buffer.getvalue()
def deep_getattr(obj: object, attr: str) -> Any:
    """Resolve a dotted attribute path such as ``"a.b.c"`` starting from ``obj``.

    Args:
        obj: The root object.
        attr: Dot-separated attribute path.

    Returns:
        The value at the end of the path.

    Raises:
        AttributeError: If any segment is missing. The message names the type of the *root*
            object and the full dotted path; the original error is chained as ``__cause__``.
    """
    current = obj
    try:
        for part in attr.split("."):
            current = getattr(current, part)
    except AttributeError as err:
        # Report the root object's type; previously the intermediate object's type was shown.
        raise AttributeError(f"'{obj.__class__.__name__}' object has no attribute '{attr}'") from err
    return current
class LabelSpec(BaseModel):
    """A label plus an optional description, used to give the LLM richer prompt hints.

    After validation, ``label`` is whitespace-stripped and must be non-empty. ``description``
    is stripped too, and a blank description becomes ``None``.
    """

    label: str
    description: str | None = None

    @model_validator(mode="after")
    def strip_and_validate(self) -> LabelSpec:
        """Normalize surrounding whitespace and reject blank labels."""
        self.label = self.label.strip()
        if self.label == "":
            raise ValueError("label must not be empty")
        desc = self.description
        if desc is not None:
            self.description = desc.strip() or None
        return self
def _format_labels(specs: list[LabelSpec]) -> str:
"""Format label specs into a prompt-friendly string."""
parts = []
for spec in specs:
parts.append(f"label: {spec.label} description: {spec.description}" if spec.description else spec.label)
return "\n".join(parts)
class LLMChatCompletions:
    """Async client for an OpenAI-compatible chat-completions endpoint.

    Requests are rate-limited with ``qps_limiter(max_qps)``. :meth:`__call__` returns the raw
    ``ChatCompletion`` (``None`` on failure). :meth:`call_with_fallback` extracts the text and
    applies ``fallback_response`` on failure. :meth:`batch_generate` runs many prompts
    concurrently.

    Args:
        base_url: Endpoint base URL.
        model: Model name sent with every request.
        api_key: API key.
        max_qps: Optional maximum requests per second.
        max_retries: Retries performed by the openai client itself.
        fallback_response: Value returned by :meth:`call_with_fallback` when a request fails or
            returns empty content; ``None`` means raise instead.
        response_format: ``"text"`` or ``"json_object"``.
        **kwargs: Extra keyword arguments forwarded to every ``chat.completions.create`` call.

    Raises:
        ValueError: If ``response_format`` is not supported.
    """

    def __init__(
        self,
        base_url: str,
        model: str,
        api_key: str,
        max_qps: int | None = None,
        max_retries: int = 0,
        fallback_response: str | float | list[str] | None = None,
        response_format: str = "text",
        **kwargs: Any,
    ):
        logger.info("Using Remote Http LLMChatCompletions")
        self.model = model
        self.kwargs = kwargs
        self.fallback_response = fallback_response
        supported_formats = {"json_object", "text"}
        if response_format not in supported_formats:
            raise ValueError(f"Unsupported response_format: {response_format}")
        self.response_format = {"type": response_format}
        client = openai.AsyncClient(api_key=api_key, base_url=base_url, max_retries=max_retries)
        self.create_completions = qps_limiter(max_qps)(client.chat.completions.create)

    async def __call__(self, messages: Iterable[ChatCompletionMessageParam]) -> ChatCompletion | None:
        """Send one chat-completions request; log and return ``None`` on any exception."""
        try:
            return await self.create_completions(
                messages=messages, model=self.model, response_format=self.response_format, **self.kwargs
            )
        except Exception as e:
            # Deliberately best-effort: call_with_fallback treats None as a failed request.
            logger.error("create_completions failed: %s", e)
            return None

    async def call_with_fallback(
        self,
        messages: Iterable[ChatCompletionMessageParam],
        post_process: Callable[[str], str | float | list[str]] | None = None,
    ) -> str | float | list[str] | None:
        """Call the LLM and extract the response content, with fallback handling on failure.

        Unlike ``__call__``, which returns the raw ``ChatCompletion`` object (or ``None`` on
        network/API errors), this method goes one step further: it validates the response,
        extracts the text content, applies an optional post-processing function, and handles
        the failure case via ``fallback_response``.

        Args:
            messages: The message list to send to the LLM.
            post_process: An optional function applied to the stripped response text.
                If ``None``, the stripped text is returned as-is. Exceptions raised by
                ``post_process`` propagate; they do not trigger the fallback.

        Returns:
            The (optionally post-processed) response string or numeric value.

        Raises:
            ValueError: If the LLM call fails or returns empty content and no
                ``fallback_response`` is configured.
        """
        response = await self(messages=messages)
        if response is not None and len(response.choices) > 0 and response.choices[0].message.content is not None:
            content = response.choices[0].message.content.strip()
            return content if post_process is None else post_process(content)
        logger.error("request failed, response is: %s", response)
        if self.fallback_response is not None:
            return self.fallback_response
        raise ValueError("LLM call failed or returned empty content.")

    async def batch_generate(
        self,
        texts: pa.ChunkedArray,
        build_prompt: Callable[[str], Iterable[ChatCompletionMessageParam]],
        post_process: Callable[[str], str | float | list[str]] | None = None,
        datatype: pa.DataType = pa.string(),
    ) -> pa.Array:
        """Run :meth:`call_with_fallback` for every element of ``texts`` concurrently.

        NOTE(review): iterating a ``pa.ChunkedArray`` yields PyArrow scalars, so
        ``build_prompt`` receives ``pa.Scalar`` values despite the ``str`` annotation —
        confirm callers convert with ``.as_py()``.

        Returns:
            A PyArrow array of ``datatype`` with one result per input element, in input order.
        """
        requests = [self.call_with_fallback(messages=build_prompt(text), post_process=post_process) for text in texts]
        results = await asyncio.gather(*requests)
        return pa.array(results, type=datatype)
def _split_text_with_regex(text: str, separator: str, *, keep_separator: bool | Literal["start", "end"]) -> list[str]:
"""Split a text into chunks using re.split with separator."""
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
splits_ = re.split(f"({separator})", text)
splits = (
([splits_[i] + splits_[i + 1] for i in range(0, len(splits_) - 1, 2)])
if keep_separator == "end"
else ([splits_[i] + splits_[i + 1] for i in range(1, len(splits_), 2)])
)
if len(splits_) % 2 == 0:
splits += splits_[-1:]
splits = ([*splits, splits_[-1]]) if keep_separator == "end" else ([splits_[0], *splits])
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s]
class RecursiveCharacterTextSplitter:
    """Split text by recursively looking at characters.

    Reference: https://github.com/langchain-ai/langchain/blob/v0.1.16/libs/text-splitters/langchain_text_splitters/character.py

    Recursively tries to split by different characters to find one that works. The splitter
    tries each separator in order. It splits the text on the first separator found in it, and
    passes the remaining separators down recursively to split pieces that are still too long.
    The resulting splits are then merged back into chunks that respect ``chunk_size`` and
    ``chunk_overlap``.

    Args:
        separators: An ordered list of separator strings to try when splitting.
            The splitter iterates through this list and uses the first separator
            that is found in the text. If none match, the last separator is used
            as a fallback. Defaults to ``["\\n\\n", "\\n", " ", ""]``, which
            progressively splits on paragraph breaks, line breaks, spaces, and
            finally individual characters.
        keep_separator: Controls whether the matched separator is retained in
            the output chunks and where it is placed.
            - ``True`` or ``"start"``: the separator is prepended to the
              following chunk.
            - ``"end"``: the separator is appended to the preceding chunk.
            - ``False``: the separator is discarded entirely.
            Defaults to ``True``.
        chunk_size: The maximum number of characters allowed in a single output
            chunk. Defaults to ``4000``.
        chunk_overlap: The number of characters that consecutive chunks are
            allowed to share at their boundaries. A non-zero overlap helps
            preserve context across chunk boundaries. Must not exceed
            ``chunk_size``. Defaults to ``200``.
        is_separator_regex: If ``True``, the strings in ``separators`` are
            treated as regular expression patterns rather than plain strings.
            Defaults to ``False``.
        strip_whitespace: If ``True``, leading and trailing whitespace is
            stripped from each output chunk before it is emitted. Chunks that
            become empty after stripping are discarded. Defaults to ``True``.

    Raises:
        ValueError: If `chunk_size` is less than or equal to 0
        ValueError: If `chunk_overlap` is less than 0
        ValueError: If `chunk_overlap` is greater than `chunk_size`
    """

    def __init__(
        self,
        separators: list[str] | None = None,
        keep_separator: bool | Literal["start", "end"] = True,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        is_separator_regex: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        if chunk_size <= 0:
            raise ValueError("chunk_size must be > 0")
        if chunk_overlap < 0:
            raise ValueError("chunk_overlap must be >= 0")
        # Equality is allowed, so the message says "not greater than" rather than "less than".
        if chunk_overlap > chunk_size:
            raise ValueError("chunk_overlap must not be greater than chunk_size")
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex
        self._length_function = len
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._keep_separator = keep_separator
        self._strip_whitespace = strip_whitespace

    def _join_docs(self, docs: list[str], separator: str) -> str | None:
        """Join ``docs`` with ``separator`` (optionally stripped); return None if the result is empty."""
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        return text or None

    def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
        """Greedily pack ``splits`` into chunks of at most ``chunk_size``, keeping up to
        ``chunk_overlap`` characters of trailing context between consecutive chunks."""
        separator_len = self._length_function(separator)
        docs = []
        current_doc: list[str] = []
        total = 0
        for split in splits:
            chunk_len = self._length_function(split)
            if total + chunk_len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
                if total > self._chunk_size:
                    logger.warning(
                        "Created a chunk of size %d, which is longer than the specified %d",
                        total,
                        self._chunk_size,
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + chunk_len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
                        current_doc = current_doc[1:]
            current_doc.append(split)
            total += chunk_len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use: the first one present in the text, else the last.
        separator = separators[-1]
        new_separators = []
        for i, s_ in enumerate(separators):
            escaped_sep = s_ if self._is_separator_regex else re.escape(s_)
            if not s_:
                separator = s_
                break
            if re.search(escaped_sep, text):
                separator = s_
                new_separators = separators[i + 1 :]
                break

        separator_ = separator if self._is_separator_regex else re.escape(separator)
        raw_splits = _split_text_with_regex(text, separator_, keep_separator=self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        good_splits = []
        # Kept separators are already part of the splits, so join with "" in that case.
        separator_ = "" if self._keep_separator else separator
        for s in raw_splits:
            if self._length_function(s) < self._chunk_size:
                good_splits.append(s)
            else:
                if good_splits:
                    merged_text = self._merge_splits(good_splits, separator_)
                    final_chunks.extend(merged_text)
                    good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if good_splits:
            merged_text = self._merge_splits(good_splits, separator_)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> list[str]:
        """Split the input text into smaller chunks based on predefined separators.

        Args:
            text: The input text to be split.

        Returns:
            A list of text chunks obtained after splitting.
        """
        return self._split_text(text, self._separators)
def skip_empty_texts(func: Callable | None = None, empty_response: str = ""):
    """Decorator to skip empty or null texts in text processing.

    The decorated coroutine method ``f(self, texts)`` is called only with the non-null,
    non-empty elements of ``texts`` (whitespace-only strings count as non-empty). Its results
    are scattered back to their original positions. Skipped positions get ``empty_response``.
    The output is always a ``pa.string()`` array with the same length as ``texts``.

    Supports both ``@skip_empty_texts`` and ``@skip_empty_texts(empty_response=...)``.
    """

    def decorator(f: Callable) -> Callable:
        @wraps(f)
        async def wrapper(self, texts: pa.ChunkedArray) -> pa.Array:
            # Accept a plain pa.Array as well as a pa.ChunkedArray.
            combined = texts.combine_chunks() if isinstance(texts, pa.ChunkedArray) else texts
            is_not_empty = pc.fill_null(pc.not_equal(pc.utf8_length(combined), 0), False)
            valid_indices = pc.indices_nonzero(is_not_empty).to_pylist()
            results = [empty_response] * len(combined)
            if not valid_indices:
                return pa.array(results, type=pa.string())
            valid_texts = pc.filter(combined, is_not_empty)
            processed_results = await f(self, valid_texts)
            # One bulk conversion instead of per-element scalar indexing and .as_py().
            processed = processed_results.to_pylist()
            for idx, position in enumerate(valid_indices):
                results[position] = processed[idx]
            return pa.array(results, type=pa.string())

        return wrapper

    # Support both @skip_empty_texts and @skip_empty_texts(...)
    if func is not None:
        return decorator(func)
    return decorator
def normalize_labels(labels: list[str], labels_name: str = "labels") -> list[str]:
    """Strip each label, drop blank ones, and de-duplicate while keeping first-seen order.

    Raises:
        ValueError: If no label remains; the message is prefixed with ``labels_name``.
    """
    stripped = (label.strip() for label in labels)
    unique = list(dict.fromkeys(text for text in stripped if text))
    if not unique:
        raise ValueError(f"{labels_name} must not be empty")
    return unique
def text_tokenize(arr: pa.Array, regex_patten: str | None = None, cjk=False) -> pa.Array:
    """Tokenize each string in ``arr`` into a list of tokens.

    - With ``regex_patten``: split on that regular expression (``pc.split_pattern_regex``).
    - Otherwise, with ``cjk``: use ``dedup_utils.utf8_split_mixed``.
    - Otherwise: split on whitespace (``pc.utf8_split_whitespace``).
    """
    if regex_patten is not None:
        return pc.split_pattern_regex(arr, pattern=regex_patten)
    if cjk:
        return dedup_utils.utf8_split_mixed(arr)
    return pc.utf8_split_whitespace(arr)
def sanitize_path_segment(text: str, max_len: int = 50, default_name: str = "section") -> str:
    """Return a filesystem-safe slug from *text*, preserving CJK characters.

    Characters matched by ``_PATH_SEGMENT_SANITIZE_PATTERN`` are removed. Runs of whitespace
    become ``_``, and leading/trailing ``_`` are stripped. An empty result yields
    ``default_name``.

    Slugs longer than ``max_len`` are shortened. When ``max_len >= 10``, a prefix is kept and
    a random ``_`` + 8-hex-digit suffix is appended, so the result is exactly ``max_len``
    characters (the suffix differs on every call). For smaller ``max_len`` there is no room
    for the suffix, so the slug is simply truncated to ``max_len``.
    """
    safe = _PATH_SEGMENT_SANITIZE_PATTERN.sub("", text)
    safe = re.sub(r"\s+", "_", safe).strip("_")
    if not safe:
        return default_name
    if len(safe) <= max_len:
        return safe
    suffix_len = 9  # "_" + 8 hex digits
    if max_len <= suffix_len:
        # Previously ``max_len - 9`` went negative here, producing slugs longer than max_len.
        return safe[: max(max_len, 0)] or default_name
    u = uuid.uuid4().hex[:8]
    return f"{safe[: max_len - suffix_len]}_{u}"