from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
import pyarrow as pa
from xpark.dataset.context import DatasetContext
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model
from xpark.dataset.utils import read_image
if TYPE_CHECKING:
import aesthetics_predictor
import torch
import transformers
else:
torch = lazy_import("torch")
aesthetics_predictor = lazy_import("aesthetics_predictor")
transformers = lazy_import("transformers")
# Registry of supported aesthetic-predictor checkpoints, keyed by model name.
# Each entry pins Hugging Face PyTorch weights to an exact revision. Entries
# whose "label" set contains "all" are exposed through AVAILABLE_MODELS.
ImageAestheticModel = {
    checkpoint: {
        "label": labels,
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": checkpoint,
                    "model_revision": revision,
                    "quantizations": [None],
                },
            },
        },
    }
    for checkpoint, revision, labels in (
        (
            "shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE",
            "684098de3856fa4678bf800efc05635de5b6cde5",
            {"test", "all"},
        ),
        (
            "shunk031/aesthetics-predictor-v1-vit-large-patch14",
            "74fd3ab002ca9252b5593f079514e6a1eaa132f9",
            {"all"},
        ),
    )
}
# Aesthetic-model flavour of ModelSpec; adds no fields of its own and exists
# so the registry entries are validated under a dedicated type.
class ImageAestheticModelSpec(ModelSpec):
    ...
# Validated spec objects for every registered checkpoint, keyed by model name.
ALL_IMAGE_AESTHETIC_MODEL = {
    name: ImageAestheticModelSpec.model_validate(raw_spec)
    for name, raw_spec in ImageAestheticModel.items()
}
# Public model names: only checkpoints whose label set includes "all".
AVAILABLE_MODELS = [
    name for name, spec in ALL_IMAGE_AESTHETIC_MODEL.items() if "all" in spec.label
]
@udf(return_dtype=DataType.float32())
class ImageAestheticScore(BatchColumnClassProtocol):
    __doc__ = f"""Image aesthetic score calculation processor for CPU, GPU. Image aesthetic score is a value
    nominally between 0 and 10, with higher scores indicating better image quality (the predictor's
    output is not clipped, so values slightly outside this range are possible).

    Args:
        _local_model: The CLIP base aesthetic model name for CPU or GPU.
            default model is "shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE"
            available models: {AVAILABLE_MODELS}
        normalized: Whether to divide the score by 10 so it nominally lies in [0, 1], default is False

    Raises:
        ValueError: If ``_local_model`` is not a registered aesthetic model.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import ImageAestheticScore, from_items
            import numpy as np

            ds = from_items([
                {{"image": np.random.randint(0, 255, (256, 256, 3)).astype(np.uint8), "path": "test.jpg"}}
            ])
            ds = ds.with_column(
                "image_score",
                ImageAestheticScore()
                .options(num_workers={{"CPU": 4}}, batch_size=1)
                .with_column(col("image")),
            )
            print(ds.take(1))
    """

    def __init__(
        self,
        _local_model: str = "shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE",
        normalized: bool = False,
    ):
        """Resolve, download (via ``cache_model``) and prepare the predictor and its CLIP processor.

        Raises:
            ValueError: If ``_local_model`` is not a key of ``ALL_IMAGE_AESTHETIC_MODEL``.
        """
        if _local_model not in ALL_IMAGE_AESTHETIC_MODEL:
            raise ValueError(f"model is not supported. {_local_model}")
        self.normalized = normalized
        model_spec = ALL_IMAGE_AESTHETIC_MODEL[_local_model]
        model_path = cache_model("huggingface", "pytorch", model_spec, None)

        # Process-wide setting: lets float32 matmuls use faster reduced-precision
        # internal math (e.g. TF32) on hardware that supports it.
        torch.set_float32_matmul_precision("high")
        # Half precision only when a GPU is present; CPU inference stays in float32.
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        if _local_model == "shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE":
            model_class = aesthetics_predictor.AestheticsPredictorV2Linear
        else:
            model_class = aesthetics_predictor.AestheticsPredictorV1
        # Load the weights directly in the target dtype rather than loading float32
        # and casting afterwards, so no full-precision copy is materialized first.
        self.model = model_class.from_pretrained(model_path, device_map="auto", torch_dtype=torch_dtype)
        # Compile the forward pass; "reduce-overhead" uses CUDA graphs to cut per-call
        # launch overhead, and fullgraph=True turns any graph break into an error.
        self.model.forward = torch.compile(self.model.forward, mode="reduce-overhead", fullgraph=True)
        self.processor = transformers.CLIPProcessor.from_pretrained(model_path)

    async def __call__(self, images: pa.ChunkedArray) -> pa.Array:
        """Score one batch of images.

        Args:
            images: Column whose elements are accepted by ``read_image``.

        Returns:
            A float32 Arrow array with one aesthetic score per input row, divided by
            10 when ``normalized`` is set.
        """
        # Empty batch: skip I/O, preprocessing and the compiled model entirely rather
        # than handing an empty image list to the processor.
        if len(images) == 0:
            return pa.array([], type=pa.float32())

        storage_options = DatasetContext.get_current().storage_options
        # Read/decode every image of the batch concurrently.
        pil_images = await asyncio.gather(
            *(read_image(image, source_mode="RGB", **storage_options) for image in images)
        )
        # Run preprocessing off the event-loop thread.
        inputs = await asyncio.to_thread(self.processor, images=pil_images, return_tensors="pt")
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # On GPU the model runs in float16; upcast before normalizing and before the
        # Arrow conversion, which would otherwise need a half-float -> float32 cast.
        scores = outputs.logits.detach().float()
        if self.normalized:
            scores = scores / 10.0
        return pa.array(scores.cpu().flatten().numpy(), type=pa.float32())