from __future__ import annotations
import json
import logging
import re
from re import Match
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Tuple, overload
from xpark.dataset.datatype import DataType
from xpark.dataset.expressions import BatchColumnClassProtocol, udf
from xpark.dataset.import_utils import lazy_import
from xpark.dataset.utils import sanitize_path_segment
if TYPE_CHECKING:
import chonkie
import pyarrow as pa
import tiktoken
else:
chonkie = lazy_import("chonkie")
pa = lazy_import("pyarrow", rename="pa")
tiktoken = lazy_import("tiktoken")
# Module logger, obtained under the fixed name "ray" rather than ``__name__``.
logger = logging.getLogger("ray")
# Default token budget: ``max_tokens`` for MarkdownTreeSplitter / MarkdownChunking
# and ``chunk_size`` for the "recursive" TextChunking strategy.
DEFAULT_MAX_CHUNK_TOKENS = 1024
# Default ``chunk_size`` for the "fast" TextChunking strategy (documented there as bytes).
DEFAULT_CHUNK_SIZE_BYTES = 4096
# ``max_len`` passed to ``sanitize_path_segment`` when naming merged chunk files.
MAX_MERGED_FILENAME_LENGTH = 32
# ``max_len`` passed to ``sanitize_path_segment`` for heading-derived and doc-name path segments.
MAX_FILENAME_LENGTH = 32
# Tokenizer name handed to ``chonkie.RecursiveChunker`` when the caller supplies none.
DEFAULT_TOKENIZER = "cl100k_base"
class MarkdownTreeSplitter:
    """Split markdown into a virtual file-tree of ``(path, text, is_dir)`` tuples.

    The document is parsed into a heading tree. Subtrees whose summed token
    count fits within *max_tokens* are kept together (and merged with adjacent
    siblings where possible); larger subtrees become "directories" whose
    entry lists their children as JSON. Each emitted file chunk respects
    *max_tokens*.

    Algorithm derived from:
    https://github.com/volcengine/OpenViking/blob/3064725bd63f70ab5054a661c3de5fb8b00e6c96/openviking/parse/parsers/markdown.py (SPDX-License-Identifier: Apache-2.0)
    """

    # ATX heading. The separator excludes line breaks so that a bare ``#``
    # line cannot swallow the following line as its title (``\s`` would).
    _heading_re = re.compile(r"^(#{1,6})[^\S\r\n]+(.+)$", re.MULTILINE)
    # Leading YAML-style frontmatter; only ever applied with ``.match`` so it
    # anchors at the start of the document. Accepts LF and CRLF line endings.
    _frontmatter_re = re.compile(r"^---\r?\n(.*?)\r?\n---\r?\n", re.DOTALL)
    # Regions in which a ``#`` line must not be treated as a heading.
    _exclude_re = re.compile(
        r"^`{3,}[^\n]*\n.*?^`{3,}\s*$"  # fenced code block
        r"|<!--.*?-->"  # HTML comment
        r"|^(?: {4}|\t)[^\n]+$",  # indented code line (four spaces or a tab)
        re.DOTALL | re.MULTILINE,
    )

    def __init__(
        self,
        max_tokens: int = DEFAULT_MAX_CHUNK_TOKENS,
    ):
        """Create a splitter whose file chunks hold at most *max_tokens* tokens."""
        self.max_tokens = max_tokens
        # The same chunker is used for token counting and for the
        # paragraph-level fallback split of oversized text.
        self._chunker: chonkie.RecursiveChunker = chonkie.RecursiveChunker(
            chunk_size=max_tokens,
            tokenizer=DEFAULT_TOKENIZER,
        )

    def _count_tokens(self, text: str) -> int:
        """Return the chunker's token count for *text*."""
        # NOTE(review): relies on a private chonkie method; confirm it stays
        # available across chonkie upgrades.
        return self._chunker._estimate_token_count(text)

    def _para_split(self, content: str) -> List[str]:
        """Split *content* with the recursive chunker; never returns an empty list."""
        chunks = self._chunker.chunk(content)
        return [chunk.text.strip() for chunk in chunks] or [content]

    def _flat_split(self, content: str, doc_name: str) -> List[Tuple[str, str, bool]]:
        """Chunk *content* without regard to headings, as numbered files under *doc_name*."""
        return [
            (f"{doc_name}/{doc_name}_{i}.md", part, False)
            for i, part in enumerate(self._para_split(content), 1)
        ]

    def split(self, content: str, doc_name: str) -> List[Tuple[str, str, bool]]:
        """Return ``(relative_path, text, is_dir)`` tuples rooted at *doc_name*.

        Frontmatter is stripped first. A document that fits in one chunk is
        returned as a single file; a document without headings is split flat;
        otherwise the heading tree decides the layout.
        """
        content = self._strip_frontmatter(content)
        if self._count_tokens(content) <= self.max_tokens:
            return [(f"{doc_name}/{doc_name}.md", content, False)]

        headings = self._extract_headings(content)
        results: List[Tuple[str, str, bool]] = []
        if headings:
            self._build_and_emit(content, headings, doc_name, results)
        if not results:
            # Either there are no headings, or the tree pass emitted nothing:
            # per-section counts are taken on stripped text, so their sum at
            # the root can fit the budget even though the whole document did
            # not. Fall back to a flat split instead of dropping the content.
            results = self._flat_split(content, doc_name)
        return results

    def _strip_frontmatter(self, content: str) -> str:
        """Return *content* without a leading ``---`` frontmatter block, if present."""
        match: Match[str] | None = self._frontmatter_re.match(content)
        return content[match.end() :] if match else content

    def _extract_headings(self, content: str) -> List[Tuple[int, int, str, int]]:
        """Return ``(start, end, title, level)`` for each heading outside excluded regions."""
        # finditer yields non-overlapping matches in position order, so both
        # sequences can be walked with a single forward cursor.
        excluded: list[tuple[int, int]] = [m.span() for m in self._exclude_re.finditer(content)]
        exc_idx = 0
        exc_len = len(excluded)
        headings: List[Tuple[int, int, str, int]] = []
        for m in self._heading_re.finditer(content):
            pos = m.start()
            # Skip exclusion spans that end at or before this heading.
            while exc_idx < exc_len and excluded[exc_idx][1] <= pos:
                exc_idx += 1
            if exc_idx < exc_len and excluded[exc_idx][0] <= pos:
                continue  # heading lies inside a code block / comment
            headings.append((pos, m.end(), m.group(2).strip(), len(m.group(1))))
        return headings

    def _build_and_emit(
        self,
        content: str,
        headings: List[Tuple[int, int, str, int]],
        root_name: str,
        results: List[Tuple[str, str, bool]],
    ) -> None:
        """Build the heading tree for *content* and append its chunks to *results*.

        Args:
            content: Document text (frontmatter already removed).
            headings: Non-empty output of ``_extract_headings`` for *content*.
            root_name: Path of the root directory for this document.
            results: Output list, extended in place with ``(path, text, is_dir)``.

        Raises:
            ValueError: If a node is offered for merging twice (internal invariant).
        """
        n = len(content)

        def _make_node(name: str, level: int, body_start: int, path: str) -> Dict[str, Any]:
            # - name: sanitized heading name
            # - level: heading level (0 for the document root)
            # - path: relative path
            # - body_start: start of body (after the heading line)
            # - body_end: end of the whole section, children included
            # - direct_end: end of the node's own text (start of first child)
            # - children: list of child nodes
            # Both ends default to EOF, which is correct for nodes still open
            # when the last heading has been processed.
            return {
                "name": name,
                "level": level,
                "path": path,
                "body_start": body_start,
                "body_end": n,
                "direct_end": n,
                "children": [],
            }

        # A level-0 supernode represents the whole document.
        supernode = _make_node(root_name, 0, 0, root_name)
        stack: List[Dict[str, Any]] = [supernode]

        for h_start, h_end, title, level in headings:
            name = sanitize_path_segment(title, max_len=MAX_FILENAME_LENGTH, default_name="section")
            # Close every open node at the same or a deeper level.
            while len(stack) > 1 and stack[-1]["level"] >= level:
                closed = stack.pop()
                closed["body_end"] = h_start
                if not closed["children"]:
                    closed["direct_end"] = h_start
            parent = stack[-1]
            if not parent["children"]:
                # The parent's own text stops where its first child begins.
                parent["direct_end"] = h_start
            # A 1-based sibling index keeps paths unique under one parent.
            sibling_index = len(parent["children"]) + 1
            # ``h_end + 1`` skips the newline that terminates the heading line.
            node = _make_node(name, level, h_end + 1, f"{parent['path']}/{sibling_index}-{name}")
            parent["children"].append(node)
            stack.append(node)

        def _render(node: Dict[str, Any], end_key: str) -> str:
            # Text of *node* from its body start up to ``node[end_key]``,
            # prefixed with a regenerated heading line for non-root nodes.
            # NOTE(review): the heading is rebuilt from the sanitized name,
            # not the original title — confirm this is intended.
            body = content[node["body_start"] : node[end_key]].strip()
            lvl = node["level"]
            if lvl == 0:
                return body
            heading = f"{'#' * lvl} {node['name']}"
            return f"{heading}\n\n{body}" if body else heading

        def _text(node: Dict[str, Any]) -> str:
            return _render(node, "body_end")

        def _merge_and_output(siblings: List[Dict[str, Any]], parent_dir: str) -> None:
            # Greedily pack consecutive siblings into merged files under
            # *parent_dir*; a sibling that alone exceeds the budget is split
            # at paragraph level into numbered parts.
            pending: List[Tuple[str, str, int]] = []
            pending_tok = 0

            def flush() -> None:
                nonlocal pending, pending_tok
                if pending:
                    self._save_merged(parent_dir, pending, results)
                    pending = []
                    pending_tok = 0

            for node in siblings:
                # There should not exist nodes that have already been emitted.
                if node.get("_emitted", True):
                    raise ValueError(
                        f"Node {node['path']} has already been emitted "
                        f"(doc='{root_name}', parent_dir='{parent_dir}', "
                        f"level={node.get('level', '?')}, tok={node.get('_tok', '?')})"
                    )
                node["_emitted"] = True
                # Last path segment carries the sequence number, e.g. "1-introduction".
                unique_name = node["path"].rsplit("/", 1)[-1]
                tok: int = node["_tok"]
                text: str = _text(node)
                if tok > self.max_tokens:
                    # Flush first so output stays in document order.
                    flush()
                    for i, part in enumerate(self._para_split(text), 1):
                        results.append((f"{parent_dir}/{unique_name}_{i}.md", part, False))
                    continue
                if pending_tok + tok > self.max_tokens:
                    flush()
                pending.append((unique_name, text, tok))
                pending_tok += tok
            flush()

        def _emit(node: Dict[str, Any]) -> int:
            # Post-order walk. Sets node["_tok"] and node["_emitted"]; returns the
            # node's token total. A subtree that fits the budget is left
            # unemitted so the parent can merge it with its siblings.
            lvl = node["level"]
            if not node["children"]:
                tok = self._count_tokens(_text(node))
                node["_tok"] = tok
                node["_emitted"] = False  # leaves are only ever written by their parent
                return tok

            unemitted_children: List[Dict[str, Any]] = []
            emitted_children: List[Dict[str, Any]] = []
            for child in node["children"]:
                _emit(child)
                (emitted_children if child["_emitted"] else unemitted_children).append(child)

            own_text = _render(node, "direct_end")
            own_tok = self._count_tokens(own_text)
            total_tok = own_tok + sum(c["_tok"] for c in node["children"])
            node["_tok"] = total_tok
            if total_tok <= self.max_tokens:
                node["_emitted"] = False  # forwarded to the parent
                return total_tok

            # Too large: this node becomes a directory.
            subdir = node["path"]
            children_to_merge: List[Dict[str, Any]] = []
            if own_text:
                # Synthetic leaf for the node's own text (before its first child).
                children_to_merge.append(
                    {
                        "name": node["name"],
                        "level": lvl,
                        "path": node["path"],
                        "body_start": node["body_start"],
                        "body_end": node["direct_end"],
                        "direct_end": node["direct_end"],
                        "children": [],
                        "_tok": own_tok,
                        "_emitted": False,
                    }
                )
            children_to_merge.extend(unemitted_children)

            results_before = len(results)
            _merge_and_output(children_to_merge, subdir)

            # Directory entry format: {"children": [{"type": "file"|"dir", "path": "..."}]}
            # Already-emitted child directories first, then the files just written.
            child_refs = [{"type": "dir", "path": child["path"]} for child in emitted_children]
            child_refs.extend({"type": "file", "path": path} for path, _, _ in results[results_before:])
            if child_refs:
                # ensure_ascii=False keeps non-ASCII path text readable.
                results.append((subdir, json.dumps({"children": child_refs}, ensure_ascii=False), True))
            node["_emitted"] = True
            return total_tok

        _emit(supernode)

    def _generate_merged_filename(self, sections: List[Tuple[str, str, int]]) -> str:
        """Return a sanitized file stem for a merged group of ``(name, text, tok)`` sections."""
        if not sections:
            return "merged"
        first_name = sections[0][0]
        if len(sections) == 1:
            name = first_name
        else:
            # NOTE(review): the number in the suffix is the total section
            # count, first one included; kept as-is because it is part of
            # the emitted chunk ids.
            name = f"{first_name}_{len(sections)}more"
        return sanitize_path_segment(name, max_len=MAX_MERGED_FILENAME_LENGTH, default_name="merged")

    def _save_merged(
        self,
        parent_dir: str,
        sections: List[Tuple[str, str, int]],
        results: List[Tuple[str, str, bool]],
    ) -> None:
        """Join *sections* with blank lines and append the result to *results*.

        Sections are admitted by summing their individual token counts, which
        ignores the ``"\\n\\n"`` separators, so the joined text can exceed the
        budget. In that case the text is split at paragraph level into numbered
        parts instead of failing the whole document.
        """
        fname = self._generate_merged_filename(sections)
        merged = "\n\n".join(text for _, text, _ in sections)
        tok = self._count_tokens(merged)
        if tok <= self.max_tokens:
            results.append((f"{parent_dir}/{fname}.md", merged, False))
            return
        logger.debug(
            "Merged text exceeds token limit (%s > %s); re-splitting (parent_dir=%r, sections=%s)",
            tok,
            self.max_tokens,
            parent_dir,
            [name for name, _, _ in sections],
        )
        for i, part in enumerate(self._para_split(merged), 1):
            results.append((f"{parent_dir}/{fname}_{i}.md", part, False))
