【功能新增】AI:聊天记录返回时,增加 segments
This commit is contained in:
parent
cddaca5863
commit
b709af11a1
|
@ -12,9 +12,13 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage
|
|||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||
import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
|
||||
import cn.iocoder.yudao.module.ai.service.chat.AiChatMessageService;
|
||||
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService;
|
||||
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
|
@ -32,7 +36,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
||||
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet;
|
||||
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.*;
|
||||
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
|
||||
|
||||
@Tag(name = "管理后台 - 聊天消息")
|
||||
|
@ -47,6 +51,10 @@ public class AiChatMessageController {
|
|||
private AiChatConversationService chatConversationService;
|
||||
@Resource
|
||||
private AiChatRoleService chatRoleService;
|
||||
@Resource
|
||||
private AiKnowledgeSegmentService knowledgeSegmentService;
|
||||
@Resource
|
||||
private AiKnowledgeDocumentService knowledgeDocumentService;
|
||||
|
||||
@Operation(summary = "发送消息(段式)", description = "一次性返回,响应较慢")
|
||||
@PostMapping("/send")
|
||||
|
@ -56,7 +64,8 @@ public class AiChatMessageController {
|
|||
|
||||
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
|
||||
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) {
|
||||
public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(
|
||||
@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) {
|
||||
return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId());
|
||||
}
|
||||
|
||||
|
@ -69,8 +78,38 @@ public class AiChatMessageController {
|
|||
if (conversation == null || ObjUtil.notEqual(conversation.getUserId(), getLoginUserId())) {
|
||||
return success(Collections.emptyList());
|
||||
}
|
||||
// 1. 获取消息列表
|
||||
List<AiChatMessageDO> messageList = chatMessageService.getChatMessageListByConversationId(conversationId);
|
||||
return success(BeanUtils.toBean(messageList, AiChatMessageRespVO.class));
|
||||
if (CollUtil.isEmpty(messageList)) {
|
||||
return success(Collections.emptyList());
|
||||
}
|
||||
|
||||
// 2. 拼接数据,主要是知识库段落信息
|
||||
Map<Long, AiKnowledgeSegmentDO> segmentMap = knowledgeSegmentService.getKnowledgeSegmentMap(convertListByFlatMap(messageList,
|
||||
message -> CollUtil.isEmpty(message.getSegmentIds()) ? null : message.getSegmentIds().stream()));
|
||||
Map<Long, AiKnowledgeDocumentDO> documentMap = knowledgeDocumentService.getKnowledgeDocumentMap(
|
||||
convertList(segmentMap.values(), AiKnowledgeSegmentDO::getDocumentId));
|
||||
List<AiChatMessageRespVO> messageVOList = BeanUtils.toBean(messageList, AiChatMessageRespVO.class);
|
||||
for (int i = 0; i < messageList.size(); i++) {
|
||||
AiChatMessageDO message = messageList.get(i);
|
||||
if (CollUtil.isEmpty(message.getSegmentIds())) {
|
||||
continue;
|
||||
}
|
||||
// 设置知识库段落信息
|
||||
messageVOList.get(i).setSegments(convertList(message.getSegmentIds(), segmentId -> {
|
||||
AiKnowledgeSegmentDO segment = segmentMap.get(segmentId);
|
||||
if (segment == null) {
|
||||
return null;
|
||||
}
|
||||
AiKnowledgeDocumentDO document = documentMap.get(segment.getDocumentId());
|
||||
if (document == null) {
|
||||
return null;
|
||||
}
|
||||
return new AiChatMessageRespVO.KnowledgeSegment().setId(segment.getId()).setContent(segment.getContent())
|
||||
.setDocumentId(segment.getDocumentId()).setDocumentName(document.getName());
|
||||
}));
|
||||
}
|
||||
return success(messageVOList);
|
||||
}
|
||||
|
||||
@Operation(summary = "删除消息")
|
||||
|
@ -84,7 +123,8 @@ public class AiChatMessageController {
|
|||
@Operation(summary = "删除指定对话的消息")
|
||||
@DeleteMapping("/delete-by-conversation-id")
|
||||
@Parameter(name = "conversationId", required = true, description = "对话编号", example = "1024")
|
||||
public CommonResult<Boolean> deleteChatMessageByConversationId(@RequestParam("conversationId") Long conversationId) {
|
||||
public CommonResult<Boolean> deleteChatMessageByConversationId(
|
||||
@RequestParam("conversationId") Long conversationId) {
|
||||
chatMessageService.deleteChatMessageByConversationId(conversationId, getLoginUserId());
|
||||
return success(true);
|
||||
}
|
||||
|
@ -103,7 +143,8 @@ public class AiChatMessageController {
|
|||
Map<Long, AiChatRoleDO> roleMap = chatRoleService.getChatRoleMap(
|
||||
convertSet(pageResult.getList(), AiChatMessageDO::getRoleId));
|
||||
return success(BeanUtils.toBean(pageResult, AiChatMessageRespVO.class,
|
||||
respVO -> MapUtils.findAndThen(roleMap, respVO.getRoleId(), role -> respVO.setRoleName(role.getName()))));
|
||||
respVO -> MapUtils.findAndThen(roleMap, respVO.getRoleId(),
|
||||
role -> respVO.setRoleName(role.getName()))));
|
||||
}
|
||||
|
||||
@Operation(summary = "管理员删除消息")
|
||||
|
|
|
@ -4,6 +4,7 @@ import io.swagger.v3.oas.annotations.media.Schema;
|
|||
import lombok.Data;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.List;
|
||||
|
||||
@Schema(description = "管理后台 - AI 聊天消息 Response VO")
|
||||
@Data
|
||||
|
@ -39,6 +40,12 @@ public class AiChatMessageRespVO {
|
|||
@Schema(description = "是否携带上下文", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
|
||||
private Boolean useContext;
|
||||
|
||||
@Schema(description = "知识库段落编号数组", example = "[1,2,3]")
|
||||
private List<Long> segmentIds;
|
||||
|
||||
@Schema(description = "知识库段落数组")
|
||||
private List<KnowledgeSegment> segments;
|
||||
|
||||
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
|
||||
private LocalDateTime createTime;
|
||||
|
||||
|
@ -47,4 +54,22 @@ public class AiChatMessageRespVO {
|
|||
@Schema(description = "角色名字", example = "小黄")
|
||||
private String roleName;
|
||||
|
||||
@Schema(description = "知识库段落", example = "Java 开发手册")
|
||||
@Data
|
||||
public static class KnowledgeSegment {
|
||||
|
||||
@Schema(description = "段落编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
|
||||
private Long id;
|
||||
|
||||
@Schema(description = "切片内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 开发手册")
|
||||
private String content;
|
||||
|
||||
@Schema(description = "文档编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "24790")
|
||||
private Long documentId;
|
||||
|
||||
@Schema(description = "文档名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "产品使用手册")
|
||||
private String documentName;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -23,7 +23,6 @@ import java.util.List;
|
|||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
||||
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
|
||||
|
||||
// TODO @芋艿:增加权限标识
|
||||
@Tag(name = "管理后台 - AI 知识库")
|
||||
@RestController
|
||||
@RequestMapping("/ai/knowledge")
|
||||
|
|
|
@ -20,7 +20,7 @@ import java.util.List;
|
|||
* @since 2024/4/14 17:35
|
||||
* @since 2024/4/14 17:35
|
||||
*/
|
||||
@TableName("ai_chat_message")
|
||||
@TableName(value = "ai_chat_message", autoResultMap = true)
|
||||
@KeySequence("ai_chat_conversation_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
|
@ -71,14 +71,6 @@ public class AiChatMessageDO extends BaseDO {
|
|||
*/
|
||||
private Long roleId;
|
||||
|
||||
/**
|
||||
* 知识库段落编号数组
|
||||
*
|
||||
* 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
|
||||
*/
|
||||
@TableField(typeHandler = LongListTypeHandler.class)
|
||||
private List<Long> segmentIds;
|
||||
|
||||
/**
|
||||
* 模型标志
|
||||
*
|
||||
|
@ -102,4 +94,12 @@ public class AiChatMessageDO extends BaseDO {
|
|||
*/
|
||||
private Boolean useContext;
|
||||
|
||||
/**
|
||||
* 知识库段落编号数组
|
||||
*
|
||||
* 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
|
||||
*/
|
||||
@TableField(typeHandler = LongListTypeHandler.class)
|
||||
private List<Long> segmentIds;
|
||||
|
||||
}
|
||||
|
|
|
@ -212,7 +212,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||
// 1.4 知识库,通过 UserMessage 实现
|
||||
if (CollUtil.isNotEmpty(knowledgeSegments)) {
|
||||
String reference = knowledgeSegments.stream()
|
||||
.map(segment -> "<Reference>\n" + segment.getContent() + "</Reference>")
|
||||
.map(segment -> "<Reference>" + segment.getContent() + "</Reference>")
|
||||
.collect(Collectors.joining("\n\n"));
|
||||
chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference)));
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.springframework.transaction.annotation.Transactional;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
|
@ -205,9 +206,9 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
|
|||
@Override
|
||||
public List<AiKnowledgeDocumentDO> getKnowledgeDocumentList(Collection<Long> ids) {
|
||||
if (CollUtil.isEmpty(ids)) {
|
||||
return new ArrayList<>();
|
||||
return Collections.emptyList();
|
||||
}
|
||||
return knowledgeDocumentMapper.selectByIds(ids);
|
||||
return knowledgeDocumentMapper.selectBatchIds(ids);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -10,7 +10,11 @@ import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchR
|
|||
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
|
||||
|
||||
/**
|
||||
* AI 知识库段落 Service 接口
|
||||
|
@ -27,6 +31,24 @@ public interface AiKnowledgeSegmentService {
|
|||
*/
|
||||
AiKnowledgeSegmentDO getKnowledgeSegment(Long id);
|
||||
|
||||
/**
|
||||
* 获取知识库段落列表
|
||||
*
|
||||
* @param ids 段落编号列表
|
||||
* @return 段落列表
|
||||
*/
|
||||
List<AiKnowledgeSegmentDO> getKnowledgeSegmentList(Collection<Long> ids);
|
||||
|
||||
/**
|
||||
* 获取知识库段落 Map
|
||||
*
|
||||
* @param ids 段落编号列表
|
||||
* @return 段落 Map
|
||||
*/
|
||||
default Map<Long, AiKnowledgeSegmentDO> getKnowledgeSegmentMap(Collection<Long> ids) {
|
||||
return convertMap(getKnowledgeSegmentList(ids), AiKnowledgeSegmentDO::getId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取段落分页
|
||||
*
|
||||
|
|
|
@ -30,10 +30,7 @@ import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
|
|||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.*;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
|
||||
|
@ -322,7 +319,15 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
|
|||
|
||||
@Override
|
||||
public AiKnowledgeSegmentDO getKnowledgeSegment(Long id) {
|
||||
return validateKnowledgeSegmentExists(id);
|
||||
return segmentMapper.selectById(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AiKnowledgeSegmentDO> getKnowledgeSegmentList(Collection<Long> ids) {
|
||||
if (CollUtil.isEmpty(ids)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
return segmentMapper.selectBatchIds(ids);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue