Source code for xpark.dataset.processors.image_nsfw_score
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
import pyarrow as pa
from xpark.dataset.context import DatasetContext
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model, prepare_huggingface_model
from xpark.dataset.utils import read_image
if TYPE_CHECKING:
import torch
else:
torch = lazy_import("torch")
# Registry of supported NSFW image-classification models, keyed by model name.
# Each entry carries a set of labels (used below to decide which models are
# advertised) and the per-hub / per-framework spec needed to fetch the weights.
ImageNSFWModel = {
    "Falconsai/nsfw_image_detection": dict(
        label={"test", "all"},
        model_specs=dict(
            huggingface=dict(
                pytorch=dict(
                    model_id="Falconsai/nsfw_image_detection",
                    # Pinned revision so the downloaded weights are reproducible.
                    model_revision="04367978d3474804ab1a00a9bd6548b741764069",
                    quantizations=[None],
                ),
            ),
        ),
    ),
}
class ImageNSWFModelSpec(ModelSpec):
    """Model specification for the image NSFW-score models.

    Adds nothing on top of ``ModelSpec``; it exists as a distinct type for the
    entries of the NSFW model registry.

    NOTE: the class name spells "NSWF" (letters transposed relative to "NSFW").
    The spelling is kept because the name is part of the module's public
    surface.
    """
# Every registry entry, validated into an ImageNSWFModelSpec and keyed by model name.
ALL_IMAGE_NSFW_MODEL = {
    model_name: ImageNSWFModelSpec.model_validate(raw_spec)
    for model_name, raw_spec in ImageNSFWModel.items()
}

# Names of the models whose label set intersects {"all"}; these are the ones
# listed in the ImageNSFWScore documentation.
AVAILABLE_MODELS = [
    model_name
    for model_name, model_spec in ALL_IMAGE_NSFW_MODEL.items()
    if model_spec.label & {"all"}
]
[docs]
@udf(return_dtype=DataType.float32())
class ImageNSFWScore(BatchColumnClassProtocol):
    # The documentation is assigned to ``__doc__`` explicitly because an
    # f-string cannot serve as a regular docstring, and the list of available
    # models is interpolated at import time.
    __doc__ = f"""Image NSFW score calculation processor for CPU, GPU

    Args:
        _local_model: The nsfw model name for CPU or GPU. default is "Falconsai/nsfw_image_detection"
            available models {AVAILABLE_MODELS}

    Examples:

        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import ImageNSFWScore, from_items
            import numpy as np

            ds = from_items([
                {{"image": np.random.randint(0, 255, (256, 256, 3)).astype(np.uint8), "path": "test.jpg"}}
            ])
            ds = ds.with_column(
                "image_nsfw_score",
                ImageNSFWScore().options(num_workers={{"CPU": 4}}, batch_size=1).with_column(col("item")),
            )
            print(ds.take(1))
    """

    def __init__(self, _local_model: str = "Falconsai/nsfw_image_detection"):
        """Resolve the model spec, cache the weights and load model + processor.

        Args:
            _local_model: Key into ``ALL_IMAGE_NSFW_MODEL`` selecting the model.

        Raises:
            ValueError: If ``_local_model`` is not a registered model name.
        """
        if _local_model not in ALL_IMAGE_NSFW_MODEL:
            raise ValueError(f"model is not supported. {_local_model}")
        model_spec = ALL_IMAGE_NSFW_MODEL[_local_model]
        model_path = cache_model("huggingface", "pytorch", model_spec, None)
        self.model, self.processor = prepare_huggingface_model(model_path, device_map="auto")

    async def __call__(self, images: pa.ChunkedArray) -> pa.Array:
        """Score a batch of images and return one float32 value per image.

        Args:
            images: Column of image references; each element is handed to
                ``read_image`` together with the current storage options.

        Returns:
            A float32 Arrow array holding, for every input image, the softmax
            probability of class index 1.
        """
        # Nothing to score: return early instead of feeding an empty list of
        # images to the processor and the model.
        if len(images) == 0:
            return pa.array([], type=pa.float32())

        storage_options = DatasetContext.get_current().storage_options

        # Load all images of the batch concurrently.
        pil_images = await asyncio.gather(
            *(read_image(image, source_mode="RGB", **storage_options) for image in images)
        )

        # Preprocessing runs in a worker thread so it does not block the event loop.
        inputs = await asyncio.to_thread(self.processor, images=pil_images, return_tensors="pt")
        # Rebind the result so correctness does not depend on ``to`` mutating
        # its receiver in place.
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits
            # NOTE(review): assumes class index 1 is the NSFW label for the
            # configured model - confirm against the model's id2label mapping.
            nsfw_scores = torch.softmax(logits, dim=-1)[:, 1].detach().cpu().numpy()

        return pa.array(nsfw_scores, type=pa.float32())