from __future__ import annotations
import asyncio
import logging
import math
from string import Template
from typing import TYPE_CHECKING
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model
if TYPE_CHECKING:
import pyarrow as pa
import torch
import transformers
else:
pa = lazy_import("pyarrow", rename="pa")
torch = lazy_import("torch")
transformers = lazy_import("transformers")
# Module-level logger, bound to the logger named "ray".
logger = logging.getLogger("ray")
def _hf_pytorch_entry(model_id: str, revision: str, labels: set[str]) -> dict:
    """Build the raw registry entry for one Hugging Face / PyTorch checkpoint.

    Args:
        model_id: Hugging Face repository id of the checkpoint.
        revision: Commit hash the checkpoint is pinned to.
        labels: Labels attached to the entry.

    Returns:
        A fresh nested dict (no objects shared between entries).
    """
    return {
        "label": labels,
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": model_id,
                    "model_revision": revision,
                    "quantizations": [None],
                },
            },
        },
    }


# Raw (unvalidated) description of every perplexity model, keyed by model name.
# Every entry has the same shape, so only the id, pinned revision and labels are listed.
PerplexityModel = {
    model_id: _hf_pytorch_entry(model_id, revision, labels)
    for model_id, revision, labels in (
        ("Qwen/Qwen2.5-0.5B", "060db6499f32faf8b98477b0a26969ef7d8b9987", {"test", "all"}),
        ("Qwen/Qwen3.5-0.8B", "2fc06364715b967f1860aea9cf38778875588b17", {"all"}),
        ("Qwen/Qwen3.5-2B", "15852e8c16360a2fea060d615a32b45270f8a8fc", {"all"}),
        ("Qwen/Qwen3.5-4B", "851bf6e806efd8d0a36b00ddf55e13ccb7b8cd0a", {"all"}),
    )
}
class PerplexityModelSpec(ModelSpec):
    """Spec type for perplexity models; declares nothing beyond ``ModelSpec``."""
