# Source code for xpark.dataset.processors.text_chunking

from __future__ import annotations

import json
import logging
import re
from re import Match
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Tuple, overload

from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.utils import sanitize_path_segment

if TYPE_CHECKING:
    import chonkie
    import pyarrow as pa
    import tiktoken
else:
    chonkie = lazy_import("chonkie")
    pa = lazy_import("pyarrow", rename="pa")
    tiktoken = lazy_import("tiktoken")

# Log under the "ray" logger namespace rather than this module's own name.
logger = logging.getLogger("ray")


# Default per-chunk token budget: used by MarkdownTreeSplitter/MarkdownChunking
# and as the "recursive" TextChunking strategy's chunk_size when none is given.
DEFAULT_MAX_CHUNK_TOKENS = 1024
# Default chunk size, in bytes, for the "fast" TextChunking strategy.
DEFAULT_CHUNK_SIZE_BYTES = 4096
# max_len passed to sanitize_path_segment for merged-chunk file names.
MAX_MERGED_FILENAME_LENGTH = 32
# max_len passed to sanitize_path_segment for heading-derived path segments
# and for per-row document names.
MAX_FILENAME_LENGTH = 32

# Tokenizer name used for token counting when no tokenizer is specified
# (documented on TextChunking as a tiktoken encoding name).
DEFAULT_TOKENIZER = "cl100k_base"


