from __future__ import annotations
import asyncio
import inspect
import re
from dataclasses import dataclass
from fractions import Fraction
from typing import Callable
# We use pa.list_ for return_dtype of the @udf
import pyarrow as pa
from ray.air.util.tensor_extensions.arrow import ArrowVariableShapedTensorType
from xpark.dataset.context import DatasetContext
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import udf
from xpark.dataset.processors.video_decode_frames import VideoDecodeFrames, VideoExtractFrames
from xpark.dataset.processors.video_extract_audio import VideoExtractAudio
from xpark.dataset.processors.video_split_by_duration import VideoSplitByDuration
from xpark.dataset.processors.video_split_by_keyframe import VideoSplitByKeyFrame
from xpark.dataset.utils import deep_getattr, open_video, safe_run
@dataclass
class _VideoFileInfoMetadata:
method_name: str
dtype: pa.DataType
docstring: str
# Video-stream attributes exposed as generated ``VideoCompute`` static methods.
# Each key is the attribute path handed to ``deep_getattr`` on the first video
# stream (presumably dotted paths such as ``codec_context.codec.name`` are walked
# attribute by attribute). Each value names the method to register, the Arrow
# type of its output column, and the docstring to attach.
VIDEO_INFO_METADATA_MAP = {
    attr_path: _VideoFileInfoMetadata(method_name, arrow_type, description)
    for attr_path, method_name, arrow_type, description in (
        ("height", "height", pa.int32(), "Video height, in pixels"),
        ("width", "width", pa.int32(), "Video width, in pixels"),
        (
            "bit_rate",
            "bit_rate",
            pa.float64(),
            "The average bitrate of the video stream, in bits per second",
        ),
        (
            "base_rate",
            "base_rate",
            pa.float64(),
            "The fundamental framerate of the stream, as a float",
        ),
        (
            "average_rate",
            "average_rate",
            pa.float64(),
            "The average framerate of the video, as a float",
        ),
        (
            "time_base",
            "time_base",
            pa.float64(),
            "The time base of the stream, representing the unit of time for timestamps",
        ),
        (
            "display_aspect_ratio",
            "display_aspect_ratio",
            pa.float64(),
            "The display aspect ratio (DAR) of the video, e.g., 16:9 is 1.777...",
        ),
        (
            "codec_context.codec.name",
            "codec",
            pa.string(),
            "The name of the codec used for the video stream, e.g., 'h264', 'vp9'",
        ),
        (
            "pix_fmt",
            "pix_fmt",
            pa.string(),
            "The pixel format of the video, e.g., 'yuv420p'",
        ),
    )
}
class VideoCompute:
    """Namespace of video UDFs for use in dataset column expressions.

    .. note:: Do not construct this class; use its static methods instead.
    """

    def __new__(cls, *args, **kwargs):
        # The class is a pure namespace: every entry point is a static UDF.
        raise TypeError(f"The {cls.__name__} class cannot be instantiated.")

    @staticmethod
    @udf(return_dtype=DataType.binary())
    def extract_audio(
        videos: pa.ChunkedArray,
        codec: str | None = None,
        sample_rate: int | None = None,
        stream_index: int | None = None,
        start_second: float = 0,
        end_second: float | None = None,
    ) -> pa.Array:
        """Extract audio from video.

        This processor extracts audio data from video files and returns it as
        binary data. It supports various file systems including COS, S3, HTTP,
        binary sources, and other fsspec-compatible storage systems.

        Args:
            videos: The videos to be processed.
            codec: Output audio format.
            sample_rate: Output audio sample rate.
            stream_index: Index of the audio stream to extract.
            start_second: Start time, in seconds, of the audio to extract.
            end_second: End time, in seconds, of the audio to extract.

        Returns:
            Binary column holding the extracted audio for each video.

        Examples:
            .. code-block:: python

                from xpark.dataset import from_items
                from xpark.dataset.expressions import col
                from xpark.dataset.processors.video_compute import VideoCompute
                from xpark.dataset.context import DatasetContext

                ctx = DatasetContext.get_current()
                # set cos storage_options
                # ctx.storage_options = {"cos": {"endpoint_url": "https://your-cos-endpoint"}}

                # VIDEO_COS_PATH is cos path like cos://bucket/path/to/video.mp4
                ds = from_items([{"video": VIDEO_COS_PATH}])
                ds = ds.with_column(
                    "audio",
                    VideoCompute.extract_audio
                    .options(num_workers={"CPU": 1}, batch_size=1)
                    .with_column(col("video")),
                )
                audio_bytes = ds.take(1)[0]['audio']
        """
        return safe_run(
            VideoExtractAudio(
                codec=codec,
                sample_rate=sample_rate,
                stream_index=stream_index,
                start_second=start_second,
                end_second=end_second,
            )(videos)
        )

    @staticmethod
    @udf(return_dtype=DataType.from_arrow(pa.list_(pa.binary())))
    def split_by_duration(
        videos: pa.ChunkedArray,
        segment_duration: float = 10,
        min_segment_duration: float = 0,
    ) -> pa.Array:
        """Split video by duration.

        This processor splits video files into multiple segments based on a fixed
        time length (``segment_duration``). For each video in the input, it outputs
        a list of binary data for the video segments. The default split points are
        keyframes. It supports various file systems including COS, S3, HTTP, binary
        sources, and other fsspec-compatible storage systems.

        Args:
            videos: The videos to be processed.
            segment_duration: Target duration for each segment in seconds,
                default is 10s.
            min_segment_duration: Minimum duration for each segment; segments
                shorter than this value will be discarded. Used for handling
                overly short segments at the end of videos. Default value is 0.

        Returns:
            List-of-binary column; each element holds the segments of one video.

        Examples:
            .. code-block:: python

                from xpark.dataset import from_items
                from xpark.dataset.expressions import col
                from xpark.dataset.processors.video_compute import VideoCompute
                from xpark.dataset.context import DatasetContext

                ctx = DatasetContext.get_current()
                # set cos storage_options
                # ctx.storage_options = {"cos": {"endpoint_url": "https://your-cos-endpoint"}}

                # VIDEO_COS_PATH is cos path like cos://bucket/path/to/video.mp4
                ds = from_items([{"video": VIDEO_COS_PATH}])
                ds = ds.with_column(
                    "split_videos",
                    VideoCompute.split_by_duration
                    .options(num_workers={"CPU": 1}, batch_size=1)
                    .with_column(col("video")),
                )
                split_videos = ds.take(1)[0]['split_videos']
        """
        return safe_run(
            VideoSplitByDuration(
                segment_duration=segment_duration,
                min_segment_duration=min_segment_duration,
            )(videos)
        )

    @staticmethod
    @udf(return_dtype=DataType.from_arrow(pa.list_(pa.binary())))
    def split_by_key_frame(
        videos: pa.ChunkedArray,
    ) -> pa.Array:
        """Split video by keyframe.

        This function splits a video into segments based on keyframes, which are
        frames in a video stream that contain a complete image. Unlike other
        frames, keyframes (also known as I-frames) do not rely on previous frames
        for decoding and can be used as reference points to extract or seek
        specific video segments. Keyframes are crucial for tasks like video
        editing, seeking, or streaming, as they represent points where the video
        can be independently decoded.

        It supports various file systems including COS, S3, HTTP, binary sources,
        and other fsspec-compatible storage systems.

        Args:
            videos: The videos to be processed.

        Returns:
            List-of-binary column; each element holds the segments of one video.

        Examples:
            .. code-block:: python

                from xpark.dataset import from_items
                from xpark.dataset.expressions import col
                from xpark.dataset.processors.video_compute import VideoCompute
                from xpark.dataset.context import DatasetContext

                ctx = DatasetContext.get_current()
                # set cos storage_options
                # ctx.storage_options = {"cos": {"endpoint_url": "https://your-cos-endpoint"}}

                # VIDEO_COS_PATH is cos path like cos://bucket/path/to/video.mp4
                ds = from_items([{"video": VIDEO_COS_PATH}])
                ds = ds.with_column(
                    "split_videos",
                    VideoCompute.split_by_key_frame
                    .options(num_workers={"CPU": 1}, batch_size=1)
                    .with_column(col("video")),
                )
                split_videos = ds.take(1)[0]['split_videos']
        """
        return safe_run(VideoSplitByKeyFrame()(videos))

    @staticmethod
    @udf(return_dtype=ArrowVariableShapedTensorType(pa.float32(), ndim=3))
    def decode_frame_at(
        videos: pa.ChunkedArray,
        timestamp: pa.ChunkedArray,
        tolerance_s: float = 0.1,
    ) -> pa.Array:
        """Decode a video frame at specified timestamp.

        This processor decodes a video frame from video file at the given
        timestamp and returns it as a tensor array. It uses PyAV (ffmpeg wrapper)
        for decoding and supports various file systems including local, COS, S3,
        HuggingFace Hub (hf://), and other fsspec-compatible storage systems.

        Args:
            videos: Column of video file paths (strings) or binary data.
                Supports local paths and remote paths (cos://, s3://, hf://, etc.).
            timestamp: Column of timestamps in seconds (floats) at which to
                extract a frame from the corresponding video.
            tolerance_s: Allowed deviation in seconds for frame retrieval.
                If the exact timestamp is not available, the closest frame
                within this tolerance will be returned. Default is 0.1 seconds.

        Returns:
            ArrowTensorArray of decoded frame with shape (C, H, W),
            dtype float32, values in range [0, 1].

        Examples:
            Basic usage:

            .. code-block:: python

                from xpark.dataset import from_items
                from xpark.dataset.expressions import col
                from xpark.dataset.processors.video_compute import VideoCompute

                ds = from_items([
                    {"video": "video.mp4", "timestamp": 0.5},
                    {"video": "video.mp4", "timestamp": 1.0},
                ])
                # Decode video frame at specified timestamp
                ds = ds.with_column(
                    "frame",
                    VideoCompute.decode_frame_at(col("video"), col("timestamp")),
                )
                # Get decoded frames
                frames = ds.take(2)

            With custom options:

            .. code-block:: python

                ds = ds.with_column(
                    "frame",
                    VideoCompute.decode_frame_at
                    .options(num_workers={"CPU": 2}, batch_size=16)
                    .with_column(
                        col("video"),
                        col("timestamp"),
                        tolerance_s=0.05,
                    ),
                )
        """
        return safe_run(
            VideoDecodeFrames(
                tolerance_s=tolerance_s,
            )(videos, timestamp)
        )

    @staticmethod
    @udf(return_dtype=ArrowVariableShapedTensorType(pa.float32(), ndim=4))
    def decode_frames(
        videos: pa.ChunkedArray,
        timestamps: pa.ChunkedArray | None = None,
        *,
        fps: float | None = None,
        keyframes_only: bool = False,
        num_frames: int | None = None,
        max_frames: int | None = None,
        start_time: float | None = None,
        end_time: float | None = None,
        tolerance_s: float = 0.1,
    ) -> pa.Array:
        """Extract frames from video with multiple modes.

        This processor extracts video frames with flexible extraction modes and
        supports both path (string) and binary (bytes) input. It returns frames as
        tensor arrays.

        Extraction modes (mutually exclusive, in priority order):

        1. timestamps: Extract frames at specific timestamps (requires timestamps column)
        2. fps: Extract frames at a fixed frame rate
        3. keyframes_only: Extract only keyframes (I-frames)
        4. num_frames: Extract N frames uniformly distributed

        Args:
            videos: Column of video file paths (strings) or binary data (bytes).
                Supports local paths and remote paths (cos://, s3://, hf://, etc.).
            timestamps: Optional column of timestamps in seconds (floats) at which
                to extract frames. Can be a single float or a list of floats per
                video.
            fps: Target frames per second to extract. For example, fps=2.0
                extracts 2 frames per second of video.
            keyframes_only: If True, extract only keyframes (I-frames). Keyframes
                are independently decodable frames, useful for scene detection.
            num_frames: Number of frames to extract uniformly distributed across
                the video duration.
            max_frames: Maximum number of frames to extract. Applies to fps and
                keyframes_only modes.
            start_time: Start time in seconds for extraction range.
            end_time: End time in seconds for extraction range.
            tolerance_s: Allowed deviation in seconds for timestamp mode.
                Default is 0.1 seconds.

        Returns:
            ArrowTensorArray of extracted frames with shape (N, C, H, W) per video,
            dtype float32, values in range [0, 1].

        Examples:
            Extract frames at specific timestamps:

            .. code-block:: python

                from xpark.dataset import from_items
                from xpark.dataset.expressions import col
                from xpark.dataset.processors.video_compute import VideoCompute

                ds = from_items([{"video": "video.mp4", "ts": [0.5, 1.0, 1.5]}])
                ds = ds.with_column(
                    "frames",
                    VideoCompute.decode_frames(col("video"), col("ts")),
                )

            Extract frames at 2 FPS:

            .. code-block:: python

                ds = ds.with_column(
                    "frames",
                    VideoCompute.decode_frames(col("video"), fps=2.0, max_frames=10),
                )

            Extract keyframes only:

            .. code-block:: python

                ds = ds.with_column(
                    "keyframes",
                    VideoCompute.decode_frames(col("video"), keyframes_only=True),
                )

            Extract 10 frames uniformly:

            .. code-block:: python

                ds = ds.with_column(
                    "frames",
                    VideoCompute.decode_frames(col("video"), num_frames=10),
                )

            Extract from binary video data:

            .. code-block:: python

                ds = from_items([{"video_bytes": video_binary}])
                ds = ds.with_column(
                    "frames",
                    VideoCompute.decode_frames(col("video_bytes"), num_frames=5),
                )

        See Also:
            - :meth:`decode_frame_at`: Simpler API for single frame extraction.
            - data-juicer VideoExtractFramesMapper for similar functionality.
        """
        return safe_run(
            VideoExtractFrames(
                fps=fps,
                keyframes_only=keyframes_only,
                num_frames=num_frames,
                max_frames=max_frames,
                start_time=start_time,
                end_time=end_time,
                tolerance_s=tolerance_s,
            )(videos, timestamps)
        )
