Source code for xpark.dataset.processors.speech_to_text

from __future__ import annotations

import asyncio
import functools
import io
import logging
import os
from abc import ABCMeta, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from string import Template
from types import CoroutineType
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast

# Use np.ndarray type
import numpy as np

from xpark.dataset.constants import IO_WORKER_ENV, NOT_SET
from xpark.dataset.context import DatasetContext
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model
from xpark.dataset.utils import assert_type, auto_open, qps_limiter, split_audio, to_thread

if TYPE_CHECKING:
    import faster_whisper
    import librosa
    import nemo.collections.asr as nemo_asr
    import openai
    import openvino_genai as ov_genai
    import pyarrow as pa
    import torch
    import torch.nn.attention as torch_attention
    import transformers
else:
    faster_whisper = lazy_import("faster_whisper")
    librosa = lazy_import("librosa")
    nemo_asr = lazy_import("nemo.collections.asr", rename="nemo_asr")
    openai = lazy_import("openai")
    ov_genai = lazy_import("openvino_genai", rename="ov_genai")
    pa = lazy_import("pyarrow", rename="pa")
    torch = lazy_import("torch")
    torch_attention = lazy_import("torch.nn.attention", rename="torch_attention")
    transformers = lazy_import("transformers")

# Module logger. The "ray" name is presumably chosen so messages surface in Ray worker logs.
logger = logging.getLogger(name="ray")

# Sampling rate (Hz) that local audio is decoded/resampled to before it reaches a backend.
SAMPLE_RATE: int = 16_000

# Registry of supported speech-to-text models, keyed by model name. Each entry is
# validated into a ``SpeechToTextModelSpec`` (see ``ALL_SPEECH_TO_TEXT_MODELS``).
#   - "backend": either a single backend name, or a {"CPU": ..., "GPU": ...} mapping that
#     ``SpeechToTextLocal`` resolves to the "GPU" entry when CUDA is available.
#   - "label": tags used for filtering; only models tagged "all" end up in
#     ``AVAILABLE_MODELS`` (the "test" tag presumably marks small models used in tests).
#   - "model_specs": download sources, hub -> model format -> id / revision / quantizations,
#     consumed by ``cache_model``. The huggingface entries pin a ``model_revision``;
#     the modelscope entries do not.
SpeechToTextModel = {
    "openai/whisper-large-v3": {
        "backend": {
            "CPU": "transformers",
            "GPU": "vLLM",
        },
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "openai/whisper-large-v3",
                    "model_revision": "06f233fe06e710322aca913c1bc4249a0d71fce1",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "pytorch": {
                    "model_id": "AI-ModelScope/whisper-large-v3",
                    "quantizations": [None],
                },
            },
        },
    },
    "openai/whisper-tiny": {
        "label": {"test"},
        "backend": "transformers",
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "openai/whisper-tiny",
                    "model_revision": "169d4a4341b33bc18d8881c4b69c2e104e1cc0af",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "pytorch": {
                    "model_id": "openai-mirror/whisper-tiny",
                    "quantizations": [None],
                },
            },
        },
    },
    "Systran/faster-whisper-large-v3": {
        "backend": "faster-whisper",
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "Systran/faster-whisper-large-v3",
                    "model_revision": "edaa852ec7e145841d8ffdb056a99866b5f0a478",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "pytorch": {
                    "model_id": "Systran/faster-whisper-large-v3",
                    "quantizations": [None],
                },
            },
        },
    },
    "Systran/faster-whisper-tiny": {
        "label": {"test"},
        "backend": "faster-whisper",
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "Systran/faster-whisper-tiny",
                    "model_revision": "d90ca5fe260221311c53c58e660288d3deb8d356",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "pytorch": {
                    "model_id": "Systran/faster-whisper-tiny",
                    "quantizations": [None],
                },
            },
        },
    },
    "nvidia/parakeet-tdt-0.6b-v3": {
        "label": {"test", "all"},
        "backend": "nemo",
        "model_specs": {
            "huggingface": {
                "nemo": {
                    "model_id": "nvidia/parakeet-tdt-0.6b-v3",
                    "model_revision": "be0d803fd1970eca8627f5467c208118f0f6c171",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "nemo": {
                    "model_id": "nv-community/parakeet-tdt-0.6b-v3",
                    "quantizations": [None],
                },
            },
        },
    },
    "OpenVINO/whisper-large-v3-int8-ov": {
        "backend": "openvino",
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "OpenVINO/whisper-large-v3-int8-ov",
                    "model_revision": "b31e1dcee5de24d49c6cc96da2a603eae409e722",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "pytorch": {
                    "model_id": "OpenVINO/whisper-large-v3-int8-ov",
                    "quantizations": [None],
                },
            },
        },
    },
    "OpenVINO/whisper-tiny-int8-ov": {
        "label": {"test"},
        "backend": "openvino",
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "OpenVINO/whisper-tiny-int8-ov",
                    "model_revision": "f04a5e7ec899ce17bfbc3f0efcdea55df4356028",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "pytorch": {
                    "model_id": "OpenVINO/whisper-tiny-int8-ov",
                    "quantizations": [None],
                },
            },
        },
    },
}


