From 6bada543d3e1272a649ba6715e8f460f84204281 Mon Sep 17 00:00:00 2001 From: "MS-QKBGNHPINNKD\\Administrator" Date: Mon, 1 Dec 2025 16:37:53 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E3=80=90ai=E3=80=91=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=99=BA=E8=83=BD=E6=96=87=E6=A1=A3=E5=88=87=E7=89=87?= =?UTF-8?q?=E7=AD=96=E7=95=A5=EF=BC=8C=E6=94=AF=E6=8C=81=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E8=AF=86=E5=88=AB=20Markdown=20QA=20=E5=92=8C=E8=AF=AD?= =?UTF-8?q?=E4=B9=89=E5=8C=96=E5=88=87=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 新增 AiDocumentSplitStrategyEnum 枚举,支持 5 种切片策略:自动识别、Token 切分、段落切分、Markdown QA 切分、语义切分 2. 实现 MarkdownQaSplitter:专门处理 Markdown QA 格式文档,识别二级标题作为问题,保持问答对完整性 3. 实现 SemanticTextSplitter:语义化切片器,优先在段落和句子边界处切分,避免截断语义 4. 优化 AiKnowledgeSegmentServiceImpl:增加自动检测文档类型功能,根据文档特征选择最佳切片策略 5. 启用 AI 模块(pom.xml) 6. 修复 AiKnowledgeSegmentPageReqVO:documentId 类型从 Integer 改为 Long 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- pom.xml | 2 +- .../segment/AiKnowledgeSegmentPageReqVO.java | 2 +- .../ai/enums/AiDocumentSplitStrategyEnum.java | 65 ++++ .../AiKnowledgeSegmentServiceImpl.java | 128 ++++++- .../splitter/MarkdownQaSplitter.java | 349 ++++++++++++++++++ .../splitter/SemanticTextSplitter.java | 293 +++++++++++++++ 6 files changed, 826 insertions(+), 13 deletions(-) create mode 100644 yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/enums/AiDocumentSplitStrategyEnum.java create mode 100644 yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/splitter/MarkdownQaSplitter.java create mode 100644 yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/splitter/SemanticTextSplitter.java diff --git a/pom.xml b/pom.xml index 1de1b27e92..c60954fbe5 100644 --- a/pom.xml +++ b/pom.xml @@ -23,7 +23,7 @@ - + yudao-module-ai diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/segment/AiKnowledgeSegmentPageReqVO.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/segment/AiKnowledgeSegmentPageReqVO.java index f53d5be076..dd3b90300b 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/segment/AiKnowledgeSegmentPageReqVO.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/segment/AiKnowledgeSegmentPageReqVO.java @@ -11,7 +11,7 @@ import lombok.Data; public class AiKnowledgeSegmentPageReqVO extends PageParam { @Schema(description = "文档编号", example = "1") - private Integer documentId; + private Long documentId; @Schema(description = "分段内容关键字", example = "Java 开发") private String content; diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/enums/AiDocumentSplitStrategyEnum.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/enums/AiDocumentSplitStrategyEnum.java new file mode 100644 index 0000000000..2c9f657579 --- /dev/null +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/enums/AiDocumentSplitStrategyEnum.java @@ -0,0 +1,65 @@ +package cn.iocoder.yudao.module.ai.enums; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * AI 知识库文档切片策略枚举 + * + * @author runzhen + */ +@AllArgsConstructor +@Getter +public enum AiDocumentSplitStrategyEnum { + + /** + * 自动识别文档类型并选择最佳切片策略 + */ + AUTO("auto", "自动识别"), + + /** + * 基于 Token 数量机械切分(默认策略) + */ + TOKEN("token", "Token 切分"), + + /** + * 按段落切分(以双换行符为分隔) + */ + PARAGRAPH("paragraph", "段落切分"), + + /** + * Markdown QA 格式专用切片器 + * 识别二级标题作为问题,保持问答对完整性 + * 长答案智能切分但保留问题作为上下文 + */ + MARKDOWN_QA("markdown_qa", "Markdown QA 切分"), + + /** + * 语义化切分,保留句子完整性 + * 在段落和句子边界处切分,避免截断 + */ + SEMANTIC("semantic", "语义切分"); + + /** + * 策略代码 + */ + private final String code; + + /** + * 策略名称 + */ + private final String name; + + /** + * 根据代码获取枚举 + */ + public static AiDocumentSplitStrategyEnum fromCode(String code) { + for (AiDocumentSplitStrategyEnum strategy : values()) { + if (strategy.getCode().equals(code)) { + return strategy; + } + } + return AUTO; // 默认返回自动识别 + } + +} diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java index 43c7e9cefe..51a1ce94d5 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java @@ -4,6 +4,7 @@ import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.ListUtil; import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.StrUtil; + import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; @@ -15,8 +16,11 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper; +import cn.iocoder.yudao.module.ai.enums.AiDocumentSplitStrategyEnum; import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO; import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO; +import cn.iocoder.yudao.module.ai.service.knowledge.splitter.MarkdownQaSplitter; +import cn.iocoder.yudao.module.ai.service.knowledge.splitter.SemanticTextSplitter; import cn.iocoder.yudao.module.ai.service.model.AiModelService; import com.alibaba.cloud.ai.dashscope.rerank.DashScopeRerankOptions; import com.alibaba.cloud.ai.model.RerankModel; @@ -39,8 +43,7 @@ import java.util.*; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; -import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG; -import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS; +import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*; import static org.springframework.ai.vectorstore.SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL; /** @@ -95,16 +98,20 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService AiKnowledgeDO knowledgeDO = knowledgeService.validateKnowledgeExists(documentDO.getKnowledgeId()); VectorStore vectorStore = getVectorStoreById(knowledgeDO); - // 2. 文档切片 - List documentSegments = splitContentByToken(content, documentDO.getSegmentMaxTokens()); + // 2. 文档切片(使用自动检测策略) + List documentSegments = splitContentByStrategy(content, documentDO.getSegmentMaxTokens(), + AiDocumentSplitStrategyEnum.AUTO, documentDO.getUrl()); // 3.1 存储切片 List segmentDOs = convertList(documentSegments, segment -> { if (StrUtil.isEmpty(segment.getText())) { return null; } - return new AiKnowledgeSegmentDO().setKnowledgeId(documentDO.getKnowledgeId()).setDocumentId(documentId) - .setContent(segment.getText()).setContentLength(segment.getText().length()) + return new AiKnowledgeSegmentDO() + .setKnowledgeId(documentDO.getKnowledgeId()) + .setDocumentId(documentId) + .setContent(segment.getText()) + .setContentLength(segment.getText().length()) .setVectorId(AiKnowledgeSegmentDO.VECTOR_ID_EMPTY) .setTokens(tokenCountEstimator.estimate(segment.getText())) .setStatus(CommonStatusEnum.ENABLE.getStatus()); @@ -295,10 +302,13 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService // 1. 读取 URL 内容 String content = knowledgeDocumentService.readUrl(url); - // 2. 文档切片 - List documentSegments = splitContentByToken(content, segmentMaxTokens); + // 2. 自动检测文档类型并选择策略 + AiDocumentSplitStrategyEnum strategy = detectDocumentStrategy(content, url); - // 3. 转换为段落对象 + // 3. 文档切片 + List documentSegments = splitContentByStrategy(content, segmentMaxTokens, strategy, url); + + // 4. 转换为段落对象 return convertList(documentSegments, segment -> { if (StrUtil.isEmpty(segment.getText())) { return null; @@ -333,11 +343,107 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService return getVectorStoreById(knowledge); } - private static List splitContentByToken(String content, Integer segmentMaxTokens) { - TextSplitter textSplitter = buildTokenTextSplitter(segmentMaxTokens); + /** + * 根据策略切分内容 + * + * @param content 文档内容 + * @param segmentMaxTokens 分段的最大 Token 数 + * @param strategy 切片策略 + * @param url 文档 URL(用于自动检测文件类型) + * @return 切片后的文档列表 + */ + private List splitContentByStrategy(String content, Integer segmentMaxTokens, + AiDocumentSplitStrategyEnum strategy, String url) { + // 自动检测策略 + if (strategy == AiDocumentSplitStrategyEnum.AUTO) { + strategy = detectDocumentStrategy(content, url); + log.info("[splitContentByStrategy][自动检测到文档策略: {}]", strategy.getName()); + } + + TextSplitter textSplitter; + switch (strategy) { + case MARKDOWN_QA: + textSplitter = new MarkdownQaSplitter(segmentMaxTokens); + break; + case SEMANTIC: + textSplitter = new SemanticTextSplitter(segmentMaxTokens); + break; + case PARAGRAPH: + textSplitter = new SemanticTextSplitter(segmentMaxTokens, 0); // 段落切分,无重叠 + break; + case TOKEN: + default: + textSplitter = buildTokenTextSplitter(segmentMaxTokens); + break; + } + return textSplitter.apply(Collections.singletonList(new Document(content))); } + /** + * 自动检测文档类型并选择切片策略 + * + * @param content 文档内容 + * @param url 文档 URL + * @return 推荐的切片策略 + */ + private AiDocumentSplitStrategyEnum detectDocumentStrategy(String content, String url) { + if (StrUtil.isEmpty(content)) { + return AiDocumentSplitStrategyEnum.TOKEN; + } + + // 1. 检测 Markdown QA 格式 + if (isMarkdownQaFormat(content, url)) { + return AiDocumentSplitStrategyEnum.MARKDOWN_QA; + } + + // 2. 检测普通 Markdown 文档 + if (isMarkdownDocument(url)) { + return AiDocumentSplitStrategyEnum.SEMANTIC; + } + + // 3. 默认使用语义切分(比 Token 切分更智能) + return AiDocumentSplitStrategyEnum.SEMANTIC; + } + + /** + * 检测是否为 Markdown QA 格式 + * 特征:包含多个二级标题(## )且标题后紧跟答案内容 + */ + private boolean isMarkdownQaFormat(String content, String url) { + // 文件扩展名判断 + if (StrUtil.isNotEmpty(url) && !url.toLowerCase().endsWith(".md")) { + return false; + } + + // 统计二级标题数量 + long h2Count = content.lines() + .filter(line -> line.trim().startsWith("## ")) + .count(); + + // 至少包含 2 个二级标题才认为是 QA 格式 + if (h2Count < 2) { + return false; + } + + // 检查标题占比(QA 文档标题行数相对较多) + long totalLines = content.lines().count(); + double h2Ratio = (double) h2Count / totalLines; + + // 如果二级标题占比超过 10%,认为是 QA 格式 + return h2Ratio > 0.1; + } + + /** + * 检测是否为 Markdown 文档 + */ + private boolean isMarkdownDocument(String url) { + return StrUtil.isNotEmpty(url) && url.toLowerCase().endsWith(".md"); + } + + /** + * 构建基于 Token 的文本切片器(原有逻辑保留) + */ private static TextSplitter buildTokenTextSplitter(Integer segmentMaxTokens) { return TokenTextSplitter.builder() .withChunkSize(segmentMaxTokens) diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/splitter/MarkdownQaSplitter.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/splitter/MarkdownQaSplitter.java new file mode 100644 index 0000000000..2957f4140e --- /dev/null +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/splitter/MarkdownQaSplitter.java @@ -0,0 +1,349 @@ +package cn.iocoder.yudao.module.ai.service.knowledge.splitter; + +import cn.hutool.core.util.StrUtil; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.transformer.splitter.TextSplitter; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Markdown QA 格式专用切片器 + * + *

功能特点: + *

    + *
  • 识别二级标题(## )作为问题标记
  • + *
  • 短 QA 对保持完整(不超过 Token 限制)
  • + *
  • 长答案智能切分,每个片段保留完整问题作为上下文
  • + *
  • 支持自定义 Token 估算器
  • + *
+ * + * @author runzhen + */ +@Slf4j +public class MarkdownQaSplitter extends TextSplitter { + + /** + * 二级标题正则:匹配 "## " 开头的行 + */ + private static final Pattern H2_PATTERN = Pattern.compile("^##\\s+(.+)$", Pattern.MULTILINE); + + /** + * 段落分隔符:双换行 + */ + private static final String PARAGRAPH_SEPARATOR = "\n\n"; + + /** + * 句子分隔符 + */ + private static final Pattern SENTENCE_PATTERN = Pattern.compile("[。!?.!?]\\s*"); + + /** + * 分段的最大 Token 数 + */ + private final int chunkSize; + + /** + * Token 估算器(简单实现:中文按字符数,英文按单词数的 1.3 倍) + */ + private final TokenEstimator tokenEstimator; + + public MarkdownQaSplitter(int chunkSize) { + this.chunkSize = chunkSize; + this.tokenEstimator = new SimpleTokenEstimator(); + } + + @Override + protected List splitText(String text) { + if (StrUtil.isEmpty(text)) { + return Collections.emptyList(); + } + + List result = new ArrayList<>(); + + // 解析 QA 对 + List qaPairs = parseQaPairs(text); + + if (qaPairs.isEmpty()) { + // 如果没有识别到 QA 格式,按段落切分 + return fallbackSplit(text); + } + + // 处理每个 QA 对 + for (QaPair qaPair : qaPairs) { + result.addAll(splitQaPair(qaPair)); + } + + return result; + } + + /** + * 解析 Markdown QA 对 + */ + private List parseQaPairs(String content) { + List qaPairs = new ArrayList<>(); + Matcher matcher = H2_PATTERN.matcher(content); + + List headingPositions = new ArrayList<>(); + List questions = new ArrayList<>(); + + // 找到所有二级标题位置 + while (matcher.find()) { + headingPositions.add(matcher.start()); + questions.add(matcher.group(1).trim()); + } + + if (headingPositions.isEmpty()) { + return qaPairs; + } + + // 提取每个 QA 对 + for (int i = 0; i < headingPositions.size(); i++) { + int start = headingPositions.get(i); + int end = (i + 1 < headingPositions.size()) + ? headingPositions.get(i + 1) + : content.length(); + + String qaText = content.substring(start, end).trim(); + String question = questions.get(i); + + // 提取答案部分(去掉问题标题) + String answer = qaText.substring(qaText.indexOf('\n') + 1).trim(); + + qaPairs.add(new QaPair(question, answer, qaText)); + } + + return qaPairs; + } + + /** + * 切分单个 QA 对 + */ + private List splitQaPair(QaPair qaPair) { + List chunks = new ArrayList<>(); + + String fullQa = qaPair.fullText; + int qaTokens = tokenEstimator.estimate(fullQa); + + // 如果整个 QA 对不超过限制,保持完整 + if (qaTokens <= chunkSize) { + chunks.add(fullQa); + return chunks; + } + + // 长答案需要切分 + log.debug("QA 对超过 Token 限制 ({} > {}),开始智能切分: {}", + qaTokens, chunkSize, qaPair.question); + + List answerChunks = splitLongAnswer(qaPair.answer, qaPair.question); + + for (String answerChunk : answerChunks) { + // 每个片段都包含完整问题 + String chunkText = "## " + qaPair.question + "\n" + answerChunk; + chunks.add(chunkText); + } + + return chunks; + } + + /** + * 切分长答案 + */ + private List splitLongAnswer(String answer, String question) { + List chunks = new ArrayList<>(); + + // 预留问题的 Token 空间 + String questionHeader = "## " + question + "\n"; + int questionTokens = tokenEstimator.estimate(questionHeader); + int availableTokens = chunkSize - questionTokens - 10; // 预留 10 个 Token 的缓冲 + + // 先按段落切分 + String[] paragraphs = answer.split(PARAGRAPH_SEPARATOR); + + StringBuilder currentChunk = new StringBuilder(); + int currentTokens = 0; + + for (String paragraph : paragraphs) { + if (StrUtil.isEmpty(paragraph)) { + continue; + } + + int paragraphTokens = tokenEstimator.estimate(paragraph); + + // 如果单个段落就超过限制,需要按句子切分 + if (paragraphTokens > availableTokens) { + // 先保存当前块 + if (currentChunk.length() > 0) { + chunks.add(currentChunk.toString().trim()); + currentChunk = new StringBuilder(); + currentTokens = 0; + } + + // 按句子切分长段落 + chunks.addAll(splitLongParagraph(paragraph, availableTokens)); + continue; + } + + // 如果加上这个段落会超过限制 + if (currentTokens + paragraphTokens > availableTokens && currentChunk.length() > 0) { + chunks.add(currentChunk.toString().trim()); + currentChunk = new StringBuilder(); + currentTokens = 0; + } + + if (currentChunk.length() > 0) { + currentChunk.append("\n\n"); + } + currentChunk.append(paragraph); + currentTokens += paragraphTokens; + } + + // 添加最后一块 + if (currentChunk.length() > 0) { + chunks.add(currentChunk.toString().trim()); + } + + return chunks.isEmpty() ? Collections.singletonList(answer) : chunks; + } + + /** + * 切分长段落(按句子) + */ + private List splitLongParagraph(String paragraph, int availableTokens) { + List chunks = new ArrayList<>(); + String[] sentences = SENTENCE_PATTERN.split(paragraph); + + StringBuilder currentChunk = new StringBuilder(); + int currentTokens = 0; + + for (String sentence : sentences) { + if (StrUtil.isEmpty(sentence)) { + continue; + } + + int sentenceTokens = tokenEstimator.estimate(sentence); + + // 如果单个句子就超过限制,强制切分 + if (sentenceTokens > availableTokens) { + if (currentChunk.length() > 0) { + chunks.add(currentChunk.toString().trim()); + currentChunk = new StringBuilder(); + currentTokens = 0; + } + chunks.add(sentence.trim()); + continue; + } + + if (currentTokens + sentenceTokens > availableTokens && currentChunk.length() > 0) { + chunks.add(currentChunk.toString().trim()); + currentChunk = new StringBuilder(); + currentTokens = 0; + } + + currentChunk.append(sentence); + currentTokens += sentenceTokens; + } + + if (currentChunk.length() > 0) { + chunks.add(currentChunk.toString().trim()); + } + + return chunks.isEmpty() ? Collections.singletonList(paragraph) : chunks; + } + + /** + * 降级切分策略(当未识别到 QA 格式时) + */ + private List fallbackSplit(String content) { + List chunks = new ArrayList<>(); + String[] paragraphs = content.split(PARAGRAPH_SEPARATOR); + + StringBuilder currentChunk = new StringBuilder(); + int currentTokens = 0; + + for (String paragraph : paragraphs) { + if (StrUtil.isEmpty(paragraph)) { + continue; + } + + int paragraphTokens = tokenEstimator.estimate(paragraph); + + if (currentTokens + paragraphTokens > chunkSize && currentChunk.length() > 0) { + chunks.add(currentChunk.toString().trim()); + currentChunk = new StringBuilder(); + currentTokens = 0; + } + + if (currentChunk.length() > 0) { + currentChunk.append("\n\n"); + } + currentChunk.append(paragraph); + currentTokens += paragraphTokens; + } + + if (currentChunk.length() > 0) { + chunks.add(currentChunk.toString().trim()); + } + + return chunks.isEmpty() ? Collections.singletonList(content) : chunks; + } + + /** + * QA 对数据结构 + */ + private static class QaPair { + String question; + String answer; + String fullText; + + QaPair(String question, String answer, String fullText) { + this.question = question; + this.answer = answer; + this.fullText = fullText; + } + } + + /** + * Token 估算器接口 + */ + public interface TokenEstimator { + int estimate(String text); + } + + /** + * 简单的 Token 估算器实现 + * 中文:1 字符 ≈ 1 Token + * 英文:1 单词 ≈ 1.3 Token + */ + private static class SimpleTokenEstimator implements TokenEstimator { + @Override + public int estimate(String text) { + if (StrUtil.isEmpty(text)) { + return 0; + } + + int chineseChars = 0; + int englishWords = 0; + + // 简单统计中英文 + for (char c : text.toCharArray()) { + if (c >= 0x4E00 && c <= 0x9FA5) { + chineseChars++; + } + } + + // 英文单词估算 + String[] words = text.split("\\s+"); + for (String word : words) { + if (word.matches(".*[a-zA-Z].*")) { + englishWords++; + } + } + + return chineseChars + (int) (englishWords * 1.3); + } + } +} diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/splitter/SemanticTextSplitter.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/splitter/SemanticTextSplitter.java new file mode 100644 index 0000000000..64160a41a0 --- /dev/null +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/splitter/SemanticTextSplitter.java @@ -0,0 +1,293 @@ +package cn.iocoder.yudao.module.ai.service.knowledge.splitter; + +import cn.hutool.core.util.StrUtil; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.transformer.splitter.TextSplitter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.regex.Pattern; + +/** + * 语义化文本切片器 + * + *

功能特点: + *

    + *
  • 优先在段落边界(双换行)处切分
  • + *
  • 其次在句子边界(句号、问号、感叹号)处切分
  • + *
  • 避免在句子中间截断,保持语义完整性
  • + *
  • 支持中英文标点符号识别
  • + *
+ * + * @author runzhen + */ +@Slf4j +public class SemanticTextSplitter extends TextSplitter { + + /** + * 分段的最大 Token 数 + */ + private final int chunkSize; + + /** + * 段落重叠大小(用于保持上下文连贯性) + */ + private final int chunkOverlap; + + /** + * 段落分隔符(按优先级排序) + */ + private static final List PARAGRAPH_SEPARATORS = Arrays.asList( + "\n\n\n", // 三个换行 + "\n\n", // 双换行 + "\n" // 单换行 + ); + + /** + * 句子结束标记(中英文标点) + */ + private static final Pattern SENTENCE_END_PATTERN = Pattern.compile( + "[。!?.!?]+[\\s\"'))】\\]]*" + ); + + /** + * Token 估算器 + */ + private final MarkdownQaSplitter.TokenEstimator tokenEstimator; + + public SemanticTextSplitter(int chunkSize, int chunkOverlap) { + this.chunkSize = chunkSize; + this.chunkOverlap = Math.min(chunkOverlap, chunkSize / 2); // 重叠不超过一半 + this.tokenEstimator = new SimpleTokenEstimator(); + } + + public SemanticTextSplitter(int chunkSize) { + this(chunkSize, 50); // 默认重叠 50 个 Token + } + + @Override + protected List splitText(String text) { + if (StrUtil.isEmpty(text)) { + return Collections.emptyList(); + } + + return splitTextRecursive(text); + } + + /** + * 切分文本(递归策略) + */ + private List splitTextRecursive(String text) { + List chunks = new ArrayList<>(); + + // 如果文本不超过限制,直接返回 + int textTokens = tokenEstimator.estimate(text); + if (textTokens <= chunkSize) { + chunks.add(text.trim()); + return chunks; + } + + // 尝试按不同分隔符切分 + List splits = null; + String usedSeparator = null; + + for (String separator : PARAGRAPH_SEPARATORS) { + if (text.contains(separator)) { + splits = Arrays.asList(text.split(Pattern.quote(separator))); + usedSeparator = separator; + break; + } + } + + // 如果没有找到段落分隔符,按句子切分 + if (splits == null || splits.size() == 1) { + splits = splitBySentences(text); + usedSeparator = ""; // 句子切分不需要分隔符 + } + + // 合并小片段 + chunks = mergeSplits(splits, usedSeparator); + + return chunks; + } + + /** + * 按句子切分 + */ + private List splitBySentences(String text) { + List sentences = new ArrayList<>(); + int lastEnd = 0; + + java.util.regex.Matcher matcher = SENTENCE_END_PATTERN.matcher(text); + while (matcher.find()) { + String sentence = text.substring(lastEnd, matcher.end()).trim(); + if (StrUtil.isNotEmpty(sentence)) { + sentences.add(sentence); + } + lastEnd = matcher.end(); + } + + // 添加剩余部分 + if (lastEnd < text.length()) { + String remaining = text.substring(lastEnd).trim(); + if (StrUtil.isNotEmpty(remaining)) { + sentences.add(remaining); + } + } + + return sentences.isEmpty() ? Collections.singletonList(text) : sentences; + } + + /** + * 合并切分后的小片段 + */ + private List mergeSplits(List splits, String separator) { + List chunks = new ArrayList<>(); + List currentChunks = new ArrayList<>(); + int currentLength = 0; + + for (String split : splits) { + if (StrUtil.isEmpty(split)) { + continue; + } + + int splitTokens = tokenEstimator.estimate(split); + + // 如果单个片段就超过限制,进一步递归切分 + if (splitTokens > chunkSize) { + // 先保存当前累积的块 + if (!currentChunks.isEmpty()) { + String chunkText = String.join(separator, currentChunks); + chunks.add(chunkText.trim()); + currentChunks.clear(); + currentLength = 0; + } + + // 递归切分大片段 + if (!separator.isEmpty()) { + // 如果是段落分隔符,尝试按句子切分 + chunks.addAll(splitTextRecursive(split)); + } else { + // 如果已经是句子级别,强制按字符切分 + chunks.addAll(forceSplitLongText(split)); + } + continue; + } + + // 计算加上分隔符的 Token 数 + int separatorTokens = StrUtil.isEmpty(separator) ? 0 : tokenEstimator.estimate(separator); + + // 如果加上这个片段会超过限制 + if (!currentChunks.isEmpty() && currentLength + splitTokens + separatorTokens > chunkSize) { + // 保存当前块 + String chunkText = String.join(separator, currentChunks); + chunks.add(chunkText.trim()); + + // 处理重叠:保留最后几个片段 + currentChunks = getOverlappingChunks(currentChunks, separator); + currentLength = estimateTokens(currentChunks, separator); + } + + currentChunks.add(split); + currentLength += splitTokens + separatorTokens; + } + + // 添加最后一块 + if (!currentChunks.isEmpty()) { + String chunkText = String.join(separator, currentChunks); + chunks.add(chunkText.trim()); + } + + return chunks; + } + + /** + * 获取重叠的片段(用于保持上下文) + */ + private List getOverlappingChunks(List chunks, String separator) { + if (chunkOverlap == 0 || chunks.isEmpty()) { + return new ArrayList<>(); + } + + List overlapping = new ArrayList<>(); + int tokens = 0; + + // 从后往前取片段,直到达到重叠大小 + for (int i = chunks.size() - 1; i >= 0; i--) { + String chunk = chunks.get(i); + int chunkTokens = tokenEstimator.estimate(chunk); + + if (tokens + chunkTokens > chunkOverlap) { + break; + } + + overlapping.add(0, chunk); + tokens += chunkTokens + (StrUtil.isEmpty(separator) ? 0 : tokenEstimator.estimate(separator)); + } + + return overlapping; + } + + /** + * 估算片段列表的总 Token 数 + */ + private int estimateTokens(List chunks, String separator) { + int total = 0; + for (int i = 0; i < chunks.size(); i++) { + total += tokenEstimator.estimate(chunks.get(i)); + if (i < chunks.size() - 1 && StrUtil.isNotEmpty(separator)) { + total += tokenEstimator.estimate(separator); + } + } + return total; + } + + /** + * 强制切分长文本(当语义切分失败时) + */ + private List forceSplitLongText(String text) { + List chunks = new ArrayList<>(); + int charsPerChunk = (int) (chunkSize * 0.8); // 保守估计 + + for (int i = 0; i < text.length(); i += charsPerChunk) { + int end = Math.min(i + charsPerChunk, text.length()); + String chunk = text.substring(i, end); + chunks.add(chunk.trim()); + } + + log.warn("文本过长,已强制按字符切分,可能影响语义完整性"); + return chunks; + } + + /** + * 简单的 Token 估算器实现 + */ + private static class SimpleTokenEstimator implements MarkdownQaSplitter.TokenEstimator { + @Override + public int estimate(String text) { + if (StrUtil.isEmpty(text)) { + return 0; + } + + int chineseChars = 0; + int englishWords = 0; + + for (char c : text.toCharArray()) { + if (c >= 0x4E00 && c <= 0x9FA5) { + chineseChars++; + } + } + + String[] words = text.split("\\s+"); + for (String word : words) { + if (word.matches(".*[a-zA-Z].*")) { + englishWords++; + } + } + + return chineseChars + (int) (englishWords * 1.3); + } + } +}