Source code for xpark.dataset.processors.image_compute

from __future__ import annotations

import asyncio
import inspect
import textwrap
from types import FunctionType
from typing import Callable

import numpy as np
import pyarrow as pa
from PIL import Image
from ray.air.util.tensor_extensions.arrow import ArrowTensorArray
from ray.data.datatype import DataType
from ray.data.expressions import Expr, udf

from xpark.dataset.context import DatasetContext
from xpark.dataset.utils import read_image, safe_run

# Pillow ``Image.Image`` members that are exposed on ``ImageCompute``.
# The module-level registration loop wraps every entry with
# ``pil_method_wrapper`` and attaches the result as a staticmethod named after
# the member. Entries that are not plain functions are read through ``.fget``,
# i.e. they are handled as properties.
PIL_FUNC_LIST = [
    Image.Image.width,
    Image.Image.height,
    Image.Image.size,
    Image.Image.mode,
    Image.Image.convert,
    Image.Image.quantize,
    Image.Image.crop,
    Image.Image.filter,
    Image.Image.entropy,
    Image.Image.point,
    Image.Image.resize,
    Image.Image.rotate,
    Image.Image.reduce,
    Image.Image.getchannel,
]

# Maps the return annotation of a wrapped Pillow member (as reported by
# ``inspect.signature``) to the ``DataType`` handed to ``udf(return_dtype=...)``
# in the registration loop. A member whose annotation is not listed here makes
# that lookup raise ``KeyError``.
# Keys are strings, so this presumably relies on Pillow declaring its
# annotations as strings (postponed evaluation) -- TODO confirm.
# NOTE(review): "Image" maps to binary, while the wrapper returns image results
# through ``ArrowTensorArray.from_numpy`` -- confirm this dtype is intended.
RETURN_ANNOTATION_MAP = {
    # TODO(anthonycai) Current DataType limitation: cannot handle mixed ndarray types (e.g., uint8 in 'RGB' mode vs float32 in 'F' mode)
    "Image": DataType.binary(),
    "int": DataType.int32(),
    "float": DataType.float64(),
    "str": DataType.string(),
    "tuple[int, int]": DataType.from_arrow(pa.list_(pa.int32(), 2)),
}

# Docstring text for the property entries of ``PIL_FUNC_LIST``, keyed by the
# property object itself. ``wrap_pil_doc`` uses this text instead of the
# ``doc`` argument whenever the wrapped member is a property, so every property
# in ``PIL_FUNC_LIST`` needs an entry here.
PROPERTY_DOC = {
    Image.Image.width: "Image width, in pixels.",
    Image.Image.height: "Image height, in pixels.",
    Image.Image.mode: "Image mode. This is a string specifying the pixel format used by the image.\n"
    + "Typical values are “1”, “L”, “RGB”, or “CMYK.”",
    Image.Image.size: "Image size, in pixels. The size is given as a 2-tuple (width, height).",
}

# reST field lines that ``wrap_pil_doc`` injects into the generated docstrings
# to document the two parameters ``pil_method_wrapper`` adds to every wrapped
# member: ``images`` (first parameter) and ``source_mode`` (last parameter).
IMAGES_DESCRIBE = ":param images: The images to be processed."
SOURCE_MODE_DESCRIBE = ":param source_mode: The mode of the images. Default is auto detect."


class ImageCompute:
    """Holder for image-processing static methods; it is never instantiated.

    .. note:: Do not construct this class, use the staticmethod instead.
    """

    def __new__(cls, *args, **kwargs):
        # Refuse construction for this class and for any subclass; the message
        # names the class that was actually requested.
        message = f"The {cls.__name__} class cannot be instantiated."
        raise TypeError(message)
