from __future__ import annotations
import asyncio
import inspect
import textwrap
from types import FunctionType
from typing import Callable
import numpy as np
import pyarrow as pa
from PIL import Image
from ray.air.util.tensor_extensions.arrow import ArrowTensorArray
from ray.data.datatype import DataType
from ray.data.expressions import Expr, udf
from xpark.dataset.context import DatasetContext
from xpark.dataset.utils import read_image, safe_run
# PIL ``Image.Image`` members (plain methods and properties) that are exposed
# as static expression builders on ``ImageCompute`` via ``pil_method_wrapper``.
PIL_FUNC_LIST = [
    Image.Image.width,
    Image.Image.height,
    Image.Image.size,
    Image.Image.mode,
    Image.Image.convert,
    Image.Image.quantize,
    Image.Image.crop,
    Image.Image.filter,
    Image.Image.entropy,
    Image.Image.point,
    Image.Image.resize,
    Image.Image.rotate,
    Image.Image.reduce,
    Image.Image.getchannel,
]
# Maps a wrapped PIL member's return annotation — the *string* form reported
# by ``inspect.signature`` (see ``pil_method_wrapper``) — to the output column
# ``DataType``.
RETURN_ANNOTATION_MAP = {
    # TODO(anthonycai) Current DataType limitation: cannot handle mixed ndarray types (e.g., uint8 in 'RGB' mode vs float32 in 'F' mode)
    "Image": DataType.binary(),
    "int": DataType.int32(),
    "float": DataType.float64(),
    "str": DataType.string(),
    "tuple[int, int]": DataType.from_arrow(pa.list_(pa.int32(), 2)),
}
# Hand-written descriptions for the PIL *properties* in ``PIL_FUNC_LIST``;
# ``wrap_pil_doc`` substitutes these wholesale when building the wrapper's
# docstring (property objects are not given the method-style doc rewrite).
PROPERTY_DOC = {
    Image.Image.width: "Image width, in pixels.",
    Image.Image.height: "Image height, in pixels.",
    Image.Image.mode: "Image mode. This is a string specifying the pixel format used by the image.\n"
    + "Typical values are “1”, “L”, “RGB”, or “CMYK.”",
    Image.Image.size: "Image size, in pixels. The size is given as a 2-tuple (width, height).",
}
# reST parameter lines injected into every generated docstring for the two
# extra arguments the wrapper adds (see ``wrap_pil_doc``).
IMAGES_DESCRIBE = ":param images: The images to be processed."
SOURCE_MODE_DESCRIBE = ":param source_mode: The mode of the images. Default is auto detect."
# NOTE(review): a stray ``[docs]`` token (a Sphinx-rendering artifact) that
# preceded this class has been removed — it would raise ``NameError`` at
# import time.
class ImageCompute:
    """Namespace of PIL-backed image expression builders.

    .. note:: Do not construct this class, use the staticmethod instead.

    The public static methods are attached dynamically, one per entry in
    ``PIL_FUNC_LIST``.
    """

    def __new__(cls, *args, **kwargs):
        # Instantiation is meaningless: every public member is a staticmethod.
        raise TypeError(f"The {cls.__name__} class cannot be instantiated.")
def wrap_pil_doc(pil_method: FunctionType | property, doc: str) -> str:
    """Build the docstring for a wrapped PIL method or property.

    For properties, a canned description from ``PROPERTY_DOC`` is combined
    with the shared ``images``/``source_mode`` parameter docs.  For plain
    methods, the original PIL docstring is rewritten so that ``images`` is
    documented just before the first ``:param`` line and ``source_mode``
    just before ``:returns``.  The result is indented to sit inside the
    generated class body.
    """
    if isinstance(pil_method, property):
        rewritten = (
            "\n"
            + PROPERTY_DOC[pil_method]
            + "\n\n"
            + IMAGES_DESCRIBE
            + "\n"
            + SOURCE_MODE_DESCRIBE
            + "\n"
        )
    else:
        rewritten_lines = []
        images_doc_inserted = False
        for doc_line in inspect.cleandoc(doc).split("\n"):
            stripped = doc_line.strip()
            if stripped.startswith(":param") and not images_doc_inserted:
                rewritten_lines.append(IMAGES_DESCRIBE)
                images_doc_inserted = True
            elif stripped.startswith(":returns"):
                rewritten_lines.append(SOURCE_MODE_DESCRIBE)
            rewritten_lines.append(doc_line)
        rewritten = "\n".join(rewritten_lines)
    return textwrap.indent(rewritten, " " * 8)
