from __future__ import annotations
import functools
import inspect
import io
import operator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import IO, TYPE_CHECKING, Any, Callable, Iterable
import numpy as np
import pyarrow as pa
from numpy.typing import DTypeLike
from ray.air.util.tensor_extensions.arrow import (
ArrowTensorArray,
ArrowVariableShapedTensorArray,
ArrowVariableShapedTensorType,
)
from xpark.dataset.context import DatasetContext
from xpark.dataset.expressions import Expr, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.utils import auto_open, multi_replace, split_audio
if TYPE_CHECKING:
import librosa
import soundfile
else:
librosa = lazy_import("librosa")
soundfile = lazy_import("soundfile")
@dataclass
class _SoundFileInfoMetadata:
    """Metadata describing one attribute of a :func:`soundfile.info` result."""

    # Arrow type used when materializing this attribute's values into a column.
    dtype: pa.DataType
    # Text assigned as the __doc__ of the generated accessor function.
    docstring: str
# One entry per soundfile.info attribute exposed as a generated accessor UDF
# on AudioCompute (attached via setattr near the bottom of the module).
SOUNDFILE_INFO_METADATA_MAP = {
    "samplerate": _SoundFileInfoMetadata(pa.int32(), "The sample rate of the sound file."),
    "channels": _SoundFileInfoMetadata(pa.int32(), "The number of channels in the sound file."),
    "frames": _SoundFileInfoMetadata(pa.int64(), "The number of frames in the sound file."),
    # Typo fix: was "the the sound file".
    "subtype": _SoundFileInfoMetadata(pa.string(), "The subtype of data in the sound file."),
    "format": _SoundFileInfoMetadata(pa.string(), "The major format of the sound file."),
    "duration": _SoundFileInfoMetadata(pa.float64(), "The duration of the sound file."),
}
def _bytes_opener(raw: bytes, *_args, **_kwargs) -> IO:
    """Wrap raw audio bytes in an in-memory buffer; extra open args are ignored."""
    return io.BytesIO(raw)


# Dispatch table: the Python type of an input item -> a callable that opens it
# as a binary file-like object.
TYPED_OPEN = {
    str: functools.partial(auto_open, mode="rb"),
    bytes: _bytes_opener,
}
def _vectorized_process(
    audio_array: pa.ChunkedArray, fn: Callable[[IO], Any], return_fn: Callable[[Iterable[Any]], pa.Array]
) -> pa.Array:
    """Open every item concurrently, apply ``fn`` to each file object, and fold
    the results into an Arrow array via ``return_fn``.
    """
    # Empty column: the element type cannot be inferred, so return a null array.
    if not len(audio_array):
        return pa.array([], type=pa.null())

    items = audio_array.to_pylist()
    # The opener is chosen from the first item only; the column is assumed to
    # be homogeneously typed (all paths/URLs, or all raw bytes).
    first_type: type[str | bytes] = type(items[0])
    opener = TYPED_OPEN.get(first_type)
    if opener is None:
        raise ValueError(f"Unsupported data type: {first_type}, expected values of {list(TYPED_OPEN.keys())}")

    storage_options = DatasetContext.get_current().storage_options

    def _apply(item):
        # Each worker opens its own handle, so concurrent processing is safe.
        with opener(item, **storage_options) as fio:
            return fn(fio)

    # The work is I/O-bound; threads overlap the blocking reads.
    with ThreadPoolExecutor() as pool:
        processed = pool.map(_apply, items)
        return return_fn(processed)
def _get_audio_info_wrapper(attr_name: str) -> Callable[..., pa.Array]:
    """Build an accessor UDF that reads ``attr_name`` from ``soundfile.info``
    for every audio item in a column.
    """
    metadata = SOUNDFILE_INFO_METADATA_MAP[attr_name]

    def _read_attr(fio):
        return getattr(soundfile.info(fio), attr_name)

    def _to_array(values):
        return pa.array(values, type=metadata.dtype)

    def _get_audio_info(audios: pa.ChunkedArray) -> pa.Array:
        return _vectorized_process(audios, fn=_read_attr, return_fn=_to_array)

    # Surface the attribute's name/description on the generated function so it
    # documents itself (e.g. in the generated stub).
    _get_audio_info.__name__ = attr_name
    _get_audio_info.__doc__ = metadata.docstring
    return _get_audio_info