class MarkdownTreeSplitter:
    """Split markdown into a virtual file-tree of ``(path, text, is_dir)`` tuples.

    A document that fits in *max_tokens* is returned as a single file. Otherwise
    it is split along its ATX headings (``#`` .. ``######``):

    - sections whose whole subtree fits are packed, in order, with neighbouring
      siblings into merged files;
    - larger sections become directories (``is_dir=True``) whose text is a JSON
      listing ``{"children": [{"type": "file"|"dir", "path": ...}, ...]}``;
    - any single piece that is still too large is split with
      ``chonkie.RecursiveChunker``.

    Token counts come from the same ``RecursiveChunker`` instance.

    Algorithm derived from:
     https://github.com/volcengine/OpenViking/blob/3064725bd63f70ab5054a661c3de5fb8b00e6c96/openviking/parse/parsers/markdown.py (SPDX-License-Identifier: Apache-2.0)
    """

    # ATX heading: 1-6 '#' characters, then spaces/tabs, then the title on the
    # same line. ``[ \t]+`` (not ``\s+``) so that a bare "#" / "## " line cannot
    # match across the newline and take the following line as its title.
    _heading_re = re.compile(r"^(#{1,6})[ \t]+(.+)$", re.MULTILINE)
    # Front matter block delimited by "---" lines at the very start of the text.
    _frontmatter_re = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
    # Regions in which '#'-prefixed lines must not be treated as headings.
    _exclude_re = re.compile(
        r"^`{3,}[^\n]*\n.*?^`{3,}\s*$"  # backtick-fenced code block
        r"|^~{3,}[^\n]*\n.*?^~{3,}\s*$"  # tilde-fenced code block
        r"|<!--.*?-->"  # HTML comment
        r"|^(?:    |\t)[^\n]+$",  # indented code line
        re.DOTALL | re.MULTILINE,
    )

    def __init__(
        self,
        max_tokens: int = DEFAULT_MAX_CHUNK_TOKENS,
    ):
        self.max_tokens = max_tokens
        # Used both for counting tokens and for splitting oversized text.
        self._chunker: chonkie.RecursiveChunker = chonkie.RecursiveChunker(
            chunk_size=max_tokens,
            tokenizer=DEFAULT_TOKENIZER,
        )

    def _count_tokens(self, text: str) -> int:
        """Return the chunker's token count for *text*."""
        # NOTE(review): relies on a private chonkie method; re-check on upgrades.
        return self._chunker._estimate_token_count(text)

    def _para_split(self, content: str) -> List[str]:
        """Split *content* with the recursive chunker.

        Pieces are stripped and whitespace-only pieces are dropped. Falls back
        to ``[content]`` when nothing remains.
        """
        pieces = [chunk.text.strip() for chunk in self._chunker.chunk(content)]
        return [piece for piece in pieces if piece] or [content]

    def split(self, content: str, doc_name: str) -> List[Tuple[str, str, bool]]:
        """Return ``(relative_path, text, is_dir)`` tuples rooted at *doc_name*.

        Args:
            content: Markdown text; a leading front matter block is removed.
            doc_name: Root path segment for every emitted path.

        Returns:
            File entries (``is_dir=False``) carry chunk text. Directory entries
            (``is_dir=True``) carry a JSON listing of their children.
        """
        results: List[Tuple[str, str, bool]] = []
        content = self._strip_frontmatter(content)

        # Small documents are emitted whole; no need to scan for headings.
        if self._count_tokens(content) <= self.max_tokens:
            results.append((f"{doc_name}/{doc_name}.md", content, False))
            return results

        headings = self._extract_headings(content)
        if not headings:
            for i, part in enumerate(self._para_split(content), 1):
                results.append((f"{doc_name}/{doc_name}_{i}.md", part, False))
            return results

        self._build_and_emit(content, headings, doc_name, results)
        return results

    def _strip_frontmatter(self, content: str) -> str:
        """Return *content* without a leading ``---``-delimited front matter block."""
        match: Match[str] | None = self._frontmatter_re.match(content)
        if not match:
            return content
        return content[match.end() :]

    def _extract_headings(self, content: str) -> List[Tuple[int, int, str, int]]:
        """Return ``(start, end, title, level)`` for each heading outside excluded regions.

        ``start``/``end`` are the offsets of the heading line (``end`` is before
        its newline). ``title`` is stripped. ``level`` is the number of '#'.
        """
        # finditer yields non-overlapping matches in position order, so one
        # forward-moving cursor over ``excluded`` suffices.
        excluded: List[Tuple[int, int]] = [(m.start(), m.end()) for m in self._exclude_re.finditer(content)]
        exc_idx = 0
        exc_len = len(excluded)

        headings: List[Tuple[int, int, str, int]] = []
        for m in self._heading_re.finditer(content):
            pos = m.start()
            # Skip exclusion regions that end at or before this heading.
            while exc_idx < exc_len and excluded[exc_idx][1] <= pos:
                exc_idx += 1
            # Heading starts inside an exclusion region (code block, comment).
            if exc_idx < exc_len and excluded[exc_idx][0] <= pos:
                continue
            headings.append((pos, m.end(), m.group(2).strip(), len(m.group(1))))
        return headings

    def _build_and_emit(
        self,
        content: str,
        headings: List[Tuple[int, int, str, int]],
        root_name: str,
        results: List[Tuple[str, str, bool]],
    ) -> None:
        """Build the heading tree for *content* and append its entries to *results*.

        Args:
            content: Markdown text (front matter already stripped).
            headings: Non-empty output of :meth:`_extract_headings` for *content*.
            root_name: Document name; root of every emitted path.
            results: Output list, appended to in place.

        Raises:
            ValueError: internal invariant violated (a node emitted twice).
        """
        n = len(content)

        def _make_node(name: str, title: str, level: int, body_start: int, path: str) -> Dict[str, Any]:
            # - name: sanitized heading, used only for path segments
            # - title: original heading text, used when rebuilding chunk text
            # - level: heading level (0 for the document supernode)
            # - path: relative path "<parent path>/<sibling index>-<name>"
            # - body_start: start of body (just after the heading line)
            # - body_end: end of the whole section, subsections included
            # - direct_end: end of the node's own text (start of first child)
            # - children: child nodes in document order
            return {
                "name": name,
                "title": title,
                "level": level,
                "path": path,
                "body_start": body_start,
                "body_end": n,
                "direct_end": n,
                "children": [],
            }

        # Level-0 supernode representing the whole document.
        supernode = _make_node(root_name, root_name, 0, 0, root_name)
        stack: List[Dict[str, Any]] = [supernode]

        # Top-down tree build; ``stack`` holds the chain of open sections.
        for h_start, h_end, title, level in headings:
            name = sanitize_path_segment(title, max_len=MAX_FILENAME_LENGTH, default_name="section")

            # A heading closes every open section at the same or a deeper level.
            while len(stack) > 1 and stack[-1]["level"] >= level:
                closed = stack.pop()
                closed["body_end"] = h_start
                if not closed["children"]:
                    closed["direct_end"] = h_start

            parent = stack[-1]
            if not parent["children"]:
                # The parent's own text stops where its first subsection starts.
                # For the supernode this makes its own text the preamble.
                parent["direct_end"] = h_start

            # The sibling index keeps paths unique even when titles repeat.
            sibling_index = len(parent["children"]) + 1
            node = _make_node(name, title, level, h_end + 1, f"{parent['path']}/{sibling_index}-{name}")
            parent["children"].append(node)
            stack.append(node)
        # Sections still open at EOF keep body_end == n from _make_node.

        def _heading_text(node: Dict[str, Any], body: str) -> str:
            # The supernode (level 0) has no heading line of its own.
            if node["level"] == 0:
                return body
            heading = f"{'#' * node['level']} {node['title']}"
            return f"{heading}\n\n{body}" if body else heading

        def _text(node: Dict[str, Any]) -> str:
            """Full text of a node's section, subsections included."""
            return _heading_text(node, content[node["body_start"] : node["body_end"]].strip())

        def _emit(node: Dict[str, Any], force: bool = False) -> int:
            """Compute token totals bottom-up and emit oversized subtrees.

            Sets ``node["_tok"]`` and ``node["_emitted"]`` and returns the
            subtree's token total. A subtree that fits (and is not *forced*) is
            left for its parent to merge.
            """
            if not node["children"]:
                tok = self._count_tokens(_text(node))
                node["_tok"] = tok
                node["_emitted"] = False  # Leaf nodes are never emitted by themselves.
                return tok

            unemitted_children: List[Dict[str, Any]] = []
            emitted_children: List[Dict[str, Any]] = []
            for child in node["children"]:
                _emit(child)
                (emitted_children if child["_emitted"] else unemitted_children).append(child)

            own_text = _heading_text(node, content[node["body_start"] : node["direct_end"]].strip())
            own_tok = self._count_tokens(own_text)
            total_tok = own_tok + sum(c["_tok"] for c in node["children"])
            node["_tok"] = total_tok

            if total_tok <= self.max_tokens and not force:
                # Forward the whole subtree to the parent for merging.
                node["_emitted"] = False
                return total_tok

            # Split: this node becomes a directory holding its own text plus
            # every child that has not already become a directory itself.
            subdir = node["path"]
            children_to_merge: List[Dict[str, Any]] = []
            if own_text:
                children_to_merge.append(
                    {
                        "name": node["name"],
                        "title": node["title"],
                        "level": node["level"],
                        "path": node["path"],
                        "body_start": node["body_start"],
                        "body_end": node["direct_end"],
                        "direct_end": node["direct_end"],
                        "children": [],
                        "_tok": own_tok,
                        "_emitted": False,
                    }
                )
            children_to_merge.extend(unemitted_children)

            results_before = len(results)
            _merge_and_output(children_to_merge, subdir)

            # Directory content: {"children": [{"type": "file"|"dir", "path": ...}]},
            # child directories first, then the files just written above.
            child_refs = [{"type": "dir", "path": c["path"]} for c in emitted_children]
            child_refs.extend({"type": "file", "path": path} for path, _, _ in results[results_before:])
            if child_refs:
                # ensure_ascii=False keeps non-ASCII titles readable in paths.
                results.append((subdir, json.dumps({"children": child_refs}, ensure_ascii=False), True))

            node["_emitted"] = True
            return total_tok

        def _merge_and_output(siblings: List[Dict[str, Any]], parent_dir: str) -> None:
            """Greedily pack *siblings*, in order, into files under *parent_dir*."""
            pending: List[Tuple[str, str, int]] = []
            pending_tok = 0

            def flush() -> None:
                nonlocal pending, pending_tok
                if pending:
                    self._save_merged(parent_dir, pending, results)
                pending = []
                pending_tok = 0

            for node in siblings:
                # Invariant: every node is written to exactly one place.
                if node.get("_emitted", True):
                    raise ValueError(
                        f"Node {node['path']} has already been emitted "
                        f"(doc='{root_name}', parent_dir='{parent_dir}', "
                        f"level={node.get('level', '?')}, tok={node.get('_tok', '?')})"
                    )
                node["_emitted"] = True

                # The last path segment carries the sibling index, e.g. "1-introduction".
                unique_name = node["path"].rsplit("/", 1)[-1]
                tok: int = node["_tok"]
                text = _text(node)

                if tok > self.max_tokens:
                    # Too big for any file: close the current group and split
                    # this piece on its own.
                    flush()
                    for i, part in enumerate(self._para_split(text), 1):
                        results.append((f"{parent_dir}/{unique_name}_{i}.md", part, False))
                    continue

                if pending_tok + tok > self.max_tokens:
                    flush()
                pending.append((unique_name, text, tok))
                pending_tok += tok

            flush()

        # Force emission at the root. When called from split(), the whole
        # document exceeds max_tokens, but the per-section counts can still sum
        # to <= max_tokens (whitespace removed by strip() and the separators are
        # not counted). Without forcing, nothing would be emitted at all.
        _emit(supernode, force=True)

    def _generate_merged_filename(self, sections: List[Tuple[str, str, int]]) -> str:
        """Name a merged file after its first section, e.g. ``"1-intro_3more"``."""
        if not sections:
            return "merged"
        names: List[str] = [sec for sec, _, _ in sections]
        if len(names) == 1:
            name = names[0]
        else:
            suffix = f"_{len(names)}more"
            name = f"{names[0]}{suffix}"
        return sanitize_path_segment(name, max_len=MAX_MERGED_FILENAME_LENGTH, default_name="merged")

    def _save_merged(
        self,
        parent_dir: str,
        sections: List[Tuple[str, str, int]],
        results: List[Tuple[str, str, bool]],
    ) -> None:
        """Append *sections* to *results* as one file, joined by blank lines.

        The per-section counts used for packing exclude the ``"\\n\\n"``
        separators, so the joined text can slightly exceed ``max_tokens``. In
        that case the group is halved recursively. A single section that is
        still too large is split with :meth:`_para_split`.
        """
        if not sections:
            return
        fname = self._generate_merged_filename(sections)
        merged = "\n\n".join(text for _, text, _ in sections)
        if self._count_tokens(merged) <= self.max_tokens:
            results.append((f"{parent_dir}/{fname}.md", merged, False))
        elif len(sections) > 1:
            mid = len(sections) // 2
            self._save_merged(parent_dir, sections[:mid], results)
            self._save_merged(parent_dir, sections[mid:], results)
        else:
            for i, part in enumerate(self._para_split(merged), 1):
                results.append((f"{parent_dir}/{fname}_{i}.md", part, False))


