From cddaca586360a3cb924190c9ade7c5a55a1a420c Mon Sep 17 00:00:00 2001 From: YunaiV Date: Sun, 9 Mar 2025 19:02:17 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E5=8A=9F=E8=83=BD=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E3=80=91AI=EF=BC=9A=E8=81=8A=E5=A4=A9=E6=97=B6=EF=BC=8C?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E7=9F=A5=E8=AF=86=E5=BA=93=E7=9A=84=E6=8B=BC?= =?UTF-8?q?=E6=8E=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../yudao/module/ai/enums/AiChatRoleEnum.java | 6 +- .../dataobject/chat/AiChatConversationDO.java | 11 +- .../dal/dataobject/chat/AiChatMessageDO.java | 9 +- .../chat/AiChatConversationServiceImpl.java | 2 +- .../chat/AiChatMessageServiceImpl.java | 169 +++++++++++------- 5 files changed, 111 insertions(+), 86 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java index 6cb98c5629..1479274959 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java @@ -35,11 +35,7 @@ public enum AiChatRoleEnum { ### 微信 除此之外不要任何解释性语句。 """), - - AI_KNOWLEDGE_ROLE("知识库助手", """ - 给你提供一些数据参考:{info},请回答我的问题。 - 请你跟进数据参考与工具返回结果回复用户的请求。 - """); + ; /** * 角色名 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java index a9c956deae..358e994104 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java @@ -1,9 +1,8 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; @@ -65,14 +64,6 @@ public class AiChatConversationDO extends BaseDO { */ private Long roleId; - // TODO @芋艿:可优化,绑定多个知识库。前提,spring ai 支持 RerankModel 的封装 - /** - * 知识库编号 - *

- * 关联 {@link AiKnowledgeDO#getId()} - */ - private Long knowledgeId; - /** * 模型编号 * diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java index a026ec8193..d121d85785 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java @@ -1,14 +1,14 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; +import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; -import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler; import lombok.*; import org.springframework.ai.chat.messages.MessageType; @@ -71,13 +71,12 @@ public class AiChatMessageDO extends BaseDO { */ private Long roleId; - /** - * 段落编号数组 + * 知识库段落编号数组 * * 关联 {@link AiKnowledgeSegmentDO#getId()} 字段 */ - @TableField(typeHandler = JacksonTypeHandler.class) + @TableField(typeHandler = LongListTypeHandler.class) private List segmentIds; /** diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java index 6483166a7f..6c35571c8f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java @@ -68,7 +68,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService // 2. 创建 AiChatConversationDO 聊天对话 AiChatConversationDO conversation = new AiChatConversationDO().setUserId(userId).setPinned(false) - .setModelId(model.getId()).setModel(model.getModel()).setKnowledgeId(createReqVO.getKnowledgeId()) + .setModelId(model.getId()).setModel(model.getModel()) .setTemperature(model.getTemperature()).setMaxTokens(model.getMaxTokens()).setMaxContexts(model.getMaxContexts()); if (role != null) { conversation.setTitle(role.getName()).setRoleId(role.getId()).setSystemMessage(role.getSystemMessage()); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 51bcc45d44..b95f7cd214 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -14,12 +14,14 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; -import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService; +import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO; +import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO; +import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.model.AiModelService; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; @@ -32,13 +34,13 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; import java.time.LocalDateTime; import java.util.*; +import java.util.stream.Collectors; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; @@ -56,12 +58,21 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_N @Slf4j public class AiChatMessageServiceImpl implements AiChatMessageService { + /** + * 知识库转 {@link UserMessage} 的内容模版 + */ + private static final String KNOWLEDGE_USER_MESSAGE_TEMPLATE = "使用 标记中的内容作为本次对话的参考:\n\n" + + "%s\n\n" + // 多个 的拼接 + "回答要求:\n- 避免提及你是从 获取的知识。"; + @Resource private AiChatMessageMapper chatMessageMapper; @Resource private AiChatConversationService chatConversationService; @Resource + private AiChatRoleService chatRoleService; + @Resource private AiModelService modalService; @Resource private AiKnowledgeSegmentService knowledgeSegmentService; @@ -69,118 +80,143 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { @Transactional(rollbackFor = Exception.class) public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) { // 1.1 校验对话存在 - AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId()); + AiChatConversationDO conversation = chatConversationService + .validateChatConversationExists(sendReqVO.getConversationId()); if (ObjUtil.notEqual(conversation.getUserId(), userId)) { throw exception(CHAT_CONVERSATION_NOT_EXISTS); } List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); // 1.2 校验模型 AiModelDO model = modalService.validateModel(conversation.getModelId()); - ChatModel chatModel = modalService.getChatModel(model.getKeyId()); + ChatModel chatModel = modalService.getChatModel(model.getId()); - // 2. 插入 user 发送消息 + // 2. 知识库找回 + List knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(), + conversation); + + // 3. 插入 user 发送消息 AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, - userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext()); + userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(), + null); // 3.1 插入 assistant 接收消息 AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, - userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); + userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(), + knowledgeSegments); - // 3.2 召回段落 - List segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId()); - - // 3.3 创建 chat 需要的 Prompt - Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO); + // 3.2 创建 chat 需要的 Prompt + Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO); ChatResponse chatResponse = chatModel.call(prompt); - // 3.4 段式返回 + // 3.3 段式返回 String newContent = chatResponse.getResult().getOutput().getText(); - chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)).setContent(newContent)); - return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) - .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent)); + chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent)); + return new AiChatMessageSendRespVO() + .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) + .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class) + .setContent(newContent)); } @Override - public Flux> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) { + public Flux> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, + Long userId) { // 1.1 校验对话存在 - AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId()); + AiChatConversationDO conversation = chatConversationService + .validateChatConversationExists(sendReqVO.getConversationId()); if (ObjUtil.notEqual(conversation.getUserId(), userId)) { throw exception(CHAT_CONVERSATION_NOT_EXISTS); } List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); // 1.2 校验模型 AiModelDO model = modalService.validateModel(conversation.getModelId()); - StreamingChatModel chatModel = modalService.getChatModel(model.getKeyId()); + StreamingChatModel chatModel = modalService.getChatModel(model.getId()); - // 2. 插入 user 发送消息 + // 2. 知识库找回 + List knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(), + conversation); + + // 3. 插入 user 发送消息 AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, - userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext()); + userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(), + null); - // 3.1 插入 assistant 接收消息 + // 4.1 插入 assistant 接收消息 AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, - userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); + userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(), + knowledgeSegments); - // 3.2 召回段落 - List segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId()); - - // 3.3 构建 Prompt,并进行调用 - Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO); + // 4.2 构建 Prompt,并进行调用 + Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO); Flux streamResponse = chatModel.stream(prompt); - // 3.4 流式返回 + // 4.3 流式返回 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getText() : null; newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况 contentBuffer.append(newContent); // 响应结果 - return success(new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) - .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent))); + return success(new AiChatMessageSendRespVO() + .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) + .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class) + .setContent(newContent))); }).doOnComplete(() -> { // 忽略租户,因为 Flux 异步无法透传租户 - TenantUtils.executeIgnore(() -> - chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)) - .setContent(contentBuffer.toString()))); + TenantUtils.executeIgnore(() -> chatMessageMapper.updateById( + new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString()))); }).doOnError(throwable -> { log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable); // 忽略租户,因为 Flux 异步无法透传租户 - TenantUtils.executeIgnore(() -> - chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage()))); + TenantUtils.executeIgnore(() -> chatMessageMapper.updateById( + new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage()))); }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR))); } - private List recallSegment(String content, Long knowledgeId) { - if (Objects.isNull(knowledgeId)) { + private List recallKnowledgeSegment(String content, + AiChatConversationDO conversation) { + // 1. 查询聊天角色 + if (conversation == null || conversation.getRoleId() == null) { return Collections.emptyList(); } -// return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content)); - return null; - } - - private Prompt buildPrompt(AiChatConversationDO conversation, List messages, List segmentList, - AiModelDO model, AiChatMessageSendReqVO sendReqVO) { - // 1. 构建 Prompt Message 列表 - List chatMessages = new ArrayList<>(); - - // 1.1 召回内容消息构建 - if (CollUtil.isNotEmpty(segmentList)) { - PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage()); - StringBuilder infoBuilder = StrUtil.builder(); - segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent())); - Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString())); - chatMessages.add(message); + AiChatRoleDO role = chatRoleService.getChatRole(conversation.getRoleId()); + if (role == null || CollUtil.isEmpty(role.getKnowledgeIds())) { + return Collections.emptyList(); } - // 1.2 system context 角色设定 + // 2. 遍历找回 + List knowledgeSegments = new ArrayList<>(); + for (Long knowledgeId : role.getKnowledgeIds()) { + knowledgeSegments.addAll(knowledgeSegmentService.searchKnowledgeSegment(new AiKnowledgeSegmentSearchReqBO() + .setKnowledgeId(knowledgeId).setContent(content))); + } + return knowledgeSegments; + } + + private Prompt buildPrompt(AiChatConversationDO conversation, List messages, + List knowledgeSegments, + AiModelDO model, AiChatMessageSendReqVO sendReqVO) { + List chatMessages = new ArrayList<>(); + // 1.1 System Context 角色设定 if (StrUtil.isNotBlank(conversation.getSystemMessage())) { chatMessages.add(new SystemMessage(conversation.getSystemMessage())); } - // 1.3 history message 历史消息 + + // 1.2 历史 history message 历史消息 List contextMessages = filterContextMessages(messages, conversation, sendReqVO); - contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent()))); - // 1.4 user message 新发送消息 + contextMessages + .forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent()))); + + // 1.3 当前 user message 新发送消息 chatMessages.add(new UserMessage(sendReqVO.getContent())); + // 1.4 知识库,通过 UserMessage 实现 + if (CollUtil.isNotEmpty(knowledgeSegments)) { + String reference = knowledgeSegments.stream() + .map(segment -> "\n" + segment.getContent() + "") + .collect(Collectors.joining("\n\n")); + chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference))); + } + // 2. 构建 ChatOptions 对象 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), @@ -199,8 +235,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { * @return 消息上下文 */ private List filterContextMessages(List messages, - AiChatConversationDO conversation, - AiChatMessageSendReqVO sendReqVO) { + AiChatConversationDO conversation, + AiChatMessageSendReqVO sendReqVO) { if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) { return Collections.emptyList(); } @@ -211,7 +247,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { continue; } AiChatMessageDO userMessage = CollUtil.get(messages, i - 1); - if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId()) + if (userMessage == null + || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId()) || StrUtil.isEmpty(assistantMessage.getContent())) { continue; } @@ -228,11 +265,13 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { } private AiChatMessageDO createChatMessage(Long conversationId, Long replyId, - AiModelDO model, Long userId, Long roleId, - MessageType messageType, String content, Boolean useContext) { + AiModelDO model, Long userId, Long roleId, + MessageType messageType, String content, Boolean useContext, + List knowledgeSegments) { AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId) .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId) - .setType(messageType.getValue()).setContent(content).setUseContext(useContext); + .setType(messageType.getValue()).setContent(content).setUseContext(useContext) + .setSegmentIds(convertList(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getId)); message.setCreateTime(LocalDateTime.now()); chatMessageMapper.insert(message); return message;