from __future__ import annotations
import logging
from functools import partial
from typing import TYPE_CHECKING, Any, Iterable
from xpark.dataset.constants import NOT_SET
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.utils import LabelSpec, LLMChatCompletions, _format_labels, skip_empty_texts
if TYPE_CHECKING:
import pyarrow as pa
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
else:
openai = lazy_import("openai")
pa = lazy_import("pyarrow", rename="pa")
# Module-level logger, obtained under the "ray" logger name.
logger = logging.getLogger("ray")
# Prompt adapted from https://github.com/apache/doris/blob/4.0.2-rc01/be/src/vec/functions/ai/ai_mask.h
# System-role message used by `build_prompt`. It tells the model to replace
# everything matching a label with "[MASKED]" and to return only the masked text.
# NOTE(review): the prompt states that each label line is "label description"
# separated by a space -- confirm that `_format_labels` emits exactly that layout.
SYSTEM_ROLE_PROMPT = (
"You are a data privacy assistant. You will identify and mask sensitive information "
"in the user's input according to the provided labels. "
"The user will provide `Labels` and `Text`. For each label, you must hide all related "
'information in the Text and replace it with "[MASKED]". Only return the text after masking. '
"Labels are organized into multiple lines, with each line containing two fields: label and description, separated by a space. "
"They represent the label name and the description of the content respectively. The description may be empty."
)
# User-message template filled positionally by `build_prompt`: the first `{}`
# receives the formatted labels, the second `{}` receives the text to mask.
PROMPT_TEMPLATE = """
Labels:
{}
Text:
{}
"""
def build_prompt(specs: list[LabelSpec], text: str) -> Iterable[ChatCompletionMessageParam]:
    """Assemble the chat messages asking the model to mask ``text``.

    Args:
        specs: Label specifications rendered into the ``Labels`` section of
            the user message via ``_format_labels``.
        text: The input to mask; it is coerced with ``str()`` before being
            substituted into the template.

    Returns:
        A two-element list: the system message carrying ``SYSTEM_ROLE_PROMPT``
        followed by the user message built from ``PROMPT_TEMPLATE``.
    """
    # Imported at call time: outside of type checking the module only binds
    # `openai` through `lazy_import`, so the concrete message types are
    # resolved here when a prompt is actually built.
    from openai.types.chat.chat_completion_message_param import (
        ChatCompletionSystemMessageParam,
        ChatCompletionUserMessageParam,
    )

    system_message = ChatCompletionSystemMessageParam(role="system", content=SYSTEM_ROLE_PROMPT)
    # Positional placeholders: labels first, then the text itself.
    user_content = PROMPT_TEMPLATE.format(_format_labels(specs), str(text))
    user_message = ChatCompletionUserMessageParam(role="user", content=user_content)
    return [system_message, user_message]
@udf(return_dtype=DataType.string())
class TextMask(BatchColumnClassProtocol):
    """Replace sensitive information in text with ``[MASKED]`` according to the labels.

    Each input text is sent to an LLM chat-completions endpoint together with
    the configured labels; the model's reply is returned as the masked text.

    Args:
        labels: The labels to mask. Accepts two formats:

            - ``list[str]``: plain label names, e.g. ``["email", "phone_num"]``
            - ``list[dict]``: dicts with ``"label"`` (required) and ``"description"`` (optional),
              e.g. ``[{"label": "email", "description": "email address"}]``

            Descriptions are injected into the prompt to guide the model when label names alone
            are ambiguous.
        base_url: The base URL of the LLM server.
        model: The request model name.
        api_key: The request API key.
        max_qps: The maximum number of requests per second.
        max_retries: The maximum number of retries per request in the event of failures.
            We retry with exponential backoff up to this specific maximum retries.
        fallback_response: The response value to return when the LLM request fails.
            If set to None, the exception will be raised instead.
        **kwargs: Keyword arguments to pass to the `openai.AsyncClient.chat.completions.create
            <https://github.com/openai/openai-python/blob/main/src/openai/resources/chat/completions/completions.py>`_ API.

    Raises:
        ValueError: If ``labels`` is empty.

    Examples:

        .. code-block:: python

            import os

            from xpark.dataset.expressions import col
            from xpark.dataset import TextMask, from_items

            ds = from_items(["My email is rarity@example.com and my phone is 123-456-7890"])

            # Plain labels
            ds = ds.with_column(
                "masked_text",
                TextMask(
                    ["email", "phone_num"],
                    model="deepseek-v3-0324",
                    base_url=os.getenv("LLM_ENDPOINT"),
                    api_key=os.getenv("LLM_API_KEY"),
                )
                .options(num_workers={"IO": 1}, batch_size=1)
                .with_column(col("item")),
            )

            # Labels with descriptions
            ds = ds.with_column(
                "masked_text",
                TextMask(
                    [
                        {"label": "email", "description": "email address"},
                        {"label": "phone_num", "description": "phone number"},
                    ],
                    model="deepseek-v3-0324",
                    base_url=os.getenv("LLM_ENDPOINT"),
                    api_key=os.getenv("LLM_API_KEY"),
                )
                .options(num_workers={"IO": 1}, batch_size=1)
                .with_column(col("item")),
            )
            print(ds.take_all())
    """

    def __init__(
        self,
        labels: list[str | dict[str, str]],
        /,
        *,
        base_url: str,
        model: str,
        api_key: str = NOT_SET,
        max_qps: int | None = None,
        max_retries: int = 0,
        fallback_response: str | None = None,
        # A ``**kwargs`` annotation describes each keyword *value*, so this is
        # ``Any`` rather than ``dict[str, Any]``.
        **kwargs: Any,
    ) -> None:
        """Validate the labels and configure the chat-completions client.

        See the class docstring for the meaning of each argument.

        Raises:
            ValueError: If ``labels`` is empty.
        """
        if not labels:
            raise ValueError("labels must not be empty")
        # Normalise both accepted input shapes to LabelSpec: a bare string is
        # wrapped as {"label": <str>}; anything else is validated as given.
        self.specs = [LabelSpec.model_validate({"label": item} if isinstance(item, str) else item) for item in labels]
        # The masked text comes back as free-form text, hence response_format="text".
        self.model = LLMChatCompletions(
            base_url=base_url,
            model=model,
            api_key=api_key,
            max_qps=max_qps,
            max_retries=max_retries,
            fallback_response=fallback_response,
            response_format="text",
            **kwargs,
        )

    # NOTE(review): `skip_empty_texts` is defined in xpark.dataset.utils;
    # presumably it bypasses the LLM call for empty inputs -- confirm there.
    @skip_empty_texts
    async def __call__(self, texts: pa.ChunkedArray) -> pa.Array:
        """Mask a batch of texts.

        Args:
            texts: The batch of input texts to mask.

        Returns:
            The result of ``LLMChatCompletions.batch_generate`` for the batch,
            with each prompt built by ``build_prompt`` bound to this
            instance's label specs.
        """
        return await self.model.batch_generate(
            texts=texts,
            build_prompt=partial(build_prompt, self.specs),
        )