from __future__ import annotations
import logging
from pathlib import Path
from string import Template
from typing import TYPE_CHECKING, cast
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model
if not TYPE_CHECKING:
    # Runtime path: bind the heavy dependencies through the project's
    # lazy_import helper (presumably defers the real import until first use —
    # confirm against xpark.dataset.import_utils).
    pa = lazy_import("pyarrow", rename="pa")
    fasttext = lazy_import("fasttext")
else:
    # Static-analysis path: give type checkers the real modules.
    import fasttext
    import pyarrow as pa

# Messages are emitted through the logger named "ray".
logger = logging.getLogger("ray")
# Registry of the supported fasttext language-identification models, keyed by
# model id.  Each row below is (model id, labels, download URI); models whose
# labels include "all" are the ones advertised in AVAILABLE_MODELS.
LanguageModel = {
    model_id: {
        "label": labels,
        "model_specs": {
            "uri": {
                "fasttext": {
                    "model_id": model_id,
                    "model_uri": model_uri,
                },
            },
        },
    }
    for model_id, labels, model_uri in (
        (
            "fasttext/lid.176.bin",
            {"test", "all"},
            "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin",
        ),
        (
            "fasttext/lid.176.ftz",
            {"all"},
            "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz",
        ),
    )
}
# Model spec type for the language-identification models in this module.
# It adds nothing to ModelSpec; it only gives these entries their own type.
class LanguageModelSpec(ModelSpec):
    ...
# Every registry entry, validated into a LanguageModelSpec and keyed by model id.
ALL_LANGUAGE_MODEL_SPECS = {
    model_id: LanguageModelSpec.model_validate(raw_spec) for model_id, raw_spec in LanguageModel.items()
}

# Model ids whose label set contains "all"; these are the ones shown to users
# in error messages and docstrings.
AVAILABLE_MODELS = [model_id for model_id, spec in ALL_LANGUAGE_MODEL_SPECS.items() if "all" in spec.label]

# Number of top predictions requested from fasttext per text by default.
FASTTEXT_TOP_K = 10
def _patch_fasttext_for_numpy2() -> None:
"""Patch fasttext's internal numpy reference to be compatible with NumPy 2.x.
fasttext uses np.array(probs, copy=False) which raises ValueError in NumPy 2.x.
We replace the np module inside fasttext with a wrapper that overrides array()
to use np.asarray() instead, without touching the third-party source files.
"""
import fasttext.FastText as _ft_module
_original_np = _ft_module.np
class _NumpyCompat:
"""Thin wrapper around numpy that fixes the copy=False issue for NumPy 2.x."""
def __getattr__(self, name: str):
return getattr(_original_np, name)
def array(self, obj, *args, copy=None, **kwargs): # type: ignore[override]
if copy is False:
return _original_np.asarray(obj, *args, **kwargs)
if copy is None:
return _original_np.array(obj, *args, **kwargs)
return _original_np.array(obj, *args, copy=copy, **kwargs)
_ft_module.np = _NumpyCompat()
class FasttextDetection:
    """Wrapper around a fasttext language-identification model.

    The model is looked up in ``ALL_LANGUAGE_MODEL_SPECS``, resolved to a path
    through ``cache_model`` and loaded with ``fasttext.load_model``.
    """

    def __init__(self, _local_model: str = "fasttext/lid.176.bin"):
        """Load the fasttext model registered under ``_local_model``.

        Args:
            _local_model: Key of the model in ``ALL_LANGUAGE_MODEL_SPECS``.

        Raises:
            ValueError: If ``_local_model`` is not a registered model.
        """
        if _local_model not in ALL_LANGUAGE_MODEL_SPECS:
            raise ValueError(f"Unsupported model: {_local_model}. Available models: {AVAILABLE_MODELS}")
        model_spec = ALL_LANGUAGE_MODEL_SPECS[_local_model]
        # NOTE(review): cache_model presumably downloads/caches the file and
        # returns its local path — confirm against xpark.dataset.model.
        model_path = cache_model("uri", "fasttext", model_spec, None)
        # Must run before the model is used: fasttext's predict path breaks on NumPy 2.x.
        _patch_fasttext_for_numpy2()
        self.model = fasttext.load_model(str(cast(Path, model_path)))

    def predict(self, texts: list[str], k: int = FASTTEXT_TOP_K) -> list[list[tuple[str, float]]]:
        """Run fasttext prediction on a list of texts.

        Newlines inside a text are replaced with spaces before prediction,
        because fasttext's Python ``predict`` handles one line at a time and
        raises ``ValueError`` for any input containing ``"\\n"``.

        Args:
            texts: Texts to classify.
            k: Number of top predictions to request per text.

        Returns:
            For each input text, its top-k ``(label, probability)`` pairs with
            the ``__label__`` prefix removed from the label.
        """
        single_line_texts = [text.replace("\n", " ") for text in texts]
        labels_batch, probs_batch = self.model.predict(single_line_texts, k=k)
        results = []
        for labels, probs in zip(labels_batch, probs_batch):
            pairs = [(lbl.replace("__label__", ""), float(prob)) for lbl, prob in zip(labels, probs)]
            assert len(pairs) > 0, "fasttext should always return at least one label"
            results.append(pairs)
        return results