def _get_video_info_wrapper(attr_name: str) -> Callable[..., pa.Array]:
    """Build a batch function that reads one attribute from each video's first video stream.

    Args:
        attr_name: Key of ``VIDEO_INFO_METADATA_MAP``. It is also the attribute
            path passed to ``deep_getattr`` on the first video stream (e.g.
            ``"codec_context.codec.name"``).

    Returns:
        A function ``(videos) -> pa.Array`` that opens every video in the column,
        reads the attribute, and returns the values typed with the map entry's
        ``dtype``. Its ``__name__`` and ``__doc__`` come from the map entry.
        Null inputs produce null outputs. At call time it raises ``ValueError``
        for a container without any video stream.

    Raises:
        KeyError: If ``attr_name`` is not a key of ``VIDEO_INFO_METADATA_MAP``.
    """
    metadata = VIDEO_INFO_METADATA_MAP[attr_name]
    dtype = metadata.dtype

    async def extract_meta(video: pa.StringScalar | pa.BinaryScalar, **storage_options) -> object:
        # Propagate nulls instead of trying to open a missing video.
        if not video.is_valid:
            return None
        async with open_video(video, **storage_options) as input_container:
            video_streams = input_container.streams.video
            if not video_streams:
                raise ValueError("container does not contain video stream")
            result = deep_getattr(video_streams[0], attr_name)
            # Fraction-valued attributes are converted so they fit the float64
            # dtypes declared for them in the map.
            if isinstance(result, Fraction):
                return float(result)
            return result

    async def _vectorized_process(videos: pa.ChunkedArray) -> pa.Array:
        # Fall back to no options in case the context has none configured.
        storage_options = DatasetContext.get_current().storage_options or {}
        # All videos in the batch are opened concurrently; gather keeps input order.
        results = await asyncio.gather(*(extract_meta(video, **storage_options) for video in videos))
        return pa.array(results, dtype)

    def _get_video_info(videos: pa.ChunkedArray) -> pa.Array:
        return safe_run(_vectorized_process(videos))

    # The udf/stub machinery reads these to name and document the generated method.
    _get_video_info.__name__ = metadata.method_name
    _get_video_info.__doc__ = metadata.docstring
    return _get_video_info
