# Source code for xpark.dataset.processors.speech_to_text

from __future__ import annotations

import asyncio
import functools
import io
import logging
import os
from abc import ABCMeta, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from types import CoroutineType
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast

# Use np.ndarray type
import numpy as np

from xpark.dataset.constants import IO_WORKER_ENV, NOT_SET
from xpark.dataset.context import DatasetContext
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model
from xpark.dataset.utils import assert_type, auto_open, qps_limiter, split_audio, to_thread

if TYPE_CHECKING:
    import faster_whisper
    import librosa
    import nemo.collections.asr as nemo_asr
    import openai
    import openvino_genai as ov_genai
    import pyarrow as pa
    import torch
    import torch.nn.attention as torch_attention
    import transformers
else:
    faster_whisper = lazy_import("faster_whisper")
    librosa = lazy_import("librosa")
    nemo_asr = lazy_import("nemo.collections.asr", rename="nemo_asr")
    openai = lazy_import("openai")
    ov_genai = lazy_import("openvino_genai", rename="ov_genai")
    pa = lazy_import("pyarrow", rename="pa")
    torch = lazy_import("torch")
    torch_attention = lazy_import("torch.nn.attention", rename="torch_attention")
    transformers = lazy_import("transformers")

# Shares the "ray" logger (presumably so messages follow Ray's logging setup).
logger = logging.getLogger(name="ray")

# Sampling rate in Hz: decoded audio files are resampled to it, and raw ndarray
# input is assumed to already use it.
SAMPLE_RATE = 16_000

def _hub_model_specs(model_format: str, hf_model_id: str, hf_revision: str, ms_model_id: str) -> dict:
    """Build the ``model_specs`` mapping for a model mirrored on HuggingFace and ModelScope.

    Only the HuggingFace entry is pinned to a revision; both entries list a single
    ``None`` quantization. Every call returns freshly allocated dicts and lists.
    """
    return {
        "huggingface": {
            model_format: {
                "model_id": hf_model_id,
                "model_revision": hf_revision,
                "quantizations": [None],
            },
        },
        "modelscope": {
            model_format: {
                "model_id": ms_model_id,
                "quantizations": [None],
            },
        },
    }


# Registry of supported speech-to-text models, keyed by model name. ``backend`` is either
# one backend name or a {"CPU": ..., "GPU": ...} mapping; ``label`` holds tags such as
# "all" (used to build AVAILABLE_MODELS) and "test".
SpeechToTextModel = {
    "openai/whisper-large-v3": {
        "backend": {"CPU": "transformers", "GPU": "vLLM"},
        "model_specs": _hub_model_specs(
            "pytorch",
            "openai/whisper-large-v3",
            "06f233fe06e710322aca913c1bc4249a0d71fce1",
            "AI-ModelScope/whisper-large-v3",
        ),
    },
    "openai/whisper-tiny": {
        "label": {"test"},
        "backend": "transformers",
        "model_specs": _hub_model_specs(
            "pytorch",
            "openai/whisper-tiny",
            "169d4a4341b33bc18d8881c4b69c2e104e1cc0af",
            "openai-mirror/whisper-tiny",
        ),
    },
    "Systran/faster-whisper-large-v3": {
        "backend": "faster-whisper",
        "model_specs": _hub_model_specs(
            "pytorch",
            "Systran/faster-whisper-large-v3",
            "edaa852ec7e145841d8ffdb056a99866b5f0a478",
            "Systran/faster-whisper-large-v3",
        ),
    },
    "Systran/faster-whisper-tiny": {
        "label": {"test"},
        "backend": "faster-whisper",
        "model_specs": _hub_model_specs(
            "pytorch",
            "Systran/faster-whisper-tiny",
            "d90ca5fe260221311c53c58e660288d3deb8d356",
            "Systran/faster-whisper-tiny",
        ),
    },
    "nvidia/parakeet-tdt-0.6b-v3": {
        "label": {"test", "all"},
        "backend": "nemo",
        "model_specs": _hub_model_specs(
            "nemo",
            "nvidia/parakeet-tdt-0.6b-v3",
            "be0d803fd1970eca8627f5467c208118f0f6c171",
            "nv-community/parakeet-tdt-0.6b-v3",
        ),
    },
    "OpenVINO/whisper-large-v3-int8-ov": {
        "backend": "openvino",
        "model_specs": _hub_model_specs(
            "pytorch",
            "OpenVINO/whisper-large-v3-int8-ov",
            "b31e1dcee5de24d49c6cc96da2a603eae409e722",
            "OpenVINO/whisper-large-v3-int8-ov",
        ),
    },
    "OpenVINO/whisper-tiny-int8-ov": {
        "label": {"test"},
        "backend": "openvino",
        "model_specs": _hub_model_specs(
            "pytorch",
            "OpenVINO/whisper-tiny-int8-ov",
            "f04a5e7ec899ce17bfbc3f0efcdea55df4356028",
            "OpenVINO/whisper-tiny-int8-ov",
        ),
    },
}


