# Source code for xpark.dataset.processors.image_aesthetic_score

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

import pyarrow as pa

from xpark.dataset.context import DatasetContext
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model
from xpark.dataset.utils import read_image

if TYPE_CHECKING:
    import aesthetics_predictor
    import torch
    import transformers
else:
    torch = lazy_import("torch")
    aesthetics_predictor = lazy_import("aesthetics_predictor")
    transformers = lazy_import("transformers")

# Registry of supported aesthetic-predictor checkpoints, keyed by Hugging Face
# model id. Each entry carries:
#   - "label": tags used to filter which models are exposed (see
#     AVAILABLE_MODELS below, which keeps only entries tagged "all").
#   - "model_specs": per-hub / per-framework download info, including a pinned
#     revision hash so cached weights are reproducible.
ImageAestheticModel = {
    "shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE": {
        "label": {"test", "all"},
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE",
                    # Pinned commit for reproducible downloads.
                    "model_revision": "684098de3856fa4678bf800efc05635de5b6cde5",
                    # None means the full-precision (unquantized) variant only.
                    "quantizations": [None],
                },
            },
        },
    },
    "shunk031/aesthetics-predictor-v1-vit-large-patch14": {
        "label": {"all"},
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "shunk031/aesthetics-predictor-v1-vit-large-patch14",
                    # Pinned commit for reproducible downloads.
                    "model_revision": "74fd3ab002ca9252b5593f079514e6a1eaa132f9",
                    "quantizations": [None],
                },
            },
        },
    },
}


class ImageAestheticModelSpec(ModelSpec):
    """Model specification for image aesthetic predictor checkpoints.

    A thin subclass of :class:`ModelSpec` that exists only to give the
    aesthetic-model registry entries their own validated type.
    """


# Validate each raw registry entry into a typed spec once, at import time.
ALL_IMAGE_AESTHETIC_MODEL = {
    name: ImageAestheticModelSpec.model_validate(raw_spec)
    for name, raw_spec in ImageAestheticModel.items()
}
# Only models tagged "all" are advertised as selectable.
AVAILABLE_MODELS = [name for name, spec in ALL_IMAGE_AESTHETIC_MODEL.items() if "all" in spec.label]


@udf(return_dtype=DataType.float32())
class ImageAestheticScore(BatchColumnClassProtocol):
    # Built as an f-string so the docstring can list the currently available
    # model names; doubled braces in the example escape literal braces.
    __doc__ = f"""Image aesthetic score calculation processor for CPU, GPU.

    Image aesthetic score is a value between 0 and 10, with higher scores
    indicating better image quality.

    Args:
        _local_model: The CLIP base aesthetic model name for CPU or GPU.
            default model is "shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE"
            available models: {AVAILABLE_MODELS}
        normalized: Whether to normalize the score to [0, 1], default is False

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import ImageAestheticScore, from_items
            import numpy as np

            ds = from_items([
                {{"image": np.random.randint(0, 255, (256, 256, 3)).astype(np.uint8), "path": "test.jpg"}}
            ])
            ds = ds.with_column(
                "image_score",
                ImageAestheticScore()
                .options(num_workers={{"CPU": 4}}, batch_size=1)
                .with_column(col("image")),
            )
            print(ds.take(1))
    """

    def __init__(
        self,
        _local_model: str = "shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE",
        normalized: bool = False,
    ):
        if _local_model not in ALL_IMAGE_AESTHETIC_MODEL:
            raise ValueError(f"model is not supported. \n{_local_model}")
        self.normalized = normalized

        model_spec = ALL_IMAGE_AESTHETIC_MODEL[_local_model]
        # Resolve (and download if needed) the pinned checkpoint to a local path.
        model_path = cache_model("huggingface", "pytorch", model_spec, None)

        torch.set_float32_matmul_precision("high")
        # Half precision on GPU, full precision on CPU (fp16 is slow/unsupported
        # on most CPUs).
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # The v2 checkpoint uses the linear-MSE head; every other registry
        # entry is a v1-style predictor.
        if _local_model == "shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE":
            model_class = aesthetics_predictor.AestheticsPredictorV2Linear
        else:
            model_class = aesthetics_predictor.AestheticsPredictorV1

        self.model = model_class.from_pretrained(model_path, device_map="auto")
        # Enable static cache and compile the forward pass.
        self.model.forward = torch.compile(self.model.forward, mode="reduce-overhead", fullgraph=True)
        self.model.to(torch_dtype)
        # NOTE(review): on CUDA the model is cast to fp16 but __call__ feeds it
        # fp32 pixel_values — confirm the predictor's forward tolerates the
        # dtype mismatch (or cast inputs to self.model.dtype).
        self.processor = transformers.CLIPProcessor.from_pretrained(model_path)

    async def __call__(self, images: pa.ChunkedArray) -> pa.Array:
        """Score a batch of image references and return a float32 pa.Array."""
        storage_options = DatasetContext.get_current().storage_options
        # Fetch and decode all images concurrently.
        pil_images = await asyncio.gather(
            *(read_image(image, source_mode="RGB", **storage_options) for image in images)
        )
        # CLIP preprocessing is CPU-bound; keep it off the event loop.
        inputs = await asyncio.to_thread(self.processor, images=pil_images, return_tensors="pt")
        # BatchEncoding.to returns the moved encoding; use the returned value
        # rather than relying on in-place mutation.
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Raw scores are on a 0-10 scale; optionally rescale to [0, 1].
        scores = outputs.logits / 10.0 if self.normalized else outputs.logits
        return pa.array(scores.detach().cpu().flatten().numpy(), type=pa.float32())