【功能新增】AI:聊天记录返回时,增加 segments

This commit is contained in:
YunaiV 2025-03-09 21:00:34 +08:00
parent cddaca5863
commit b709af11a1
8 changed files with 116 additions and 23 deletions

View File

@ -12,9 +12,13 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.chat.AiChatMessageService;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
@ -32,7 +36,7 @@ import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.*;
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
@Tag(name = "管理后台 - 聊天消息")
@ -47,6 +51,10 @@ public class AiChatMessageController {
private AiChatConversationService chatConversationService;
@Resource
private AiChatRoleService chatRoleService;
@Resource
private AiKnowledgeSegmentService knowledgeSegmentService;
@Resource
private AiKnowledgeDocumentService knowledgeDocumentService;
@Operation(summary = "发送消息(段式)", description = "一次性返回,响应较慢")
@PostMapping("/send")
@ -56,7 +64,8 @@ public class AiChatMessageController {
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) {
public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(
@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) {
return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId());
}
@ -69,8 +78,38 @@ public class AiChatMessageController {
if (conversation == null || ObjUtil.notEqual(conversation.getUserId(), getLoginUserId())) {
return success(Collections.emptyList());
}
// 1. 获取消息列表
List<AiChatMessageDO> messageList = chatMessageService.getChatMessageListByConversationId(conversationId);
return success(BeanUtils.toBean(messageList, AiChatMessageRespVO.class));
if (CollUtil.isEmpty(messageList)) {
return success(Collections.emptyList());
}
// 2. 拼接数据主要是知识库段落信息
Map<Long, AiKnowledgeSegmentDO> segmentMap = knowledgeSegmentService.getKnowledgeSegmentMap(convertListByFlatMap(messageList,
message -> CollUtil.isEmpty(message.getSegmentIds()) ? null : message.getSegmentIds().stream()));
Map<Long, AiKnowledgeDocumentDO> documentMap = knowledgeDocumentService.getKnowledgeDocumentMap(
convertList(segmentMap.values(), AiKnowledgeSegmentDO::getDocumentId));
List<AiChatMessageRespVO> messageVOList = BeanUtils.toBean(messageList, AiChatMessageRespVO.class);
for (int i = 0; i < messageList.size(); i++) {
AiChatMessageDO message = messageList.get(i);
if (CollUtil.isEmpty(message.getSegmentIds())) {
continue;
}
// 设置知识库段落信息
messageVOList.get(i).setSegments(convertList(message.getSegmentIds(), segmentId -> {
AiKnowledgeSegmentDO segment = segmentMap.get(segmentId);
if (segment == null) {
return null;
}
AiKnowledgeDocumentDO document = documentMap.get(segment.getDocumentId());
if (document == null) {
return null;
}
return new AiChatMessageRespVO.KnowledgeSegment().setId(segment.getId()).setContent(segment.getContent())
.setDocumentId(segment.getDocumentId()).setDocumentName(document.getName());
}));
}
return success(messageVOList);
}
@Operation(summary = "删除消息")
@ -84,7 +123,8 @@ public class AiChatMessageController {
@Operation(summary = "删除指定对话的消息")
@DeleteMapping("/delete-by-conversation-id")
@Parameter(name = "conversationId", required = true, description = "对话编号", example = "1024")
public CommonResult<Boolean> deleteChatMessageByConversationId(@RequestParam("conversationId") Long conversationId) {
public CommonResult<Boolean> deleteChatMessageByConversationId(
@RequestParam("conversationId") Long conversationId) {
chatMessageService.deleteChatMessageByConversationId(conversationId, getLoginUserId());
return success(true);
}
@ -103,7 +143,8 @@ public class AiChatMessageController {
Map<Long, AiChatRoleDO> roleMap = chatRoleService.getChatRoleMap(
convertSet(pageResult.getList(), AiChatMessageDO::getRoleId));
return success(BeanUtils.toBean(pageResult, AiChatMessageRespVO.class,
respVO -> MapUtils.findAndThen(roleMap, respVO.getRoleId(), role -> respVO.setRoleName(role.getName()))));
respVO -> MapUtils.findAndThen(roleMap, respVO.getRoleId(),
role -> respVO.setRoleName(role.getName()))));
}
@Operation(summary = "管理员删除消息")

View File

@ -4,6 +4,7 @@ import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.time.LocalDateTime;
import java.util.List;
@Schema(description = "管理后台 - AI 聊天消息 Response VO")
@Data
@ -39,6 +40,12 @@ public class AiChatMessageRespVO {
@Schema(description = "是否携带上下文", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
private Boolean useContext;
@Schema(description = "知识库段落编号数组", example = "[1,2,3]")
private List<Long> segmentIds;
@Schema(description = "知识库段落数组")
private List<KnowledgeSegment> segments;
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
private LocalDateTime createTime;
@ -47,4 +54,22 @@ public class AiChatMessageRespVO {
@Schema(description = "角色名字", example = "小黄")
private String roleName;
@Schema(description = "知识库段落", example = "Java 开发手册")
@Data
public static class KnowledgeSegment {
@Schema(description = "段落编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
private Long id;
@Schema(description = "切片内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 开发手册")
private String content;
@Schema(description = "文档编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "24790")
private Long documentId;
@Schema(description = "文档名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "产品使用手册")
private String documentName;
}
}

View File

@ -23,7 +23,6 @@ import java.util.List;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
// TODO @芋艿增加权限标识
@Tag(name = "管理后台 - AI 知识库")
@RestController
@RequestMapping("/ai/knowledge")

View File

@ -20,7 +20,7 @@ import java.util.List;
* @since 2024/4/14 17:35
* @since 2024/4/14 17:35
*/
@TableName("ai_chat_message")
@TableName(value = "ai_chat_message", autoResultMap = true)
@KeySequence("ai_chat_conversation_seq") // 用于 OraclePostgreSQLKingbaseDB2H2 数据库的主键自增如果是 MySQL 等数据库可不写
@Data
@EqualsAndHashCode(callSuper = true)
@ -71,14 +71,6 @@ public class AiChatMessageDO extends BaseDO {
*/
private Long roleId;
/**
* 知识库段落编号数组
*
* 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
*/
@TableField(typeHandler = LongListTypeHandler.class)
private List<Long> segmentIds;
/**
* 模型标志
*
@ -102,4 +94,12 @@ public class AiChatMessageDO extends BaseDO {
*/
private Boolean useContext;
/**
* 知识库段落编号数组
*
* 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
*/
@TableField(typeHandler = LongListTypeHandler.class)
private List<Long> segmentIds;
}

View File

@ -212,7 +212,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 1.4 知识库通过 UserMessage 实现
if (CollUtil.isNotEmpty(knowledgeSegments)) {
String reference = knowledgeSegments.stream()
.map(segment -> "<Reference>\n" + segment.getContent() + "</Reference>")
.map(segment -> "<Reference>" + segment.getContent() + "</Reference>")
.collect(Collectors.joining("\n\n"));
chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference)));
}

View File

@ -26,6 +26,7 @@ import org.springframework.transaction.annotation.Transactional;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@ -205,9 +206,9 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
@Override
public List<AiKnowledgeDocumentDO> getKnowledgeDocumentList(Collection<Long> ids) {
if (CollUtil.isEmpty(ids)) {
return new ArrayList<>();
return Collections.emptyList();
}
return knowledgeDocumentMapper.selectByIds(ids);
return knowledgeDocumentMapper.selectBatchIds(ids);
}
}

View File

@ -10,7 +10,11 @@ import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchR
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
import org.springframework.scheduling.annotation.Async;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
/**
* AI 知识库段落 Service 接口
@ -27,6 +31,24 @@ public interface AiKnowledgeSegmentService {
*/
AiKnowledgeSegmentDO getKnowledgeSegment(Long id);
/**
* 获取知识库段落列表
*
* @param ids 段落编号列表
* @return 段落列表
*/
List<AiKnowledgeSegmentDO> getKnowledgeSegmentList(Collection<Long> ids);
/**
* 获取知识库段落 Map
*
* @param ids 段落编号列表
* @return 段落 Map
*/
default Map<Long, AiKnowledgeSegmentDO> getKnowledgeSegmentMap(Collection<Long> ids) {
return convertMap(getKnowledgeSegmentList(ids), AiKnowledgeSegmentDO::getId);
}
/**
* 获取段落分页
*

View File

@ -30,10 +30,7 @@ import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.*;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@ -322,7 +319,15 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
@Override
public AiKnowledgeSegmentDO getKnowledgeSegment(Long id) {
return validateKnowledgeSegmentExists(id);
return segmentMapper.selectById(id);
}
@Override
public List<AiKnowledgeSegmentDO> getKnowledgeSegmentList(Collection<Long> ids) {
if (CollUtil.isEmpty(ids)) {
return Collections.emptyList();
}
return segmentMapper.selectBatchIds(ids);
}
}