# Portions of this code are adapted from:
# https://github.com/datajuicer/data-juicer/blob/main/data_juicer/ops/filter/flagged_words_filter.py
#
# The flagged words list is sourced from:
# - https://huggingface.co/spaces/huggingface/text-data-filtering
# - https://github.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words
# - https://github.com/datajuicer/data-juicer/blob/main/data_juicer/ops/filter/flagged_words_filter.py
from __future__ import annotations
import json
from itertools import chain
from string import Template
from typing import TYPE_CHECKING, List, Literal
from pydantic import PositiveInt
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import AssetSpec, cache_asset
from xpark.dataset.utils import text_tokenize
if TYPE_CHECKING:
import pyarrow as pa
else:
pa = lazy_import("pyarrow", rename="pa")
# Registry of flagged-word assets, keyed by asset id. Each entry carries its
# labels plus the raw spec that ``AssetSpec.model_validate`` turns into a
# downloadable file description.
FlaggedWordsAsset = {
    "data_juicer/flagged_words": dict(
        label={"test", "all"},
        asset_specs=dict(
            uri=dict(
                file=dict(
                    asset_id="data_juicer/flagged_words",
                    asset_uri="https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/flagged_words.json",
                ),
            ),
        ),
    ),
}
# Parse every raw registry entry into a validated ``AssetSpec`` once, at import time.
ALL_FLAGGED_WORDS_ASSETS_SPECS = {
    asset_name: AssetSpec.model_validate(raw_spec)
    for asset_name, raw_spec in FlaggedWordsAsset.items()
}
# Asset names whose labels include "all"; interpolated into the
# ``TextFlaggedWordRatio`` docstring.
AVAILABLE_ASSETS = [
    asset_name
    for asset_name, spec in ALL_FLAGGED_WORDS_ASSETS_SPECS.items()
    if spec.label & {"all"}
]
@udf(return_dtype=DataType.float32())
class TextFlaggedWordRatio(BatchColumnClassProtocol):
    __doc__ = Template("""Compute the ratio of flagged words in a text.

    Tokenizes the input text and calculates the proportion of flagged words
    relative to the total word count. Supports CJK and space-based tokenization,
    as well as word-bag augmentation (combining adjacent tokens into new candidates).

    Matching is case-insensitive: both the tokens and the flagged words are
    lowercased. The denominator is always the number of original tokens
    (augmented candidates are not counted), and the result is clamped to ``1.0``.
    Texts that yield no tokens score ``0.0``.

    Args:
        asset: The asset to use for flagged words. Defaults to ``"data_juicer/flagged_words"``.
            Available assets: $AVAILABLE_ASSETS.
        asset_label: The label of the asset to use. Defaults to ``"en"``.
            Supported labels: ``ar``, ``ca``, ``cs``, ``da``, ``de``, ``en``, ``eo``,
            ``es``, ``eu``, ``fa``, ``fi``, ``fil``, ``fr``, ``fr-CA-u-sd-caqc``,
            ``ha``, ``hi``, ``hu``, ``id``, ``it``, ``ja``, ``kab``, ``kn``, ``ko``,
            ``ml``, ``mr``, ``nl``, ``no``, ``pl``, ``pt``, ``ru``, ``sv``, ``ta``,
            ``te``, ``th``, ``tlh``, ``tr``, ``vi``, ``zh``.
            Pass ``"all"`` to use the aggregated flagged words across all languages.
        custom_flagged_words_list: A user-supplied list of flagged words. When this list
            is non-empty, the ``asset`` and ``asset_label`` parameters are ignored.
        tokenizer: Tokenization strategy. Supports ``"cjk"`` (mixed CJK + whitespace
            splitting) and ``"space"`` (whitespace-only splitting). Defaults to
            ``"cjk"``. Support for ``"jieba"`` may be added in the future.
        use_words_aug: Whether to enable word-bag augmentation. When enabled, adjacent
            tokens are joined using ``words_aug_join_char`` for each window size in
            ``words_aug_group_sizes`` and added to the candidate set. Defaults to
            ``False``.
        words_aug_group_sizes: Window sizes used for word-bag augmentation. Sizes
            below 2 are ignored, because a one-token window would only repeat the
            original tokens and count them twice. Defaults to ``[2]``.
        words_aug_join_char: Character used to join adjacent tokens during augmentation.
            Defaults to ``""``.

    Raises:
        ValueError: If ``tokenizer`` is not ``"cjk"`` or ``"space"``, or if the
            resolved flagged words list is empty (for example, an unknown
            ``asset_label``).
        KeyError: If ``asset`` is not a registered asset and no non-empty
            ``custom_flagged_words_list`` was given.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import from_items
            from xpark.dataset.processors.text_flagged_word_ratio import TextFlaggedWordRatio

            ds = from_items(["This is a bad text", "Hello world"])
            ds = ds.with_column(
                "flagged_ratio",
                TextFlaggedWordRatio(
                    custom_flagged_words_list=["bad"],
                    tokenizer="space",
                )
                .options(num_workers={"CPU": 1})
                .with_column(col("item")),
            )
            print(ds.take(2))
    """).safe_substitute(AVAILABLE_ASSETS=AVAILABLE_ASSETS)

    def __init__(
        self,
        asset: str = "data_juicer/flagged_words",
        asset_label: str = "en",
        custom_flagged_words_list: List[str] | None = None,
        tokenizer: Literal["cjk", "space"] = "cjk",
        use_words_aug: bool = False,
        words_aug_group_sizes: List[PositiveInt] | None = None,
        words_aug_join_char: str = "",
    ):
        # Validate before touching any asset. Otherwise an unsupported value
        # (e.g. "jieba") would silently fall back to whitespace splitting in __call__.
        if tokenizer not in ("cjk", "space"):
            raise ValueError(
                f"Unsupported tokenizer: {tokenizer!r}. Supported tokenizers: 'cjk', 'space'."
            )

        if custom_flagged_words_list:
            source_words = custom_flagged_words_list
        else:
            source_words = self.load_flagged_words_list(asset, asset_label)
        # __call__ lowercases every token, so the flagged words must be lowercased
        # too. Otherwise any entry containing uppercase letters could never match.
        self.flagged_words_set: set[str] = {w.lower() for w in source_words}
        if len(self.flagged_words_set) == 0:
            raise ValueError(
                f"Flagged words list is empty for asset: {asset}, label: {asset_label}, custom_flagged_words_list: {custom_flagged_words_list}."
            )

        # Tokenization mode: "cjk" (mixed CJK + whitespace) or "space" (whitespace-only);
        # TODO(anthonycai) "jieba" may be supported in the future
        self.tokenizer = tokenizer
        self.use_words_aug = use_words_aug
        if words_aug_group_sizes is None:
            words_aug_group_sizes = [2]
        # A window of one token reproduces the original tokens, which are already
        # candidates. Keeping it would double-count every flagged token.
        self.words_aug_group_sizes = {size for size in words_aug_group_sizes if size > 1}
        self.words_aug_join_char = words_aug_join_char

    def load_flagged_words_list(self, asset: str, asset_label: str) -> set[str]:
        """Load the flagged words stored under ``asset_label`` in ``asset``'s JSON file.

        The JSON file is read as a mapping from label to a list of words. The
        local file path comes from ``cache_asset`` (up to 3 retries).

        Args:
            asset: Name of a registered asset in ``ALL_FLAGGED_WORDS_ASSETS_SPECS``.
            asset_label: Label to read, or ``"all"`` for the union of every label.

        Returns:
            The words for the label, unchanged in case. The set is empty when the
            label is not present in the file.

        Raises:
            KeyError: If ``asset`` is not a registered asset.
        """
        try:
            asset_spec = ALL_FLAGGED_WORDS_ASSETS_SPECS[asset]
        except KeyError:
            raise KeyError(
                f"Unknown flagged words asset: {asset!r}. "
                f"Registered assets: {sorted(ALL_FLAGGED_WORDS_ASSETS_SPECS)}."
            ) from None
        file = cache_asset("uri", "file", asset_spec, max_retries=3)
        # The asset holds words in many scripts (zh, ja, ru, ...). Decode it as
        # UTF-8 explicitly instead of relying on the platform's locale encoding.
        with open(file, "r", encoding="utf-8") as f:
            flagged_words = json.load(f)
        if asset_label == "all":
            return {word for words in flagged_words.values() for word in words}
        return set(flagged_words.get(asset_label, []))

    def __call__(self, texts: pa.ChunkedArray) -> pa.Array:
        """Return one float32 flagged-word ratio per input text.

        A text with no non-empty tokens (including a null tokenization result)
        scores ``0.0``.
        """
        token_lists = text_tokenize(texts, cjk=self.tokenizer == "cjk")
        # Hoist per-call invariants out of the per-row loop.
        flagged_words_set = self.flagged_words_set
        join = self.words_aug_join_char.join
        results: list[float] = []
        for token_scalar in token_lists:
            # Drop empty tokens and lowercase the rest for case-insensitive matching.
            words = [w.lower() for w in token_scalar.as_py() or [] if w]
            if not words:
                results.append(0.0)
                continue
            if self.use_words_aug:
                # Word-bag augmentation: every run of `group_size` adjacent tokens,
                # joined, becomes an extra candidate. Zipping shifted slices yields
                # the windows; chain/map/zip is ~3x faster than a nested for-loop.
                candidate_words = chain(
                    words,
                    chain.from_iterable(
                        map(join, zip(*(words[i:] for i in range(group_size))))
                        for group_size in self.words_aug_group_sizes
                    ),
                )
            else:
                candidate_words = words
            flagged_count = sum(1 for w in candidate_words if w in flagged_words_set)
            # The denominator is always the original token count, so augmented
            # matches can push the ratio above 1. Clamp it.
            results.append(min(flagged_count / len(words), 1.0))
        return pa.array(results, type=pa.float32())