# Source code for xpark.dataset.processors.image_text_similarity_score

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

from xpark.dataset.context import DatasetContext
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model, prepare_huggingface_model
from xpark.dataset.utils import read_image

if TYPE_CHECKING:
    import pyarrow as pa
    import torch
else:
    pa = lazy_import("pyarrow", rename="pa")
    torch = lazy_import("torch")

# Registry of supported image-text models, keyed by model name.
# Each raw entry is validated into an ImageTextModelSpec further down; the
# "all" label marks models that are advertised in AVAILABLE_MODELS.
ImageTextModel = {
    "openai/clip-vit-base-patch32": dict(
        label={"test", "all"},
        model_specs=dict(
            huggingface=dict(
                pytorch=dict(
                    model_id="openai/clip-vit-base-patch32",
                    # Pinned revision so the downloaded weights are reproducible.
                    model_revision="3d74acf9a28c67741b2f4f2ea7635f0aaf6f0268",
                    quantizations=[None],
                ),
            ),
        ),
    ),
}


class ImageTextModelSpec(ModelSpec):
    pass


# Validated spec for every registered model, keyed by model name.
ALL_IMAGE_TEXT_MODEL = {
    name: ImageTextModelSpec.model_validate(raw_spec)
    for name, raw_spec in ImageTextModel.items()
}
# Names of models tagged with the "all" label; listed in the processor docstring.
AVAILABLE_MODELS = [
    name for name, spec in ALL_IMAGE_TEXT_MODEL.items() if "all" in spec.label
]


@udf(return_dtype=DataType.float32())
class ImageTextSimilarityScore(BatchColumnClassProtocol):
    __doc__ = f"""Image-text similarity score processor (runs on CPU or GPU).

    Scores every image of a column against one fixed text prompt using a CLIP
    model, producing one float32 score per image.

    Args:
        text: The text prompt every image is scored against.
        _local_model: The CLIP model name to load.
            Default is "openai/clip-vit-base-patch32".
            Available models: {AVAILABLE_MODELS}

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import ImageTextSimilarityScore, from_items
            import numpy as np

            ds = from_items([
                {{"image": np.random.randint(0, 255, (256, 256, 3)).astype(np.uint8), "path": "test.jpg"}}
            ])
            ds = ds.with_column(
                "image_score",
                ImageTextSimilarityScore(text="a photo of a cat")
                .options(num_workers={{"CPU": 4}}, batch_size=1)
                .with_column(col("image")),
            )
            print(ds.take(1))
    """

    def __init__(self, text: str, _local_model: str = "openai/clip-vit-base-patch32"):
        # Reject unknown names before any model download/caching work is done.
        if _local_model not in ALL_IMAGE_TEXT_MODEL:
            raise ValueError(f"model is not supported. {_local_model}")
        model_spec = ALL_IMAGE_TEXT_MODEL[_local_model]
        # NOTE(review): the trailing None presumably selects the unquantized
        # variant (the registry lists quantizations=[None]) — confirm.
        model_path = cache_model("huggingface", "pytorch", model_spec, None)
        self.text = text
        self.model, self.processor = prepare_huggingface_model(model_path, device_map="auto")

    async def __call__(self, images: pa.ChunkedArray) -> pa.Array:
        """Score one batch of images against ``self.text``.

        Args:
            images: Column of image values; each element is handed to
                ``read_image`` for loading/decoding as RGB.

        Returns:
            A float32 array with one score per input image, in input order.
            An empty batch yields an empty array without touching the model.
        """
        # Nothing to score: avoid running the processor/model on zero images.
        if len(images) == 0:
            return pa.array([], type=pa.float32())

        # `or {}` keeps the ** expansion below valid if no options are configured.
        storage_options = DatasetContext.get_current().storage_options or {}
        # Load all images of the batch concurrently; gather preserves input order.
        pil_images = await asyncio.gather(
            *(read_image(image, source_mode="RGB", **storage_options) for image in images)
        )
        # Preprocessing runs in a worker thread so it does not block the event loop.
        inputs = await asyncio.to_thread(
            self.processor, text=[self.text], images=pil_images, return_tensors="pt", padding=True
        )
        # Use the returned object rather than relying on .to() mutating in place.
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # logits_per_text has one row per text prompt; with a single prompt, row 0
        # holds the score for every image. Upcast so half-precision models still
        # produce float32 values for the Arrow array.
        scores = outputs.logits_per_text[0].float().cpu().numpy()
        return pa.array(scores, type=pa.float32())