@udf(return_dtype=DataType.from_arrow(pa.list_(pa.string())))
class TextChunking(BatchColumnClassProtocol):
    """Unified text chunking operator with multiple strategies.

    Supports two chunking strategies via the ``strategy`` parameter:

    - ``"fast"`` (default): SIMD-accelerated byte-based chunking (100+ GB/s throughput).
      No tokenization overhead. Best for high-throughput pipelines.
    - ``"recursive"``: Recursive token-based chunking with hierarchical splitting.
      Uses multi-level rules (RecursiveRules) for better structure preservation.
      Supports loading pre-configured recipes via ``recipe`` parameter.
      Please refer to https://huggingface.co/datasets/chonkie-ai/recipes for recipes.

    Args:
        strategy (Literal["fast", "recursive"]): Chunking strategy to use.
            Defaults to ``"fast"``.
        chunk_size (int | None): Target chunk size. For ``"fast"`` strategy this is
            measured in **bytes** (default 4096). For ``"recursive"`` strategy this is
            measured in **tokens** (default 1024). If ``None``, the strategy-specific
            default is used.

        **Fast strategy parameters** (only used when ``strategy="fast"``):

        delimiters (str): Single-byte delimiter characters to split on. Each character
            in the string is treated as an individual delimiter. Defaults to ``"\\n.?"``.
        pattern (str | None): Multi-byte pattern to split on. If set, overrides
            ``delimiters``. Defaults to ``None``.
        prefix (bool): If ``True``, keep the delimiter/pattern at the start of the
            next chunk instead of the end of the current chunk. Defaults to ``False``.
        consecutive (bool): If ``True``, split at the **start** of consecutive
            delimiter runs instead of the middle. Defaults to ``False``.
        forward_fallback (bool): If ``True``, search forward for a delimiter when
            none is found in the backward search window. Defaults to ``False``.

        **Recursive strategy parameters** (only used when ``strategy="recursive"``):

        tokenizer (str | None): Tokenizer to use for token counting. Can be a
            ``tiktoken`` encoding name (e.g. ``"cl100k_base"``, ``"o200k_base"``).
            Defaults to ``"cl100k_base"``.
        rules (RecursiveRules | None): Hierarchical splitting rules that define
            multi-level delimiters for recursive chunking. If ``None``, the default
            ``RecursiveRules()`` is used, which splits by paragraphs → sentences →
            punctuation → whitespace → characters. Defaults to ``None``.
        min_characters_per_chunk (int): Minimum number of characters per chunk.
            Chunks shorter than this will be merged with adjacent chunks.
            Defaults to ``12``.
        recipe (str | None): Name of a pre-configured recipe to load via
            ``RecursiveChunker.from_recipe()``. When set, ``rules`` is ignored and
            the recipe's built-in rules are used instead. Common recipes include
            ``"default"``. See https://huggingface.co/datasets/chonkie-ai/recipes
            for available recipes. Defaults to ``None``.
        lang (str): Language code for recipe loading (e.g. ``"en"``, ``"zh"``).
            Only used when ``recipe`` is set. Defaults to ``"en"``.

    Returns a list of text chunks (``pa.list_(pa.string())``) per input row.
    A null input row yields an empty list.

    Raises:
        ValueError: If ``strategy`` is neither ``"fast"`` nor ``"recursive"``.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import TextChunking, from_items

            text = "The quick brown fox. Jumps over the lazy dog. Hello world."

            # Fast chunking (default) - byte-based, no tokenization
            ds = from_items([text])
            ds = ds.with_column(
                "chunks_fast",
                TextChunking(strategy="fast", chunk_size=20, delimiters=". \\n")
                .options(num_workers={"CPU": 4}, batch_size=32)
                .with_column(col("item")),
            )

            # Recursive token chunking with custom parameters
            ds = ds.with_column(
                "chunks_recursive",
                TextChunking(
                    strategy="recursive",
                    chunk_size=8,
                    min_characters_per_chunk=1,
                    tokenizer="cl100k_base",
                )
                .options(num_workers={"CPU": 4}, batch_size=32)
                .with_column(col("item")),
            )

            # Recursive chunking with pre-configured recipe
            ds = ds.with_column(
                "chunks_with_recipe",
                TextChunking(
                    strategy="recursive",
                    recipe="default",
                    lang="en",
                    chunk_size=8,
                    min_characters_per_chunk=1,
                )
                .options(num_workers={"CPU": 4}, batch_size=32)
                .with_column(col("item")),
            )
            print(ds.take_all())
            # Output: [{"item": "The quick brown fox. Jumps over the lazy dog. Hello world.",
            # "chunks_fast": ["The quick brown fox.", " Jumps over the ", "lazy dog. Hello ", "world."],
            # "chunks_recursive": ["The quick brown fox. ", "Jumps over the lazy dog. ", "Hello world."],
            # "chunks_with_recipe": ["The quick brown fox. ", "Jumps over the lazy dog. ", "Hello world."]}]
    """

    @overload
    def __init__(
        self,
        /,
        *,
        strategy: Literal["fast"] = "fast",
        chunk_size: int = DEFAULT_CHUNK_SIZE_BYTES,
        delimiters: str = "\n.?",
        pattern: str | None = None,
        prefix: bool = False,
        consecutive: bool = False,
        forward_fallback: bool = False,
    ) -> None: ...

    @overload
    def __init__(
        self,
        /,
        *,
        strategy: Literal["recursive"],
        chunk_size: int = DEFAULT_MAX_CHUNK_TOKENS,
        tokenizer: str | None = None,
        rules: chonkie.RecursiveRules | None = None,
        min_characters_per_chunk: int = 12,
        # from_recipe support
        recipe: str | None = None,
        lang: str = "en",
    ) -> None: ...

    def __init__(
        self,
        /,
        *,
        strategy: Literal["fast", "recursive"] = "fast",
        chunk_size: int | None = None,
        # Fast strategy params
        delimiters: str = "\n.?",
        pattern: str | None = None,
        prefix: bool = False,
        consecutive: bool = False,
        forward_fallback: bool = False,
        # Recursive strategy params
        tokenizer: str | None = None,
        rules: chonkie.RecursiveRules | None = None,
        min_characters_per_chunk: int = 12,
        recipe: str | None = None,
        lang: str = "en",
    ) -> None:
        """Build the chonkie chunker for *strategy*; see the class docstring for parameters."""
        self._chunker: chonkie.FastChunker | chonkie.RecursiveChunker
        self._strategy = strategy
        if strategy == "fast":
            # Byte-based chunker; parameters of the other strategy are ignored.
            self._chunker = chonkie.FastChunker(
                chunk_size=DEFAULT_CHUNK_SIZE_BYTES if chunk_size is None else chunk_size,
                delimiters=delimiters,
                pattern=pattern,
                prefix=prefix,
                consecutive=consecutive,
                forward_fallback=forward_fallback,
            )
        elif strategy == "recursive":
            # Token-based chunker with hierarchical splitting.
            actual_chunk_size = DEFAULT_MAX_CHUNK_TOKENS if chunk_size is None else chunk_size
            actual_tokenizer = DEFAULT_TOKENIZER if tokenizer is None else tokenizer
            if recipe is not None:
                # A recipe brings its own rules, so ``rules`` is not passed on.
                self._chunker = chonkie.RecursiveChunker.from_recipe(
                    name=recipe,
                    lang=lang,
                    tokenizer=actual_tokenizer,
                    chunk_size=actual_chunk_size,
                    min_characters_per_chunk=min_characters_per_chunk,
                )
            else:
                self._chunker = chonkie.RecursiveChunker(
                    tokenizer=actual_tokenizer,
                    chunk_size=actual_chunk_size,
                    min_characters_per_chunk=min_characters_per_chunk,
                    rules=chonkie.RecursiveRules() if rules is None else rules,
                )
        else:
            raise ValueError(f"strategy must be 'fast' or 'recursive', got {strategy!r}")

    async def __call__(self, texts: pa.ChunkedArray) -> pa.Array:
        """Chunk every row of *texts*; returns one list of chunk strings per row."""
        text_list = texts.to_pylist()
        # Null cells come back from ``to_pylist`` as ``None``. They are kept
        # away from the chunker and map to an empty chunk list.
        present = [(i, text) for i, text in enumerate(text_list) if text is not None]
        chunks_list: list[list[str]] = [[] for _ in text_list]
        if present:
            batch_chunks = self._chunker.chunk_batch([text for _, text in present])
            for (i, _), doc_chunks in zip(present, batch_chunks):
                chunks_list[i] = [chunk.text for chunk in doc_chunks]
        return pa.array(chunks_list, type=pa.list_(pa.string()))
