# Source code for xpark.dataset.processors.audio_compute

from __future__ import annotations

import functools
import inspect
import io
import operator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import IO, TYPE_CHECKING, Any, Callable, Iterable

import numpy as np
import pyarrow as pa
from numpy.typing import DTypeLike
from ray.air.util.tensor_extensions.arrow import (
    ArrowTensorArray,
    ArrowVariableShapedTensorArray,
    ArrowVariableShapedTensorType,
)

from xpark.dataset.context import DatasetContext
from xpark.dataset.expressions import Expr, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.utils import auto_open, multi_replace, split_audio

if TYPE_CHECKING:
    import librosa
    import soundfile
else:
    librosa = lazy_import("librosa")
    soundfile = lazy_import("soundfile")


@dataclass
class _SoundFileInfoMetadata:
    """Metadata for one ``soundfile.info`` attribute exposed as a generated UDF."""

    # Arrow dtype of the column produced by the generated accessor.
    dtype: pa.DataType
    # Text installed as the generated accessor's ``__doc__``.
    docstring: str


# Maps each ``soundfile.info`` attribute that is exposed on ``AudioCompute`` to the
# Arrow dtype of the resulting column and the docstring attached to the generated UDF.
SOUNDFILE_INFO_METADATA_MAP = {
    "samplerate": _SoundFileInfoMetadata(pa.int32(), "The sample rate of the sound file."),
    "channels": _SoundFileInfoMetadata(pa.int32(), "The number of channels in the sound file."),
    "frames": _SoundFileInfoMetadata(pa.int64(), "The number of frames in the sound file."),
    # Fixed doubled word ("the the") in the user-visible docstring below.
    "subtype": _SoundFileInfoMetadata(pa.string(), "The subtype of data in the sound file."),
    "format": _SoundFileInfoMetadata(pa.string(), "The major format of the sound file."),
    "duration": _SoundFileInfoMetadata(pa.float64(), "The duration of the sound file."),
}

# Dispatch table keyed by the Python type of a column value: ``str`` values are
# treated as paths/URLs and opened read-binary via ``auto_open`` (which accepts
# storage options as kwargs); ``bytes`` values are wrapped in an in-memory buffer,
# with any extra open arguments deliberately ignored.
TYPED_OPEN = {str: functools.partial(auto_open, mode="rb"), bytes: lambda b, *_args, **_kwargs: io.BytesIO(b)}


def _vectorized_process(
    audio_array: pa.ChunkedArray, fn: Callable[[IO], Any], return_fn: Callable[[Iterable[Any]], pa.Array]
) -> pa.Array:
    """Open each audio payload concurrently, apply ``fn`` to it, and combine with ``return_fn``.

    Parameters
    ----------
    audio_array : pa.ChunkedArray
        Column of audio payloads; every value must be either a path-like ``str``
        or raw ``bytes`` (all values are assumed to share the first item's type).
    fn : callable
        Receives an open file-like object for one audio item and returns any value.
    return_fn : callable
        Receives the iterable of per-item results and builds the output Arrow array.

    Returns
    -------
    pa.Array
        ``return_fn``'s result, or an empty null-typed array for empty input.

    Raises
    ------
    ValueError
        If the payload type is neither ``str`` nor ``bytes``.
    """
    if not len(audio_array):
        return pa.array([], type=pa.null())

    items = audio_array.to_pylist()
    # The first element decides how every payload is opened.
    payload_type: type[str | bytes] = type(items[0])
    if payload_type not in TYPED_OPEN:
        raise ValueError(f"Unsupported data type: {payload_type}, expected values of {list(TYPED_OPEN.keys())}")

    opener = TYPED_OPEN[payload_type]
    storage_options = DatasetContext.get_current().storage_options

    def _apply(payload):
        with opener(payload, **storage_options) as fio:
            return fn(fio)

    # Payload opening is I/O-bound, so fan out across threads; return_fn consumes
    # the lazy map results while the executor is still alive.
    with ThreadPoolExecutor() as pool:
        return return_fn(pool.map(_apply, items))


