diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java index 58ba056595..245a19f7cb 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java @@ -4,6 +4,7 @@ import io.swagger.v3.oas.annotations.media.Schema; import lombok.Data; import java.time.LocalDateTime; +import java.util.List; @Schema(description = "管理后台 - AI 聊天消息发送 Response VO") @Data @@ -28,6 +29,12 @@ public class AiChatMessageSendRespVO { @Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "你好,你好啊") private String content; + @Schema(description = "知识库段落编号数组", example = "[1,2,3]") + private List segmentIds; + + @Schema(description = "知识库段落数组") + private List segments; + @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED) private LocalDateTime createTime; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index cd6c39d405..88f7144336 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -10,14 +10,17 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; +import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService; import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService; import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO; import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO; @@ -76,6 +79,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { private AiModelService modalService; @Resource private AiKnowledgeSegmentService knowledgeSegmentService; + @Resource + private AiKnowledgeDocumentService knowledgeDocumentService; @Transactional(rollbackFor = Exception.class) public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) { @@ -108,18 +113,23 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO); ChatResponse chatResponse = chatModel.call(prompt); - // 3.3 段式返回 + // 3.3 更新响应内容 String newContent = chatResponse.getResult().getOutput().getText(); chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent)); + // 3.4 响应结果 + List segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class, + segment -> { + AiKnowledgeDocumentDO document = knowledgeDocumentService.getKnowledgeDocument(segment.getDocumentId()); + segment.setDocumentName(document != null ? document.getName() : null); + }); return new AiChatMessageSendRespVO() .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class) - .setContent(newContent)); + .setContent(newContent).setSegments(segments)); } @Override - public Flux> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, - Long userId) { + public Flux> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) { // 1.1 校验对话存在 AiChatConversationDO conversation = chatConversationService .validateChatConversationExists(sendReqVO.getConversationId()); @@ -132,8 +142,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { StreamingChatModel chatModel = modalService.getChatModel(model.getId()); // 2. 知识库找回 - List knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(), - conversation); + List knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(), conversation); // 3. 插入 user 发送消息 AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, @@ -152,14 +161,23 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { // 4.3 流式返回 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { + // 处理知识库的返回,只有首次才有 + List segments = null; + if (StrUtil.isEmpty(contentBuffer)) { + segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class, + segment -> TenantUtils.executeIgnore(() -> { + AiKnowledgeDocumentDO document = knowledgeDocumentService.getKnowledgeDocument(segment.getDocumentId()); + segment.setDocumentName(document != null ? document.getName() : null); + })); + } + // 响应结果 String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getText() : null; newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况 contentBuffer.append(newContent); - // 响应结果 return success(new AiChatMessageSendRespVO() .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class) - .setContent(newContent))); + .setContent(newContent).setSegments(segments))); }).doOnComplete(() -> { // 忽略租户,因为 Flux 异步无法透传租户 TenantUtils.executeIgnore(() -> chatMessageMapper.updateById( @@ -173,7 +191,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { } private List recallKnowledgeSegment(String content, - AiChatConversationDO conversation) { + AiChatConversationDO conversation) { // 1. 查询聊天角色 if (conversation == null || conversation.getRoleId() == null) { return Collections.emptyList(); @@ -193,8 +211,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { } private Prompt buildPrompt(AiChatConversationDO conversation, List messages, - List knowledgeSegments, - AiModelDO model, AiChatMessageSendReqVO sendReqVO) { + List knowledgeSegments, + AiModelDO model, AiChatMessageSendReqVO sendReqVO) { List chatMessages = new ArrayList<>(); // 1.1 System Context 角色设定 if (StrUtil.isNotBlank(conversation.getSystemMessage())) { @@ -235,8 +253,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { * @return 消息上下文 */ private List filterContextMessages(List messages, - AiChatConversationDO conversation, - AiChatMessageSendReqVO sendReqVO) { + AiChatConversationDO conversation, + AiChatMessageSendReqVO sendReqVO) { if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) { return Collections.emptyList(); } @@ -265,9 +283,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { } private AiChatMessageDO createChatMessage(Long conversationId, Long replyId, - AiModelDO model, Long userId, Long roleId, - MessageType messageType, String content, Boolean useContext, - List knowledgeSegments) { + AiModelDO model, Long userId, Long roleId, + MessageType messageType, String content, Boolean useContext, + List knowledgeSegments) { AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId) .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId) .setType(messageType.getValue()).setContent(content).setUseContext(useContext)