from __future__ import annotations
import asyncio
import functools
import io
import logging
import os
from abc import ABCMeta, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from types import CoroutineType
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast
# Use np.ndarray type
import numpy as np
from xpark.dataset.constants import IO_WORKER_ENV, NOT_SET
from xpark.dataset.context import DatasetContext
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model
from xpark.dataset.utils import assert_type, auto_open, qps_limiter, split_audio, to_thread
if TYPE_CHECKING:
import faster_whisper
import librosa
import nemo.collections.asr as nemo_asr
import openai
import openvino_genai as ov_genai
import pyarrow as pa
import torch
import torch.nn.attention as torch_attention
import transformers
else:
faster_whisper = lazy_import("faster_whisper")
librosa = lazy_import("librosa")
nemo_asr = lazy_import("nemo.collections.asr", rename="nemo_asr")
openai = lazy_import("openai")
ov_genai = lazy_import("openvino_genai", rename="ov_genai")
pa = lazy_import("pyarrow", rename="pa")
torch = lazy_import("torch")
torch_attention = lazy_import("torch.nn.attention", rename="torch_attention")
transformers = lazy_import("transformers")
# Log through the "ray" logger so messages surface in Ray worker logs.
logger = logging.getLogger("ray")
# Target sampling rate (Hz); audio is decoded/resampled to this before inference.
SAMPLE_RATE = 16000
# Registry of supported local speech-to-text models, keyed by model name.
# Per entry:
#   backend:     inference backend name, or a {"CPU": ..., "GPU": ...} mapping
#                when the backend depends on the device class.
#   label:       optional tag set; used for filtering (see AVAILABLE_MODELS,
#                which keeps only models tagged "all") — presumably also for
#                test selection; confirm against ModelSpec.
#   model_specs: per-hub (huggingface / modelscope) download specs: model id,
#                an optional pinned revision, and supported quantizations.
SpeechToTextModel = {
    "openai/whisper-large-v3": {
        # Different engines per device: transformers on CPU, vLLM on GPU.
        "backend": {
            "CPU": "transformers",
            "GPU": "vLLM",
        },
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "openai/whisper-large-v3",
                    "model_revision": "06f233fe06e710322aca913c1bc4249a0d71fce1",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "pytorch": {
                    "model_id": "AI-ModelScope/whisper-large-v3",
                    "quantizations": [None],
                },
            },
        },
    },
    "openai/whisper-tiny": {
        "label": {"test"},
        "backend": "transformers",
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "openai/whisper-tiny",
                    "model_revision": "169d4a4341b33bc18d8881c4b69c2e104e1cc0af",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "pytorch": {
                    "model_id": "openai-mirror/whisper-tiny",
                    "quantizations": [None],
                },
            },
        },
    },
    "Systran/faster-whisper-large-v3": {
        "backend": "faster-whisper",
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "Systran/faster-whisper-large-v3",
                    "model_revision": "edaa852ec7e145841d8ffdb056a99866b5f0a478",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "pytorch": {
                    "model_id": "Systran/faster-whisper-large-v3",
                    "quantizations": [None],
                },
            },
        },
    },
    "Systran/faster-whisper-tiny": {
        "label": {"test"},
        "backend": "faster-whisper",
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "Systran/faster-whisper-tiny",
                    "model_revision": "d90ca5fe260221311c53c58e660288d3deb8d356",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "pytorch": {
                    "model_id": "Systran/faster-whisper-tiny",
                    "quantizations": [None],
                },
            },
        },
    },
    "nvidia/parakeet-tdt-0.6b-v3": {
        "label": {"test", "all"},
        "backend": "nemo",
        "model_specs": {
            "huggingface": {
                "nemo": {
                    "model_id": "nvidia/parakeet-tdt-0.6b-v3",
                    "model_revision": "be0d803fd1970eca8627f5467c208118f0f6c171",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "nemo": {
                    "model_id": "nv-community/parakeet-tdt-0.6b-v3",
                    "quantizations": [None],
                },
            },
        },
    },
    "OpenVINO/whisper-large-v3-int8-ov": {
        "backend": "openvino",
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "OpenVINO/whisper-large-v3-int8-ov",
                    "model_revision": "b31e1dcee5de24d49c6cc96da2a603eae409e722",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "pytorch": {
                    "model_id": "OpenVINO/whisper-large-v3-int8-ov",
                    "quantizations": [None],
                },
            },
        },
    },
    "OpenVINO/whisper-tiny-int8-ov": {
        "label": {"test"},
        "backend": "openvino",
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "OpenVINO/whisper-tiny-int8-ov",
                    "model_revision": "f04a5e7ec899ce17bfbc3f0efcdea55df4356028",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "pytorch": {
                    "model_id": "OpenVINO/whisper-tiny-int8-ov",
                    "quantizations": [None],
                },
            },
        },
    },
}
# A backend is either a single backend name or a per-device-class mapping, so
# the same model may use different engines on CPU vs GPU.
BACKEND = (
    Literal["faster-whisper", "transformers", "nemo", "openvino", "vLLM"]
    | dict[Literal["CPU", "GPU"], Literal["faster-whisper", "transformers", "nemo", "openvino", "vLLM"]]
)