def _get_audio_info_wrapper(attr_name: str) -> Callable[..., pa.Array]:
    """Build a vectorized accessor that reads one ``soundfile.info`` attribute per item.

    The returned function takes a ``pa.ChunkedArray`` of audio payloads and
    returns a ``pa.Array`` whose dtype comes from ``SOUNDFILE_INFO_METADATA_MAP``.
    Its ``__name__``/``__doc__`` are set so it presents as a named UDF.
    """
    meta = SOUNDFILE_INFO_METADATA_MAP[attr_name]

    def _get_audio_info(audios: pa.ChunkedArray) -> pa.Array:
        def _read(fio):
            return getattr(soundfile.info(fio), attr_name)

        def _collect(values):
            return pa.array(values, type=meta.dtype)

        return _vectorized_process(audios, fn=_read, return_fn=_collect)

    _get_audio_info.__name__ = attr_name
    _get_audio_info.__doc__ = meta.docstring
    return _get_audio_info


class AudioCompute:
    """Namespace of audio-processing UDFs.

    .. note:: Do not construct this class, use the staticmethod instead.
    """

    def __new__(cls, *args, **kwargs):
        # This class is a pure namespace; instantiation is always an error.
        raise TypeError(f"The {cls.__name__} class cannot be instantiated.")

    @staticmethod
    # The return dtype is dynamic, we give it a fake type here.
    # TODO(baobliu): Fix the return_dtype
    @udf(return_dtype=pa.struct([pa.field("y", pa.binary()), pa.field("sr", pa.float32())]))
    def load(
        audios: pa.ChunkedArray,
        *,
        sr: float | None = 16000,
        mono: bool = True,
        offset: float = 0.0,
        duration: float | None = None,
        dtype: DTypeLike = np.float32,
        res_type: str = "soxr_hq",
    ) -> pa.Array:
        """Load each audio file as a floating point time series.

        Audio is resampled to ``sr`` (default 16000); pass ``sr=None`` to keep
        each file's native sampling rate.

        Parameters
        ----------
        audios : pa.ChunkedArray
            Column of audio payloads: a path, HTTP/object-store URL, or the raw
            bytes of an audio file (not an ndarray). Any codec supported by
            ``soundfile`` or ``audioread`` works.
        sr : number > 0 or None
            Target sampling rate; ``None`` keeps the native rate.
        mono : bool
            Convert the signal to mono.
        offset : float
            Start reading after this time (in seconds).
        duration : float or None
            Only load up to this much audio (in seconds).
        dtype : numeric type
            Data type of the decoded samples.
        res_type : str
            Resample mode; defaults to soxr's high-quality mode. ``audioread``
            may truncate precision to 16 bits.

        Returns
        -------
        pa.Array
            A StructArray with fields ``y`` (tensor of audio samples, possibly
            multi-channel) and ``sr`` (the sampling rate of ``y``).

        Examples
        --------
        >>> from xpark.dataset import AudioCompute, from_items
        >>> from xpark.dataset.expressions import col
        >>> from_items(["sample.wav"]).with_column(
        ...     "audio_data", AudioCompute.load(col("item"))
        ... ).show()
        """

        def _load_audio(fio):
            # librosa handles decoding, channel mixing and resampling in one call.
            return librosa.load(
                fio, sr=sr, mono=mono, offset=offset, duration=duration, dtype=dtype, res_type=res_type
            )

        def _convert_to_array(results):
            signals = []
            rates = []
            for signal, rate in results:
                signals.append(signal)
                rates.append(rate)
            return pa.StructArray.from_arrays(
                [
                    ArrowTensorArray.from_numpy(signals),
                    pa.array(rates, type=SOUNDFILE_INFO_METADATA_MAP["samplerate"].dtype),
                ],
                names=["y", "sr"],
            )

        return _vectorized_process(audios, fn=_load_audio, return_fn=_convert_to_array)

    @staticmethod
    @udf(return_dtype=pa.list_(ArrowVariableShapedTensorType(pa.float32(), ndim=1)))
    def split_by_duration(
        audios: pa.ChunkedArray,
        sample_rate: int | pa.ChunkedArray | None = None,
        max_audio_clip_s: int = 30,
        overlap_chunk_second: int = 1,
        min_energy_split_window_size: int = 1600,
    ) -> pa.Array:
        """Split each audio item into clips no longer than ``max_audio_clip_s`` seconds.

        Parameters
        ----------
        audios : pa.ChunkedArray
            Column of audio payloads: path, URL, file bytes, or decoded ndarray.
            When the values are ndarrays, ``sample_rate`` must be provided.
        sample_rate : int or pa.ChunkedArray, optional
            Sample rate(s) of the ndarray inputs — a scalar applied to every
            row, or a per-row column. Ignored for non-ndarray inputs.
        max_audio_clip_s : int
            Maximum clip duration in seconds; longer audio is chunked.
        overlap_chunk_second : int
            Overlap in seconds between consecutive chunks.
        min_energy_split_window_size : int
            Window size in samples used to locate a low-energy (quiet) split
            point, minimizing cuts through speech; 1600 samples ≈ 100 ms at
            16 kHz.

        Returns
        -------
        pa.Array
            List array where each element holds the clips of one input audio as
            variable-shaped float32 tensors.
        """
        if not audios:
            # NOTE(review): empty input yields a string-typed array, which does not
            # match the declared tensor-list return dtype — confirm with callers.
            return pa.array([], type=pa.string())
        if max_audio_clip_s <= 0:
            raise ValueError("max_audio_clip_s must be positive.")
        if overlap_chunk_second < 0:
            raise ValueError("overlap_chunk_second must be zero or positive.")
        if min_energy_split_window_size < 0:
            raise ValueError("min_energy_split_window_size must be positive.")

        py_audio_array = audios.to_pylist()
        first = py_audio_array[0]
        if isinstance(first, np.ndarray):
            if sample_rate is None:
                raise ValueError("Sample rate must be specified for ndarray audios.")
            # A scalar rate is broadcast over every row; otherwise a per-row
            # column of the same length is expected.
            if type(sample_rate) is int:
                rates = (sample_rate for _ in range(len(audios)))
            else:
                assert len(audios) == len(sample_rate)
                rates = sample_rate.to_pylist()

            splitter = functools.partial(
                split_audio,
                max_audio_clip_s=max_audio_clip_s,
                overlap_chunk_second=overlap_chunk_second,
                min_energy_split_window_size=min_energy_split_window_size,
            )
            clips_per_row = map(splitter, py_audio_array, rates)
            tensor_arrays = [ArrowVariableShapedTensorArray.from_numpy(parts) for parts in clips_per_row]
            # Build the list array over the raw storage, then cast back to the
            # extension list type so each element keeps its tensor semantics.
            nested = pa.array(
                map(operator.attrgetter("storage"), tensor_arrays),
                type=pa.list_(tensor_arrays[0].type.storage_type),
            )
            return nested.cast(pa.list_(tensor_arrays[0].type))

        # Non-ndarray payloads: decode first (keeping native rate and channels —
        # no resampling), then recurse on the decoded signals.
        loaded = AudioCompute.load.__metadata__.wrapped(audios, sr=None, mono=False)
        return AudioCompute.split_by_duration.__metadata__.wrapped(
            audios=loaded.field("y"),
            sample_rate=loaded.field("sr"),
            max_audio_clip_s=max_audio_clip_s,
            overlap_chunk_second=overlap_chunk_second,
            min_energy_split_window_size=min_energy_split_window_size,
        )