BACKEND = (
    Literal["faster-whisper", "transformers", "nemo", "openvino", "vLLM"]
    | dict[Literal["CPU", "GPU"], Literal["faster-whisper", "transformers", "nemo", "openvino", "vLLM"]]
)


class SpeechToTextModelSpec(ModelSpec):
    """Model spec for a speech-to-text model, extending ``ModelSpec`` with a backend choice.

    ``backend`` is either a single backend name or a mapping from "CPU"/"GPU" to a
    backend name; ``SpeechToTextLocal`` uses the "GPU" entry when CUDA is available
    and the "CPU" entry otherwise.
    """

    backend: BACKEND


# Validated spec for every registered model, keyed by model name.
ALL_SPEECH_TO_TEXT_MODELS = {
    name: SpeechToTextModelSpec.model_validate(raw_spec) for name, raw_spec in SpeechToTextModel.items()
}
# Models advertised to users: only those tagged with the "all" label.
AVAILABLE_MODELS = [name for name, spec in ALL_SPEECH_TO_TEXT_MODELS.items() if spec.label & {"all"}]

# A local backend receives decoded waveforms and returns one transcription string per waveform.
InputType: TypeAlias = np.ndarray
OutputType: TypeAlias = str


class SpeechToTextBackend(BatchColumnClassProtocol, metaclass=ABCMeta):
    """Interface implemented by every local speech-to-text backend.

    A backend loads its model from a ``SpeechToTextModelSpec`` in ``__init__`` and
    transcribes a batch of waveforms in ``__call__``, returning one string per input.
    """

    @abstractmethod
    def __init__(self, model_spec: SpeechToTextModelSpec):
        """Load the model described by ``model_spec``."""

    @abstractmethod
    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        """Transcribe every waveform in ``audio_list`` and return the texts in the same order."""