def _register_video_info_methods() -> None:
    """Attach one static ``udf`` to ``VideoCompute`` per ``VIDEO_INFO_METADATA_MAP`` entry.

    Each generated method is named by the entry's ``method_name`` and declares
    the entry's Arrow ``dtype`` as its return type. The loop runs inside a
    function so its loop variables do not leak into the module namespace.
    """
    for attr_path, info in VIDEO_INFO_METADATA_MAP.items():
        batch_fn = _get_video_info_wrapper(attr_path)
        column_udf = udf(return_dtype=DataType.from_arrow(info.dtype))(batch_fn)
        setattr(VideoCompute, info.method_name, staticmethod(column_udf))


_register_video_info_methods()
def _gen_stub_code() -> str:
from io import StringIO
with StringIO() as stub:
stub.write('"""This is an auto-generated stub. Please do not modify this file."""\n\n')
stub.write("from ray.data.expressions import ColumnExpr, Expr\n\n")
stub.write("def _gen_stub_code() -> str: ...\n\n")
stub.write("class VideoCompute:\n")
for name, fn in VideoCompute.__dict__.items():
if name != "__new__" and isinstance(fn, staticmethod):
signature = inspect.signature(fn)
# replace <IntEnum: id> to IntEnum for
pattern = r"<([a-zA-Z]+\.[a-zA-Z]+):\s*\w+>"
params = re.sub(pattern, r"\1", str(signature))
params = re.sub(r"videos: 'pa\.ChunkedArray'", "videos: ColumnExpr", params)
params = re.sub(r"timestamp: 'pa\.ChunkedArray'", "timestamp: Expr", params)
params = re.sub(r"timestamps: 'pa\.ChunkedArray'", "timestamps: Expr", params)
params = re.sub(r"timestamps: 'pa\.ChunkedArray \| None'", "timestamps: 'Expr | None'", params)
params = re.sub(r" -> 'pa.Array'", "", params)
stub.write(" @staticmethod\n")
stub.write(f" def {name}{params} -> Expr: ...\n")
return stub.getvalue()
if __name__ == "__main__":
    # Regenerate the sibling ``.pyi`` stub for this module, then run ruff on it.
    import subprocess

    stub_path = __file__ + "i"
    # Explicit encoding so the stub does not depend on the platform locale.
    with open(stub_path, "w", encoding="utf-8") as f:
        f.write(_gen_stub_code())
    # ruff format, then apply ruff's auto-fixable lint fixes.
    subprocess.run(["ruff", "format", stub_path])
    subprocess.run(["ruff", "check", "--fix", stub_path])