def pil_method_wrapper(
    pil_method: FunctionType | property,
) -> tuple[Callable[..., Expr], str]:
    """Build a batch wrapper around one PIL ``Image`` method or property.

    The wrapper accepts a ``pyarrow.ChunkedArray`` of encoded images plus
    the PIL member's own keyword arguments, decodes every image with
    ``read_image``, applies the member to each image, and re-packs the
    per-image results into an Arrow array.

    :param pil_method: A ``PIL.Image.Image`` method or property to wrap.
    :return: ``(wrapper, return_annotation)`` — the batch wrapper and the
        PIL member's return annotation string, which the caller uses as a
        key into ``RETURN_ANNOTATION_MAP``.
    """
    # Properties carry the underlying function on ``fget``.
    fn = pil_method if isinstance(pil_method, FunctionType) else pil_method.fget
    orig_sig = inspect.signature(fn)
    return_annotation = orig_sig.return_annotation
    # Drop ``self`` — the wrapper works on a batch, not a bound instance.
    params = [p for name, p in orig_sig.parameters.items() if name != "self"]
    method_name = fn.__name__
    # add images and source_mode params
    new_params = [
        inspect.Parameter(
            "images",
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=pa.ChunkedArray,
        ),
        *params,
        inspect.Parameter(
            "source_mode",
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=None,
            annotation=str | None,
        ),
    ]
    new_sig = inspect.Signature(new_params)

    def wrapper(*args, **kwargs):
        # async func
        async def async_wrapper(*args, **kwargs):
            bound = new_sig.bind_partial(*args, **kwargs)
            bound.apply_defaults()
            images = bound.arguments.pop("images")
            source_mode = bound.arguments.pop("source_mode")
            # Everything left over is forwarded to the PIL member itself.
            call_kwargs = dict(bound.arguments)
            results = []
            tasks = []
            storage_options = DatasetContext.get_current().storage_options
            if len(images) == 0:
                # NOTE(review): empty batch returns an untyped empty array —
                # confirm downstream casts it to the declared output DataType.
                return pa.array([])
            # Decode all images concurrently; ``gather`` preserves order.
            for image in images:
                tasks.append(read_image(image, source_mode=source_mode, **storage_options))
            # TODO(anthonycai) support unordered preprocessor, or use auto batch
            pil_images = await asyncio.gather(*tasks)
            for image in pil_images:
                if isinstance(pil_method, FunctionType):
                    out = getattr(image, method_name)(**call_kwargs)
                else:
                    out = getattr(image, method_name)
                # PIL image outputs become ndarrays; scalars pass through.
                if isinstance(out, Image.Image):
                    results.append(np.array(out))
                else:
                    results.append(out)
            # Dispatch on the first result's type — assumes every image in
            # the batch produces the same result type.
            if isinstance(results[0], np.ndarray):
                return ArrowTensorArray.from_numpy(results)
            elif isinstance(results[0], int):
                return pa.array(results, type=pa.int32())
            elif isinstance(results[0], float):
                return pa.array(results, type=pa.float64())
            elif isinstance(results[0], str):
                return pa.array(results, type=pa.string())
            elif isinstance(results[0], tuple):
                # Only 2-tuples of ints are expected here (e.g. ``size``).
                return pa.array(results, type=pa.list_(pa.int32(), 2))
            else:
                raise ValueError(f"Unsupported return type: {type(results[0])}")

        # Drive the coroutine to completion synchronously.
        return safe_run(async_wrapper(*args, **kwargs))

    # Make the wrapper mimic the PIL member for docs and introspection.
    wrapper.__name__ = method_name
    wrapper.__doc__ = wrap_pil_doc(pil_method, pil_method.__doc__)
    wrapper.__signature__ = new_sig
    ann = dict(getattr(pil_method, "__annotations__", {}))
    ann.update({"images": pa.ChunkedArray, "source_mode": str | None})
    ann["return"] = pa.Array
    wrapper.__annotations__ = ann
    return wrapper, return_annotation
# Attach one static expression builder to ImageCompute per PIL member.
for _pil_method in PIL_FUNC_LIST:
    _fn = _pil_method if isinstance(_pil_method, FunctionType) else _pil_method.fget
    _wrapped, _return_annotation = pil_method_wrapper(_pil_method)
    _dtype = RETURN_ANNOTATION_MAP[_return_annotation]
    setattr(
        ImageCompute,
        _fn.__name__,
        staticmethod(udf(return_dtype=_dtype)(_wrapped)),
    )
def _gen_stub_code() -> str:
    """Render the ``.pyi`` stub text for ``ImageCompute``.

    One ``@staticmethod`` stub entry is produced per dynamically attached
    method, with enum reprs collapsed and the ``images`` parameter retyped
    as ``ColumnExpr``.
    """
    import re

    enum_repr_pattern = r"<([a-zA-Z]+\.[a-zA-Z]+):\s*\w+>"
    parts = [
        '"""This is an auto-generated stub. Please do not modify this file."""\n\n',
        "from ray.data.expressions import ColumnExpr, Expr\n",
        "from PIL import Dither, Palette, ImageFilter\n",
        "from PIL.Image import Image, ImagePointTransform, ImagePointHandler, Resampling\n",
        "from PIL._typing import NumpyArray\n",
        "from typing import Sequence, Callable\n\n",
        "def _gen_stub_code() -> str: ...\n\n",
        "class ImageCompute:\n",
    ]
    for name, member in ImageCompute.__dict__.items():
        if name == "__new__" or not isinstance(member, staticmethod):
            continue
        sig_text = str(inspect.signature(member))
        # Collapse enum reprs such as ``<Resampling.BICUBIC: 3>`` to the
        # bare name ``Resampling.BICUBIC``.
        sig_text = re.sub(enum_repr_pattern, r"\1", sig_text)
        sig_text = re.sub(r"images: pyarrow.lib.ChunkedArray", "images: ColumnExpr", sig_text)
        parts.append("    @staticmethod\n")
        parts.append(f"    def {name}{sig_text} -> Expr: ...\n")
    return "".join(parts)
if __name__ == "__main__":
    import subprocess

    # Write the stub next to this module (``<module>.pyi``).
    stub_path = __file__ + "i"
    with open(stub_path, "w") as f:
        f.write(_gen_stub_code())
    # ruff format
    subprocess.run(["ruff", "format", stub_path])
    subprocess.run(["ruff", "check", "--fix", stub_path])