Source code for xpark.dataset.processors.text_language_detector

from __future__ import annotations

import logging
from pathlib import Path
from string import Template
from typing import TYPE_CHECKING, cast

from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model

# Static type checkers see the real modules; at runtime the same names are
# bound through lazy_import instead (presumably deferring the actual import
# until first use -- confirm against xpark.dataset.import_utils).
if TYPE_CHECKING:
    import fasttext
    import pyarrow as pa
else:
    pa = lazy_import("pyarrow", rename="pa")
    fasttext = lazy_import("fasttext")

# Module logger; note that it is registered under the shared "ray" logger name
# rather than this module's __name__.
logger = logging.getLogger("ray")

# Registry rows for the supported fasttext language-identification models:
# (model id, label set, download URL). Row order defines the key order of
# LanguageModel below.
_FASTTEXT_LID_MODELS = (
    (
        "fasttext/lid.176.bin",
        {"test", "all"},
        "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin",
    ),
    (
        "fasttext/lid.176.ftz",
        {"all"},
        "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz",
    ),
)

# Raw (unvalidated) model specifications keyed by model id. Every entry has the
# same nested shape, so it is expanded from the table above.
LanguageModel = {
    model_id: {
        "label": labels,
        "model_specs": {
            "uri": {
                "fasttext": {
                    "model_id": model_id,
                    "model_uri": model_uri,
                },
            },
        },
    }
    for model_id, labels, model_uri in _FASTTEXT_LID_MODELS
}


# Spec type used to validate the LanguageModel registry entries. It adds no
# fields or behaviour of its own on top of ModelSpec.
class LanguageModelSpec(ModelSpec):
    ...


# Every registry entry, validated into a LanguageModelSpec and keyed by model id.
ALL_LANGUAGE_MODEL_SPECS = {
    model_id: LanguageModelSpec.model_validate(raw_spec)
    for model_id, raw_spec in LanguageModel.items()
}
# Model ids whose label set contains "all"; these are the ones advertised to
# users in error messages and docstrings.
AVAILABLE_MODELS = [
    model_id
    for model_id, spec in ALL_LANGUAGE_MODEL_SPECS.items()
    if "all" in spec.label
]
# Default number of (label, probability) candidates requested per text.
FASTTEXT_TOP_K = 10


def _patch_fasttext_for_numpy2() -> None:
    """Patch fasttext's internal numpy reference to be compatible with NumPy 2.x.

    fasttext uses np.array(probs, copy=False) which raises ValueError in NumPy 2.x.
    We replace the np module inside fasttext with a wrapper that overrides array()
    to use np.asarray() instead, without touching the third-party source files.

    The patch is idempotent: it is invoked on every ``FasttextDetection``
    construction, and without the guard below each call would wrap the previous
    wrapper again, adding one more ``__getattr__`` hop per instance created.
    """
    import fasttext.FastText as _ft_module

    _original_np = _ft_module.np

    # Already wrapped by an earlier call. The real numpy module does not define
    # this marker, so getattr() falls back to False for it.
    if getattr(_original_np, "_xpark_numpy2_compat", False):
        return

    class _NumpyCompat:
        """Thin wrapper around numpy that fixes the copy=False issue for NumPy 2.x."""

        # Marker checked above so the wrapper is installed at most once.
        _xpark_numpy2_compat = True

        def __getattr__(self, name: str):
            # Only reached for names not defined on the wrapper itself:
            # everything except array() is served by the wrapped module.
            return getattr(_original_np, name)

        def array(self, obj, *args, copy=None, **kwargs):  # type: ignore[override]
            if copy is False:
                # asarray() avoids the copy when possible instead of raising.
                return _original_np.asarray(obj, *args, **kwargs)
            if copy is None:
                # Caller did not pass copy: keep numpy's own default.
                return _original_np.array(obj, *args, **kwargs)
            return _original_np.array(obj, *args, copy=copy, **kwargs)

    _ft_module.np = _NumpyCompat()


class FasttextDetection:
    """Wrapper around a fasttext language-identification model.

    Resolves the model file through ``cache_model`` and exposes a batch
    ``predict`` that returns plain ``(language, probability)`` pairs.
    """

    def __init__(self, _local_model: str = "fasttext/lid.176.bin"):
        """Load the requested fasttext model.

        Args:
            _local_model: Key into ``ALL_LANGUAGE_MODEL_SPECS``.

        Raises:
            ValueError: If ``_local_model`` is not a registered model id.
        """
        if _local_model not in ALL_LANGUAGE_MODEL_SPECS:
            raise ValueError(f"Unsupported model: {_local_model}. Available models: {AVAILABLE_MODELS}")
        model_spec = ALL_LANGUAGE_MODEL_SPECS[_local_model]
        model_path = cache_model("uri", "fasttext", model_spec, None)
        # Install the NumPy 2.x compatibility shim before the model is used.
        _patch_fasttext_for_numpy2()
        self.model = fasttext.load_model(str(cast(Path, model_path)))

    def predict(self, texts: list[str], k: int = FASTTEXT_TOP_K) -> list[list[tuple[str, float]]]:
        """Run fasttext prediction on a list of texts.

        Args:
            texts: Input strings. Embedded newlines are replaced by spaces
                before prediction.
            k: Number of candidates requested per text.

        Returns:
            One list per input text holding its top-k ``(label, probability)``
            pairs, with the ``__label__`` prefix stripped from each label.
        """
        if not texts:
            return []
        # fasttext predicts one line at a time and raises ValueError for any
        # entry that contains "\n"; multi-line documents are flattened so they
        # can be scored instead of failing the whole batch.
        single_line_texts = [text.replace("\n", " ") for text in texts]
        labels_batch, probs_batch = self.model.predict(single_line_texts, k=k)
        results = []
        for labels, probs in zip(labels_batch, probs_batch):
            pairs = [(lbl.replace("__label__", ""), float(prob)) for lbl, prob in zip(labels, probs)]
            assert len(pairs) > 0, "fasttext should always return at least one label"
            results.append(pairs)
        return results