BACKEND = (
    Literal["faster-whisper", "transformers", "nemo", "openvino", "vLLM"]
    | dict[Literal["CPU", "GPU"], Literal["faster-whisper", "transformers", "nemo", "openvino", "vLLM"]]
)


class SpeechToTextModelSpec(ModelSpec):
    """``ModelSpec`` extended with the speech-to-text backend selection.

    ``backend`` is either one backend name or a ``{"CPU": ..., "GPU": ...}`` mapping
    from device kind to backend name.
    """

    backend: BACKEND


# Every registered model, validated into a SpeechToTextModelSpec and keyed by model name.
ALL_SPEECH_TO_TEXT_MODELS = {
    name: SpeechToTextModelSpec.model_validate(raw_spec) for name, raw_spec in SpeechToTextModel.items()
}
# Models tagged with the "all" label; this list is interpolated into the SpeechToText docs.
AVAILABLE_MODELS = [name for name, spec in ALL_SPEECH_TO_TEXT_MODELS.items() if "all" in spec.label]

InputType: TypeAlias = np.ndarray  # one waveform handed to a backend
OutputType: TypeAlias = str  # the transcription of one waveform


class SpeechToTextBackend(BatchColumnClassProtocol, metaclass=ABCMeta):
    """Interface implemented by every local speech-to-text backend.

    A backend is constructed from a ``SpeechToTextModelSpec`` and, when called with a
    batch of waveforms, returns one transcription string per waveform.
    """

    @abstractmethod
    def __init__(self, model_spec: SpeechToTextModelSpec):
        """Load the model described by ``model_spec``."""

    @abstractmethod
    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        """Transcribe ``audio_list`` into one string per input waveform."""


class FasterWhisper(SpeechToTextBackend):
    """Speech-to-text backend using ``faster_whisper``'s ``BatchedInferencePipeline``."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using FasterWhisper backend.")
        local_path = cache_model("huggingface", "pytorch", model_spec, None)
        # float16 weights when CUDA is available, int8 otherwise.
        on_gpu = torch.cuda.is_available()
        whisper = faster_whisper.WhisperModel(local_path, compute_type="float16" if on_gpu else "int8")
        # The batched pipeline only batches the segments of ONE audio; separate audios
        # are transcribed one after another in __call__.
        self.batched_model = faster_whisper.BatchedInferencePipeline(model=whisper)

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # TODO(baobliu): Expose kwargs to user.
        # Multiple audio batching has not merged: https://github.com/SYSTRAN/faster-whisper/pull/1359
        return [self._transcribe_one(audio) for audio in audio_list]

    def _transcribe_one(self, audio: InputType) -> OutputType:
        """Transcribe a single waveform, joining its segment texts with spaces."""
        segments, _info = self.batched_model.transcribe(audio, batch_size=16)
        return " ".join(segment.text for segment in segments)


class Transformers(SpeechToTextBackend):
    """Speech-to-text backend using a HuggingFace ``automatic-speech-recognition`` pipeline."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using Transformers backend.")
        local_path = cache_model("huggingface", "pytorch", model_spec, None)

        torch.set_float32_matmul_precision("high")

        # fp16 on the first CUDA device when available, otherwise fp32 on the CPU.
        if torch.cuda.is_available():
            device, dtype = "cuda:0", torch.float16
        else:
            device, dtype = "cpu", torch.float32

        seq2seq = transformers.AutoModelForSpeechSeq2Seq.from_pretrained(
            local_path, torch_dtype=dtype, low_cpu_mem_usage=True
        ).to(device)

        # Static cache for generation plus a compiled forward pass.
        generation_config = seq2seq.generation_config
        generation_config.cache_implementation = "static"
        generation_config.max_new_tokens = 256
        seq2seq.forward = torch.compile(seq2seq.forward, mode="reduce-overhead", fullgraph=True)

        processor = transformers.AutoProcessor.from_pretrained(local_path)
        self.model = transformers.pipeline(
            "automatic-speech-recognition",
            model=seq2seq,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=dtype,
            device=device,
        )

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # Run the pipeline with the MATH scaled-dot-product-attention kernel selected.
        with torch_attention.sdpa_kernel(torch_attention.SDPBackend.MATH):
            # If the input audio is an http or https url, it downloads all the contents in the preprocess step.
            # https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/automatic_speech_recognition.py
            # TODO(baobliu): Expose kwargs to user.
            pipeline_outputs = self.model(audio_list, return_timestamps=True)
            texts = [output["text"] for output in pipeline_outputs]
        return texts


class Nemo(SpeechToTextBackend):
    """Speech-to-text backend for NVIDIA NeMo ASR models restored from the cached checkpoint."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using Nemo backend.")
        checkpoint_path, extracted_dir = cache_model("huggingface", "nemo", model_spec, None)
        from nemo.core.connectors.save_restore_connector import SaveRestoreConnector

        # Point the connector at the directory cache_model already extracted
        # (presumably so restore_from can reuse it — TODO confirm).
        restore_connector = SaveRestoreConnector()
        restore_connector._model_extracted_dir = extracted_dir
        self.model = nemo_asr.models.ASRModel.restore_from(
            checkpoint_path, save_restore_connector=restore_connector
        )

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # TODO(baobliu): Expose kwargs to user.
        hypotheses = self.model.transcribe(audio_list)
        return [hypothesis.text for hypothesis in hypotheses]


class OpenVINO(SpeechToTextBackend):
    """Speech-to-text backend running an OpenVINO GenAI ``WhisperPipeline`` on the CPU device."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using OpenVINO backend.")
        local_path = cache_model("huggingface", "pytorch", model_spec, None)
        pipeline_config = {
            "INFERENCE_NUM_THREADS": 1,
            "ENABLE_MMAP": True,
            "ENABLE_CPU_PINNING": False,
        }
        self.model = ov_genai.WhisperPipeline(local_path, "CPU", **pipeline_config)

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # TODO(baobliu): Expose kwargs to user.
        # Each waveform is passed to generate() as a plain Python list, one audio at a time.
        return [self.model.generate(audio.tolist()).texts[0] for audio in audio_list]


class vLLM(SpeechToTextBackend):
    """Speech-to-text backend serving Whisper through vLLM's offline ``LLM`` engine."""

    def __init__(self, model_spec: SpeechToTextModelSpec):
        logger.info("Using vLLM backend.")
        # Select vLLM's V0 engine; note this mutates the process environment.
        os.environ["VLLM_USE_V1"] = "0"

        local_path = cache_model("huggingface", "pytorch", model_spec, None)

        from vllm import LLM

        self.model = LLM(
            model=local_path,
            max_model_len=448,
            max_num_seqs=16,
            limit_mm_per_prompt={"audio": 1},
            dtype="half",
        )

    @staticmethod
    def _chunk_prompt(chunk):
        """Build the explicit encoder/decoder prompt for one audio chunk."""
        return {
            "encoder_prompt": {
                # Whisper does not support encoder prompt.
                "prompt": "",
                "multi_modal_data": {
                    "audio": (chunk, SAMPLE_RATE),
                },
            },
            # Whisper does not support language detection:
            # https://github.com/vllm-project/vllm/issues/14174
            "decoder_prompt": "<|prev|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
        }

    def __call__(self, audio_list: list[InputType]) -> list[OutputType]:
        # TODO(baobliu): Expose kwargs to user.

        from vllm import PromptType, SamplingParams

        prompts = []
        chunks_per_audio = []
        for waveform in audio_list:
            # Audio shorter than 30 s is sent whole; longer audio is split with split_audio.
            if librosa.get_duration(y=waveform, sr=SAMPLE_RATE) < 30:
                pieces = [waveform]
            else:
                pieces = split_audio(waveform, SAMPLE_RATE)
            chunks_per_audio.append(len(pieces))
            prompts.extend(cast(PromptType, self._chunk_prompt(piece)) for piece in pieces)
        assert len(prompts) == sum(chunks_per_audio)

        # temperature=0 selects greedy decoding.
        greedy = SamplingParams(
            temperature=0,
            top_p=1.0,
            max_tokens=4096,
            skip_special_tokens=False,
        )

        outputs = self.model.generate(prompts, greedy)
        assert len(outputs) == len(prompts)

        # Outputs come back in prompt order: consume them audio by audio and
        # concatenate each audio's chunk texts.
        remaining = iter(outputs)
        return ["".join(next(remaining).outputs[0].text for _ in range(count)) for count in chunks_per_audio]


# Maps each backend name (the string values allowed by ``BACKEND``) to its implementation.
# Keys are plain backend-name strings: ``BACKEND`` itself also admits a per-device dict,
# which could never be a key here, so it is not used as the key type.
BACKENDS: dict[str, type[SpeechToTextBackend]] = {
    "faster-whisper": FasterWhisper,
    "transformers": Transformers,
    "vLLM": vLLM,
    "nemo": Nemo,
    "openvino": OpenVINO,
}


def _open_bytes(content: bytes, *_args, **_kwargs) -> io.BytesIO:
    """Wrap in-memory audio ``bytes`` in a file object; any open arguments are ignored."""
    return io.BytesIO(content)


# Opener per exact input type: ``str`` sources go through ``auto_open`` in binary mode,
# raw ``bytes`` are wrapped in memory.
TYPED_OPEN = {
    str: functools.partial(auto_open, mode="rb"),
    bytes: _open_bytes,
}


