From b709af11a138b9ebadc0877980a38071304ebe6a Mon Sep 17 00:00:00 2001 From: YunaiV Date: Sun, 9 Mar 2025 21:00:34 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E5=8A=9F=E8=83=BD=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E3=80=91AI=EF=BC=9A=E8=81=8A=E5=A4=A9=E8=AE=B0=E5=BD=95?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=E6=97=B6=EF=BC=8C=E5=A2=9E=E5=8A=A0=20segmen?= =?UTF-8?q?ts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../admin/chat/AiChatMessageController.java | 51 +++++++++++++++++-- .../chat/vo/message/AiChatMessageRespVO.java | 25 +++++++++ .../knowledge/AiKnowledgeController.java | 1 - .../dal/dataobject/chat/AiChatMessageDO.java | 18 +++---- .../chat/AiChatMessageServiceImpl.java | 2 +- .../AiKnowledgeDocumentServiceImpl.java | 5 +- .../knowledge/AiKnowledgeSegmentService.java | 22 ++++++++ .../AiKnowledgeSegmentServiceImpl.java | 15 ++++-- 8 files changed, 116 insertions(+), 23 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java index 43ae9a40db..b4fa8ab88c 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java @@ -12,9 +12,13 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService; import cn.iocoder.yudao.module.ai.service.chat.AiChatMessageService; +import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService; +import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; @@ -32,7 +36,7 @@ import java.util.List; import java.util.Map; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; -import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet; +import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.*; import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId; @Tag(name = "管理后台 - 聊天消息") @@ -47,6 +51,10 @@ public class AiChatMessageController { private AiChatConversationService chatConversationService; @Resource private AiChatRoleService chatRoleService; + @Resource + private AiKnowledgeSegmentService knowledgeSegmentService; + @Resource + private AiKnowledgeDocumentService knowledgeDocumentService; @Operation(summary = "发送消息(段式)", description = "一次性返回,响应较慢") @PostMapping("/send") @@ -56,7 +64,8 @@ public class AiChatMessageController { @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public Flux> sendChatMessageStream(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) { + public Flux> sendChatMessageStream( + @Valid @RequestBody AiChatMessageSendReqVO sendReqVO) { return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId()); } @@ -69,8 +78,38 @@ public class AiChatMessageController { if (conversation == null || ObjUtil.notEqual(conversation.getUserId(), getLoginUserId())) { return success(Collections.emptyList()); } + // 1. 获取消息列表 List messageList = chatMessageService.getChatMessageListByConversationId(conversationId); - return success(BeanUtils.toBean(messageList, AiChatMessageRespVO.class)); + if (CollUtil.isEmpty(messageList)) { + return success(Collections.emptyList()); + } + + // 2. 拼接数据,主要是知识库段落信息 + Map segmentMap = knowledgeSegmentService.getKnowledgeSegmentMap(convertListByFlatMap(messageList, + message -> CollUtil.isEmpty(message.getSegmentIds()) ? null : message.getSegmentIds().stream())); + Map documentMap = knowledgeDocumentService.getKnowledgeDocumentMap( + convertList(segmentMap.values(), AiKnowledgeSegmentDO::getDocumentId)); + List messageVOList = BeanUtils.toBean(messageList, AiChatMessageRespVO.class); + for (int i = 0; i < messageList.size(); i++) { + AiChatMessageDO message = messageList.get(i); + if (CollUtil.isEmpty(message.getSegmentIds())) { + continue; + } + // 设置知识库段落信息 + messageVOList.get(i).setSegments(convertList(message.getSegmentIds(), segmentId -> { + AiKnowledgeSegmentDO segment = segmentMap.get(segmentId); + if (segment == null) { + return null; + } + AiKnowledgeDocumentDO document = documentMap.get(segment.getDocumentId()); + if (document == null) { + return null; + } + return new AiChatMessageRespVO.KnowledgeSegment().setId(segment.getId()).setContent(segment.getContent()) + .setDocumentId(segment.getDocumentId()).setDocumentName(document.getName()); + })); + } + return success(messageVOList); } @Operation(summary = "删除消息") @@ -84,7 +123,8 @@ public class AiChatMessageController { @Operation(summary = "删除指定对话的消息") @DeleteMapping("/delete-by-conversation-id") @Parameter(name = "conversationId", required = true, description = "对话编号", example = "1024") - public CommonResult deleteChatMessageByConversationId(@RequestParam("conversationId") Long conversationId) { + public CommonResult deleteChatMessageByConversationId( + @RequestParam("conversationId") Long conversationId) { chatMessageService.deleteChatMessageByConversationId(conversationId, getLoginUserId()); return success(true); } @@ -103,7 +143,8 @@ public class AiChatMessageController { Map roleMap = chatRoleService.getChatRoleMap( convertSet(pageResult.getList(), AiChatMessageDO::getRoleId)); return success(BeanUtils.toBean(pageResult, AiChatMessageRespVO.class, - respVO -> MapUtils.findAndThen(roleMap, respVO.getRoleId(), role -> respVO.setRoleName(role.getName())))); + respVO -> MapUtils.findAndThen(roleMap, respVO.getRoleId(), + role -> respVO.setRoleName(role.getName())))); } @Operation(summary = "管理员删除消息") diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java index 9b358df6f2..5d44e4f967 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java @@ -4,6 +4,7 @@ import io.swagger.v3.oas.annotations.media.Schema; import lombok.Data; import java.time.LocalDateTime; +import java.util.List; @Schema(description = "管理后台 - AI 聊天消息 Response VO") @Data @@ -39,6 +40,12 @@ public class AiChatMessageRespVO { @Schema(description = "是否携带上下文", requiredMode = Schema.RequiredMode.REQUIRED, example = "true") private Boolean useContext; + @Schema(description = "知识库段落编号数组", example = "[1,2,3]") + private List segmentIds; + + @Schema(description = "知识库段落数组") + private List segments; + @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51") private LocalDateTime createTime; @@ -47,4 +54,22 @@ public class AiChatMessageRespVO { @Schema(description = "角色名字", example = "小黄") private String roleName; + @Schema(description = "知识库段落", example = "Java 开发手册") + @Data + public static class KnowledgeSegment { + + @Schema(description = "段落编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") + private Long id; + + @Schema(description = "切片内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 开发手册") + private String content; + + @Schema(description = "文档编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "24790") + private Long documentId; + + @Schema(description = "文档名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "产品使用手册") + private String documentName; + + } + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/AiKnowledgeController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/AiKnowledgeController.java index 39fb8d9430..7dd2e1647e 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/AiKnowledgeController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/AiKnowledgeController.java @@ -23,7 +23,6 @@ import java.util.List; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; -// TODO @芋艿:增加权限标识 @Tag(name = "管理后台 - AI 知识库") @RestController @RequestMapping("/ai/knowledge") diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java index d121d85785..94f764c85e 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java @@ -20,7 +20,7 @@ import java.util.List; * @since 2024/4/14 17:35 * @since 2024/4/14 17:35 */ -@TableName("ai_chat_message") +@TableName(value = "ai_chat_message", autoResultMap = true) @KeySequence("ai_chat_conversation_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。 @Data @EqualsAndHashCode(callSuper = true) @@ -71,14 +71,6 @@ public class AiChatMessageDO extends BaseDO { */ private Long roleId; - /** - * 知识库段落编号数组 - * - * 关联 {@link AiKnowledgeSegmentDO#getId()} 字段 - */ - @TableField(typeHandler = LongListTypeHandler.class) - private List segmentIds; - /** * 模型标志 * @@ -102,4 +94,12 @@ public class AiChatMessageDO extends BaseDO { */ private Boolean useContext; + /** + * 知识库段落编号数组 + * + * 关联 {@link AiKnowledgeSegmentDO#getId()} 字段 + */ + @TableField(typeHandler = LongListTypeHandler.class) + private List segmentIds; + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index b95f7cd214..cd6c39d405 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -212,7 +212,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { // 1.4 知识库,通过 UserMessage 实现 if (CollUtil.isNotEmpty(knowledgeSegments)) { String reference = knowledgeSegments.stream() - .map(segment -> "\n" + segment.getContent() + "") + .map(segment -> "" + segment.getContent() + "") .collect(Collectors.joining("\n\n")); chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference))); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java index 48dd78dbb9..2d78f94f34 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java @@ -26,6 +26,7 @@ import org.springframework.transaction.annotation.Transactional; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; @@ -205,9 +206,9 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic @Override public List getKnowledgeDocumentList(Collection ids) { if (CollUtil.isEmpty(ids)) { - return new ArrayList<>(); + return Collections.emptyList(); } - return knowledgeDocumentMapper.selectByIds(ids); + return knowledgeDocumentMapper.selectBatchIds(ids); } } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java index 15ab941fe8..32272abdb8 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java @@ -10,7 +10,11 @@ import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchR import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO; import org.springframework.scheduling.annotation.Async; +import java.util.Collection; import java.util.List; +import java.util.Map; + +import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap; /** * AI 知识库段落 Service 接口 @@ -27,6 +31,24 @@ public interface AiKnowledgeSegmentService { */ AiKnowledgeSegmentDO getKnowledgeSegment(Long id); + /** + * 获取知识库段落列表 + * + * @param ids 段落编号列表 + * @return 段落列表 + */ + List getKnowledgeSegmentList(Collection ids); + + /** + * 获取知识库段落 Map + * + * @param ids 段落编号列表 + * @return 段落 Map + */ + default Map getKnowledgeSegmentMap(Collection ids) { + return convertMap(getKnowledgeSegmentList(ids), AiKnowledgeSegmentDO::getId); + } + /** * 获取段落分页 * diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java index efdd86203f..94b735b3c5 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java @@ -30,10 +30,7 @@ import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Service; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Objects; +import java.util.*; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; @@ -322,7 +319,15 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService @Override public AiKnowledgeSegmentDO getKnowledgeSegment(Long id) { - return validateKnowledgeSegmentExists(id); + return segmentMapper.selectById(id); + } + + @Override + public List getKnowledgeSegmentList(Collection ids) { + if (CollUtil.isEmpty(ids)) { + return Collections.emptyList(); + } + return segmentMapper.selectBatchIds(ids); } }