Source code for xpark.dataset.processors.image_text_similarity_score
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from xpark.dataset.context import DatasetContext
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model, prepare_huggingface_model
from xpark.dataset.utils import read_image
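
# Heavy dependencies are imported lazily at runtime; under TYPE_CHECKING they
# are imported normally so the annotations below (pa.ChunkedArray, torch) resolve.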
if TYPE_CHECKING:
    import pyarrow as pa
    import torch
else:
    pa = lazy_import("pyarrow", rename="pa")
    torch = lazy_import("torch")

ImageTextModel = {
    "openai/clip-vit-base-patch32": {
        "label": {"test", "all"},
        "model_specs": {
            "huggingface": {
                "pytorch": {
                    "model_id": "openai/clip-vit-base-patch32",
                    "model_revision": "3d74acf9a28c67741b2f4f2ea7635f0aaf6f0268",
                    "quantizations": [None],
                },
            },
        },
    },
}


class ImageTextModelSpec(ModelSpec):
    pass

ALL_IMAGE_TEXT_MODEL = {k: ImageTextModelSpec.model_validate(v) for k, v in ImageTextModel.items()}
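# Only models whose label set includes "all" are surfaced to users below;
# entries tagged only "test" are presumably reserved for internal testing.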
AVAILABLE_MODELS = [k for k, v in ALL_IMAGE_TEXT_MODEL.items() if v.label & {"all"}]
@udf(return_dtype=DataType.float32())
class ImageTextSimilarityScore(BatchColumnClassProtocol):
__doc__ = f"""Image text similarity score calculation processor for CPU, GPU
Args:
text: The text to be used for similarity score calculation.
_local_model: The CLIP model name for CPU or GPU. default is "openai/clip-vit-base-patch32"
available models: {AVAILABLE_MODELS}
Examples:
.. code-block:: python
from xpark.dataset.expressions import col
from xpark.dataset import ImageTextSimilarityScore, from_items
import numpy as np
ds = from_items([
{{"image": np.random.randint(0, 255, (256, 256, 3)).astype(np.uint8), "path": "test.jpg"}}
])
ds = ds.with_column(
"image_score",
ImageTextSimilarityScore(text="a photo of a cat")
.options(num_workers={{"CPU": 4}}, batch_size=1)
.with_column(col("image")),
)
print(ds.take(1))
"""
    def __init__(self, text: str, _local_model: str = "openai/clip-vit-base-patch32"):
        if _local_model not in ALL_IMAGE_TEXT_MODEL:
            raise ValueError(
                f"Model {_local_model!r} is not supported. Available models: {AVAILABLE_MODELS}"
            )
        model_spec = ALL_IMAGE_TEXT_MODEL[_local_model]
        # Download the checkpoint into the local cache (or reuse a cached copy).
        model_path = cache_model("huggingface", "pytorch", model_spec, None)
        self.text = text
        self.model, self.processor = prepare_huggingface_model(model_path, device_map="auto")
    async def __call__(self, images: pa.ChunkedArray) -> pa.Array:
        storage_options = DatasetContext.get_current().storage_options
        # Fetch and decode every image in the batch concurrently.
        tasks = [read_image(image, source_mode="RGB", **storage_options) for image in images]
        pil_images = await asyncio.gather(*tasks)
        # Preprocessing is CPU-bound, so run it off the event loop.
        inputs = await asyncio.to_thread(
            self.processor, text=[self.text], images=pil_images, return_tensors="pt", padding=True
        )
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # logits_per_text has shape (num_texts, num_images); row 0 holds the
        # single query text's similarity to every image in the batch.
        logits_per_text = outputs.logits_per_text
        return pa.array(logits_per_text.detach().cpu().numpy()[0], type=pa.float32())
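

# Minimal usage sketch, mirroring the class docstring; it assumes the
# `from_items` / `with_column` / `.options` API behaves as documented above,
# so treat it as illustrative rather than canonical.
if __name__ == "__main__":
    import numpy as np

    from xpark.dataset import ImageTextSimilarityScore, from_items
    from xpark.dataset.expressions import col

    ds = from_items([
        {"image": np.random.randint(0, 255, (256, 256, 3)).astype(np.uint8), "path": "test.jpg"}
    ])
    ds = ds.with_column(
        "image_score",
        ImageTextSimilarityScore(text="a photo of a cat")
        .options(num_workers={"CPU": 4}, batch_size=1)
        .with_column(col("image")),
    )
    print(ds.take(1))  # one row with a float32 "image_score" column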