class SpeechToTextLocal(BatchColumnClassProtocol):
    """Transcribe audio on the local CPU/GPU with the backend registered for ``model``."""

    def __init__(self, model: str):
        model_spec = ALL_SPEECH_TO_TEXT_MODELS[model]
        backend = model_spec.backend
        if isinstance(backend, dict):
            # Per-device backends: use the "GPU" entry when CUDA is available, else "CPU".
            backend = backend["GPU" if torch.cuda.is_available() else "CPU"]
        self.model = BACKENDS[backend](model_spec)
        # Single-worker executor: inference calls submitted through it run one at a time.
        self.executor = ThreadPoolExecutor(max_workers=1)

    async def __call__(self, py_audio_array: list) -> pa.Array:
        """Transcribe a batch of audio inputs.

        Args:
            py_audio_array: Non-empty list; the type of its first element selects the path.
                ``np.ndarray`` waveforms must be 1-D or shaped ``(1, n)``. ``str``/``bytes``
                sources are opened, decoded to mono and resampled to ``SAMPLE_RATE`` first.

        Returns:
            A ``pa.string()`` array with one transcription per input.

        Raises:
            ValueError: On an empty batch, multi-channel or >2-D arrays, or an
                unsupported element type.
        """
        if not py_audio_array:
            raise ValueError("Audio array is empty.")

        item = py_audio_array[0]

        if isinstance(item, np.ndarray):
            formatted = [self._to_single_channel(audio) for audio in py_audio_array]
            outputs: list[OutputType] = await to_thread(self.model, formatted, _executor=self.executor)
            return pa.array(outputs, type=pa.string())

        if isinstance(item, (str, bytes)):
            # We just use type here, do not support subtype.
            arrays = await self._decode_sources(py_audio_array, TYPED_OPEN[type(item)])
            return await self(arrays)

        raise ValueError(f"We expect a numpy ndarray, a str or bytes as input, got `{type(item)}`")

    @staticmethod
    def _to_single_channel(audio: np.ndarray) -> np.ndarray:
        """Unwrap a ``(1, n)`` array to 1-D; arrays with ndim <= 1 are returned unchanged."""
        if audio.ndim == 2:
            if audio.shape[0] == 1:
                return audio[0]
            raise ValueError(
                "Local transcription (Speech To Text) requires a single channel audio input, "
                "however the audio has multiple channels."
            )
        if audio.ndim > 2:
            raise ValueError("The shape of audio ndarray is incorrect, we expect the ndim <= 2.")
        return audio

    @staticmethod
    async def _decode_sources(py_audio_array: list, typed_open) -> list[np.ndarray]:
        """Decode every source into a mono ``SAMPLE_RATE`` waveform without blocking the event loop."""
        # Read the storage options on the calling (event-loop) side, as before.
        storage_options = DatasetContext.get_current().storage_options

        def _read_ndarray(_path):
            # The Nemo backend only accepts local files, ndarrays, or torch tensors.
            # HTTP URLs or bytes are not supported, hence sources are decoded to ndarrays here.
            with typed_open(_path, **storage_options) as f:
                return librosa.load(f, sr=SAMPLE_RATE, mono=True)[0]

        def _read_all():
            # Fetch/decode sources in parallel; results keep the input order.
            with ThreadPoolExecutor() as executor:
                return list(executor.map(_read_ndarray, py_audio_array))

        # Fetching and decoding is blocking work; previously it ran directly inside this
        # coroutine and stalled the event loop for the whole batch.
        return await asyncio.to_thread(_read_all)


class SpeechToTextHttp(BatchColumnClassProtocol):
    """Transcribe audio through an OpenAI-compatible ``audio.transcriptions`` endpoint."""

    def __init__(
        self,
        base_url: str,
        model: str,
        api_key: str,
        max_qps: int | None = None,
        max_retries: int = 0,
        **kwargs: Any,
    ):
        """
        Args:
            base_url: Base URL of the OpenAI-compatible server.
            model: Model name sent with every request.
            api_key: API key for the server.
            max_qps: Request-rate limit, applied by wrapping the create call with ``qps_limiter``.
            max_retries: Retry budget passed to ``openai.AsyncClient``.
            **kwargs: Extra keyword arguments forwarded to every ``transcriptions.create`` call.
        """
        logger.info("Using Remote Http Transcription (Speech To Text)")
        self.model = model
        self.kwargs = kwargs
        client = openai.AsyncClient(api_key=api_key, base_url=base_url, max_retries=max_retries)
        self.create_transcriptions = qps_limiter(max_qps)(client.audio.transcriptions.create)

    async def __call__(self, py_audio_array: list) -> pa.Array:
        """Send one transcription request per input and return the texts as a string array.

        ``bytes`` inputs are uploaded as-is; ``str`` inputs are opened with ``auto_open``
        (using the dataset storage options) and streamed. ``np.ndarray`` is rejected.
        """
        if not py_audio_array:
            raise ValueError("Audio array is empty.")

        item = py_audio_array[0]
        storage_options = DatasetContext.get_current().storage_options

        if isinstance(item, np.ndarray):
            raise ValueError("Http remote transcription (Speech To Text) does not support numpy arrays.")

        if isinstance(item, bytes):
            requests = [
                self.create_transcriptions(file=audio, model=self.model, **self.kwargs) for audio in py_audio_array
            ]

        elif isinstance(item, str):

            async def _create_request(path):
                # Open the (possibly remote) file off the event loop, keep it open for the upload.
                with await asyncio.to_thread(auto_open, path, "rb", **storage_options) as file:
                    return await self.create_transcriptions(file=file, model=self.model, **self.kwargs)

            requests = [_create_request(audio) for audio in py_audio_array]

        else:
            raise ValueError(f"Http Transcription (Speech To Text) does not support type {type(item)}.")

        responses = await asyncio.gather(*requests)
        return pa.array([self._response_text(r) for r in responses], type=pa.string())

    @staticmethod
    def _response_text(response: Any) -> str:
        """Extract the transcription text from a ``transcriptions.create`` response."""
        # The OpenAI SDK returns a plain ``str`` for text-like ``response_format`` values
        # ("text", "srt", "vtt"), which users can set via ``**kwargs``; other formats
        # return an object exposing ``.text``.
        return response if isinstance(response, str) else response.text


