# Source code for xpark.dataset.processors.text_perplexity

from __future__ import annotations

import asyncio
import logging
import math
from string import Template
from typing import TYPE_CHECKING

from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model

if TYPE_CHECKING:
    import pyarrow as pa
    import torch
    import transformers
else:
    pa = lazy_import("pyarrow", rename="pa")
    torch = lazy_import("torch")
    transformers = lazy_import("transformers")

logger = logging.getLogger("ray")

# Perplexity-capable checkpoints, keyed by Hugging Face repo id.
# Every entry shares one spec layout: a "huggingface" source with a "pytorch"
# format, the repo id as "model_id", and a single unquantized variant.
# Only the pinned revision and the label set differ, so the table is
# generated from (repo id, revision, labels) rows.
# Entries labelled "all" are the ones collected into AVAILABLE_MODELS further down.
PerplexityModel = {
    repo_id: {
        "label": set(labels),
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": repo_id,
                    "model_revision": revision,
                    "quantizations": [None],
                },
            },
        },
    }
    for repo_id, revision, labels in (
        ("Qwen/Qwen2.5-0.5B", "060db6499f32faf8b98477b0a26969ef7d8b9987", ("test", "all")),
        ("Qwen/Qwen3.5-0.8B", "2fc06364715b967f1860aea9cf38778875588b17", ("all",)),
        ("Qwen/Qwen3.5-2B", "15852e8c16360a2fea060d615a32b45270f8a8fc", ("all",)),
        ("Qwen/Qwen3.5-4B", "851bf6e806efd8d0a36b00ddf55e13ccb7b8cd0a", ("all",)),
    )
}


class PerplexityModelSpec(ModelSpec):
    """``ModelSpec`` subclass used to validate the entries of ``PerplexityModel``.

    It declares no fields or methods of its own; all validation behaviour is
    inherited from ``ModelSpec``.
    """


# Validated spec for every entry of the raw PerplexityModel table, same keys and order.
ALL_PERPLEXITY_MODELS = {
    name: PerplexityModelSpec.model_validate(raw_spec)
    for name, raw_spec in PerplexityModel.items()
}
# Models exposed to users (listed in the TextPerplexity docstring and in its
# "Unsupported model" error message): the ones whose label set contains "all".
AVAILABLE_MODELS = [name for name, spec in ALL_PERPLEXITY_MODELS.items() if "all" in spec.label]