# Attach one generated staticmethod UDF per soundfile.info attribute
# (AudioCompute.samplerate, .channels, .frames, ...).
for attr, metadata in SOUNDFILE_INFO_METADATA_MAP.items():
    accessor = _get_audio_info_wrapper(attr)
    setattr(AudioCompute, attr, staticmethod(udf(return_dtype=metadata.dtype)(accessor)))


def _gen_stub_code() -> str:
    """Render a ``.pyi`` stub for this module covering AudioCompute's staticmethods."""
    from io import StringIO

    with StringIO() as stub:
        stub.write('"""This is an auto-generated stub. Please do not modify this file."""\n\n')
        stub.write("import numpy as np\n")
        stub.write("from numpy.typing import DTypeLike\n")
        stub.write("import ray.data.expressions\n\n")
        stub.write("def _gen_stub_code() -> str: ...\n\n")
        stub.write("class AudioCompute:\n")
        # Rewrite runtime annotations into stub-friendly spellings.
        replace_map = {
            "<class 'numpy.float32'>": "np.float32",
            "pa.ChunkedArray": "ray.data.expressions.ColumnExpr",
        }
        for name, fn in AudioCompute.__dict__.items():
            if name == "__new__" or not isinstance(fn, staticmethod):
                continue
            signature = inspect.signature(fn).replace(return_annotation=Expr)
            signature_string = multi_replace(str(signature), replace_map)
            stub.write("    @staticmethod\n")
            stub.write(f"    def {name}{signature_string}: ...\n")
        return stub.getvalue()


if __name__ == "__main__":
    import subprocess

    stub_path = __file__ + "i"
    with open(stub_path, "w") as f:
        f.write(_gen_stub_code())
    # ruff format
    subprocess.run(["ruff", "format", stub_path])
    subprocess.run(["ruff", "check", "--fix", stub_path])