@udf(return_dtype=DataType.string())
class SpeechToText(BatchColumnClassProtocol):
    __doc__ = f"""Speech to text processor for CPU, GPU and remote Http requests.

    Args:
        _local_model: The speech to text model name for CPU or GPU, available models: {AVAILABLE_MODELS}
        base_url: The base URL of the LLM server.
        model: The request model name.
        api_key: The request API key.
        batch_rows: The number of rows to request once.
        max_qps: The maximum number of requests per second.
        max_retries: The maximum number of retries per request in the event of failures.
            We retry with exponential backoff up to this specific maximum retries.
        **kwargs: Keyword arguments to pass to the
            `openai.AsyncClient.audio.transcriptions.create
            <https://github.com/openai/openai-python/blob/main/src/openai/resources/audio/transcriptions.py>`_ API.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import SpeechToText, from_items

            ds = from_items(["multilingual.mp3"])
            ds = ds.with_column(
                "text",
                SpeechToText(
                    # Local transcriptions model.
                    "openai/whisper-large-v3",
                    # For remote transcriptions requests.
                    base_url="http://127.0.0.1:9997/v1",
                    model="whisper1",
                )
                # One IO worker for HTTP request, 10 CPU workers for local transcriptions.
                .options(num_workers={{"CPU": 10, "IO": 1}})
                .with_column(col("item")),
            )
            print(ds.take(2))
    """

    # SpeechToTextHttp on IO workers, SpeechToTextLocal on CPU/GPU workers.
    model: BatchColumnClassProtocol

    def __init__(
        self,
        _local_model: str | None = None,
        /,
        *,
        base_url: str | None = None,
        model: str | None = None,
        api_key: str = NOT_SET,
        max_qps: int | None = None,
        max_retries: int = 0,
        **kwargs: Any,
    ):
        if _local_model is None and base_url is None:
            raise ValueError("Either _local_model or base_url must be specified.")
        # The worker type decides the implementation: processes with IO_WORKER_ENV set
        # send HTTP requests, every other worker runs the local model.
        if os.environ.get(IO_WORKER_ENV):
            if base_url is None:
                raise ValueError("base_url must be specified for IO worker.")
            if model is None:
                raise ValueError("model must be specified if base_url is specified.")
            self.model = SpeechToTextHttp(
                base_url=base_url, model=model, api_key=api_key, max_qps=max_qps, max_retries=max_retries, **kwargs
            )
        else:
            if _local_model is None:
                raise TypeError("_local_model must be specified for CPU or GPU worker.")
            self.model = SpeechToTextLocal(_local_model)

    async def __call__(self, audio_array: pa.ChunkedArray) -> pa.Array:
        """Transcribe the audio array to text array.

        Args:
            audio_array: The audio array. The type of array is either:

                - `str` that is either the filename of a local audio file, or a public URL address to
                  download the audio file. The file will be read at the correct sampling rate to get
                  the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
                - `bytes` it is supposed to be the content of an audio file and is interpreted by
                  *ffmpeg* in the same way.
                - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`) Raw audio at
                  the correct sampling rate (For example, 16K). Local models only; the Http
                  path rejects numpy arrays.

        Returns:
            The transcribed text array (an empty string array for empty input).
        """
        if len(audio_array) == 0:
            return pa.array([], type=pa.string())
        return await assert_type(self.model(audio_array.to_pylist()), CoroutineType)