@udf(return_dtype=DataType.float32())
class TextPerplexity(BatchColumnClassProtocol):
    __doc__ = Template("""Computes the perplexity of text using a language model to evaluate fluency and naturalness.

    Perplexity is a metric that measures how well a language model predicts a given text.
    A lower perplexity indicates more natural and fluent text, while a higher perplexity
    suggests the text is harder to predict (e.g., noisy, garbled, or low-quality content).

    Null input rows produce a null perplexity. Texts that tokenize to fewer than two
    tokens have no next-token prediction to score and produce ``inf``.

    Args:
        _local_model: Name of the language model used for perplexity computation.
            Available models: $AVAILABLE_MODELS
        max_length: Maximum token length for input truncation. Text exceeding this
            length will be truncated. Defaults to None, which uses the model's
            maximum supported length.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import TextPerplexity, from_items

            ds = from_items([
                "The quick brown fox jumps over the lazy dog.",
                "asdf qwer zxcv random noise text 1234",
            ])
            ds = ds.with_column(
                "perplexity",
                TextPerplexity()
                .options(num_workers={"CPU": 4}, batch_size=8)
                .with_column(col("item")),
            )
            print(ds.take(2))
    """).safe_substitute(AVAILABLE_MODELS=AVAILABLE_MODELS)

    def __init__(
        self,
        _local_model: str = "Qwen/Qwen2.5-0.5B",
        /,
        *,
        max_length: int | None = None,
    ):
        """Load the tokenizer and causal LM for ``_local_model``.

        Raises:
            ValueError: If ``_local_model`` is not a key of ``ALL_PERPLEXITY_MODELS``.
        """
        if _local_model not in ALL_PERPLEXITY_MODELS:
            raise ValueError(f"Unsupported model: {_local_model}. Available models: {AVAILABLE_MODELS}")
        self.max_length = max_length
        model_spec = ALL_PERPLEXITY_MODELS[_local_model]
        model_path = cache_model("huggingface", "pytorch", model_spec, None)
        # Process-global PyTorch setting: trade a little float32 matmul precision for speed.
        torch.set_float32_matmul_precision("high")

        tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        # _batch_compute relies on right padding: each sequence's real tokens
        # start at position 0 and padding only trails them, so no real token
        # is ever scored from a padding position.
        tokenizer.padding_side = "right"
        if tokenizer.pad_token is None:
            # padding=True fails without a pad token. Padded positions are
            # excluded via the attention mask, so reusing EOS is harmless.
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
        )
        self.model.eval()

    async def __call__(self, texts: pa.ChunkedArray) -> pa.Array:
        """Score a column of texts; returns a float32 array aligned with the input rows."""
        text_list = texts.to_pylist()
        results = await self._batch_compute(text_list)
        # None entries (null inputs) become nulls in the Arrow array.
        return pa.array(results, type=pa.float32())

    def batch_tokenizer(self, text_list: list[str]) -> transformers.BatchEncoding:
        """Tokenize a batch into padded, truncated PyTorch tensors."""
        return self.tokenizer(
            text_list,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True,
        )

    def batch_forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Run the model without gradient tracking and return its logits.

        The logits are indexed by the caller as ``[batch, position, vocab]``.
        """
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_mask)
        return outputs.logits

    @staticmethod
    def _nll_to_perplexity(mean_nll: float) -> float:
        """Return ``exp(mean_nll)``, saturating to ``inf`` instead of raising OverflowError."""
        try:
            return math.exp(mean_nll)
        except OverflowError:
            return math.inf

    async def _batch_compute(self, text_list: list[str | None]) -> list[float | None]:
        """Compute one perplexity value per entry of ``text_list``.

        Null entries yield ``None``. Entries with fewer than two real tokens yield
        ``inf``. Every other entry yields ``exp(mean next-token NLL)``.
        """
        import torch.nn.functional as F

        results: list[float | None] = [None] * len(text_list)
        # The tokenizer cannot encode None, so null rows are skipped. Their slot
        # stays None and becomes a null in the output array.
        row_positions = [pos for pos, text in enumerate(text_list) if text is not None]
        if not row_positions:
            return results
        texts = [text_list[pos] for pos in row_positions]

        encodings = await asyncio.to_thread(self.batch_tokenizer, texts)
        input_ids = encodings.input_ids.to(self.model.device)
        attention_mask = encodings.attention_mask.to(self.model.device)
        # Number of real (non-padding) tokens per row, as Python ints.
        seq_lengths = attention_mask.sum(dim=1).tolist()

        # With fewer than two real tokens there is no (context, next token)
        # pair to score. If that holds for every row, skip the forward pass
        # entirely; this also avoids running the model on an empty sequence.
        if max(seq_lengths) <= 1:
            for pos in row_positions:
                results[pos] = math.inf
            return results

        logits = await asyncio.to_thread(self.batch_forward, input_ids, attention_mask)

        # A causal LM predicts the *next* token at each position, so logits and
        # labels are shifted by one to align predictions with their targets:
        #   - logits at positions 0..n-2, each predicting the following token
        #   - labels: token ids at positions 1..n-1, the ground-truth next tokens
        # A prediction counts only if BOTH the predicting position and the target
        # position are real tokens. Under right padding this equals
        # attention_mask[:, 1:]. Requiring both ends also keeps the score correct
        # if a tokenizer pads on the left.
        pair_mask = attention_mask[:, :-1] * attention_mask[:, 1:]

        for row, (pos, n_tokens) in enumerate(zip(row_positions, seq_lengths)):
            if n_tokens <= 1:
                results[pos] = math.inf
                continue
            # Upcast so the softmax / NLL is computed in float32 even when the
            # model produced half-precision logits.
            shift_logits = logits[row, :-1, :].float()
            shift_labels = input_ids[row, 1:]
            shift_mask = pair_mask[row].to(shift_logits.dtype)
            # Per-token negative log-likelihood -log P(token_t | token_<t).
            # reduction="none" keeps one loss value per position.
            loss_per_token = F.cross_entropy(shift_logits, shift_labels, reduction="none")
            # Mean NLL over scored positions only:
            #   mean NLL = sum(loss * mask) / sum(mask)
            # Perplexity = exp(mean NLL). Lower means the model finds the text
            # more predictable, i.e. more fluent and natural.
            mean_nll = (loss_per_token * shift_mask).sum() / shift_mask.sum()
            results[pos] = self._nll_to_perplexity(mean_nll.item())
        return results