def wrap_pil_doc(pil_method: FunctionType | property, doc: str) -> str:
    """Build the docstring for a wrapped Pillow member.

    For a property the text comes from ``PROPERTY_DOC`` (``doc`` is ignored),
    followed by the ``images`` and ``source_mode`` field lines.

    For a function, ``doc`` is cleaned with :func:`inspect.cleandoc` and the
    two field lines are spliced in: ``images`` before the first ``:param``
    line and ``source_mode`` before the first ``:returns`` line. A field line
    that could not be placed that way (no ``:param`` or no ``:returns`` line
    in ``doc``) is appended at the end, so both wrapper parameters are always
    documented.

    :param pil_method: The Pillow function or property being wrapped.
    :param doc: The original docstring. ``None`` or an empty string is
        treated as an empty docstring.
    :returns: The new docstring with every non-blank line prefixed by
        ``" " * 8``.
    """
    if isinstance(pil_method, property):
        new_doc = "\n" + PROPERTY_DOC[pil_method] + "\n\n"
        new_doc += IMAGES_DESCRIBE + "\n" + SOURCE_MODE_DESCRIBE + "\n"
    else:
        new_lines = []
        images_documented = False
        source_mode_documented = False
        # ``doc`` can be None (member without a docstring, or docstrings
        # stripped by the interpreter); cleandoc() only accepts strings.
        for line in inspect.cleandoc(doc or "").split("\n"):
            stripped = line.strip()
            if stripped.startswith(":param") and not images_documented:
                new_lines.append(IMAGES_DESCRIBE)
                images_documented = True
            elif stripped.startswith(":returns") and not source_mode_documented:
                if not images_documented:
                    # No ``:param`` line preceded ``:returns``; still document
                    # ``images`` ahead of ``source_mode``.
                    new_lines.append(IMAGES_DESCRIBE)
                    images_documented = True
                new_lines.append(SOURCE_MODE_DESCRIBE)
                source_mode_documented = True
            new_lines.append(line)

        # Field lines that found no anchor in the original docstring.
        missing = []
        if not images_documented:
            missing.append(IMAGES_DESCRIBE)
        if not source_mode_documented:
            missing.append(SOURCE_MODE_DESCRIBE)
        if missing:
            # ``images`` is only still missing when neither ``:param`` nor
            # ``:returns`` occurred; separate the new fields from the
            # preceding text with a blank line in that case.
            if not images_documented and new_lines and new_lines[-1].strip():
                new_lines.append("")
            new_lines.extend(missing)
        new_doc = "\n".join(new_lines)
    new_doc = textwrap.indent(new_doc, " " * 8)
    return new_doc


def pil_method_wrapper(
    pil_method: FunctionType | property,
) -> tuple[Callable[..., pa.Array], str]:
    """Wrap a Pillow ``Image`` member so that it runs over a column of images.

    The returned callable takes ``images`` as its first parameter, then the
    wrapped member's own parameters (without ``self``), then ``source_mode``.
    Each element of ``images`` is loaded with ``read_image`` and the member is
    applied to every loaded image.

    :param pil_method: A function or property taken from ``Image.Image``.
    :returns: A ``(wrapper, return_annotation)`` pair. ``return_annotation``
        is the wrapped member's return annotation exactly as reported by
        :func:`inspect.signature`; the caller uses it as a key into
        ``RETURN_ANNOTATION_MAP``.
    """
    fn = pil_method if isinstance(pil_method, FunctionType) else pil_method.fget
    orig_sig = inspect.signature(fn)
    return_annotation = orig_sig.return_annotation
    params = [p for name, p in orig_sig.parameters.items() if name != "self"]
    method_name = fn.__name__

    # add images and source_mode params
    new_params = [
        inspect.Parameter(
            "images",
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=pa.ChunkedArray,
        ),
        *params,
        inspect.Parameter(
            "source_mode",
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=None,
            annotation=str | None,
        ),
    ]
    new_sig = inspect.Signature(new_params)

    def wrapper(*args, **kwargs):
        # async func
        async def async_wrapper(*args, **kwargs):
            bound = new_sig.bind_partial(*args, **kwargs)
            bound.apply_defaults()
            images = bound.arguments.pop("images")
            source_mode = bound.arguments.pop("source_mode")
            # Everything left belongs to the wrapped Pillow member.
            call_kwargs = dict(bound.arguments)

            storage_options = DatasetContext.get_current().storage_options
            if len(images) == 0:
                return pa.array([])

            tasks = [read_image(image, source_mode=source_mode, **storage_options) for image in images]
            # TODO(anthonycai) support unordered preprocessor, or use auto batch
            pil_images = await asyncio.gather(*tasks)

            results = []
            for image in pil_images:
                if isinstance(pil_method, FunctionType):
                    out = getattr(image, method_name)(**call_kwargs)
                else:
                    # Properties are read, not called.
                    out = getattr(image, method_name)
                if isinstance(out, Image.Image):
                    results.append(np.array(out))
                else:
                    results.append(out)

            # The Arrow type of the whole batch is chosen from the first result.
            if isinstance(results[0], np.ndarray):
                return ArrowTensorArray.from_numpy(results)
            elif isinstance(results[0], int):
                return pa.array(results, type=pa.int32())
            elif isinstance(results[0], float):
                return pa.array(results, type=pa.float64())
            elif isinstance(results[0], str):
                return pa.array(results, type=pa.string())
            elif isinstance(results[0], tuple):
                return pa.array(results, type=pa.list_(pa.int32(), 2))
            else:
                raise ValueError(f"Unsupported return type: {type(results[0])}")

        return safe_run(async_wrapper(*args, **kwargs))

    # Make the wrapper introspect like the Pillow member plus the two added
    # parameters; ``_gen_stub_code`` relies on the signature.
    wrapper.__name__ = method_name
    wrapper.__doc__ = wrap_pil_doc(pil_method, pil_method.__doc__)
    wrapper.__signature__ = new_sig
    ann = dict(getattr(pil_method, "__annotations__", {}))
    ann.update({"images": pa.ChunkedArray, "source_mode": str | None})
    ann["return"] = pa.Array
    wrapper.__annotations__ = ann
    return wrapper, return_annotation