@udf(return_dtype=DataType.from_arrow(pa.list_(pa.struct({"id": pa.string(), "content": pa.string()}))))
class MarkdownChunking(BatchColumnClassProtocol):
    """Markdown-aware structural text chunking operator.

    Each input document is split along its heading hierarchy into a virtual
    file tree, keeping sections together where the token budget allows.
    The result for every row is a list of ``{"id": ..., "content": ...}``
    structs. Ids are ``markdown:<path>`` for text chunks and ``dir:<path>``
    for directory entries.

    The ``doc_name`` column supplies the per-row label from which the virtual
    relative paths are built; texts and doc names are paired row by row.
    Rows whose text is null or blank produce an empty list.

    Args:
        max_tokens: Maximum tokens per chunk. Defaults to 1024.

    Examples:
        .. code-block:: python

            from xpark.dataset.expressions import col
            from xpark.dataset import MarkdownChunking, from_items

            # Structural markdown chunking with unique path IDs
            ds = from_items([
                {"text": "# Introduction1\\n\\nContent here...", "doc_name": "document1"},
                {"text": "# Introduction2\\n\\nContent here...", "doc_name": "document2"},
            ])
            ds = ds.with_column(
                "chunks",
                MarkdownChunking(max_tokens=512)
                .options(num_workers={"CPU": 2}, batch_size=32)
                .with_column(col("text"), col("doc_name")),
            )
            print(ds.take_all())
            # Output: [{'text': '# Introduction1\\n\\nContent here...', 'doc_name':
            # 'document1', 'chunks': [{'id': 'markdown:document1/document1.md',
            # 'content': '# Introduction1\\n\\nContent here...'}]},
            # {'text': '# Introduction2\\n\\nContent here...', 'doc_name': 'document2',
            # 'chunks': [{'id': 'markdown:document2/document2.md', 'content': '# Introduction2\\n\\nContent here...'}]}]
    """

    def __init__(self, /, *, max_tokens: int = DEFAULT_MAX_CHUNK_TOKENS) -> None:
        self._splitter = MarkdownTreeSplitter(max_tokens=max_tokens)

    def _rows_for(self, raw_text: str | None, raw_name: str | None) -> list[dict[str, str]]:
        """Return the id/content structs for a single document row."""
        body = raw_text or ""
        if not body.strip():
            # Null or whitespace-only text: nothing to chunk.
            return []
        label = sanitize_path_segment(raw_name or "", max_len=MAX_FILENAME_LENGTH, default_name="doc")
        rows: list[dict[str, str]] = []
        for path, chunk_text, is_dir in self._splitter.split(body, label):
            kind = "dir" if is_dir else "markdown"
            rows.append({"id": f"{kind}:{path}", "content": chunk_text})
        return rows

    def __call__(self, texts: pa.ChunkedArray, doc_names: pa.ChunkedArray) -> pa.Array:
        """Chunk each ``(text, doc_name)`` pair and return the per-row struct lists."""
        out_type = pa.list_(pa.struct({"id": pa.string(), "content": pa.string()}))
        per_row: List[list[dict[str, str]]] = [
            self._rows_for(raw_text, raw_name)
            for raw_text, raw_name in zip(texts.to_pylist(), doc_names.to_pylist())
        ]
        return pa.array(per_row, type=out_type)