class SpeechToTextModelSpec(ModelSpec):
    """Model spec for speech-to-text models; adds the inference backend choice."""

    backend: BACKEND


# Validate the raw registry dicts into typed specs once, at import time.
ALL_SPEECH_TO_TEXT_MODELS = {k: SpeechToTextModelSpec.model_validate(v) for k, v in SpeechToTextModel.items()}
# Models advertised to users: only those tagged with the "all" label.
# NOTE(review): most registry entries lack "all" — confirm this filter is intentional.
AVAILABLE_MODELS = [k for k, v in ALL_SPEECH_TO_TEXT_MODELS.items() if v.label & {"all"}]
# A waveform ndarray in; a transcript string out.
InputType: TypeAlias = np.ndarray
OutputType: TypeAlias = str
class SpeechToTextBackend(BatchColumnClassProtocol, metaclass=ABCMeta):
    """Interface every local speech-to-text backend implements.

    A backend is constructed from a model spec and transcribes a batch of
    waveforms into one text per input.
    """

    @abstractmethod
    def __init__(self, model_spec: SpeechToTextModelSpec): ...

    @abstractmethod
    def __call__(self, audio_list: list[InputType]) -> list[OutputType]: ...
class FasterWhisper(SpeechToTextBackend):
    """faster-whisper backend; batches segments of a single audio internally."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using FasterWhisper backend.")
        model_path = cache_model("huggingface", "pytorch", model_spec, None)
        # float16 on GPU, int8 on CPU.
        precision = "float16" if torch.cuda.is_available() else "int8"
        whisper = faster_whisper.WhisperModel(model_path, compute_type=precision)
        # This is only for batch multiple segments of single audio
        self.batched_model = faster_whisper.BatchedInferencePipeline(model=whisper)

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # TODO(baobliu): Expose kwargs to user.
        # Multiple audio batching has not merged: https://github.com/SYSTRAN/faster-whisper/pull/1359
        transcripts: list[OutputType] = []
        for waveform in audio_list:
            segments, _info = self.batched_model.transcribe(waveform, batch_size=16)
            transcripts.append(" ".join(segment.text for segment in segments))
        return transcripts
class Transformers(SpeechToTextBackend):
    """HuggingFace transformers ASR pipeline backend with a compiled forward pass."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using Transformers backend.")
        model_path = cache_model("huggingface", "pytorch", model_spec, None)
        torch.set_float32_matmul_precision("high")
        on_gpu = torch.cuda.is_available()
        device = "cuda:0" if on_gpu else "cpu"
        torch_dtype = torch.float16 if on_gpu else torch.float32
        model = transformers.AutoModelForSpeechSeq2Seq.from_pretrained(
            model_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True
        ).to(device)
        # Enable static cache and compile the forward pass
        model.generation_config.cache_implementation = "static"
        model.generation_config.max_new_tokens = 256
        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
        processor = transformers.AutoProcessor.from_pretrained(model_path)
        self.model = transformers.pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
        )

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # If the input audio is an http or https url, it downloads all the contents in the preprocess step.
        # https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/automatic_speech_recognition.py
        # TODO(baobliu): Expose kwargs to user.
        with torch_attention.sdpa_kernel(torch_attention.SDPBackend.MATH):
            results = self.model(audio_list, return_timestamps=True)
            return [entry["text"] for entry in results]
