from __future__ import annotations
import asyncio
import logging
import os
from types import CoroutineType
from typing import TYPE_CHECKING, Any, Literal
from xpark.dataset.constants import IO_WORKER_ENV, NOT_SET
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.model import ModelSpec, cache_model
from xpark.dataset.utils import assert_type, iter_batch, qps_limiter
if TYPE_CHECKING:
import openai
import pyarrow as pa
else:
openai = lazy_import("openai")
pa = lazy_import("pyarrow", rename="pa")
logger = logging.getLogger("ray")
# Registry of supported embedding models, keyed by canonical model name.
# Schema of each entry:
#   label       - optional tag set used to filter AVAILABLE_MODELS ("all" exposes it;
#                 entries without an explicit label presumably get a default from
#                 ModelSpec - TODO confirm).
#   pooling     - token-pooling strategy: "mean" | "cls" | "last".
#   dimensions  - size of the output embedding vector.
#   model_specs - per-hub ("huggingface" / "modelscope") download specs with two
#                 variants: "gguf" (quantized, used by the CPU backend) and
#                 "pytorch" (used by the GPU backend). huggingface entries pin an
#                 exact snapshot via model_revision.
EmbeddingModel: dict[str, Any] = {
    "Qwen/Qwen3-Embedding-0.6B": {
        "label": {"test", "all"},
        "pooling": "last",
        "dimensions": 1024,
        "model_specs": {
            "huggingface": {
                "gguf": {
                    "model_id": "Qwen/Qwen3-Embedding-0.6B-GGUF",
                    "model_revision": "48f36f50b4a081a6f56dd4a227f9b66668e1399f",
                    "quantizations": ["Q8_0"],
                    "model_file_name_template": "Qwen3-Embedding-0.6B-{quantization}.gguf",
                },
                "pytorch": {
                    "model_id": "Qwen/Qwen3-Embedding-0.6B",
                    "model_revision": "744169034862c8eec56628663995004342e4e449",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "gguf": {
                    "model_id": "Qwen/Qwen3-Embedding-0.6B-GGUF",
                    "quantizations": ["Q8_0"],
                    "model_file_name_template": "Qwen3-Embedding-0.6B-{quantization}.gguf",
                },
                "pytorch": {
                    "model_id": "Qwen/Qwen3-Embedding-0.6B",
                    "quantizations": [None],
                },
            },
        },
    },
    "Qwen/Qwen3-Embedding-4B": {
        "pooling": "last",
        "dimensions": 2560,
        "model_specs": {
            "huggingface": {
                "gguf": {
                    "model_id": "Qwen/Qwen3-Embedding-4B-GGUF",
                    "model_revision": "f4602530db1d980e16da9d7d3a70294cf5c190be",
                    "quantizations": ["Q5_K_M"],
                    "model_file_name_template": "Qwen3-Embedding-4B-{quantization}.gguf",
                },
                "pytorch": {
                    "model_id": "Qwen/Qwen3-Embedding-4B",
                    "model_revision": "5cf2132abc99cad020ac570b19d031efec650f2b",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "gguf": {
                    "model_id": "Qwen/Qwen3-Embedding-4B-GGUF",
                    "quantizations": ["Q5_K_M"],
                    "model_file_name_template": "Qwen3-Embedding-4B-{quantization}.gguf",
                },
                "pytorch": {
                    "model_id": "Qwen/Qwen3-Embedding-4B",
                    "quantizations": [None],
                },
            },
        },
    },
    "Qwen/Qwen3-Embedding-8B": {
        "pooling": "last",
        "dimensions": 4096,
        "model_specs": {
            "huggingface": {
                "gguf": {
                    "model_id": "Qwen/Qwen3-Embedding-8B-GGUF",
                    "model_revision": "69d0e58a13e463cd99a9b83e3f5fee7c10265fab",
                    "quantizations": ["Q4_K_M"],
                    "model_file_name_template": "Qwen3-Embedding-8B-{quantization}.gguf",
                },
                "pytorch": {
                    "model_id": "Qwen/Qwen3-Embedding-8B",
                    "model_revision": "1d8ad4ca9b3dd8059ad90a75d4983776a23d44af",
                    "quantizations": [None],
                },
            },
            "modelscope": {
                "gguf": {
                    "model_id": "Qwen/Qwen3-Embedding-8B-GGUF",
                    "quantizations": ["Q4_K_M"],
                    "model_file_name_template": "Qwen3-Embedding-8B-{quantization}.gguf",
                },
                "pytorch": {
                    "model_id": "Qwen/Qwen3-Embedding-8B",
                    "quantizations": [None],
                },
            },
        },
    },
}
class EmbeddingModelSpec(ModelSpec):
    """Validated spec for one embedding model.

    Extends ModelSpec (which supplies the hub/download fields such as
    ``model_specs`` and ``label`` used elsewhere in this module) with the
    embedding-specific attributes below.
    """

    # Token-pooling strategy used to reduce per-token states to a single vector.
    pooling: Literal["mean", "cls", "last"]
    # Dimensionality of the produced embedding vector.
    dimensions: int
# Validated spec objects for every registered model, keyed by model name.
ALL_EMBEDDING_MODELS = {k: EmbeddingModelSpec.model_validate(v) for k, v in EmbeddingModel.items()}
# Models advertised to users: those whose label set intersects {"all"}.
# NOTE(review): only some raw entries declare "label"; the rest presumably get a
# default from ModelSpec - confirm the default excludes/includes "all" as intended.
AVAILABLE_MODELS = [k for k, v in ALL_EMBEDDING_MODELS.items() if v.label & {"all"}]
class TextEmbeddingGPU(BatchColumnClassProtocol):
    """GPU embedding backend backed by a local vLLM engine."""

    def __init__(self, model: str):
        logger.info("Using GPU TextEmbedding")
        # We force vLLM to use the v0 engine to allocate VLLM workers in current job,
        # and v1 engine does not support embedding well.
        os.environ["VLLM_USE_V1"] = "0"
        from vllm import LLM
        from vllm.config import PoolerConfig

        spec = ALL_EMBEDDING_MODELS[model]
        weights_path = cache_model("huggingface", "pytorch", spec, None)
        # vLLM PoolingType enum for reference:
        #   LAST = 0, ALL = 1, CLS = 2, STEP = 3, MEAN = 4
        # spec.pooling is "mean" | "cls" | "last", so upper-casing maps onto it.
        pooler = PoolerConfig(pooling_type=spec.pooling.upper(), normalize=True)
        self.model = LLM(
            weights_path,
            task="embed",
            enforce_eager=True,
            override_pooler_config=pooler,
        )

    async def __call__(self, prompts: pa.ChunkedArray) -> pa.Array:
        """Embed all prompts; returns a fixed-size-list<float32> Arrow array."""
        # LLM.embed returns a list of EmbeddingRequestOutputs.
        request_outputs = self.model.embed(prompts.to_pylist())
        vectors = [out.outputs.embedding for out in request_outputs]
        width = len(vectors[0]) if vectors else 0
        return pa.array(vectors, type=pa.list_(pa.float32(), width))
class TextEmbeddingCPU(BatchColumnClassProtocol):
    """CPU embedding backend backed by a quantized GGUF model via xllamacpp."""

    def __init__(self, model: str):
        logger.info("Using CPU TextEmbedding")
        # Disable LLAMA_SET_ROWS for better embedding performance.
        os.environ["LLAMA_SET_ROWS"] = "0"
        import xllamacpp as xlc

        spec = ALL_EMBEDDING_MODELS[model]
        gguf_path = cache_model("huggingface", "gguf", spec, None)
        assert os.path.exists(gguf_path)

        # llama.cpp pooling enum for reference:
        #   UNSPECIFIED = -1, NONE = 0, MEAN = 1, CLS = 2, LAST = 3,
        #   RANK = 4 (used by reranking models)
        pooling_by_name = {
            "last": xlc.llama_pooling_type.LLAMA_POOLING_TYPE_LAST,
            "mean": xlc.llama_pooling_type.LLAMA_POOLING_TYPE_MEAN,
            "cls": xlc.llama_pooling_type.LLAMA_POOLING_TYPE_CLS,
        }

        # Single-threaded, single-slot server: one worker process per CPU slot.
        params = xlc.CommonParams()
        params.verbosity = -1
        params.model.path = gguf_path
        params.embedding = True
        params.warmup = False
        params.no_perf = True
        params.n_parallel = 1
        params.cpuparams.n_threads = 1
        params.cpuparams_batch.n_threads = 1
        params.pooling_type = pooling_by_name[spec.pooling]
        self.model = xlc.Server(params)

    async def __call__(self, prompts: pa.ChunkedArray) -> pa.Array:
        """Embed all prompts; returns a fixed-size-list<float32> Arrow array."""
        result = self.model.handle_embeddings({"input": prompts.to_pylist()})
        # A non-zero/truthy "code" in the reply signals an error.
        if result.get("code"):
            raise RuntimeError(f"Embedding error: {result}")
        vectors = [row["embedding"] for row in result["data"]]
        width = len(vectors[0]) if vectors else 0
        return pa.array(vectors, type=pa.list_(pa.float32(), width))
class TextEmbeddingHttp(BatchColumnClassProtocol):
    """Remote embedding backend that calls an OpenAI-compatible HTTP endpoint."""

    def __init__(
        self,
        base_url: str,
        model: str,
        api_key: str,
        batch_rows: int = 10,
        max_qps: int | None = None,
        max_retries: int = 0,
        **kwargs: dict[str, Any],
    ):
        logger.info("Using Remote Http TextEmbedding")
        self.batch_rows = batch_rows
        self.model = model
        self.kwargs = kwargs
        async_client = openai.AsyncClient(api_key=api_key, base_url=base_url, max_retries=max_retries)
        # Wrap the create call so concurrent batches respect the QPS cap.
        self.create_embeddings = qps_limiter(max_qps)(async_client.embeddings.create)

    async def __call__(self, prompts: pa.ChunkedArray) -> pa.Array:
        # Batch requests for embeddings follow the format of the OpenAI API:
        # https://platform.openai.com/docs/api-reference/embeddings/create
        pending = [
            self.create_embeddings(input=batch.to_pylist(), model=self.model, **self.kwargs)
            for batch in iter_batch(prompts, self.batch_rows)
        ]
        # Fire all batches concurrently; gather preserves request order.
        responses = await asyncio.gather(*pending)
        vectors = [item.embedding for response in responses for item in response.data]
        width = len(vectors[0]) if vectors else 0
        return pa.array(vectors, type=pa.list_(pa.float32(), width))
# NOTE: a stray "[docs]" Sphinx-HTML artifact preceded this class; as a bare
# expression it raised NameError at import time and has been removed.
@udf(return_dtype=DataType.float32())
class TextEmbedding(BatchColumnClassProtocol):
    __doc__ = f"""Text Embedding processor for CPU, GPU and remote Http requests.
    Args:
        _local_model: The embedding model name for CPU or GPU, available models: {AVAILABLE_MODELS}
        base_url: The base URL of the LLM server.
        model: The request model name.
        api_key: The request API key.
        batch_rows: The number of rows to request once.
        max_qps: The maximum number of requests per second.
        max_retries: The maximum number of retries per request in the event of failures.
            We retry with exponential backoff upto this specific maximum retries.
        **kwargs: Keyword arguments to pass to the `openai.AsyncClient.embeddings.create
            <https://github.com/openai/openai-python/blob/main/src/openai/resources/embeddings.py>`_ API.
    Examples:
        .. code-block:: python
            from xpark.dataset.expressions import col
            from xpark.dataset import TextEmbedding, from_items
            ds = from_items([
                "what is the advantage of using the GPU rendering options in Android?",
                "Blank video when converting uncompressed AVI files with ffmpeg",
            ])
            ds = ds.with_column(
                "embedding",
                TextEmbedding(
                    # Local embedding model.
                    "Qwen/Qwen3-Embedding-0.6B",
                    # For remote embedding requests.
                    base_url="http://127.0.0.1:9997/v1",
                    model="qwen3",
                )
                # One IO worker for HTTP request, 10 CPU workers for local embedding.
                .options(num_workers={{"CPU": 10, "IO": 1}})
                .with_column(col("item")),
            )
            print(ds.take(2))
    """

    # The concrete backend selected at construction time (GPU, HTTP, or CPU).
    model: BatchColumnClassProtocol

    def __init__(
        self,
        _local_model: str | None = None,
        /,
        *,
        base_url: str | None = None,
        model: str | None = None,
        api_key: str = NOT_SET,
        batch_rows: int = 10,
        max_qps: int | None = None,
        max_retries: int = 0,
        **kwargs: dict[str, Any],
    ):
        # At least one execution path (local model or remote endpoint) is required.
        if _local_model is None and base_url is None:
            raise ValueError("Either _local_model or base_url must be specified.")
        # Backend dispatch by worker environment:
        #   CUDA_VISIBLE_DEVICES set -> GPU worker -> vLLM backend
        #   IO_WORKER_ENV set        -> IO worker  -> remote HTTP backend
        #   otherwise                -> CPU worker -> llama.cpp backend
        if os.environ.get("CUDA_VISIBLE_DEVICES"):
            if _local_model is None:
                raise ValueError("_local_model must be specified for GPU worker.")
            self.model = TextEmbeddingGPU(model=_local_model)
        elif os.environ.get(IO_WORKER_ENV):
            if base_url is None:
                raise ValueError("base_url must be specified for IO worker.")
            if model is None:
                raise ValueError("model must be specified if base_url is specified.")
            self.model = TextEmbeddingHttp(
                base_url=base_url,
                model=model,
                api_key=api_key,
                batch_rows=batch_rows,
                max_qps=max_qps,
                max_retries=max_retries,
                **kwargs,
            )
        else:
            if _local_model is None:
                raise ValueError("_local_model must be specified for CPU worker.")
            self.model = TextEmbeddingCPU(model=_local_model)

    async def __call__(self, prompt: pa.ChunkedArray) -> pa.Array:
        """Delegate to the selected backend; all backends return a coroutine."""
        return await assert_type(self.model(prompt), CoroutineType)