# NOTE(review): stray "[docs]" Sphinx-viewcode artifact commented out — as
# code it was a bare list expression over the undefined name `docs` and would
# raise NameError at import time.
class AudioCompute:
    """.. note:: Do not construct this class, use the staticmethod instead."""

    def __new__(cls, *args, **kwargs):
        # Pure namespace class: every public member is a static UDF, so
        # instantiation is always a programming error.
        raise TypeError(f"The {cls.__name__} class cannot be instantiated.")

    @staticmethod
    # The return dtype is dynamic, we give it a fake type here.
    # TODO(baobliu): Fix the return_dtype
    @udf(return_dtype=pa.struct([pa.field("y", pa.binary()), pa.field("sr", pa.float32())]))
    def load(
        audios: pa.ChunkedArray,
        *,
        sr: float | None = 16000,
        mono: bool = True,
        offset: float = 0.0,
        duration: float | None = None,
        dtype: DTypeLike = np.float32,
        res_type: str = "soxr_hq",
    ) -> pa.Array:
        """Load an audio file as a floating point time series.

        Audio will be automatically resampled to the given rate
        (default ``sr=16000``).

        To preserve the native sampling rate of the file, use ``sr=None``.

        Parameters
        ----------
        audios : string, pathlib.Path, http URL or audio bytes
            path to the input audio.
            Any codec supported by `soundfile` or `audioread` will work.
            The audio bytes is all the bytes of the audio file, not the ndarray.
        sr : number > 0 [scalar]
            target sampling rate
            'None' uses the native sampling rate
        mono : bool
            convert signal to mono
        offset : float
            start reading after this time (in seconds)
        duration : float
            only load up to this much audio (in seconds)
        dtype : numeric type
            data type of ``y``
        res_type : str
            resample type (see note)

            .. note::
                By default, this uses `soxr`'s high-quality mode ('HQ').
                For alternative resampling modes, see `resample`

            .. note::
                `audioread` may truncate the precision of the audio data to 16 bits.
                See :ref:`ioformats` for alternate loading methods.

        Returns
        -------
        A StructArray of PyArrow that contains these two fields.

        y : np.ndarray [shape=(n,) or (..., n)]
            audio time series. Multi-channel is supported.
        sr : number > 0 [scalar]
            sampling rate of ``y``

        Examples
        --------
        >>> # Load audio from path
        >>> from xpark.dataset import AudioCompute, from_items
        >>> from xpark.dataset.expressions import col
        >>>
        >>> items = ["sample.wav", "cos://my_bucket/sample.wav", "http://127.0.0.1:12345/sample.wav"]
        >>> from_items(items).with_column("audio_data", AudioCompute.load(col("item"))).show()
        {'item': 'sample.wav', 'audio_data': {'y': array([-7.4705182e-05, -5.2042997e-05, 5.3031382e-04, ...,
        -1.2249170e-02, -7.8290682e-03, 0.0000000e+00], shape=(77286,), dtype=float32), 'sr': 22050}}
        ...

        >>> # Load audio from bytes
        >>> from xpark.dataset import AudioCompute, from_items
        >>> from xpark.dataset.expressions import col
        >>>
        >>> with open("sample.wav", "rb") as f:
        ...     items = [f.read()]
        >>> from_items(items).with_column("audio_data", AudioCompute.load(col("item"))).show()
        {'item': b'RIFFD...', 'audio_data': {'y': array([-7.4705182e-05, -5.2042997e-05, 5.3031382e-04, ...,
        -1.2249170e-02, -7.8290682e-03, 0.0000000e+00], shape=(77286,), dtype=float32), 'sr': 22050}}

        >>> # Load and resample the audio to 16000 samplerate
        >>> from xpark.dataset import AudioCompute, from_items
        >>> from xpark.dataset.expressions import col
        >>>
        >>> from_items(["sample.wav"]).with_column("audio_data", AudioCompute.load(col("item"))).show()
        {'item': 'sample.wav', 'audio_data': {'y': array([-6.1035156e-05, 9.1552734e-05, 1.0681152e-03, ...,
        -2.1972656e-03, -1.1383057e-02, -8.8195801e-03], shape=(56080,), dtype=float32), 'sr': 16000}}
        """

        def _load_audio(fio):
            # Decoding and resampling are delegated entirely to librosa.
            return librosa.load(fio, sr=sr, mono=mono, offset=offset, duration=duration, dtype=dtype, res_type=res_type)

        def _convert_to_array(results):
            # Split the (y, sr) pairs into parallel columns for the struct array.
            # The loop variable is named `loaded_sr` to avoid shadowing the
            # `sr` parameter captured by `_load_audio`.
            signals = []
            sample_rates = []
            for y, loaded_sr in results:
                signals.append(y)
                sample_rates.append(loaded_sr)
            return pa.StructArray.from_arrays(
                [
                    ArrowTensorArray.from_numpy(signals),
                    pa.array(sample_rates, type=SOUNDFILE_INFO_METADATA_MAP["samplerate"].dtype),
                ],
                names=["y", "sr"],
            )

        return _vectorized_process(audios, fn=_load_audio, return_fn=_convert_to_array)

    @staticmethod
    @udf(return_dtype=pa.list_(ArrowVariableShapedTensorType(pa.float32(), ndim=1)))
    def split_by_duration(
        audios: pa.ChunkedArray,
        sample_rate: int | pa.ChunkedArray | None = None,
        max_audio_clip_s: int = 30,
        overlap_chunk_second: int = 1,
        min_energy_split_window_size: int = 1600,
    ) -> pa.Array:
        """Split audio by duration.

        Parameters
        ----------
        audios : string, pathlib.Path, http URL, audio bytes, or ndarray
            path to the input audio.
            Any codec supported by `soundfile` or `audioread` will work.
            The audio bytes is all the bytes of the audio file, not the ndarray.
            If the input audio type is ndarray, then sample rate must be specified.
        sample_rate : int, optional
            The input audio sample rate. If the input audio type is **not** ndarray,
            then sample rate is not used.
        max_audio_clip_s : int, optional
            Maximum duration in seconds for a single audio clip without chunking.
            Audio longer than this will be split into smaller chunks if allow_audio_chunking
            evaluates to True, otherwise it will be rejected.
        overlap_chunk_second : int, optional
            Overlap duration in seconds between consecutive audio chunks when splitting
            long audio.
        min_energy_split_window_size : int, optional
            Window size in samples for finding low-energy (quiet) regions to split
            audio chunks. The algorithm looks for the quietest moment within this
            window to minimize cutting through speech. Default 1600 samples ≈ 100ms
            at 16kHz. If None, no chunking will be done.

        Returns
        -------
        An array of ArrowVariableShapedTensorArray, each element of the array is a part of the audio in ndarray format.
        """
        if not audios:
            # NOTE(review): the empty result is typed pa.string(), which does
            # not match the declared return_dtype above — kept as-is for
            # backward compatibility; confirm downstream expectations.
            return pa.array([], type=pa.string())
        if max_audio_clip_s <= 0:
            raise ValueError("max_audio_clip_s must be positive.")
        if overlap_chunk_second < 0:
            raise ValueError("overlap_chunk_second must be zero or positive.")
        if min_energy_split_window_size < 0:
            # Message fixed: the check only rejects negative values, so zero
            # is allowed (previously claimed "must be positive").
            raise ValueError("min_energy_split_window_size must be zero or positive.")
        py_audio_array = audios.to_pylist()
        item = py_audio_array[0]
        if isinstance(item, np.ndarray):
            # Already-decoded signals: the caller must supply the sample rate.
            if sample_rate is None:
                raise ValueError("Sample rate must be specified for ndarray audios.")
            if isinstance(sample_rate, int):
                # One scalar rate shared by every row.
                gen_sample_rate = (sample_rate for _ in range(len(audios)))
            else:
                # Per-row rates: lengths must line up one-to-one.
                if len(audios) != len(sample_rate):
                    raise ValueError("audios and sample_rate must have the same length.")
                gen_sample_rate = sample_rate.to_pylist()
            split_list = map(
                functools.partial(
                    split_audio,
                    max_audio_clip_s=max_audio_clip_s,
                    overlap_chunk_second=overlap_chunk_second,
                    min_energy_split_window_size=min_energy_split_window_size,
                ),
                py_audio_array,
                gen_sample_rate,
            )
            tensor_array_list = [ArrowVariableShapedTensorArray.from_numpy(parts) for parts in split_list]
            # Build a list<storage> array first, then cast the values back to
            # the variable-shaped tensor extension type.
            nested_tensor_array = pa.array(
                map(operator.attrgetter("storage"), tensor_array_list),
                type=pa.list_(tensor_array_list[0].type.storage_type),
            )
            return nested_tensor_array.cast(pa.list_(tensor_array_list[0].type))
        else:
            # Paths/bytes: decode first, then recurse on the decoded ndarrays
            # with their native sample rates.
            # We do not resample the audio and keep the original signal tracks.
            loaded_audio = AudioCompute.load.__metadata__.wrapped(audios, sr=None, mono=False)
            return AudioCompute.split_by_duration.__metadata__.wrapped(
                audios=loaded_audio.field("y"),
                sample_rate=loaded_audio.field("sr"),
                max_audio_clip_s=max_audio_clip_s,
                overlap_chunk_second=overlap_chunk_second,
                min_energy_split_window_size=min_energy_split_window_size,
            )
