from __future__ import annotations
import asyncio
import contextlib
import contextvars
import enum
import functools
import inspect
import io
import itertools
import logging
import math
import os
import threading
import time
from asyncio import events
from collections import deque
from collections.abc import Callable
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Coroutine,
Generator,
Iterable,
ParamSpec,
Tuple,
TypeVar,
cast,
)
from urllib.parse import urlparse
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from PIL import Image
from xpark.dataset.constants import (
AUDIO_FORMAT_MAPPING,
RAY_DOC_REPLACE_PAIRS,
WRAPPER_ASSIGNMENTS,
)
from xpark.dataset.context import DatasetContext
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.types import KeyType
if TYPE_CHECKING:
import av
import fsspec
import openai
import pyarrow.compute as pc
from av.container.input import InputContainer
from av.container.output import OutputContainer
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_param import (
ChatCompletionMessageParam,
)
from xpark.dataset.filters.dedup import utils as dedup_utils
else:
av = lazy_import("av")
fsspec = lazy_import("fsspec")
openai = lazy_import("openai")
pc = lazy_import("pyarrow.compute", rename="pc")
pa = lazy_import("pyarrow", rename="pa")
dedup_utils = lazy_import("xpark.dataset.filters.dedup.utils", rename="dedup_utils")
logger = logging.getLogger("ray")
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def copy_sig(fn: Callable[P, R]) -> Callable[[Callable[..., T]], Callable[P, T]]:
"""Extend the functools.wraps (The `wraps` function name results in incorrect type inference in PyCharm,
so it has been renamed to copy_sig):
1. Skip copying __qualname__ and __module__ attributes
2. Update __doc__
3. Update return annotations
"""
def decorator(target: Callable[..., T]) -> Callable[P, T]:
signature = inspect.signature(target)
        if signature.return_annotation is inspect.Signature.empty:
raise TypeError(
f"The {copy_sig.__name__} decorated function `{target.__name__}` must have a return annotation."
)
wrapped = functools.wraps(fn, assigned=WRAPPER_ASSIGNMENTS)(target)
wrapped.__doc__ = wrap_ray_doc(fn.__doc__)
wrapped.__annotations__ = wrapped.__annotations__.copy()
wrapped.__annotations__["return"] = signature.return_annotation
return cast(Callable[P, T], wrapped)
return decorator
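# Example of `copy_sig` (illustrative sketch; `source` and `target` are hypothetical functions):
#
#     def source(x: int, y: int = 0) -> int:
#         """Add two integers."""
#         return x + y
#
#     @copy_sig(source)
#     def target(*args, **kwargs) -> int:
#         return source(*args, **kwargs)
#
# `target` now carries the parameter signature and (doc-rewritten) docstring of `source`;
# the decorator raises TypeError if `target` has no return annotation.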
def deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]:
updated_mapping = mapping.copy()
for updating_mapping in updating_mappings:
for k, v in updating_mapping.items():
if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
updated_mapping[k] = deep_update(updated_mapping[k], v)
else:
updated_mapping[k] = v
return updated_mapping
def multi_replace(s: str | None, replace_pairs: dict[str, str]) -> str | None:
if s is None:
return s
for old, new in replace_pairs.items():
s = s.replace(old, new)
return s
def wrap_ray_doc(s: str | None) -> str | None:
return multi_replace(s, RAY_DOC_REPLACE_PAIRS)
def wrap_ray_class(ray_class: type[T]) -> type[T]:
"""Create a wrapper class that inherits from a Ray class and replaces its documentation.
This function creates a new class that:
1. Inherits from the original Ray class
2. Automatically replaces Ray-specific documentation with Xpark equivalents
3. Preserves all functionality of the original class
Similar to `copy_sig`, this function helps maintain Xpark-specific documentation
without modifying the original Ray classes.
Args:
ray_class: The Ray class to wrap.
Returns:
A new class that inherits from ray_class with updated documentation.
Example:
>>> import ray.data.aggregate
>>> Count = wrap_ray_class(ray.data.aggregate.Count)
>>> # Count now has Xpark-specific documentation
"""
# Special handling for Enum classes
if isinstance(ray_class, enum.EnumMeta):
raise Exception("Do not wrap enum class.")
# For regular classes, use inheritance
class WrappedClass(ray_class): # type: ignore
__doc__ = wrap_ray_doc(ray_class.__doc__)
# Preserve the original class name and module
WrappedClass.__name__ = ray_class.__name__
WrappedClass.__qualname__ = ray_class.__qualname__
return WrappedClass # type: ignore
def iter_batch(array: pa.ChunkedArray, batch_size: int) -> Generator[pa.ChunkedArray, None, None]:
    total = len(array)
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        yield array[start:end]
def qps_limiter(max_qps: int | None = None, max_concurrency: int | None = None):
"""Decorator to limit QPS (queries per second) for async functions."""
if max_qps and max_qps < 1:
raise ValueError("max_qps must be >= 1")
if max_concurrency and max_concurrency < 1:
raise ValueError("max_concurrency must be >= 1")
def _decorator(func):
if max_qps is None and max_concurrency is None:
return func
# Store timestamps of function calls
timestamps = deque()
lock = asyncio.Lock()
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
@copy_sig(func)
async def wrapper(*args, **kwargs) -> Any:
if semaphore:
await semaphore.acquire()
try:
if max_qps:
# Acquire lock to ensure that rate-limiting is concurrent-safe
async with lock:
# Current time
now = time.time()
# Remove timestamps that are older than 1 second from the current time
while timestamps and timestamps[0] < now - 1:
timestamps.popleft()
# Check if we exceeded the max QPS
if len(timestamps) >= max_qps:
# Calculate how much time we need to wait
sleep_time = 1 - (now - timestamps[0])
await asyncio.sleep(sleep_time)
timestamps.append(time.time())
return await func(*args, **kwargs)
finally:
if semaphore:
semaphore.release()
return wrapper
return _decorator
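# Example of `qps_limiter` (illustrative sketch; `fetch` is a hypothetical coroutine):
#
#     @qps_limiter(max_qps=5, max_concurrency=2)
#     async def fetch(url: str) -> str:
#         ...  # perform the actual request
#
# At most 5 calls start within any rolling second and at most 2 run concurrently.
# With both limits left as None, the decorated function is returned unchanged.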
class Count:
def __init__(self, firstval=0, step=1):
self._firstval = firstval
self._step = step
self._gen = itertools.count(firstval, step)
self._val = firstval - step
@property
def last_value(self) -> int:
return self._val
def range(self) -> Iterable[int]:
return range(self._firstval, self._val + self._step, self._step)
def next(self) -> int:
self._val = val = next(self._gen)
return val
def reset(self) -> None:
self._gen = itertools.count(self._firstval, self._step)
self._val = self._firstval - self._step
def is_reset(self):
return self._val == self._firstval - self._step
def audio_codec_to_format_ext(codec_name: str) -> tuple[str, str]:
"""
Convert audio codec name to file format and extension
Args:
codec_name: Audio codec name.
Returns:
File format and extension.
"""
codec_name = codec_name.lower()
if codec_name.startswith("pcm_"):
return "wav", "wav"
if codec_name not in AUDIO_FORMAT_MAPPING:
raise ValueError(f"Audio codec name not found in AUDIO_FORMAT_MAPPING: {codec_name}")
return AUDIO_FORMAT_MAPPING[codec_name]
@contextlib.asynccontextmanager
async def open_video(
video: pa.BinaryScalar | pa.StringScalar | str | bytes, **storage_options: dict[str, Any]
) -> AsyncGenerator[InputContainer, None]:
"""Open a video file asynchronously.
Supports multiple input types:
- pa.StringScalar: PyArrow string scalar (video path)
- pa.BinaryScalar: PyArrow binary scalar (video data)
- str: Raw string path
- bytes: Raw binary data
Args:
video: Video path or binary data.
**storage_options: Storage options for remote paths.
Yields:
PyAV InputContainer for the video.
"""
video_binary, video_url = b"", ""
# Handle different input types
if isinstance(video, str):
video_url = video
elif isinstance(video, bytes):
video_binary = video
elif isinstance(video, pa.StringScalar):
video_url = video.as_py()
elif isinstance(video, pa.BinaryScalar) and not isinstance(video, pa.StringScalar):
video_binary = video.as_py()
if video_url != "" and len(video_binary) == 0:
if video_url.startswith("http://") or video_url.startswith("https://"):
with await asyncio.to_thread(av.open, video_url) as input_container:
yield input_container
else:
with await asyncio.to_thread(auto_open, video_url, "rb", **storage_options) as file_obj:
with await asyncio.to_thread(av.open, file_obj) as input_container:
yield input_container
else:
with io.BytesIO(video_binary) as bio:
with await asyncio.to_thread(av.open, bio, "r") as input_container:
yield input_container
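# Example of `open_video` (illustrative sketch; "s3://bucket/clip.mp4" is a placeholder path):
#
#     async def _count_video_streams() -> int:
#         async with open_video("s3://bucket/clip.mp4") as container:
#             return len(container.streams.video)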
def assert_type(x: Any, expect_type: type[T]) -> T:
"""The helper function to narrow the types."""
assert isinstance(x, expect_type)
return x
async def to_thread(func, /, *args, _executor=None, **kwargs):
"""Asynchronously run function *func* in a separate thread.
Any *args and **kwargs supplied for this function are directly passed
to *func*. Also, the current :class:`contextvars.Context` is propagated,
allowing context variables from the main thread to be accessed in the
separate thread.
Return a coroutine that can be awaited to get the eventual result of *func*.
"""
loop = events.get_running_loop()
ctx = contextvars.copy_context()
func_call = functools.partial(ctx.run, func, *args, **kwargs)
return await loop.run_in_executor(_executor, func_call)
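# Example of `to_thread` (illustrative sketch; behaves like asyncio.to_thread but also accepts
# a custom executor via `_executor`):
#
#     async def _read_bytes(path: str) -> bytes:
#         with open(path, "rb") as f:
#             return await to_thread(f.read)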
def split_audio(
audio_data: np.ndarray,
sample_rate: int,
max_audio_clip_s: int = 30,
overlap_chunk_second: int = 1,
min_energy_split_window_size: int = 1600,
) -> list[np.ndarray]:
"""Copy from: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/speech_to_text.py"""
chunk_size = sample_rate * max_audio_clip_s
overlap_size = max(0, sample_rate * overlap_chunk_second)
chunks = []
i = 0
total_samples = audio_data.shape[-1]
while i < total_samples:
if i + chunk_size >= total_samples:
# handle last chunk
chunks.append(audio_data[..., i:])
break
search_end = min(i + chunk_size, total_samples)
search_window = max(1, min_energy_split_window_size, overlap_size)
search_range = min(chunk_size, search_window)
search_start = max(i, search_end - search_range)
split_point = _find_split_point(
audio_data,
search_start,
search_end,
min_energy_split_window_size,
)
if split_point <= i or split_point > search_end:
split_point = search_end
# Extract chunk up to the split point
chunks.append(audio_data[..., i:split_point])
i = split_point
return chunks
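# Example of `split_audio` (illustrative sketch; a 65 s mono signal at 16 kHz is split into
# roughly 30 s chunks, cutting at the quietest window near each chunk boundary):
#
#     sample_rate = 16000
#     audio = np.zeros(sample_rate * 65, dtype=np.float32)
#     chunks = split_audio(audio, sample_rate, max_audio_clip_s=30)
#     assert sum(chunk.shape[-1] for chunk in chunks) == audio.shape[-1]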
def _find_split_point(wav: np.ndarray, start_idx: int, end_idx: int, min_energy_split_window_size: int) -> int:
"""Find the best point to split audio by
looking for silence or low amplitude.
Args:
wav: Audio tensor [1, T]
start_idx: Start index of search region
        end_idx: End index of search region
        min_energy_split_window_size: Window size (in samples) used to compute RMS energy
Returns:
Index of best splitting point
"""
segment = wav[start_idx:end_idx]
# Calculate RMS energy in small windows
min_energy = math.inf
quietest_idx = 0
min_energy_window = min_energy_split_window_size
assert min_energy_window is not None
for i in range(0, len(segment) - min_energy_window, min_energy_window):
window = segment[i : i + min_energy_window]
energy = (window**2).mean() ** 0.5
if energy < min_energy:
quietest_idx = i + start_idx
min_energy = energy
return quietest_idx
async def auto_batch(
items: list[Any],
*,
map_fn: Callable[[int, Any], Coroutine[None, None, tuple[int, Any]] | tuple[int, Any]],
filter_fn: Callable[[Any], bool] | None = None,
) -> AsyncGenerator[tuple[list[int], list[Any]], None]:
"""Auto batch process the items by the map function.
Args:
items: A list of items
map_fn: The sync or async map function to process one item.
It should accept (index, item) and return (index, value).
filter_fn: A sync filter function to determine which items to process.
All the items will be processed if the filter_fn is None.
Returns:
        An async generator that yields tuples of (batch indices, batch items).
"""
batch_index = []
batch_items = []
async_tasks: list[asyncio.Task] = []
if not inspect.iscoroutinefunction(map_fn):
map_fn = cast(
Callable[[int, Any], Coroutine[None, None, tuple[int, Any]]],
functools.partial(asyncio.to_thread, map_fn), # type: ignore[call-arg]
)
if filter_fn is not None:
for index, item in enumerate(items):
if filter_fn(item):
async_tasks.append(asyncio.create_task(map_fn(index, item)))
else:
batch_index.append(index)
batch_items.append(item)
else:
async_tasks = [asyncio.create_task(map_fn(index, item)) for index, item in enumerate(items)]
while batch_index or async_tasks:
if async_tasks:
ready, remaining = await asyncio.wait(async_tasks, timeout=0)
for task in ready:
idx, data = task.result()
batch_index.append(idx)
batch_items.append(data)
async_tasks[:] = remaining
if not batch_items:
                # Nothing is ready yet, so wait until at least one task completes.
ready, remaining = await asyncio.wait(async_tasks, return_when=asyncio.FIRST_COMPLETED)
for task in ready:
idx, data = task.result()
batch_index.append(idx)
batch_items.append(data)
async_tasks[:] = remaining
assert len(batch_index) == len(batch_items)
yield batch_index, batch_items
batch_index.clear()
batch_items.clear()
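# Example of `auto_batch` (illustrative sketch; the map function may be sync or async and is
# executed concurrently, with completed results yielded in batches):
#
#     async def _square_all(values: list[int]) -> list[int]:
#         out = [0] * len(values)
#         async for indices, results in auto_batch(values, map_fn=lambda i, v: (i, v * v)):
#             for i, r in zip(indices, results):
#                 out[i] = r
#         return out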
@contextlib.asynccontextmanager
async def open_image(
image: pa.BinaryScalar | pa.StringScalar | pa.ExtensionScalar,
source_mode: str | None = None,
**storage_options: dict[str, Any],
) -> AsyncGenerator[Image.Image, None]:
if isinstance(image, pa.StringScalar):
image_url = image.as_py()
with await asyncio.to_thread(auto_open, image_url, **storage_options) as file_obj:
img = Image.open(file_obj)
# convert PIL.ImageFile.ImageFile to Image.Image
yield img.convert(source_mode if source_mode is not None else img.mode)
elif isinstance(image, pa.ExtensionScalar):
image_ndarray = image.as_py()
yield Image.fromarray(image_ndarray, source_mode)
# Note: pa.StringScalar is a subclass of pa.BinaryScalar, so we need to exclude it explicitly
elif isinstance(image, pa.BinaryScalar) and not isinstance(image, pa.StringScalar):
image_binary = image.as_py()
with io.BytesIO(image_binary) as bio:
with Image.open(bio) as img:
# convert PIL.ImageFile.ImageFile to Image.Image
yield img.convert(source_mode if source_mode is not None else img.mode)
async def read_image(
image: pa.BinaryScalar | pa.StringScalar | pa.ExtensionScalar,
source_mode: str | None = None,
**storage_options: dict[str, Any],
) -> Image.Image:
async with open_image(image, source_mode, **storage_options) as img:
return img
def _register_cos_protocol():
# Import s3fs takes more than 200ms, lazy import it here.
cos_protocol = "cos"
if cos_protocol not in fsspec.registry:
from s3fs import S3FileSystem
class COSFileSystem(S3FileSystem):
protocol = cos_protocol
try:
fsspec.register_implementation(cos_protocol, COSFileSystem)
except ValueError:
pass
if cos_protocol not in fsspec.registry:
raise Exception("Register COSFileSystem failed!")
def _register_hf_protocol():
hf_protocol = "hf"
if hf_protocol not in fsspec.registry:
try:
from huggingface_hub import HfFileSystem
fsspec.register_implementation(hf_protocol, HfFileSystem)
except (ImportError, ValueError):
pass
def get_filesystem(path: str) -> Tuple["fsspec.AbstractFileSystem", str]:
"""Get filesystem and normalized path from a path string."""
parsed = urlparse(path)
protocol = parsed.scheme or "file"
if protocol == "cos":
_register_cos_protocol()
elif protocol == "hf":
_register_hf_protocol()
ctx = DatasetContext.get_current()
storage_options = ctx.storage_options
if protocol == "hf":
hf_token = os.environ.get("HF_TOKEN") or storage_options.get("hf", {}).get("token")
proto_options = {"token": hf_token} if hf_token else {}
else:
proto_options = storage_options.get(protocol, {})
fs = fsspec.filesystem(protocol, **proto_options)
if protocol != "file":
normalized_path = path.split("://", 1)[1] if "://" in path else path
else:
normalized_path = parsed.path or path
return fs, normalized_path
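# Example of `get_filesystem` (illustrative sketch; "s3://bucket/data/file.parquet" is a
# placeholder path and assumes the matching fsspec backend, e.g. s3fs, is installed):
#
#     fs, normalized = get_filesystem("s3://bucket/data/file.parquet")
#     # fs is the fsspec filesystem for the "s3" protocol,
#     # normalized == "bucket/data/file.parquet"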
def auto_open(
urlpath,
mode="rb",
compression=None,
encoding="utf8",
errors=None,
protocol=None,
newline=None,
expand=None,
**storage_options,
) -> fsspec.core.OpenFile:
"""Auto select storage options based on protocol."""
if isinstance(urlpath, (list, tuple, set)):
if not urlpath:
raise ValueError("empty urlpath sequence")
urlpath0 = fsspec.core.stringify_path(next(iter(urlpath)))
else:
urlpath0 = fsspec.core.stringify_path(urlpath)
# Ensure that the cos protocol is registered before open.
_register_cos_protocol()
selected_storage_options = {}
    # Note: avoid shadowing the `protocol` parameter, which is forwarded to fsspec.open below.
    for _bit, chained_protocol, _kw in fsspec.core._un_chain(urlpath0, {}):
        so = storage_options.get(chained_protocol)
        if so:
            selected_storage_options[chained_protocol] = so
return fsspec.open(
urlpath,
mode=mode,
compression=compression,
encoding=encoding,
errors=errors,
protocol=protocol,
newline=newline,
expand=expand,
**selected_storage_options,
)
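# Example of `auto_open` (illustrative sketch; the credentials are placeholders and only the
# storage options matching the path's protocol are forwarded to fsspec.open):
#
#     with auto_open("cos://bucket/key.txt", "rb", cos={"key": "...", "secret": "..."}) as f:
#         data = f.read()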
@contextlib.contextmanager
def isolated_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
"""A simpler version of `loop_in_thread`, don't need the ThreadPoolExecutor.
https://docs.python.org/3/library/asyncio-task.html#asyncio.run_coroutine_threadsafe
"""
def _run_loop(_loop: asyncio.AbstractEventLoop):
asyncio.set_event_loop(_loop)
_loop.run_forever()
try:
to_cancel = asyncio.tasks.all_tasks(_loop)
if not to_cancel:
return
for task in to_cancel:
task.cancel()
_loop.run_until_complete(asyncio.tasks.gather(*to_cancel, return_exceptions=True))
for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
_loop.call_exception_handler(
{
"message": "unhandled exception during isolated loop shutdown",
"exception": task.exception(),
"task": task,
}
)
_loop.run_until_complete(_loop.shutdown_asyncgens())
_loop.run_until_complete(_loop.shutdown_default_executor())
finally:
asyncio.set_event_loop(None)
_loop.close()
    new_loop = asyncio.new_event_loop()
    loop_thread = threading.Thread(target=_run_loop, args=(new_loop,))
    loop_thread.start()
    try:
        yield new_loop
    finally:
        # Always stop the loop and join the thread, even if the caller's block raised.
        new_loop.call_soon_threadsafe(new_loop.stop)
        loop_thread.join()
def safe_run(main: Coroutine) -> Any:
    """Run *main* to completion, whether or not an event loop is already running in this thread."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so asyncio.run can be used directly.
        return asyncio.run(main)
    # A loop is already running here; execute the coroutine on an isolated loop in another thread.
    with isolated_loop() as loop:
        return asyncio.run_coroutine_threadsafe(main, loop).result()
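# Example of `safe_run` (illustrative sketch; works from plain sync code and from code that
# already has a running event loop):
#
#     async def _answer() -> int:
#         await asyncio.sleep(0)
#         return 42
#
#     assert safe_run(_answer()) == 42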
def copy_video_stream(
input_container: InputContainer,
output_container: OutputContainer,
*,
start_second: float = 0,
end_second: float | None = None,
):
"""Copies the video stream without re-encoding. This method relies on locating the I-frame in the video,
which may result in slight inaccuracies. For example, if you extract a 10-second segment,
the final video might be 10.3 seconds long. This behavior is equivalent to using the '-ss' option
before '-i' and '-c:v copy' in ffmpeg.
"""
# Import av takes more than 200ms, lazy import it here.
from av.audio.stream import AudioStream
from av.packet import Packet
from av.stream import Stream
from av.video.stream import VideoStream
streams_map: dict[Stream, VideoStream | AudioStream] = {}
video_stream = input_container.streams.video[0]
if video_stream.time_base is None:
raise ValueError("video stream time base is zero, file may be broken.")
for stream in input_container.streams:
if stream.type == "video":
streams_map[stream] = output_container.add_stream_from_template(template=cast(VideoStream, stream))
elif stream.type == "audio":
streams_map[stream] = output_container.add_stream_from_template(template=cast(AudioStream, stream))
start_pts_map = {}
running_stream = set(streams_map.keys())
input_container.seek(int(start_second / video_stream.time_base), stream=video_stream)
# TODO(anthonycai) add precise segmentation logic, use decode method
for packet in input_container.demux(*streams_map.keys()):
if not isinstance(packet, Packet) or packet.pts is None:
continue
current_time_sec = packet.pts * packet.time_base
if end_second is not None and current_time_sec > end_second:
if packet.stream in running_stream:
running_stream.remove(packet.stream)
if len(running_stream) == 0:
break
continue
input_stream = packet.stream
if input_stream not in start_pts_map:
start_pts_map[input_stream] = packet.pts
offset = start_pts_map[input_stream]
packet.pts -= offset
if packet.dts is not None:
packet.dts -= offset
packet.stream = streams_map[input_stream]
output_container.mux(packet)
# flush encoder
for stream in streams_map.values():
if stream.type == "video":
output_container.mux(stream.encode(None))
def desire_video_format(list_format: list[str]) -> str:
"""
Select a single, explicitly supported output container format (FFmpeg muxer name)
from the list of possible container formats reported by FFmpeg.
Why is this function needed?
1. `input_container.format.name` returns a string like "mov,mp4,m4a,3gp,3g2,mj2",
which lists **all possible container formats** that FFmpeg believes the file
could be compatible with (comma-separated).
Passing this full string directly to `av.open(..., format=...)` will fail
because FFmpeg cannot recognize a combined name like "mov,mp4,m4a,..." —
it expects **one valid muxer name**.
2. The `format` parameter in `av.open(..., mode="w", format=...)` **must be a single
valid muxer name** (e.g., "mp4", "mov", "matroska"). Otherwise, it raises:
`ValueError: unknown format`.
3. Priority strategy:
- Prefer **mp4**: most widely supported, excellent web and mobile compatibility.
- Then **mov**: common in Apple ecosystem, good compatibility with QuickTime.
- Then **mkv** (as "matroska"): powerful container, supports many codecs.
- Fallback: use the first format in the list (guaranteed to be recognized by FFmpeg).
This ensures we always pass a **valid and highly compatible** output format to `av.open`.
"""
if "mp4" in list_format:
output_format = "mp4"
elif "mov" in list_format:
output_format = "mov"
elif "mkv" in list_format:
output_format = "matroska"
else:
output_format = list_format[0]
return output_format
def cut_video_by_seconds(
input_container: InputContainer,
start_second: float,
end_second: float | None = None,
) -> bytes:
input_video_stream = input_container.streams.video[0]
if input_video_stream.time_base is None or input_video_stream.time_base == 0:
raise ValueError("video stream time base is zero, file may be broken.")
with io.BytesIO() as output_buffer:
with av.open(
output_buffer, mode="w", format=desire_video_format(input_container.format.name.split(","))
) as output_container:
copy_video_stream(input_container, output_container, start_second=start_second, end_second=end_second)
return output_buffer.getvalue()
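# Example of `cut_video_by_seconds` (illustrative sketch; "input.mp4" is a placeholder path):
#
#     async def _clip_first_ten_seconds() -> bytes:
#         async with open_video("input.mp4") as container:
#             return cut_video_by_seconds(container, start_second=0, end_second=10)
#
#     clip_bytes = safe_run(_clip_first_ten_seconds())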
def deep_getattr(obj: object, attr: str) -> Any:
try:
for part in attr.split("."):
obj = getattr(obj, part)
return obj
except AttributeError:
raise AttributeError(f"'{obj.__class__.__name__}' object has no attribute '{attr}'")
class LLMChatCompletions:
def __init__(
self,
base_url: str,
model: str,
api_key: str,
max_qps: int | None = None,
max_retries: int = 0,
fallback_response: str | float | None = None,
response_format: str = "text",
**kwargs: dict[str, Any],
):
logger.info("Using Remote Http LLMChatCompletions")
self.model = model
self.kwargs = kwargs
self.fallback_response = fallback_response
supported_formats = {"json_object", "text"}
if response_format not in supported_formats:
raise ValueError(f"Unsupported response_format: {response_format}")
self.response_format = {"type": response_format}
client = openai.AsyncClient(api_key=api_key, base_url=base_url, max_retries=max_retries)
self.create_completions = qps_limiter(max_qps)(client.chat.completions.create)
async def __call__(self, messages: Iterable[ChatCompletionMessageParam]) -> ChatCompletion | None:
try:
return await self.create_completions(
messages=messages, model=self.model, response_format=self.response_format, **self.kwargs
)
except Exception as e:
logger.error(f"create_completions failed: {e}")
return None
async def batch_generate(
self,
texts: pa.ChunkedArray,
build_prompt: Callable[[str], Iterable[ChatCompletionMessageParam]],
post_process: Callable[[str], str | float] | None = None,
datatype: pa.DataType = pa.string(),
) -> pa.Array:
requests = []
for text in texts:
requests.append(self(messages=build_prompt(text)))
responses = await asyncio.gather(*requests)
results = []
for response in responses:
if response is not None and len(response.choices) > 0 and response.choices[0].message.content is not None:
results.append(
response.choices[0].message.content.strip()
if post_process is None
else post_process(response.choices[0].message.content.strip())
)
else:
logger.error(f"request failed, response is: {response}")
if self.fallback_response is not None:
results.append(self.fallback_response)
else:
raise ValueError("LLM call failed or returned empty content.")
return pa.array(results, type=datatype)
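# Example of `LLMChatCompletions` (illustrative sketch; the base_url, model name, and prompt
# are placeholders for an OpenAI-compatible endpoint):
#
#     llm = LLMChatCompletions(
#         base_url="http://localhost:8000/v1",
#         model="my-model",
#         api_key="EMPTY",
#         max_qps=10,
#         fallback_response="",
#     )
#     texts = pa.chunked_array([pa.array(["hello", "world"])])
#     labels = safe_run(
#         llm.batch_generate(
#             texts,
#             build_prompt=lambda text: [{"role": "user", "content": f"Classify: {text}"}],
#         )
#     )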
def skip_empty_texts(func: Callable | None = None, empty_response: str = ""):
"""Decorator to skip empty or null texts in text processing."""
def decorator(f: Callable) -> Callable:
@wraps(f)
async def wrapper(self, texts: pa.ChunkedArray) -> pa.Array:
combined = texts.combine_chunks()
is_not_empty = pc.fill_null(pc.not_equal(pc.utf8_length(combined), 0), False)
valid_indices = pc.indices_nonzero(is_not_empty)
if len(valid_indices) == 0:
return pa.array([empty_response] * len(combined), type=pa.string())
valid_texts = pc.filter(combined, is_not_empty)
processed_results = await f(self, valid_texts)
results = [empty_response] * len(combined)
for idx, valid_idx in enumerate(valid_indices):
results[valid_idx.as_py()] = processed_results[idx].as_py()
return pa.array(results, type=pa.string())
return wrapper
# Support both @skip_empty_texts and @skip_empty_texts(...)
if func is not None:
return decorator(func)
return decorator
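# Example of `skip_empty_texts` (illustrative sketch; `Upper` is a hypothetical processor whose
# __call__ only receives the non-empty strings, while empty or null inputs map to `empty_response`):
#
#     class Upper:
#         @skip_empty_texts(empty_response="")
#         async def __call__(self, texts: pa.ChunkedArray) -> pa.Array:
#             return pa.array([t.as_py().upper() for t in texts])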
def normalize_labels(labels: list[str], labels_name: str = "labels") -> list[str]:
labels = list(dict.fromkeys(label.strip() for label in labels if label.strip()))
if len(labels) == 0:
raise ValueError(f"{labels_name} must not be empty")
return labels
def text_tokenize(arr: pa.Array, regex_pattern: str | None = None, cjk: bool = False) -> pa.Array:
    if regex_pattern is None:
        if cjk:
            tokens = dedup_utils.utf8_split_mixed(arr)
        else:
            tokens = pc.utf8_split_whitespace(arr)
    else:
        tokens = pc.split_pattern_regex(arr, pattern=regex_pattern)
    return tokens