@udf(return_dtype=DataType.float32())
class TextLanguageScore(BatchColumnClassProtocol):
    __doc__ = Template("""Language score operator based on a fasttext model.

    For each input text, returns the probability that the text belongs to the
    specified ``lang``.

    Args:
        _local_model: fasttext model name. Default is ``"fasttext/lid.176.bin"``.
            available models: $AVAILABLE_MODELS.
        lang: Language code supported by fasttext. For details, see
            https://fasttext.cc/docs/en/language-identification.html.
            You can also refer to the ISO 639 standard, e.g. ``en``, ``zh``.

    Examples:

        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import from_items
            from xpark.dataset.processors.text_language_detector import TextLanguageScore

            ds = from_items(["Hello world", "今天天气很好"])
            ds = ds.with_column("en_score", TextLanguageScore(lang="en").options(num_workers={"CPU": 1}).with_column(col("item")))
            print(ds.take(2))
    """).safe_substitute(AVAILABLE_MODELS=AVAILABLE_MODELS)

    def __init__(self, _local_model: str = "fasttext/lid.176.bin", lang: str = "en"):
        self.lang = lang
        self.detection = FasttextDetection(_local_model)

    def __call__(self, batch: "pa.ChunkedArray") -> "pa.Array":
        """Score each text in the batch for the target language.

        Args:
            batch: A PyArrow ChunkedArray of string values.

        Returns:
            A PyArrow float32 Array where each element is the probability that
            the corresponding text belongs to ``self.lang`` (0.0 if not in
            top-10). Null inputs produce null outputs.
        """
        texts = batch.to_pylist()
        # Null rows come back from to_pylist() as None, which fasttext cannot
        # process, so only the non-null texts are sent to the model.
        present = [text for text in texts if text is not None]
        predictions = iter(self.detection.predict(present))
        scores = []
        for text in texts:
            if text is None:
                scores.append(None)
                continue
            # predict() yields one result per non-null text, in input order.
            pairs = next(predictions)
            scores.append(next((prob for label, prob in pairs if label == self.lang), 0.0))
        return pa.array(scores, type=pa.float32())


@udf(return_dtype=DataType.string())
class TextLanguageDetector(BatchColumnClassProtocol):
    # NOTE: string.Template only substitutes "$name" / "${name}" placeholders;
    # a bare "{name}" would be left in the rendered docstring verbatim.
    __doc__ = Template("""Language detection operator based on a fasttext model.

    Identifies the language of each input text and returns the top-1 language.
    Language code supported by fasttext. For details, see
    https://fasttext.cc/docs/en/language-identification.html.
    You can also refer to the ISO 639 standard, e.g. ``en``, ``zh``.

    Args:
        _local_model: fasttext model name. Default is ``"fasttext/lid.176.bin"``.
            available models: $AVAILABLE_MODELS.

    Examples:

        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import from_items
            from xpark.dataset.processors.text_language_detector import TextLanguageDetector

            ds = from_items(["Hello world", "今天阳光明媚"])
            ds = ds.with_column("lang", TextLanguageDetector().options(num_workers={"CPU": 1})(col("text")))
            print(ds.take(2))
    """).safe_substitute(AVAILABLE_MODELS=AVAILABLE_MODELS)

    def __init__(self, _local_model: str = "fasttext/lid.176.bin"):
        self.detection = FasttextDetection(_local_model)

    def __call__(self, batch: "pa.ChunkedArray") -> "pa.Array":
        """Detect the top-1 language of each text in the batch.

        Args:
            batch: A PyArrow ChunkedArray of string values.

        Returns:
            A PyArrow string Array holding the top-1 language label for each
            text. Null inputs produce null outputs.
        """
        texts = batch.to_pylist()
        # Null rows come back from to_pylist() as None, which fasttext cannot
        # process, so only the non-null texts are sent to the model.
        present = [text for text in texts if text is not None]
        top1 = iter(self.detection.predict(present, k=1))
        # predict() yields one result per non-null text, in input order; each
        # result is a list of (label, probability) pairs, best first.
        langs = [None if text is None else next(top1)[0][0] for text in texts]
        return pa.array(langs, type=pa.string())