# Register every entry of PIL_FUNC_LIST on ImageCompute as a staticmethod
# wrapped by ``udf`` with the return dtype looked up from RETURN_ANNOTATION_MAP.
for _pil_method in PIL_FUNC_LIST:
    if isinstance(_pil_method, FunctionType):
        method_name = _pil_method.__name__
    else:
        method_name = _pil_method.fget.__name__
    method, return_annotation = pil_method_wrapper(_pil_method)
    data_type = RETURN_ANNOTATION_MAP[return_annotation]
    setattr(ImageCompute, method_name, staticmethod(udf(return_dtype=data_type)(method)))


def _gen_stub_code() -> str:
    """Return the text of a ``.pyi`` stub describing ``ImageCompute``.

    One ``@staticmethod`` stub is emitted per staticmethod found in
    ``ImageCompute.__dict__``, using the signature reported by
    :func:`inspect.signature`, with ``images`` typed as ``ColumnExpr`` and the
    return type written as ``Expr``.

    :returns: The stub source code as a single string.
    """
    import re
    from io import StringIO

    with StringIO() as stub:
        stub.write('"""This is an auto-generated stub. Please do not modify this file."""\n\n')
        stub.write("from ray.data.expressions import ColumnExpr, Expr\n")
        stub.write("from PIL import Dither, Palette, ImageFilter\n")
        stub.write("from PIL.Image import Image, ImagePointTransform, ImagePointHandler, Resampling\n")
        stub.write("from PIL._typing import NumpyArray\n")
        stub.write("from typing import Sequence, Callable\n\n")
        stub.write("def _gen_stub_code() -> str: ...\n\n")
        stub.write("class ImageCompute:\n")
        for name, fn in ImageCompute.__dict__.items():
            if name != "__new__" and isinstance(fn, staticmethod):
                signature = inspect.signature(fn)
                # Default values rendered as ``<Name.MEMBER: value>`` are not
                # valid stub syntax; reduce them to ``Name.MEMBER``.
                pattern = r"<([a-zA-Z]+\.[a-zA-Z]+):\s*\w+>"
                params = re.sub(pattern, r"\1", str(signature))
                # Literal text replacement, so no regex is needed here.
                params = params.replace("images: pyarrow.lib.ChunkedArray", "images: ColumnExpr")
                # The body lines only need to be indented consistently; the
                # __main__ path below runs ``ruff format`` on the written file.
                stub.write(" @staticmethod\n")
                stub.write(f" def {name}{params} -> Expr: ...\n")
        return stub.getvalue()


if __name__ == "__main__":
    import subprocess

    # Write the stub next to this module: "<module>.py" -> "<module>.pyi".
    with open(__file__ + "i", "w") as f:
        f.write(_gen_stub_code())
    # ruff format
    subprocess.run(["ruff", "format", __file__ + "i"])
    subprocess.run(["ruff", "check", "--fix", __file__ + "i"])