class FasterWhisper(SpeechToTextBackend):
    """Speech-to-text backend built on ``faster_whisper``'s batched inference pipeline."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using FasterWhisper backend.")
        local_path = cache_model("huggingface", "pytorch", model_spec, None)
        # Half precision when CUDA is available, int8 otherwise.
        precision = "float16" if torch.cuda.is_available() else "int8"
        whisper = faster_whisper.WhisperModel(local_path, compute_type=precision)
        # The batched pipeline only batches the segments of a single audio, not several audios.
        self.batched_model = faster_whisper.BatchedInferencePipeline(model=whisper)

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # TODO(baobliu): Expose kwargs to user.
        # Batching across multiple audios is not merged upstream
        # (https://github.com/SYSTRAN/faster-whisper/pull/1359), so each audio is transcribed on its own
        # and its segment texts are joined with a single space.
        texts = []
        for waveform in audio_list:
            segments, _info = self.batched_model.transcribe(waveform, batch_size=16)
            texts.append(" ".join(segment.text for segment in segments))
        return texts


class Transformers(SpeechToTextBackend):
    """Speech-to-text backend using a Hugging Face ``transformers`` ASR pipeline."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using Transformers backend.")
        local_path = cache_model("huggingface", "pytorch", model_spec, None)

        torch.set_float32_matmul_precision("high")

        # Run on the first GPU in float16 when CUDA is available, otherwise on CPU in float32.
        has_cuda = torch.cuda.is_available()
        device = "cuda:0" if has_cuda else "cpu"
        dtype = torch.float16 if has_cuda else torch.float32

        seq2seq = transformers.AutoModelForSpeechSeq2Seq.from_pretrained(
            local_path, torch_dtype=dtype, low_cpu_mem_usage=True
        )
        seq2seq = seq2seq.to(device)

        # Use a static generation cache, cap new tokens, and torch.compile the forward pass.
        generation_config = seq2seq.generation_config
        generation_config.cache_implementation = "static"
        generation_config.max_new_tokens = 256
        seq2seq.forward = torch.compile(seq2seq.forward, mode="reduce-overhead", fullgraph=True)

        processor = transformers.AutoProcessor.from_pretrained(local_path)
        self.model = transformers.pipeline(
            "automatic-speech-recognition",
            model=seq2seq,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=dtype,
            device=device,
        )

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # For http/https URL inputs the pipeline downloads the full content during preprocessing:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/automatic_speech_recognition.py
        # TODO(baobliu): Expose kwargs to user.
        with torch_attention.sdpa_kernel(torch_attention.SDPBackend.MATH):
            return [result["text"] for result in self.model(audio_list, return_timestamps=True)]


class Nemo(SpeechToTextBackend):
    """Speech-to-text backend using a NVIDIA NeMo ``ASRModel``."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using Nemo backend.")
        nemo_file, extracted_dir = cache_model("huggingface", "nemo", model_spec, None)
        from nemo.core.connectors.save_restore_connector import SaveRestoreConnector

        # Hand the connector the directory cache_model returned; presumably this lets NeMo
        # reuse already-extracted model files instead of unpacking the archive again — verify.
        restore_connector = SaveRestoreConnector()
        restore_connector._model_extracted_dir = extracted_dir
        self.model = nemo_asr.models.ASRModel.restore_from(nemo_file, save_restore_connector=restore_connector)

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # TODO(baobliu): Expose kwargs to user.
        hypotheses = self.model.transcribe(audio_list)
        return [hypothesis.text for hypothesis in hypotheses]


class OpenVINO(SpeechToTextBackend):
    """Speech-to-text backend using an OpenVINO GenAI ``WhisperPipeline`` on the CPU."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using OpenVINO backend.")
        local_path = cache_model("huggingface", "pytorch", model_spec, None)
        # Single inference thread, memory-mapped weights, and no CPU pinning.
        self.model = ov_genai.WhisperPipeline(
            local_path,
            "CPU",
            INFERENCE_NUM_THREADS=1,
            ENABLE_MMAP=True,
            ENABLE_CPU_PINNING=False,
        )

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # TODO(baobliu): Expose kwargs to user.
        # The pipeline is fed a plain Python list per audio; the first text of each result is kept.
        return [self.model.generate(waveform.tolist()).texts[0] for waveform in audio_list]


