"""Text embedding processors for xpark datasets (xpark.dataset.processors.text_embedding)."""

from __future__ import annotations

import asyncio
import logging
import os
from types import CoroutineType
from typing import TYPE_CHECKING, Any, Literal

from xpark.dataset.constants import IO_WORKER_ENV, NOT_SET
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model
from xpark.dataset.utils import assert_type, iter_batch, qps_limiter

if TYPE_CHECKING:
    import openai
    import pyarrow as pa
else:
    openai = lazy_import("openai")
    pa = lazy_import("pyarrow", rename="pa")

# Processors in this module log through Ray's shared logger so messages
# surface in the driver/job logs.
logger = logging.getLogger("ray")

# Registry of supported embedding models, keyed by canonical model name.
# Per entry:
#   - "label":       optional tag set; only models labeled "all" are exposed
#                    via AVAILABLE_MODELS (see the filter below this table).
#   - "pooling":     pooling strategy consumed by the GPU/CPU backends.
#   - "dimensions":  declared embedding width (not consulted in this module).
#   - "model_specs": download sources keyed first by hub
#                    ("huggingface"/"modelscope"), then by weight format
#                    ("gguf" for the CPU backend, "pytorch" for the GPU one).
#                    "model_revision" pins a commit hash — presumably honored
#                    by cache_model when fetching; confirm in xpark.dataset.model.
EmbeddingModel: dict[str, dict[str, Any]] = {
    "Qwen/Qwen3-Embedding-0.6B": {
        "label": {"test", "all"},
        "pooling": "last",
        "dimensions": 1024,
        "model_specs": {
            "huggingface": {
                "gguf": {
                    "model_id": "Qwen/Qwen3-Embedding-0.6B-GGUF",
                    "model_revision": "48f36f50b4a081a6f56dd4a227f9b66668e1399f",
                    "quantizations": ["Q8_0"],
                    "model_file_name_template": "Qwen3-Embedding-0.6B-{quantization}.gguf",
                },
                "pytorch": {
                    "model_id": "Qwen/Qwen3-Embedding-0.6B",
                    "model_revision": "744169034862c8eec56628663995004342e4e449",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "gguf": {
                    "model_id": "Qwen/Qwen3-Embedding-0.6B-GGUF",
                    "quantizations": ["Q8_0"],
                    "model_file_name_template": "Qwen3-Embedding-0.6B-{quantization}.gguf",
                },
                "pytorch": {
                    "model_id": "Qwen/Qwen3-Embedding-0.6B",
                    "quantizations": [None],
                },
            },
        },
    },
    "Qwen/Qwen3-Embedding-4B": {
        "pooling": "last",
        "dimensions": 2560,
        "model_specs": {
            "huggingface": {
                "gguf": {
                    "model_id": "Qwen/Qwen3-Embedding-4B-GGUF",
                    "model_revision": "f4602530db1d980e16da9d7d3a70294cf5c190be",
                    "quantizations": ["Q5_K_M"],
                    "model_file_name_template": "Qwen3-Embedding-4B-{quantization}.gguf",
                },
                "pytorch": {
                    "model_id": "Qwen/Qwen3-Embedding-4B",
                    "model_revision": "5cf2132abc99cad020ac570b19d031efec650f2b",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "gguf": {
                    "model_id": "Qwen/Qwen3-Embedding-4B-GGUF",
                    "quantizations": ["Q5_K_M"],
                    "model_file_name_template": "Qwen3-Embedding-4B-{quantization}.gguf",
                },
                "pytorch": {
                    "model_id": "Qwen/Qwen3-Embedding-4B",
                    "quantizations": [None],
                },
            },
        },
    },
    "Qwen/Qwen3-Embedding-8B": {
        "pooling": "last",
        "dimensions": 4096,
        "model_specs": {
            "huggingface": {
                "gguf": {
                    "model_id": "Qwen/Qwen3-Embedding-8B-GGUF",
                    "model_revision": "69d0e58a13e463cd99a9b83e3f5fee7c10265fab",
                    "quantizations": ["Q4_K_M"],
                    "model_file_name_template": "Qwen3-Embedding-8B-{quantization}.gguf",
                },
                "pytorch": {
                    "model_id": "Qwen/Qwen3-Embedding-8B",
                    "model_revision": "1d8ad4ca9b3dd8059ad90a75d4983776a23d44af",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "gguf": {
                    "model_id": "Qwen/Qwen3-Embedding-8B-GGUF",
                    "quantizations": ["Q4_K_M"],
                    "model_file_name_template": "Qwen3-Embedding-8B-{quantization}.gguf",
                },
                "pytorch": {
                    "model_id": "Qwen/Qwen3-Embedding-8B",
                    "quantizations": [None],
                },
            },
        },
    },
}


class EmbeddingModelSpec(ModelSpec):
    """Validated spec for a single embedding model entry of ``EmbeddingModel``."""

    # Pooling strategy consumed by the GPU (vLLM) and CPU (llama.cpp) backends.
    pooling: Literal["mean", "cls", "last"]
    # Declared embedding vector width (informational; not read in this module).
    dimensions: int


# Validate every raw registry entry once at import time. AVAILABLE_MODELS keeps
# only the models whose label set contains "all" (i.e. generally available).
ALL_EMBEDDING_MODELS = {name: EmbeddingModelSpec.model_validate(raw) for name, raw in EmbeddingModel.items()}
AVAILABLE_MODELS = [name for name, spec in ALL_EMBEDDING_MODELS.items() if "all" in spec.label]


class TextEmbeddingGPU(BatchColumnClassProtocol):
    """Local GPU embedding backend built on vLLM's pooling ("embed") task."""

    def __init__(self, model: str):
        logger.info("Using GPU TextEmbedding")
        # We force vLLM to use the v0 engine to allocate VLLM workers in current job,
        # and v1 engine does not support embedding well.
        os.environ["VLLM_USE_V1"] = "0"
        from vllm import LLM
        from vllm.config import PoolerConfig

        spec = ALL_EMBEDDING_MODELS[model]
        weights = cache_model("huggingface", "pytorch", spec, None)
        # vLLM's PoolingType enum names (LAST=0, ALL=1, CLS=2, STEP=3, MEAN=4)
        # match the upper-cased spec.pooling values, so the name passes through.
        pooler = PoolerConfig(pooling_type=spec.pooling.upper(), normalize=True)
        self.model = LLM(
            weights,
            task="embed",
            enforce_eager=True,
            override_pooler_config=pooler,
        )

    async def __call__(self, prompts: pa.ChunkedArray) -> pa.Array:
        # vLLM returns one EmbeddingRequestOutput per prompt, in input order.
        vectors = [out.outputs.embedding for out in self.model.embed(prompts.to_pylist())]
        width = len(vectors[0]) if vectors else 0
        return pa.array(vectors, type=pa.list_(pa.float32(), width))


class TextEmbeddingCPU(BatchColumnClassProtocol):
    """Local CPU embedding backend built on xllamacpp (GGUF weights)."""

    def __init__(self, model: str):
        logger.info("Using CPU TextEmbedding")
        # Disable LLAMA_SET_ROWS for better embedding performance.
        os.environ["LLAMA_SET_ROWS"] = "0"
        import xllamacpp as xlc

        spec = ALL_EMBEDDING_MODELS[model]
        weights = cache_model("huggingface", "gguf", spec, None)
        assert os.path.exists(weights)

        # Map spec.pooling onto llama.cpp's llama_pooling_type enum
        # (UNSPECIFIED=-1, NONE=0, MEAN=1, CLS=2, LAST=3, RANK=4).
        pooling_by_name = {
            "last": xlc.llama_pooling_type.LLAMA_POOLING_TYPE_LAST,
            "mean": xlc.llama_pooling_type.LLAMA_POOLING_TYPE_MEAN,
            "cls": xlc.llama_pooling_type.LLAMA_POOLING_TYPE_CLS,
        }

        # Single-threaded, quiet, embedding-only server configuration.
        params = xlc.CommonParams()
        params.verbosity = -1
        params.model.path = weights
        params.embedding = True
        params.warmup = False
        params.no_perf = True
        params.n_parallel = 1
        params.cpuparams.n_threads = 1
        params.cpuparams_batch.n_threads = 1
        params.pooling_type = pooling_by_name[spec.pooling]
        self.model = xlc.Server(params)

    async def __call__(self, prompts: pa.ChunkedArray) -> pa.Array:
        # The server speaks an OpenAI-style embeddings payload; a truthy "code"
        # field in the reply signals an error.
        reply = self.model.handle_embeddings({"input": prompts.to_pylist()})
        if reply.get("code"):
            raise RuntimeError(f"Embedding error: {reply}")

        vectors = [row["embedding"] for row in reply["data"]]
        width = len(vectors[0]) if vectors else 0
        return pa.array(vectors, type=pa.list_(pa.float32(), width))


class TextEmbeddingHttp(BatchColumnClassProtocol):
    """Remote embedding backend speaking the OpenAI embeddings HTTP API."""

    def __init__(
        self,
        base_url: str,
        model: str,
        api_key: str,
        batch_rows: int = 10,
        max_qps: int | None = None,
        max_retries: int = 0,
        **kwargs: dict[str, Any],
    ):
        logger.info("Using Remote Http TextEmbedding")
        self.batch_rows = batch_rows
        self.model = model
        self.kwargs = kwargs
        client = openai.AsyncClient(api_key=api_key, base_url=base_url, max_retries=max_retries)
        # Rate-limit the client call so at most max_qps requests/second go out.
        self.create_embeddings = qps_limiter(max_qps)(client.embeddings.create)

    async def __call__(self, prompts: pa.ChunkedArray) -> pa.Array:
        # Fan out one request per batch of batch_rows prompts and await them
        # concurrently. Request/response format follows the OpenAI API:
        # https://platform.openai.com/docs/api-reference/embeddings/create
        tasks = [
            self.create_embeddings(input=chunk.to_pylist(), model=self.model, **self.kwargs)
            for chunk in iter_batch(prompts, self.batch_rows)
        ]
        replies = await asyncio.gather(*tasks)
        vectors = [item.embedding for reply in replies for item in reply.data]
        width = len(vectors[0]) if vectors else 0
        return pa.array(vectors, type=pa.list_(pa.float32(), width))


@udf(return_dtype=DataType.float32())
class TextEmbedding(BatchColumnClassProtocol):
    # Assigned via __doc__ because a literal docstring cannot be an f-string
    # (we interpolate AVAILABLE_MODELS into the help text).
    __doc__ = f"""Text Embedding processor for CPU, GPU and remote Http requests.

    Args:
        _local_model: The embedding model name for CPU or GPU, available models: {AVAILABLE_MODELS}
        base_url: The base URL of the LLM server.
        model: The request model name.
        api_key: The request API key.
        batch_rows: The number of rows to request once.
        max_qps: The maximum number of requests per second.
        max_retries: The maximum number of retries per request in the event of failures.
            We retry with exponential backoff upto this specific maximum retries.
        **kwargs: Keyword arguments to pass to the `openai.AsyncClient.embeddings.create
            <https://github.com/openai/openai-python/blob/main/src/openai/resources/embeddings.py>`_ API.

    Examples:

        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import TextEmbedding, from_items

            ds = from_items([
                "what is the advantage of using the GPU rendering options in Android?",
                "Blank video when converting uncompressed AVI files with ffmpeg",
            ])
            ds = ds.with_column(
                "embedding",
                TextEmbedding(
                    # Local embedding model.
                    "Qwen/Qwen3-Embedding-0.6B",
                    # For remote embedding requests.
                    base_url="http://127.0.0.1:9997/v1",
                    model="qwen3",
                )
                # One IO worker for HTTP request, 10 CPU workers for local embedding.
                .options(num_workers={{"CPU": 10, "IO": 1}})
                .with_column(col("item")),
            )
            print(ds.take(2))
    """

    # Concrete backend chosen at construction time (GPU, HTTP or CPU).
    model: BatchColumnClassProtocol

    def __init__(
        self,
        _local_model: str | None = None,
        /,
        *,
        base_url: str | None = None,
        model: str | None = None,
        api_key: str = NOT_SET,
        batch_rows: int = 10,
        max_qps: int | None = None,
        max_retries: int = 0,
        **kwargs: dict[str, Any],
    ):
        if _local_model is None and base_url is None:
            raise ValueError("Either _local_model or base_url must be specified.")
        # Dispatch on the worker type this process was launched as:
        # CUDA devices visible -> GPU backend, IO worker env set -> HTTP backend,
        # otherwise -> local CPU backend.
        if os.environ.get("CUDA_VISIBLE_DEVICES"):
            if _local_model is None:
                raise ValueError("_local_model must be specified for GPU worker.")
            self.model = TextEmbeddingGPU(model=_local_model)
        elif os.environ.get(IO_WORKER_ENV):
            if base_url is None:
                raise ValueError("base_url must be specified for IO worker.")
            if model is None:
                raise ValueError("model must be specified if base_url is specified.")
            self.model = TextEmbeddingHttp(
                base_url=base_url,
                model=model,
                api_key=api_key,
                batch_rows=batch_rows,
                max_qps=max_qps,
                max_retries=max_retries,
                **kwargs,
            )
        else:
            if _local_model is None:
                raise ValueError("_local_model must be specified for CPU worker.")
            self.model = TextEmbeddingCPU(model=_local_model)

    async def __call__(self, prompt: pa.ChunkedArray) -> pa.Array:
        # All backends implement async __call__, so the result must be a
        # coroutine; assert_type guards against a non-async backend slipping in.
        return await assert_type(self.model(prompt), CoroutineType)