# Validated spec for every entry of the raw registry, keyed by model name.
ALL_PERPLEXITY_MODELS = {
    model_name: PerplexityModelSpec.model_validate(raw_spec)
    for model_name, raw_spec in PerplexityModel.items()
}
# Names of the models whose labels intersect {"all"}; these are the names shown to users.
AVAILABLE_MODELS = [
    model_name
    for model_name, spec in ALL_PERPLEXITY_MODELS.items()
    if spec.label & {"all"}
]
# [docs]
@udf(return_dtype=DataType.float32())
class TextPerplexity(BatchColumnClassProtocol):
    __doc__ = Template("""Computes the perplexity of text using a language model to evaluate fluency and naturalness.

    Perplexity is a metric that measures how well a language model predicts a given text.
    A lower perplexity indicates more natural and fluent text, while a higher perplexity
    suggests the text is harder to predict (e.g., noisy, garbled, or low-quality content).

    Each value is ``exp`` of the mean next-token negative log-likelihood over the real
    (non-padding) tokens of one text. A text that tokenizes to fewer than two tokens has
    no next-token target and yields ``inf``; a null input yields a null output.

    Args:
        _local_model: Name of the language model used for perplexity computation.
            Available models: $AVAILABLE_MODELS
        max_length: Maximum token length for input truncation. Text exceeding this length
            will be truncated. Defaults to None, which uses the model's maximum supported length.

    Raises:
        ValueError: If ``_local_model`` is not a supported model name.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import TextPerplexity, from_items

            ds = from_items([
                "The quick brown fox jumps over the lazy dog.",
                "asdf qwer zxcv random noise text 1234",
            ])
            ds = ds.with_column(
                "perplexity",
                TextPerplexity()
                .options(num_workers={"CPU": 4}, batch_size=8)
                .with_column(col("item")),
            )
            print(ds.take(2))
    """).safe_substitute(AVAILABLE_MODELS=AVAILABLE_MODELS)

    def __init__(
        self,
        _local_model: str = "Qwen/Qwen2.5-0.5B",
        /,
        *,
        max_length: int | None = None,
    ):
        """Load the tokenizer and causal LM for ``_local_model`` and switch the model to eval mode."""
        if _local_model not in ALL_PERPLEXITY_MODELS:
            raise ValueError(f"Unsupported model: {_local_model}. Available models: {AVAILABLE_MODELS}")
        self.max_length = max_length
        model_spec = ALL_PERPLEXITY_MODELS[_local_model]
        # NOTE(review): cache_model presumably returns a local path to the pinned checkpoint;
        # confirm against xpark.dataset.model.
        model_path = cache_model("huggingface", "pytorch", model_spec, None)
        # Process-wide PyTorch setting: allows faster, reduced-precision float32 matmuls.
        torch.set_float32_matmul_precision("high")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
        )
        self.model.eval()

    async def __call__(self, texts: pa.ChunkedArray) -> pa.Array:
        """Score one batch of texts.

        Args:
            texts: Column of input strings; null entries are allowed.

        Returns:
            A float32 array with one perplexity per input row (null for null rows).
        """
        results = await self._batch_compute(texts.to_pylist())
        return pa.array(results, type=pa.float32())

    def batch_tokenizer(self, text_list: list[str]) -> transformers.BatchEncoding:
        """Tokenize ``text_list`` into padded, truncated PyTorch tensors."""
        return self.tokenizer(
            text_list,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True,
        )

    def batch_forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Run the model without gradient tracking and return its logits."""
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_mask)
        return outputs.logits

    async def _batch_compute(self, text_list: list[str | None]) -> list[float | None]:
        """Compute the perplexity of every text in ``text_list``.

        Args:
            text_list: Input texts; ``None`` marks a null row.

        Returns:
            One entry per input, in order: ``None`` for null rows, ``inf`` for texts with
            fewer than two tokens, otherwise the perplexity.
        """
        results: list[float | None] = [None] * len(text_list)
        # Null rows stay None; only real strings are handed to the tokenizer.
        rows = [row for row, text in enumerate(text_list) if text is not None]
        if not rows:
            # Empty or all-null batch: nothing to tokenize.
            return results

        encodings = await asyncio.to_thread(self.batch_tokenizer, [text_list[row] for row in rows])
        input_ids = encodings.input_ids.to(self.model.device)
        attention_mask = encodings.attention_mask.to(self.model.device)

        if input_ids.shape[1] < 2:
            # The padded batch is narrower than two tokens, so no row has a next-token
            # target; skip the forward pass entirely.
            scores = [float("inf")] * len(rows)
        else:
            logits = await asyncio.to_thread(self.batch_forward, input_ids, attention_mask)
            # The loss involves a vocabulary-wide log-softmax; keep it off the event loop.
            scores = await asyncio.to_thread(self._perplexity_from_logits, logits, input_ids, attention_mask)

        for row, score in zip(rows, scores):
            results[row] = score
        return results

    @staticmethod
    def _perplexity_from_logits(
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> list[float]:
        """Turn causal-LM logits into one perplexity per batch row.

        Args:
            logits: Model output, indexed as ``[row, position, vocab]``.
            input_ids: Token ids, indexed as ``[row, position]``.
            attention_mask: 1 for real tokens and 0 for padding, same layout as ``input_ids``.

        Returns:
            ``exp(mean NLL)`` per row; ``inf`` for rows without any scorable token.
        """
        import torch.nn.functional as F

        # A causal LM predicts the *next* token at each position, so position t-1 is scored
        # against token t. A pair counts only when both the context token and the target
        # token are real. For right padding this equals the target-only mask; for left
        # padding it also drops the "padding predicts first token" pair.
        pair_mask = (attention_mask[:, :-1] * attention_mask[:, 1:]).bool()
        # One host transfer for all rows instead of a device sync per row.
        target_counts = pair_mask.sum(dim=1).tolist()

        scored_rows: list[int] = []
        mean_nlls: list[torch.Tensor] = []
        # Rows are processed one at a time on purpose: the log-softmax temporaries are
        # vocabulary-sized, so a fully batched loss would multiply peak memory by the batch size.
        for row, n_targets in enumerate(target_counts):
            if n_targets == 0:
                continue
            # Per-token negative log-likelihood -log P(token_t | token_<t), computed in
            # float32 so a reduced-precision model dtype does not degrade the loss.
            token_nll = F.cross_entropy(
                logits[row, :-1, :].float(),
                input_ids[row, 1:],
                reduction="none",
            )
            # masked_fill (rather than multiplying by the mask) keeps a NaN at an ignored
            # position from poisoning the sum.
            kept_nll = token_nll.masked_fill(~pair_mask[row], 0.0)
            scored_rows.append(row)
            mean_nlls.append(kept_nll.sum() / n_targets)

        perplexities = [float("inf")] * len(target_counts)
        if mean_nlls:
            # Single transfer of all mean losses back to the host.
            for row, mean_nll in zip(scored_rows, torch.stack(mean_nlls).tolist()):
                try:
                    # Perplexity PPL = exp(mean NLL); lower means more predictable text.
                    perplexities[row] = math.exp(mean_nll)
                except OverflowError:
                    perplexities[row] = float("inf")
        return perplexities