"""Image NSFW scoring processor (module ``xpark.dataset.processors.image_nsfw_score``)."""

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

import pyarrow as pa

from xpark.dataset.context import DatasetContext
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model, prepare_huggingface_model
from xpark.dataset.utils import read_image

if TYPE_CHECKING:
    import torch
else:
    torch = lazy_import("torch")

# Registry of supported image NSFW classifiers, keyed by the public model name.
# Every entry carries:
#   * "label": tags for the entry; models tagged "all" are advertised as available.
#   * "model_specs": hub -> framework -> download settings (repo id, pinned
#     revision, and the list of quantizations that may be requested).
ImageNSFWModel = {
    "Falconsai/nsfw_image_detection": dict(
        label={"all", "test"},
        model_specs=dict(
            huggingface=dict(
                pytorch=dict(
                    model_id="Falconsai/nsfw_image_detection",
                    # Pinned commit so downloads are reproducible.
                    model_revision="04367978d3474804ab1a00a9bd6548b741764069",
                    quantizations=[None],
                ),
            ),
        ),
    ),
}


class ImageNSWFModelSpec(ModelSpec):
    """Model spec type used to validate the entries of ``ImageNSFWModel``.

    It adds no fields or behavior on top of ``ModelSpec``. The class name
    transposes "NSFW" as "NSWF"; it is kept as-is for backward compatibility,
    and ``ImageNSFWModelSpec`` is provided as a correctly spelled alias.
    """


# Correctly spelled alias of the historically misspelled class name.
ImageNSFWModelSpec = ImageNSWFModelSpec


# Validated spec objects for every registered model, keyed by model name.
ALL_IMAGE_NSFW_MODEL = {
    name: ImageNSWFModelSpec.model_validate(raw_spec) for name, raw_spec in ImageNSFWModel.items()
}
# Names of models tagged "all"; these are the ones listed as available in
# the ImageNSFWScore docstring.
AVAILABLE_MODELS = [name for name, spec in ALL_IMAGE_NSFW_MODEL.items() if "all" in spec.label]


@udf(return_dtype=DataType.float32())
class ImageNSFWScore(BatchColumnClassProtocol):
    __doc__ = f"""Image NSFW score calculation processor for CPU, GPU

    Each image is classified by a HuggingFace image-classification model and
    the output column holds, per image, the float32 softmax probability of the
    model's class index 1, which this processor treats as the NSFW class.

    Args:
        _local_model: The nsfw model name for CPU or GPU.
            default is "Falconsai/nsfw_image_detection"
            available models {AVAILABLE_MODELS}

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import ImageNSFWScore, from_items
            import numpy as np

            ds = from_items([
                {{"image": np.random.randint(0, 255, (256, 256, 3)).astype(np.uint8), "path": "test.jpg"}}
            ])
            ds = ds.with_column(
                "image_nsfw_score",
                ImageNSFWScore().options(num_workers={{"CPU": 4}}, batch_size=1).with_column(col("image")),
            )
            print(ds.take(1))
    """

    def __init__(self, _local_model: str = "Falconsai/nsfw_image_detection"):
        """Resolve the requested model and load it together with its processor.

        Raises:
            ValueError: if ``_local_model`` is not a key of ``ALL_IMAGE_NSFW_MODEL``.
        """
        if _local_model not in ALL_IMAGE_NSFW_MODEL:
            raise ValueError(f"model is not supported. {_local_model}")
        model_spec = ALL_IMAGE_NSFW_MODEL[_local_model]
        # Fetch/cache the HuggingFace PyTorch weights. The trailing None presumably
        # selects the unquantized variant (the spec lists quantizations=[None]) — confirm.
        model_path = cache_model("huggingface", "pytorch", model_spec, None)
        self.model, self.processor = prepare_huggingface_model(model_path, device_map="auto")

    async def __call__(self, images: pa.ChunkedArray) -> pa.Array:
        """Score a batch of images.

        Args:
            images: Column of image values, each passed to ``read_image``.

        Returns:
            A float32 array with one score per input image, in input order.
        """
        # Nothing to preprocess or run the model on; keep the batch length (0).
        if len(images) == 0:
            return pa.array([], type=pa.float32())

        # Tolerate a context without storage options instead of failing on `**None`.
        storage_options = DatasetContext.get_current().storage_options or {}
        # Decode all images of the batch concurrently.
        pil_images = await asyncio.gather(
            *(read_image(image, source_mode="RGB", **storage_options) for image in images)
        )
        # Preprocessing is CPU-bound; run it off the event loop.
        inputs = await asyncio.to_thread(self.processor, images=pil_images, return_tensors="pt")
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Upcast before softmax: keeps the probabilities in fp32 for reduced-precision
        # models and avoids `.numpy()` failing on bfloat16 tensors.
        nsfw_scores = torch.softmax(logits.float(), dim=-1)[:, 1].cpu().numpy()
        return pa.array(nsfw_scores, type=pa.float32())