# Attach one generated accessor UDF per soundfile.info attribute, e.g.
# AudioCompute.samplerate, .channels, .frames, .subtype, .format, .duration.
for _attr_name, _attr_metadata in SOUNDFILE_INFO_METADATA_MAP.items():
    _accessor = udf(return_dtype=_attr_metadata.dtype)(_get_audio_info_wrapper(_attr_name))
    setattr(AudioCompute, _attr_name, staticmethod(_accessor))
def _gen_stub_code() -> str:
    """Render the text of a ``.pyi`` stub describing AudioCompute's public API.

    Signatures are read from the live class via ``inspect`` and textually
    rewritten so the stub refers to expression/column types instead of the
    runtime Arrow types.
    """
    # Textual substitutions applied to every rendered signature.
    annotation_fixups = {
        "<class 'numpy.float32'>": "np.float32",
        "pa.ChunkedArray": "ray.data.expressions.ColumnExpr",
    }
    lines = [
        '"""This is an auto-generated stub. Please do not modify this file."""\n\n',
        "import numpy as np\n",
        "from numpy.typing import DTypeLike\n",
        "import ray.data.expressions\n\n",
        "def _gen_stub_code() -> str: ...\n\n",
        "class AudioCompute:\n",
    ]
    for name, member in AudioCompute.__dict__.items():
        if name == "__new__" or not isinstance(member, staticmethod):
            continue
        # Every stubbed method is declared to return an expression.
        signature = inspect.signature(member).replace(return_annotation=Expr)
        rendered = multi_replace(str(signature), annotation_fixups)
        lines.append("    @staticmethod\n")
        lines.append(f"    def {name}{rendered}: ...\n")
    return "".join(lines)
if __name__ == "__main__":
    import subprocess

    stub_path = __file__ + "i"  # foo.py -> foo.pyi
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) could
    # mangle non-ASCII characters such as the "≈" in generated docstrings.
    with open(stub_path, "w", encoding="utf-8") as f:
        f.write(_gen_stub_code())
    # ruff format — fail loudly if the formatter itself errors.
    subprocess.run(["ruff", "format", stub_path], check=True)
    # `ruff check --fix` exits non-zero when unfixable violations remain,
    # which is acceptable for a generated stub, so no check=True here.
    subprocess.run(["ruff", "check", "--fix", stub_path])