class vLLM(SpeechToTextBackend):
    """Speech-to-text backend that serves Whisper through vLLM's offline ``LLM`` engine."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using vLLM backend.")
        # Set before vllm is imported below; selects the V0 engine (presumably required for
        # Whisper's encoder/decoder prompts with the pinned vLLM version — verify).
        os.environ["VLLM_USE_V1"] = "0"

        local_path = cache_model("huggingface", "pytorch", model_spec, None)

        from vllm import LLM

        self.model = LLM(
            model=local_path,
            max_model_len=448,
            max_num_seqs=16,
            limit_mm_per_prompt={"audio": 1},
            dtype="half",
        )

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # TODO(baobliu): Expose kwargs to user.
        from vllm import PromptType, SamplingParams

        # Whisper in vLLM does not support language detection
        # (https://github.com/vllm-project/vllm/issues/14174), so the decoder prompt pins the
        # <|en|> language token, the transcribe task, and no timestamps.
        decoder_prompt = "<|prev|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>"

        prompts = []
        chunks_per_audio = []
        for waveform in audio_list:
            # Audio of 30 seconds or more is cut into chunks that are transcribed separately.
            if librosa.get_duration(y=waveform, sr=SAMPLE_RATE) < 30:
                pieces = [waveform]
            else:
                pieces = split_audio(waveform, SAMPLE_RATE)
            chunks_per_audio.append(len(pieces))
            prompts.extend(
                cast(
                    PromptType,
                    {
                        "encoder_prompt": {
                            # Whisper does not support encoder prompt.
                            "prompt": "",
                            "multi_modal_data": {"audio": (piece, SAMPLE_RATE)},
                        },
                        "decoder_prompt": decoder_prompt,
                    },
                )
                for piece in pieces
            )
        assert len(prompts) == sum(chunks_per_audio)

        # Greedy decoding.
        sampling_params = SamplingParams(temperature=0, top_p=1.0, max_tokens=4096, skip_special_tokens=False)

        outputs = self.model.generate(prompts, sampling_params)
        assert len(outputs) == len(prompts)

        # Stitch the chunk transcriptions back together, one text per input audio.
        texts = []
        start = 0
        for count in chunks_per_audio:
            texts.append("".join(request.outputs[0].text for request in outputs[start : start + count]))
            start += count
        return texts


# Backend name (as used in ``SpeechToTextModel``) -> implementation class.
BACKENDS: dict[BACKEND, type[SpeechToTextBackend]] = dict(
    (
        ("faster-whisper", FasterWhisper),
        ("transformers", Transformers),
        ("vLLM", vLLM),
        ("nemo", Nemo),
        ("openvino", OpenVINO),
    )
)


def _open_bytes(data, *_args, **_kwargs):
    """Wrap in-memory audio bytes in a binary file object; any extra open options are ignored."""
    return io.BytesIO(data)


# Per-input-type opener used to read audio: ``str`` values go through ``auto_open`` in binary
# mode (extra keyword arguments, e.g. storage options, are forwarded), while ``bytes`` values
# are wrapped in ``io.BytesIO``.
TYPED_OPEN = {
    str: functools.partial(auto_open, mode="rb"),
    bytes: _open_bytes,
}


class SpeechToTextLocal(BatchColumnClassProtocol):
    """Transcribe audio on the local CPU/GPU with one of the registered backends.

    The backend comes from ``ALL_SPEECH_TO_TEXT_MODELS[model].backend``. When the spec maps
    device kinds to backends, the "GPU" entry is used if CUDA is available, else "CPU".
    """

    def __init__(self, model: str):
        """Instantiate the backend for ``model``.

        Args:
            model: A key of ``ALL_SPEECH_TO_TEXT_MODELS``.

        Raises:
            KeyError: If ``model`` is not a registered speech-to-text model.
        """
        model_spec = ALL_SPEECH_TO_TEXT_MODELS[model]
        backend = model_spec.backend
        if isinstance(backend, dict):
            backend = backend["GPU" if torch.cuda.is_available() else "CPU"]
        self.model = BACKENDS[backend](model_spec)
        # One worker thread: model calls submitted through ``to_thread`` run one at a time.
        self.executor = ThreadPoolExecutor(max_workers=1)

    async def __call__(self, py_audio_array: list) -> pa.Array:
        """Transcribe a batch of audio inputs.

        Args:
            py_audio_array: A non-empty list. The input kind is decided by the first item:
                ``np.ndarray`` waveforms (1-D, or 2-D with a single leading channel), or
                ``str`` paths/URLs / ``bytes`` file contents, which are decoded with librosa
                into mono waveforms at ``SAMPLE_RATE``.

        Returns:
            A ``pa.string()`` array with one transcription per input.

        Raises:
            ValueError: If the list is empty, an ndarray has several channels or more than
                two dimensions, or the item type is not supported.
        """
        if not py_audio_array:
            raise ValueError("Audio array is empty.")

        item = py_audio_array[0]

        if isinstance(item, np.ndarray):
            formatted = [self._to_mono(audio) for audio in py_audio_array]
            outputs: list[OutputType] = await to_thread(self.model, formatted, _executor=self.executor)
            return pa.array(outputs, type=pa.string())

        if isinstance(item, (str, bytes)):
            # The opener is picked from the first item's exact type; subclasses are not supported.
            arrays = await self._load_audio(py_audio_array, TYPED_OPEN[type(item)])
            return await self(arrays)

        raise ValueError(f"We expect a numpy ndarray or a str as input, got `{type(item)}`")

    @staticmethod
    def _to_mono(audio: np.ndarray) -> np.ndarray:
        """Unwrap a single leading channel of a 2-D array; return arrays with ndim <= 1 unchanged."""
        if audio.ndim > 2:
            raise ValueError("The shape of audio ndarray is incorrect, we expect the ndim <= 2.")
        if audio.ndim == 2:
            if audio.shape[0] != 1:
                raise ValueError(
                    "Local transcription (Speech To Text) requires a single channel audio input, "
                    "however the audio has multiple channels."
                )
            return audio[0]
        return audio

    @staticmethod
    async def _load_audio(sources: list, typed_open) -> list[np.ndarray]:
        """Decode paths/URLs or raw file bytes into mono waveforms at ``SAMPLE_RATE``.

        Each item is opened and decoded in a worker thread and awaited, so the event loop is
        not blocked while files are fetched and decoded. Results keep the order of ``sources``.
        """
        # Resolved here rather than in the worker threads, in case ``DatasetContext.get_current()``
        # depends on the calling context.
        storage_options = DatasetContext.get_current().storage_options

        def _read_ndarray(source):
            # The Nemo backend only accepts local files, ndarrays, or torch tensors;
            # HTTP URLs or bytes are not supported, so every input is decoded here.
            with typed_open(source, **storage_options) as f:
                return librosa.load(f, sr=SAMPLE_RATE, mono=True)[0]

        return await asyncio.gather(*(asyncio.to_thread(_read_ndarray, source) for source in sources))


class SpeechToTextHttp(BatchColumnClassProtocol):
    """Transcribe audio through an OpenAI-compatible ``audio.transcriptions`` endpoint."""

    def __init__(
        self,
        base_url: str,
        model: str,
        api_key: str,
        max_qps: int | None = None,
        max_retries: int = 0,
        **kwargs: dict[str, Any],
    ):
        """Create the async OpenAI client and a rate-limited ``transcriptions.create`` callable.

        Args:
            base_url: Base URL of the transcription server.
            model: Model name sent with every request.
            api_key: API key for the server.
            max_qps: Maximum requests per second, applied through ``qps_limiter``.
            max_retries: Passed to the client as its per-request retry limit.
            **kwargs: Extra arguments forwarded to every ``transcriptions.create`` call.
        """
        logger.info("Using Remote Http Transcription (Speech To Text)")
        self.model = model
        self.kwargs = kwargs
        async_client = openai.AsyncClient(api_key=api_key, base_url=base_url, max_retries=max_retries)
        rate_limit = qps_limiter(max_qps)
        self.create_transcriptions = rate_limit(async_client.audio.transcriptions.create)

    async def __call__(self, py_audio_array: list) -> pa.Array:
        """Send one transcription request per item concurrently and collect the texts in order.

        Items are raw file ``bytes`` or ``str`` paths/URLs opened with ``auto_open``; the kind is
        decided by the first item. ndarrays are rejected.
        """
        if not py_audio_array:
            raise ValueError("Audio array is empty.")

        first = py_audio_array[0]
        storage_options = DatasetContext.get_current().storage_options

        if isinstance(first, np.ndarray):
            raise ValueError("Http remote transcription (Speech To Text) does not support numpy arrays.")

        async def _transcribe_path(path):
            # Open off the event loop, then upload the file object.
            with await asyncio.to_thread(auto_open, path, "rb", **storage_options) as file:
                return await self.create_transcriptions(file=file, model=self.model, **self.kwargs)

        if isinstance(first, bytes):
            pending = [
                self.create_transcriptions(file=blob, model=self.model, **self.kwargs) for blob in py_audio_array
            ]
        elif isinstance(first, str):
            pending = [_transcribe_path(path) for path in py_audio_array]
        else:
            raise ValueError(f"Http Transcription (Speech To Text) does not support type {type(first)}.")

        responses = await asyncio.gather(*pending)
        return pa.array([response.text for response in responses], type=pa.string())


@udf(return_dtype=DataType.string())
class SpeechToText(BatchColumnClassProtocol):
    __doc__ = Template(
        """Speech to text processor for CPU, GPU and remote Http requests.

    On IO workers (when the environment variable named by ``IO_WORKER_ENV`` is set) audio is sent to
    an OpenAI-compatible transcription server; on CPU and GPU workers the local model is used.

    Args:
        _local_model: The speech to text model name for CPU or GPU, available models: $AVAILABLE_MODELS
        base_url: The base URL of the OpenAI-compatible transcription server.
        model: The request model name. Required when ``base_url`` is used.
        api_key: The request API key.
        max_qps: The maximum number of requests per second.
        max_retries: The maximum number of retries per request in the event of failures.
            We retry with exponential backoff up to this specific maximum retries.
        **kwargs: Keyword arguments to pass to the `openai.AsyncClient.audio.transcriptions.create <https://github.com/openai/openai-python/blob/main/src/openai/resources/audio/transcriptions.py>`_ API.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import SpeechToText, from_items

            ds = from_items(["multilingual.mp3"])
            ds = ds.with_column(
                "text",
                SpeechToText(
                    # Local transcriptions model.
                    "openai/whisper-large-v3",
                    # For remote transcriptions requests.
                    base_url="http://127.0.0.1:9997/v1",
                    model="whisper1",
                )
                # One IO worker for HTTP request, 10 CPU workers for local transcriptions.
                .options(num_workers={"CPU": 10, "IO": 1})
                .with_column(col("item")),
            )
            print(ds.take(2))
    """
    ).safe_substitute(AVAILABLE_MODELS=AVAILABLE_MODELS)

    # Either the remote HTTP client or the local model runner, chosen in ``__init__``.
    model: BatchColumnClassProtocol

    def __init__(
        self,
        _local_model: str | None = None,
        /,
        *,
        base_url: str | None = None,
        model: str | None = None,
        api_key: str = NOT_SET,
        max_qps: int | None = None,
        max_retries: int = 0,
        **kwargs: Any,
    ):
        if _local_model is None and base_url is None:
            raise ValueError("Either _local_model or base_url must be specified.")
        if os.environ.get(IO_WORKER_ENV):
            # IO workers only issue HTTP requests to the remote server.
            if base_url is None:
                raise ValueError("base_url must be specified for IO worker.")
            if model is None:
                raise ValueError("model must be specified if base_url is specified.")
            self.model = SpeechToTextHttp(
                base_url=base_url, model=model, api_key=api_key, max_qps=max_qps, max_retries=max_retries, **kwargs
            )
        else:
            # CPU/GPU workers run the local model.
            if _local_model is None:
                raise TypeError("_local_model must be specified for CPU or GPU worker.")
            self.model = SpeechToTextLocal(_local_model)

    async def __call__(self, audio_array: pa.ChunkedArray) -> pa.Array:
        """Transcribe the audio array to a text array.

        Args:
            audio_array: The audio array. Its items are either:

                - `str`: the path of a local audio file or a URL, opened through the dataset's
                  storage options. Locally it is decoded with librosa at 16 kHz; remotely the
                  file is uploaded to the server.
                - `bytes`: the content of an audio file, handled the same way.
                - `np.ndarray` of shape (n, ) or (1, n): raw mono audio sampled at 16 kHz.
                  Only supported by the local (CPU/GPU) path.

        Returns:
            The transcribed text array (empty if ``audio_array`` is empty).
        """
        if len(audio_array) == 0:
            return pa.array([], type=pa.string())
        return await assert_type(self.model(audio_array.to_pylist()), CoroutineType)