[docs]
@udf(return_dtype=DataType.float32())
class TextLanguageScore(BatchColumnClassProtocol):
    __doc__ = Template("""Language score operator based on a fasttext model.

    For each input text, returns the probability that the text belongs to the
    specified ``lang``. Null inputs produce null scores.

    Args:
        _local_model: fasttext model name. Default is ``"fasttext/lid.176.bin"``.
            available models: $AVAILABLE_MODELS.
        lang: Language code supported by fasttext. For details, see https://fasttext.cc/docs/en/language-identification.html.
            You can also refer to the ISO 639 standard, e.g. ``en``, ``zh``.
            Default is ``"en"``.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import from_items
            from xpark.dataset.processors.text_language_detector import TextLanguageScore

            ds = from_items(["Hello world", "今天天气很好"])
            ds = ds.with_column("en_score", TextLanguageScore(lang="en").options(num_workers={"CPU": 1})(col("item")))
            print(ds.take(2))
    """).safe_substitute(AVAILABLE_MODELS=AVAILABLE_MODELS)

    def __init__(self, _local_model: str = "fasttext/lid.176.bin", lang: str = "en"):
        self.lang = lang
        self.detection = FasttextDetection(_local_model)

    def __call__(self, batch: "pa.ChunkedArray") -> "pa.Array":
        """Score each text in the batch for the target language.

        Args:
            batch: A PyArrow ChunkedArray of string values.

        Returns:
            A PyArrow float32 Array where each element is the probability that
            the corresponding text belongs to ``self.lang`` (0.0 if the language
            is not among the returned top predictions, null if the text is null).
        """
        texts = batch.to_pylist()
        # Null rows come back as None; fasttext cannot process them, so only
        # the non-null rows are sent to the model and nulls stay null.
        present = [idx for idx, text in enumerate(texts) if text is not None]
        scores: list[float | None] = [None] * len(texts)
        if present:
            results = self.detection.predict([texts[idx] for idx in present])
            for idx, pairs in zip(present, results):
                scores[idx] = next((prob for lang, prob in pairs if lang == self.lang), 0.0)
        return pa.array(scores, type=pa.float32())
[docs]
@udf(return_dtype=DataType.string())
class TextLanguageDetector(BatchColumnClassProtocol):
    __doc__ = Template("""Language detection operator based on a fasttext model.

    Identifies the language of each input text and returns the top-1 language
    code. Null inputs produce null outputs.

    The returned codes are the language codes supported by fasttext. For
    details, see https://fasttext.cc/docs/en/language-identification.html.
    You can also refer to the ISO 639 standard, e.g. ``en``, ``zh``.

    Args:
        _local_model: fasttext model name. Default is ``"fasttext/lid.176.bin"``.
            available models: $AVAILABLE_MODELS.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import from_items
            from xpark.dataset.processors.text_language_detector import TextLanguageDetector

            ds = from_items(["Hello world", "今天阳光明媚"])
            ds = ds.with_column("lang", TextLanguageDetector().options(num_workers={"CPU": 1})(col("item")))
            print(ds.take(2))
    """).safe_substitute(AVAILABLE_MODELS=AVAILABLE_MODELS)

    def __init__(self, _local_model: str = "fasttext/lid.176.bin"):
        self.detection = FasttextDetection(_local_model)

    def __call__(self, batch: "pa.ChunkedArray") -> "pa.Array":
        """Detect the language of each text in the batch.

        Args:
            batch: A PyArrow ChunkedArray of string values.

        Returns:
            A PyArrow string Array holding the top-1 language code for each
            text, or null where the input text is null.
        """
        texts = batch.to_pylist()
        # Null rows come back as None; fasttext cannot process them, so only
        # the non-null rows are sent to the model and nulls stay null.
        present = [idx for idx, text in enumerate(texts) if text is not None]
        langs: list[str | None] = [None] * len(texts)
        if present:
            results = self.detection.predict([texts[idx] for idx in present], k=1)
            for idx, pairs in zip(present, results):
                # k=1: the single returned pair is the most probable language.
                langs[idx] = pairs[0][0]
        return pa.array(langs, type=pa.string())