# Source code for xpark.dataset.processors.video_compute

from __future__ import annotations

import asyncio
import inspect
import re
from dataclasses import dataclass
from fractions import Fraction
from typing import Callable

# We use pa.list_ for return_dtype of the @udf
import pyarrow as pa
from ray.air.util.tensor_extensions.arrow import ArrowVariableShapedTensorType

from xpark.dataset.context import DatasetContext
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import udf
from xpark.dataset.processors.video_decode_frames import VideoDecodeFrames, VideoExtractFrames
from xpark.dataset.processors.video_extract_audio import VideoExtractAudio
from xpark.dataset.processors.video_split_by_duration import VideoSplitByDuration
from xpark.dataset.processors.video_split_by_keyframe import VideoSplitByKeyFrame
from xpark.dataset.utils import deep_getattr, open_video, safe_run


@dataclass(frozen=True)
class _VideoFileInfoMetadata:
    """Description of one video-stream attribute exposed as a generated method.

    Instances are stored as shared module-level constants, so the record is
    frozen to prevent accidental mutation after the methods were generated.
    """

    # Name under which the generated static method is published.
    method_name: str
    # Arrow type of the values produced for this attribute.
    dtype: pa.DataType
    # Text used as the docstring of the generated method.
    docstring: str

# Maps the (possibly dotted) attribute path read from a video stream to the
# metadata of the method generated for it. Each row of the table below is
# (attribute path, method name, arrow dtype, docstring).
VIDEO_INFO_METADATA_MAP = {
    attr_path: _VideoFileInfoMetadata(method_name, arrow_type, description)
    for attr_path, method_name, arrow_type, description in (
        ("height", "height", pa.int32(), "Video height, in pixels"),
        ("width", "width", pa.int32(), "Video width, in pixels"),
        (
            "bit_rate",
            "bit_rate",
            pa.float64(),
            "The average bitrate of the video stream, in bits per second",
        ),
        (
            "base_rate",
            "base_rate",
            pa.float64(),
            "The fundamental framerate of the stream, as a float",
        ),
        (
            "average_rate",
            "average_rate",
            pa.float64(),
            "The average framerate of the video, as a float",
        ),
        (
            "time_base",
            "time_base",
            pa.float64(),
            "The time base of the stream, representing the unit of time for timestamps",
        ),
        (
            "display_aspect_ratio",
            "display_aspect_ratio",
            pa.float64(),
            "The display aspect ratio (DAR) of the video, e.g., 16:9 is 1.777...",
        ),
        (
            # The only entry whose method name differs from the attribute path.
            "codec_context.codec.name",
            "codec",
            pa.string(),
            "The name of the codec used for the video stream, e.g., 'h264', 'vp9'",
        ),
        ("pix_fmt", "pix_fmt", pa.string(), "The pixel format of the video, e.g., 'yuv420p'"),
    )
}