class Nemo(SpeechToTextBackend):
    """NVIDIA NeMo ASR backend; restores a checkpoint from a pre-extracted dir."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using Nemo backend.")
        model_path, nemo_extracted_dir = cache_model("huggingface", "nemo", model_spec, None)
        from nemo.core.connectors.save_restore_connector import SaveRestoreConnector

        restore_connector = SaveRestoreConnector()
        # Point the connector at the already-extracted archive to skip re-extraction.
        restore_connector._model_extracted_dir = nemo_extracted_dir
        self.model = nemo_asr.models.ASRModel.restore_from(model_path, save_restore_connector=restore_connector)

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # TODO(baobliu): Expose kwargs to user.
        return [hypothesis.text for hypothesis in self.model.transcribe(audio_list)]
class OpenVINO(SpeechToTextBackend):
    """OpenVINO GenAI Whisper backend pinned to a single CPU inference thread."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using OpenVINO backend.")
        model_path = cache_model("huggingface", "pytorch", model_spec, None)
        self.model = ov_genai.WhisperPipeline(
            model_path, "CPU", INFERENCE_NUM_THREADS=1, ENABLE_MMAP=True, ENABLE_CPU_PINNING=False
        )

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # TODO(baobliu): Expose kwargs to user.
        # The pipeline takes one audio at a time; keep the first candidate text.
        return [self.model.generate(waveform.tolist()).texts[0] for waveform in audio_list]