@udf(return_dtype=DataType.from_arrow(pa.list_(pa.string())))
class TextChunking(BatchColumnClassProtocol):
    """Unified text chunking operator with multiple strategies.

    Supports two chunking strategies via the ``strategy`` parameter:

    - ``"fast"`` (default): SIMD-accelerated byte-based chunking (100+ GB/s
      throughput). No tokenization overhead. Best for high-throughput pipelines.
    - ``"recursive"``: Recursive token-based chunking with hierarchical splitting.
      Uses multi-level rules (RecursiveRules) for better structure preservation.
      Supports loading pre-configured recipes via the ``recipe`` parameter.
      Please refer to https://huggingface.co/datasets/chonkie-ai/recipes for recipes.

    Args:
        strategy (Literal["fast", "recursive"]): Chunking strategy to use.
            Defaults to ``"fast"``.
        chunk_size (int | None): Target chunk size. For the ``"fast"`` strategy this
            is measured in **bytes** (default 4096). For the ``"recursive"`` strategy
            this is measured in **tokens** (default 1024). If ``None``, the
            strategy-specific default is used.

        **Fast strategy parameters** (only used when ``strategy="fast"``):

        delimiters (str): Single-byte delimiter characters to split on. Each
            character in the string is treated as an individual delimiter.
            Defaults to ``"\\n.?"``.
        pattern (str | None): Multi-byte pattern to split on. If set, overrides
            ``delimiters``. Defaults to ``None``.
        prefix (bool): If ``True``, keep the delimiter/pattern at the start of the
            next chunk instead of the end of the current chunk. Defaults to ``False``.
        consecutive (bool): If ``True``, split at the **start** of consecutive
            delimiter runs instead of the middle. Defaults to ``False``.
        forward_fallback (bool): If ``True``, search forward for a delimiter when
            none is found in the backward search window. Defaults to ``False``.

        **Recursive strategy parameters** (only used when ``strategy="recursive"``):

        tokenizer (str | None): Tokenizer to use for token counting. Can be a
            ``tiktoken`` encoding name (e.g. ``"cl100k_base"``, ``"o200k_base"``).
            Defaults to ``"cl100k_base"``.
        rules (RecursiveRules | None): Hierarchical splitting rules that define
            multi-level delimiters for recursive chunking. If ``None``, the default
            ``RecursiveRules()`` is used, which splits by paragraphs → sentences →
            punctuation → whitespace → characters. Defaults to ``None``.
        min_characters_per_chunk (int): Minimum number of characters per chunk.
            Chunks shorter than this will be merged with adjacent chunks.
            Defaults to ``12``.
        recipe (str | None): Name of a pre-configured recipe to load via
            ``RecursiveChunker.from_recipe()``. When set, ``rules`` is ignored and
            the recipe's built-in rules are used instead. Common recipes include
            ``"default"``. See https://huggingface.co/datasets/chonkie-ai/recipes
            for available recipes. Defaults to ``None``.
        lang (str): Language code for recipe loading (e.g. ``"en"``, ``"zh"``).
            Only used when ``recipe`` is set. Defaults to ``"en"``.

    Raises:
        ValueError: If ``strategy`` is neither ``"fast"`` nor ``"recursive"``.

    Returns a list of text chunks (``pa.list_(pa.string())``) per input row.
    Null input rows produce an empty list.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import TextChunking, from_items

            text = "The quick brown fox. Jumps over the lazy dog. Hello world."

            # Fast chunking (default) - byte-based, no tokenization
            ds = from_items([text])
            ds = ds.with_column(
                "chunks_fast",
                TextChunking(strategy="fast", chunk_size=20, delimiters=". \\n")
                .options(num_workers={"CPU": 4}, batch_size=32)
                .with_column(col("item")),
            )

            # Recursive token chunking with custom parameters
            ds = ds.with_column(
                "chunks_recursive",
                TextChunking(
                    strategy="recursive",
                    chunk_size=8,
                    min_characters_per_chunk=1,
                    tokenizer="cl100k_base",
                )
                .options(num_workers={"CPU": 4}, batch_size=32)
                .with_column(col("item")),
            )

            # Recursive chunking with pre-configured recipe
            ds = ds.with_column(
                "chunks_with_recipe",
                TextChunking(
                    strategy="recursive",
                    recipe="default",
                    lang="en",
                    chunk_size=8,
                    min_characters_per_chunk=1,
                )
                .options(num_workers={"CPU": 4}, batch_size=32)
                .with_column(col("item")),
            )

            print(ds.take_all())
            # Output: [{"item": "The quick brown fox. Jumps over the lazy dog. Hello world.",
            #           "chunks_fast": ["The quick brown fox.", " Jumps over the ", "lazy dog. Hello ", "world."],
            #           "chunks_recursive": ["The quick brown fox. ", "Jumps over the lazy dog. ", "Hello world."],
            #           "chunks_with_recipe": ["The quick brown fox. ", "Jumps over the lazy dog. ", "Hello world."]}]
    """

    @overload
    def __init__(
        self,
        /,
        *,
        strategy: Literal["fast"] = "fast",
        chunk_size: int = DEFAULT_CHUNK_SIZE_BYTES,
        delimiters: str = "\n.?",
        pattern: str | None = None,
        prefix: bool = False,
        consecutive: bool = False,
        forward_fallback: bool = False,
    ) -> None: ...

    @overload
    def __init__(
        self,
        /,
        *,
        strategy: Literal["recursive"],
        chunk_size: int = DEFAULT_MAX_CHUNK_TOKENS,
        tokenizer: str | None = None,
        rules: chonkie.RecursiveRules | None = None,
        min_characters_per_chunk: int = 12,
        # from_recipe support
        recipe: str | None = None,
        lang: str = "en",
    ) -> None: ...

    def __init__(
        self,
        /,
        *,
        strategy: Literal["fast", "recursive"] = "fast",
        chunk_size: int | None = None,
        # Fast strategy params
        delimiters: str = "\n.?",
        pattern: str | None = None,
        prefix: bool = False,
        consecutive: bool = False,
        forward_fallback: bool = False,
        # Recursive strategy params
        tokenizer: str | None = None,
        rules: chonkie.RecursiveRules | None = None,
        min_characters_per_chunk: int = 12,
        recipe: str | None = None,
        lang: str = "en",
    ) -> None:
        self._chunker: chonkie.FastChunker | chonkie.RecursiveChunker
        self._strategy = strategy

        if strategy == "fast":
            # Fast chunker: byte-based, SIMD-accelerated.
            actual_chunk_size = chunk_size if chunk_size is not None else DEFAULT_CHUNK_SIZE_BYTES
            self._chunker = chonkie.FastChunker(
                chunk_size=actual_chunk_size,
                delimiters=delimiters,
                pattern=pattern,
                prefix=prefix,
                consecutive=consecutive,
                forward_fallback=forward_fallback,
            )
        elif strategy == "recursive":
            # Recursive chunker: token-based with hierarchical splitting.
            actual_chunk_size = chunk_size if chunk_size is not None else DEFAULT_MAX_CHUNK_TOKENS
            actual_tokenizer = tokenizer if tokenizer is not None else DEFAULT_TOKENIZER
            if recipe is not None:
                # The recipe supplies its own rules; ``rules`` is not used here.
                self._chunker = chonkie.RecursiveChunker.from_recipe(
                    name=recipe,
                    lang=lang,
                    tokenizer=actual_tokenizer,
                    chunk_size=actual_chunk_size,
                    min_characters_per_chunk=min_characters_per_chunk,
                )
            else:
                self._chunker = chonkie.RecursiveChunker(
                    tokenizer=actual_tokenizer,
                    chunk_size=actual_chunk_size,
                    min_characters_per_chunk=min_characters_per_chunk,
                    rules=chonkie.RecursiveRules() if rules is None else rules,
                )
        else:
            raise ValueError(f"strategy must be 'fast' or 'recursive', got {strategy!r}")

    async def __call__(self, texts: pa.ChunkedArray) -> pa.Array:
        """Chunk every row of *texts*; null rows yield an empty list."""
        text_list = texts.to_pylist()
        # Only non-null rows are handed to the chunker; null rows keep [].
        present = [i for i, text in enumerate(text_list) if text is not None]
        chunks_list: List[List[str]] = [[] for _ in text_list]
        if present:
            batch_chunks = self._chunker.chunk_batch([text_list[i] for i in present])
            for row, doc_chunks in zip(present, batch_chunks):
                chunks_list[row] = [chunk.text for chunk in doc_chunks]
        return pa.array(chunks_list, type=pa.list_(pa.string()))