class VideoCompute:
    """Collection of video-processing UDFs exposed as static methods.

    Besides the methods defined in the class body, one static method per entry
    of ``VIDEO_INFO_METADATA_MAP`` (``height``, ``width``, ``codec``, ...) is
    attached to this class when the module is imported.

    .. note::
        Do not construct this class, use the staticmethod instead.
    """

    def __new__(cls, *args, **kwargs):
        # The class is only a namespace; every instantiation attempt is refused.
        message = f"The {cls.__name__} class cannot be instantiated."
        raise TypeError(message)

    @staticmethod
    @udf(return_dtype=DataType.binary())
    def extract_audio(
        videos: pa.ChunkedArray,
        codec: str | None = None,
        sample_rate: int | None = None,
        stream_index: int | None = None,
        start_second: float = 0,
        end_second: float | None = None,
    ) -> pa.Array:
        """Extract audio from video.

        This processor extracts the audio data from video files and returns it
        as binary data. It supports various file systems including COS, S3,
        HTTP, binary sources, and other fsspec-compatible storage systems.

        Args:
            videos: The videos to be processed.
            codec: Output audio format.
            sample_rate: Output audio sample rate.
            stream_index: Index of the audio stream to extract.
            start_second: Start time of the audio to extract.
            end_second: End time of the audio to extract.

        Examples:
            .. code-block:: python

                from xpark.dataset import from_items
                from xpark.dataset.expressions import col
                from xpark.dataset.processors.video_compute import VideoCompute
                from xpark.dataset.context import DatasetContext

                ctx = DatasetContext.get_current()
                # set cos storage_options
                # ctx.storage_options = {"cos": {"endpoint_url": "https://your-cos-endpoint"}}
                # VIDEO_COS_PATH is cos path like cos://bucket/path/to/video.mp4
                ds = from_items([{"video": VIDEO_COS_PATH}])
                ds = ds.with_column(
                    "audio",
                    VideoCompute.extract_audio
                    .options(num_workers={"CPU": 1}, batch_size=1)
                    .with_column(col("video")),
                )
                audio_bytes = ds.take(1)[0]['audio']
        """
        extractor = VideoExtractAudio(
            codec=codec,
            sample_rate=sample_rate,
            stream_index=stream_index,
            start_second=start_second,
            end_second=end_second,
        )
        return safe_run(extractor(videos))

    @staticmethod
    @udf(return_dtype=DataType.from_arrow(pa.list_(pa.binary())))
    def split_by_duration(
        videos: pa.ChunkedArray,
        segment_duration: float = 10,
        min_segment_duration: float = 0,
    ) -> pa.Array:
        """Split video by duration.

        This processor splits video files into multiple segments based on a
        fixed time length (`segment_duration`). For each video in the input, it
        outputs a list of binary data for the video segments. The default split
        points are keyframes.

        It supports various file systems including COS, S3, HTTP, binary
        sources, and other fsspec-compatible storage systems.

        Args:
            videos: The videos to be processed.
            segment_duration: Target duration for each segment in seconds,
                default is 10s.
            min_segment_duration: Minimum duration for each segment; segments
                shorter than this value will be discarded. Used for handling
                overly short segments at the end of videos, default value is 0.

        Examples:
            .. code-block:: python

                from xpark.dataset import from_items
                from xpark.dataset.expressions import col
                from xpark.dataset.processors.video_compute import VideoCompute
                from xpark.dataset.context import DatasetContext

                ctx = DatasetContext.get_current()
                # set cos storage_options
                # ctx.storage_options = {"cos": {"endpoint_url": "https://your-cos-endpoint"}}
                # VIDEO_COS_PATH is cos path like cos://bucket/path/to/video.mp4
                ds = from_items([{"video": VIDEO_COS_PATH}])
                ds = ds.with_column(
                    "split_videos",
                    VideoCompute.split_by_duration
                    .options(num_workers={"CPU": 1}, batch_size=1)
                    .with_column(col("video")),
                )
                split_videos = ds.take(1)[0]['split_videos']
        """
        splitter = VideoSplitByDuration(
            segment_duration=segment_duration,
            min_segment_duration=min_segment_duration,
        )
        return safe_run(splitter(videos))

    @staticmethod
    @udf(return_dtype=DataType.from_arrow(pa.list_(pa.binary())))
    def split_by_key_frame(
        videos: pa.ChunkedArray,
    ) -> pa.Array:
        """Split video by keyframe.

        This function splits a video into segments based on keyframes, which
        are frames in a video stream that contain a complete image. Unlike
        other frames, keyframes (also known as I-frames) do not rely on
        previous frames for decoding and can be used as reference points to
        extract or seek specific video segments. Keyframes are crucial for
        tasks like video editing, seeking, or streaming, as they represent
        points where the video can be independently decoded.

        It supports various file systems including COS, S3, HTTP, binary
        sources, and other fsspec-compatible storage systems.

        Args:
            videos: The videos to be processed.

        Examples:
            .. code-block:: python

                from xpark.dataset import from_items
                from xpark.dataset.expressions import col
                from xpark.dataset.processors.video_compute import VideoCompute
                from xpark.dataset.context import DatasetContext

                ctx = DatasetContext.get_current()
                # set cos storage_options
                # ctx.storage_options = {"cos": {"endpoint_url": "https://your-cos-endpoint"}}
                # VIDEO_COS_PATH is cos path like cos://bucket/path/to/video.mp4
                ds = from_items([{"video": VIDEO_COS_PATH}])
                ds = ds.with_column(
                    "split_videos",
                    VideoCompute.split_by_key_frame
                    .options(num_workers={"CPU": 1}, batch_size=1)
                    .with_column(col("video")),
                )
                split_videos = ds.take(1)[0]['split_videos']
        """
        splitter = VideoSplitByKeyFrame()
        return safe_run(splitter(videos))

    @staticmethod
    @udf(return_dtype=ArrowVariableShapedTensorType(pa.float32(), ndim=3))
    def decode_frame_at(
        videos: pa.ChunkedArray,
        timestamp: pa.ChunkedArray,
        tolerance_s: float = 0.1,
    ) -> pa.Array:
        """Decode a video frame at specified timestamp.

        This processor decodes a video frame from video file at the given
        timestamp and returns it as a tensor array. It uses PyAV (ffmpeg
        wrapper) for decoding and supports various file systems including
        local, COS, S3, HuggingFace Hub (hf://), and other fsspec-compatible
        storage systems.

        Args:
            videos: Column of video file paths (strings) or binary data.
                Supports local paths and remote paths (cos://, s3://, hf://,
                etc.).
            timestamp: Column of timestamps in seconds (floats) at which to
                extract a frame from the corresponding video.
            tolerance_s: Allowed deviation in seconds for frame retrieval. If
                the exact timestamp is not available, the closest frame within
                this tolerance will be returned. Default is 0.1 seconds.

        Returns:
            ArrowTensorArray of decoded frame with shape (C, H, W), dtype
            float32, values in range [0, 1].

        Examples:
            Basic usage:

            .. code-block:: python

                from xpark.dataset import from_items
                from xpark.dataset.expressions import col
                from xpark.dataset.processors.video_compute import VideoCompute

                ds = from_items([
                    {"video": "video.mp4", "timestamp": 0.5},
                    {"video": "video.mp4", "timestamp": 1.0},
                ])

                # Decode video frame at specified timestamp
                ds = ds.with_column(
                    "frame",
                    VideoCompute.decode_frame_at(col("video"), col("timestamp")),
                )

                # Get decoded frames
                frames = ds.take(2)

            With custom options:

            .. code-block:: python

                ds = ds.with_column(
                    "frame",
                    VideoCompute.decode_frame_at
                    .options(num_workers={"CPU": 2}, batch_size=16)
                    .with_column(
                        col("video"),
                        col("timestamp"),
                        tolerance_s=0.05,
                    ),
                )
        """
        decoder = VideoDecodeFrames(tolerance_s=tolerance_s)
        return safe_run(decoder(videos, timestamp))

    @staticmethod
    @udf(return_dtype=ArrowVariableShapedTensorType(pa.float32(), ndim=4))
    def decode_frames(
        videos: pa.ChunkedArray,
        timestamps: pa.ChunkedArray | None = None,
        *,
        fps: float | None = None,
        keyframes_only: bool = False,
        num_frames: int | None = None,
        max_frames: int | None = None,
        start_time: float | None = None,
        end_time: float | None = None,
        tolerance_s: float = 0.1,
    ) -> pa.Array:
        """Extract frames from video with multiple modes.

        This processor extracts video frames with flexible extraction modes and
        supports both path (string) and binary (bytes) input. It returns frames
        as tensor arrays.

        Extraction modes (mutually exclusive, in priority order):

        1. timestamps: Extract frames at specific timestamps (requires
           timestamps column)
        2. fps: Extract frames at a fixed frame rate
        3. keyframes_only: Extract only keyframes (I-frames)
        4. num_frames: Extract N frames uniformly distributed

        Args:
            videos: Column of video file paths (strings) or binary data
                (bytes). Supports local paths and remote paths (cos://, s3://,
                hf://, etc.).
            timestamps: Optional column of timestamps in seconds (floats) at
                which to extract frames. Can be a single float or a list of
                floats per video.
            fps: Target frames per second to extract. For example, fps=2.0
                extracts 2 frames per second of video.
            keyframes_only: If True, extract only keyframes (I-frames).
                Keyframes are independently decodable frames, useful for scene
                detection.
            num_frames: Number of frames to extract uniformly distributed
                across the video duration.
            max_frames: Maximum number of frames to extract. Applies to fps and
                keyframes_only modes.
            start_time: Start time in seconds for extraction range.
            end_time: End time in seconds for extraction range.
            tolerance_s: Allowed deviation in seconds for timestamp mode.
                Default is 0.1 seconds.

        Returns:
            ArrowTensorArray of extracted frames with shape (N, C, H, W) per
            video, dtype float32, values in range [0, 1].

        Examples:
            Extract frames at specific timestamps:

            .. code-block:: python

                from xpark.dataset import from_items
                from xpark.dataset.expressions import col
                from xpark.dataset.processors.video_compute import VideoCompute

                ds = from_items([{"video": "video.mp4", "ts": [0.5, 1.0, 1.5]}])
                ds = ds.with_column(
                    "frames",
                    VideoCompute.decode_frames(col("video"), col("ts")),
                )

            Extract frames at 2 FPS:

            .. code-block:: python

                ds = ds.with_column(
                    "frames",
                    VideoCompute.decode_frames(col("video"), fps=2.0, max_frames=10),
                )

            Extract keyframes only:

            .. code-block:: python

                ds = ds.with_column(
                    "keyframes",
                    VideoCompute.decode_frames(col("video"), keyframes_only=True),
                )

            Extract 10 frames uniformly:

            .. code-block:: python

                ds = ds.with_column(
                    "frames",
                    VideoCompute.decode_frames(col("video"), num_frames=10),
                )

            Extract from binary video data:

            .. code-block:: python

                ds = from_items([{"video_bytes": video_binary}])
                ds = ds.with_column(
                    "frames",
                    VideoCompute.decode_frames(col("video_bytes"), num_frames=5),
                )

        See Also:
            - :meth:`decode_frame_at`: Simpler API for single frame extraction.
            - data-juicer VideoExtractFramesMapper for similar functionality.
        """
        frame_extractor = VideoExtractFrames(
            fps=fps,
            keyframes_only=keyframes_only,
            num_frames=num_frames,
            max_frames=max_frames,
            start_time=start_time,
            end_time=end_time,
            tolerance_s=tolerance_s,
        )
        return safe_run(frame_extractor(videos, timestamps))
def _get_video_info_wrapper(attr_name: str) -> Callable[..., pa.Array]:
    """Build a batch function that reads one attribute of each video's first video stream.

    Args:
        attr_name: Key of ``VIDEO_INFO_METADATA_MAP``; also the (possibly
            dotted) attribute path resolved on the stream with ``deep_getattr``.

    Returns:
        A function taking a column of videos and returning an arrow array of
        the attribute values. Its ``__name__`` and ``__doc__`` are taken from
        the metadata entry.

    Raises:
        KeyError: If ``attr_name`` is not a key of ``VIDEO_INFO_METADATA_MAP``.
    """
    info = VIDEO_INFO_METADATA_MAP[attr_name]
    dtype = info.dtype

    async def extract_meta(video: pa.StringScalar | pa.BinaryScalar, **storage_options) -> float | int | str:
        """Open one video and return the requested attribute of its first video stream.

        Raises:
            ValueError: If the container has no video stream.
        """
        async with open_video(video, **storage_options) as input_container:
            video_streams = input_container.streams.video
            if len(video_streams) == 0:
                raise ValueError("container does not contain video stream")
            value = deep_getattr(video_streams[0], attr_name)
            # Rational attributes arrive as Fraction; they are exposed as
            # plain floats (their entries in the metadata map use float64).
            if isinstance(value, Fraction):
                return float(value)
            return value

    async def _vectorized_process(videos: pa.ChunkedArray) -> pa.Array:
        """Read the attribute of every video concurrently, keeping input order."""
        storage_options = DatasetContext.get_current().storage_options
        values = await asyncio.gather(*(extract_meta(video, **storage_options) for video in videos))
        return pa.array(values, dtype)

    def _get_video_info(videos: pa.ChunkedArray) -> pa.Array:
        return safe_run(_vectorized_process(videos))

    _get_video_info.__name__ = info.method_name
    _get_video_info.__doc__ = info.docstring
    return _get_video_info