class vLLM(SpeechToTextBackend):
    """vLLM Whisper backend; long audio is split into chunks and the chunk
    transcripts are re-joined per input audio."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using vLLM backend.")
        os.environ["VLLM_USE_V1"] = "0"
        model_path = cache_model("huggingface", "pytorch", model_spec, None)
        from vllm import LLM

        self.model = LLM(
            model=model_path,
            max_model_len=448,
            max_num_seqs=16,
            limit_mm_per_prompt={"audio": 1},
            dtype="half",
        )

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # TODO(baobliu): Expose kwargs to user.
        from vllm import PromptType, SamplingParams

        prompts: list[PromptType] = []
        chunk_count: list[int] = []
        for waveform in audio_list:
            duration = librosa.get_duration(y=waveform, sr=SAMPLE_RATE)
            if duration < 30:
                chunks = [waveform]
            else:
                chunks = split_audio(waveform, SAMPLE_RATE)
            chunk_count.append(len(chunks))
            for chunk in chunks:
                request = {  # Test implicit prompt
                    "encoder_prompt": {
                        # Whisper does not support encoder prompt.
                        "prompt": "",
                        "multi_modal_data": {
                            "audio": (chunk, SAMPLE_RATE),
                        },
                    },
                    # Whisper does not support language detection:
                    # https://github.com/vllm-project/vllm/issues/14174
                    "decoder_prompt": "<|prev|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
                }
                prompts.append(cast(PromptType, request))
        assert len(prompts) == sum(chunk_count)
        # Create a sampling params object.
        sampling_params = SamplingParams(
            temperature=0,
            top_p=1.0,
            max_tokens=4096,
            skip_special_tokens=False,
        )
        outputs = self.model.generate(prompts, sampling_params)
        assert len(outputs) == len(prompts)
        # Re-assemble one transcript per input audio from the flat chunk outputs.
        results: list[OutputType] = []
        start = 0
        for count in chunk_count:
            piece = outputs[start : start + count]
            results.append("".join(out.outputs[0].text for out in piece))
            start += count
        return results
# Maps a backend name to its implementation class.
BACKENDS: dict[BACKEND, type[SpeechToTextBackend]] = {
    "faster-whisper": FasterWhisper,
    "transformers": Transformers,
    "vLLM": vLLM,
    "nemo": Nemo,
    "openvino": OpenVINO,
}
# How to open each supported audio reference type: paths/URLs via auto_open in
# binary mode; raw bytes wrapped in an in-memory buffer (extra args ignored).
TYPED_OPEN = {str: functools.partial(auto_open, mode="rb"), bytes: lambda b, *_args, **_kwargs: io.BytesIO(b)}
class SpeechToTextLocal(BatchColumnClassProtocol):
    """Runs speech-to-text locally with the backend chosen from the model spec.

    Accepts a batch of audios as numpy waveforms, file paths/URLs (str), or raw
    audio-file bytes; str/bytes inputs are decoded to mono SAMPLE_RATE waveforms
    first, then transcribed through the ndarray path.
    """

    def __init__(self, model: str):
        model_spec = ALL_SPEECH_TO_TEXT_MODELS[model]
        backend = model_spec.backend
        if isinstance(model_spec.backend, dict):
            # The backend may differ per device class (e.g. vLLM on GPU, transformers on CPU).
            backend = model_spec.backend["GPU" if torch.cuda.is_available() else "CPU"]
        self.model = BACKENDS[backend](model_spec)
        # Single worker: model inference is serialized and kept off the event loop.
        self.executor = ThreadPoolExecutor(max_workers=1)

    async def __call__(self, py_audio_array: list) -> pa.Array:
        """Transcribe a batch and return a pyarrow string array, one text per input.

        Raises:
            ValueError: On an empty batch, an unsupported element type, or a
                multi-channel / wrongly-shaped ndarray input.
        """
        if not py_audio_array:
            raise ValueError("Audio array is empty.")
        item = py_audio_array[0]
        # Check the ndarray ndim and yield the channel data if the channel is 1
        if isinstance(item, np.ndarray):
            formatted = []
            for audio in py_audio_array:
                if audio.ndim == 2:
                    if audio.shape[0] == 1:
                        # (1, n) -> (n,): squeeze the single channel dimension.
                        formatted.append(audio[0])
                    else:
                        raise ValueError(
                            "Local transcription (Speech To Text) requires a single channel audio input, "
                            "however the audio has multiple channels."
                        )
                elif audio.ndim > 2:
                    raise ValueError("The shape of audio ndarray is incorrect, we expect the ndim <= 2.")
                else:
                    formatted.append(audio)
            outputs: list[OutputType] = await to_thread(self.model, formatted, _executor=self.executor)
            return pa.array(outputs, type=pa.string())
        elif isinstance(item, (str, bytes)):
            typed_open = TYPED_OPEN[type(item)]
            storage_options = DatasetContext.get_current().storage_options

            def _read_ndarray(_path):
                # The Nemo backend only accepts local files, ndarrays, or torch tensors.
                # HTTP URLs or bytes are not supported.
                # We just use type here, do not support subtype.
                with typed_open(_path, **storage_options) as f:
                    return librosa.load(f, sr=SAMPLE_RATE, mono=True)[0]

            # Fan the blocking reads out through the event loop instead of a
            # synchronous executor.map, which would stall every other coroutine
            # for the full duration of the reads. gather preserves input order.
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as executor:
                arrays = list(
                    await asyncio.gather(*(loop.run_in_executor(executor, _read_ndarray, p) for p in py_audio_array))
                )
            # Recurse with the decoded waveforms; hits the ndarray branch above.
            return await self(arrays)
        else:
            raise ValueError(f"We expect a numpy ndarray or a str as input, got `{type(item)}`")
class SpeechToTextHttp(BatchColumnClassProtocol):
    """Transcribes audio through an OpenAI-compatible HTTP transcription endpoint."""

    def __init__(
        self,
        base_url: str,
        model: str,
        api_key: str,
        max_qps: int | None = None,
        max_retries: int = 0,
        **kwargs: dict[str, Any],
    ):
        logger.info("Using Remote Http Transcription (Speech To Text)")
        self.model = model
        self.kwargs = kwargs
        client = openai.AsyncClient(api_key=api_key, base_url=base_url, max_retries=max_retries)
        # Rate-limit every transcription request to max_qps (None = unlimited).
        self.create_transcriptions = qps_limiter(max_qps)(client.audio.transcriptions.create)

    async def __call__(self, py_audio_array: list) -> pa.Array:
        if not py_audio_array:
            raise ValueError("Audio array is empty.")
        item = py_audio_array[0]
        storage_options = DatasetContext.get_current().storage_options
        if isinstance(item, np.ndarray):
            raise ValueError("Http remote transcription (Speech To Text) does not support numpy arrays.")
        if isinstance(item, bytes):
            # Raw bytes can be posted directly as the request file payload.
            pending = [
                self.create_transcriptions(file=audio, model=self.model, **self.kwargs) for audio in py_audio_array
            ]
        elif isinstance(item, str):

            async def _create_request(path):
                # Open in a worker thread: auto_open may hit remote storage and block.
                with await asyncio.to_thread(auto_open, path, "rb", **storage_options) as file:
                    return await self.create_transcriptions(file=file, model=self.model, **self.kwargs)

            pending = [_create_request(audio) for audio in py_audio_array]
        else:
            raise ValueError(f"Http Transcription (Speech To Text) does not support type {type(item)}.")
        responses = await asyncio.gather(*pending)
        return pa.array([r.text for r in responses], type=pa.string())
@udf(return_dtype=DataType.string())
class SpeechToText(BatchColumnClassProtocol):
    # NOTE: __doc__ is assigned explicitly (an f-string cannot be a literal
    # docstring) so AVAILABLE_MODELS can be interpolated at import time.
    __doc__ = f"""Speech to text processor for CPU, GPU and remote Http requests.

    Args:
        _local_model: The speech to text model name for CPU or GPU, available models: {AVAILABLE_MODELS}
        base_url: The base URL of the LLM server.
        model: The request model name.
        api_key: The request API key.
        batch_rows: The number of rows to request once.
        max_qps: The maximum number of requests per second.
        max_retries: The maximum number of retries per request in the event of failures.
            We retry with exponential backoff upto this specific maximum retries.
        **kwargs: Keyword arguments to pass to the `openai.AsyncClient.audio.transcriptions.create
            <https://github.com/openai/openai-python/blob/main/src/openai/resources/audio/transcriptions.py>`_ API.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import SpeechToText, from_items

            ds = from_items(["multilingual.mp3"])
            ds = ds.with_column(
                "text",
                SpeechToText(
                    # Local transcriptions model.
                    "openai/whisper-large-v3",
                    # For remote transcriptions requests.
                    base_url="http://127.0.0.1:9997/v1",
                    model="whisper1",
                )
                # One IO worker for HTTP request, 10 CPU workers for local transcriptions.
                .options(num_workers={{"CPU": 10, "IO": 1}})
                .with_column(col("item")),
            )
            print(ds.take(2))
    """

    # The concrete worker: SpeechToTextHttp on IO workers, SpeechToTextLocal otherwise.
    model: BatchColumnClassProtocol

    def __init__(
        self,
        _local_model: str | None = None,
        /,
        *,
        base_url: str | None = None,
        model: str | None = None,
        api_key: str = NOT_SET,
        max_qps: int | None = None,
        max_retries: int = 0,
        **kwargs: dict[str, Any],
    ):
        if _local_model is None and base_url is None:
            raise ValueError("Either _local_model or base_url must be specified.")
        # IO workers do remote HTTP transcription; CPU/GPU workers run the model locally.
        if os.environ.get(IO_WORKER_ENV):
            if base_url is None:
                raise ValueError("base_url must be specified for IO worker.")
            if model is None:
                raise ValueError("model must be specified if base_url is specified.")
            self.model = SpeechToTextHttp(
                base_url=base_url, model=model, api_key=api_key, max_qps=max_qps, max_retries=max_retries, **kwargs
            )
        else:
            if _local_model is None:
                raise TypeError("_local_model must be specified for CPU or GPU worker.")
            self.model = SpeechToTextLocal(_local_model)

    async def __call__(self, audio_array: pa.ChunkedArray) -> pa.Array:
        """Transcript the audio array to text array.

        Args:
            audio_array: The audio array. The type of array is either:

                - `str` that is either the filename of a local audio file, or a public URL address to download the
                  audio file. The file will be read at the correct sampling rate to get the waveform using
                  *ffmpeg*. This requires *ffmpeg* to be installed on the system.
                - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
                  same way.
                - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
                  Raw audio at the correct sampling rate (For example, 16K).

        Returns:
            The transcribed text array.
        """
        if len(audio_array) == 0:
            return pa.array([], type=pa.string())
        # The delegate is async; assert we actually got a coroutine before awaiting.
        return await assert_type(self.model(audio_array.to_pylist()), CoroutineType)