@udf(return_dtype=DataType.from_arrow(pa.list_(pa.struct({"id": pa.string(), "content": pa.string()}))))
class MarkdownChunking(BatchColumnClassProtocol):
    """Structure-preserving chunker for markdown text.

    Each row's markdown is laid out as a virtual file tree (see
    ``MarkdownTreeSplitter``). Splits follow heading boundaries, and every file
    chunk targets at most ``max_tokens`` tokens. For each input row the operator
    returns a list of ``{"id": ..., "content": ...}`` structs. ``id`` is
    ``"markdown:<path>"`` for file chunks and ``"dir:<path>"`` for directory
    entries; the content of a directory entry is a JSON listing of its children.

    The second input column, ``doc_name``, labels each row. After sanitization
    (``"doc"`` when missing) it becomes the root of that row's paths. Rows whose
    text is null, empty or whitespace-only yield an empty list.

    Args:
        max_tokens: Maximum tokens per chunk. Defaults to 1024.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import MarkdownChunking, from_items

            # Structural markdown chunking with unique path IDs
            ds = from_items([
                {"text": "# Introduction1\\n\\nContent here...", "doc_name": "document1"},
                {"text": "# Introduction2\\n\\nContent here...", "doc_name": "document2"},
            ])
            ds = ds.with_column(
                "chunks",
                MarkdownChunking(max_tokens=512)
                .options(num_workers={"CPU": 2}, batch_size=32)
                .with_column(col("text"), col("doc_name")),
            )
            print(ds.take_all())
            # Output: [{'text': '# Introduction1\\n\\nContent here...', 'doc_name': 'document1',
            #           'chunks': [{'id': 'markdown:document1/document1.md',
            #                       'content': '# Introduction1\\n\\nContent here...'}]},
            #          {'text': '# Introduction2\\n\\nContent here...', 'doc_name': 'document2',
            #           'chunks': [{'id': 'markdown:document2/document2.md',
            #                       'content': '# Introduction2\\n\\nContent here...'}]}]
    """

    def __init__(self, /, *, max_tokens: int = DEFAULT_MAX_CHUNK_TOKENS) -> None:
        self._splitter = MarkdownTreeSplitter(max_tokens=max_tokens)

    def __call__(self, texts: pa.ChunkedArray, doc_names: pa.ChunkedArray) -> pa.Array:
        chunk_struct = pa.struct({"id": pa.string(), "content": pa.string()})
        per_row: List[list[dict[str, str]]] = []
        for raw_text, raw_name in zip(texts.to_pylist(), doc_names.to_pylist()):
            markdown = raw_text or ""
            if not markdown.strip():
                # Nothing to split: this row gets an empty chunk list.
                per_row.append([])
                continue
            label = sanitize_path_segment(raw_name or "", max_len=MAX_FILENAME_LENGTH, default_name="doc")
            entries: list[dict[str, str]] = []
            for path, chunk_text, is_dir in self._splitter.split(markdown, label):
                kind = "dir" if is_dir else "markdown"
                entries.append({"id": f"{kind}:{path}", "content": chunk_text})
            per_row.append(entries)
        return pa.array(per_row, type=pa.list_(chunk_struct))