# Publish one UDF static method on VideoCompute per metadata entry.
for attr, metadata in VIDEO_INFO_METADATA_MAP.items():
    method = _get_video_info_wrapper(attr)
    setattr(
        VideoCompute,
        metadata.method_name,
        staticmethod(udf(return_dtype=DataType.from_arrow(metadata.dtype))(method)),
    )


def _gen_stub_code() -> str:
    """Return the source text of the ``.pyi`` stub describing ``VideoCompute``.

    Every static method found in ``VideoCompute.__dict__`` (except ``__new__``)
    is emitted with its runtime signature, in which arrow column parameters
    are rewritten to expression types and the return type becomes ``Expr``.
    """
    # (pattern, replacement) pairs applied in this order to the signature text.
    substitutions = (
        # Enum defaults render as ``<Enum.MEMBER: value>`` in a signature
        # string; reduce them to ``Enum.MEMBER`` so the stub stays valid Python.
        (r"<([a-zA-Z]+\.[a-zA-Z]+):\s*\w+>", r"\1"),
        (r"videos: 'pa\.ChunkedArray'", "videos: ColumnExpr"),
        (r"timestamp: 'pa\.ChunkedArray'", "timestamp: Expr"),
        (r"timestamps: 'pa\.ChunkedArray'", "timestamps: Expr"),
        (r"timestamps: 'pa\.ChunkedArray \| None'", "timestamps: 'Expr | None'"),
        # The original return annotation is dropped; ``-> Expr`` is appended below.
        (r" -> 'pa.Array'", ""),
    )

    pieces = [
        '"""This is an auto-generated stub. Please do not modify this file."""\n\n',
        "from ray.data.expressions import ColumnExpr, Expr\n\n",
        "def _gen_stub_code() -> str: ...\n\n",
        "class VideoCompute:\n",
    ]
    for name, fn in VideoCompute.__dict__.items():
        # ``__new__`` is implicitly a staticmethod but is not part of the API.
        if name == "__new__" or not isinstance(fn, staticmethod):
            continue
        # Inspect the wrapped callable: staticmethod objects themselves are
        # only accepted by inspect.signature on Python 3.10 and later.
        params = str(inspect.signature(fn.__func__))
        for pattern, replacement in substitutions:
            params = re.sub(pattern, replacement, params)
        pieces.append("    @staticmethod\n")
        pieces.append(f"    def {name}{params} -> Expr: ...\n")
    return "".join(pieces)


if __name__ == "__main__":
    import subprocess

    # The stub lives next to this module: "<module>.py" -> "<module>.pyi".
    stub_path = __file__ + "i"
    with open(stub_path, "w", encoding="utf-8") as f:
        f.write(_gen_stub_code())
    # ruff format
    subprocess.run(["ruff", "format", stub_path], check=False)
    subprocess.run(["ruff", "check", "--fix", stub_path], check=False)