diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java
index 6cb98c5629..1479274959 100644
--- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java
+++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java
@@ -35,11 +35,7 @@ public enum AiChatRoleEnum {
### 微信
除此之外不要任何解释性语句。
"""),
-
- AI_KNOWLEDGE_ROLE("知识库助手", """
- 给你提供一些数据参考:{info},请回答我的问题。
- 请你跟进数据参考与工具返回结果回复用户的请求。
- """);
+ ;
/**
* 角色名
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java
index a9c956deae..358e994104 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java
@@ -1,9 +1,8 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
@@ -65,14 +64,6 @@ public class AiChatConversationDO extends BaseDO {
*/
private Long roleId;
- // TODO @芋艿:可优化,绑定多个知识库。前提,spring ai 支持 RerankModel 的封装
- /**
- * 知识库编号
- *
- * 关联 {@link AiKnowledgeDO#getId()}
- */
- private Long knowledgeId;
-
/**
* 模型编号
*
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java
index a026ec8193..d121d85785 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java
@@ -1,14 +1,14 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
-import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import lombok.*;
import org.springframework.ai.chat.messages.MessageType;
@@ -71,13 +71,12 @@ public class AiChatMessageDO extends BaseDO {
*/
private Long roleId;
-
/**
- * 段落编号数组
+ * 知识库段落编号数组
*
* 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
*/
- @TableField(typeHandler = JacksonTypeHandler.class)
+ @TableField(typeHandler = LongListTypeHandler.class)
private List segmentIds;
/**
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java
index 6483166a7f..6c35571c8f 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java
@@ -68,7 +68,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
// 2. 创建 AiChatConversationDO 聊天对话
AiChatConversationDO conversation = new AiChatConversationDO().setUserId(userId).setPinned(false)
- .setModelId(model.getId()).setModel(model.getModel()).setKnowledgeId(createReqVO.getKnowledgeId())
+ .setModelId(model.getId()).setModel(model.getModel())
.setTemperature(model.getTemperature()).setMaxTokens(model.getMaxTokens()).setMaxContexts(model.getMaxContexts());
if (role != null) {
conversation.setTitle(role.getName()).setRoleId(role.getId()).setSystemMessage(role.getSystemMessage());
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java
index 51bcc45d44..b95f7cd214 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java
@@ -14,12 +14,14 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
-import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
+import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
+import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
+import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
@@ -32,13 +34,13 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux;
import java.time.LocalDateTime;
import java.util.*;
+import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
@@ -56,12 +58,21 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_N
@Slf4j
public class AiChatMessageServiceImpl implements AiChatMessageService {
+ /**
+ * 知识库转 {@link UserMessage} 的内容模版
+ */
+ private static final String KNOWLEDGE_USER_MESSAGE_TEMPLATE = "使用 标记中的内容作为本次对话的参考:\n\n" +
+ "%s\n\n" + // 多个 的拼接
+ "回答要求:\n- 避免提及你是从 获取的知识。";
+
@Resource
private AiChatMessageMapper chatMessageMapper;
@Resource
private AiChatConversationService chatConversationService;
@Resource
+ private AiChatRoleService chatRoleService;
+ @Resource
private AiModelService modalService;
@Resource
private AiKnowledgeSegmentService knowledgeSegmentService;
@@ -69,118 +80,143 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
@Transactional(rollbackFor = Exception.class)
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
// 1.1 校验对话存在
- AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
+ AiChatConversationDO conversation = chatConversationService
+ .validateChatConversationExists(sendReqVO.getConversationId());
if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
throw exception(CHAT_CONVERSATION_NOT_EXISTS);
}
List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型
AiModelDO model = modalService.validateModel(conversation.getModelId());
- ChatModel chatModel = modalService.getChatModel(model.getKeyId());
+ ChatModel chatModel = modalService.getChatModel(model.getId());
- // 2. 插入 user 发送消息
+ // 2. 知识库找回
+ List knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(),
+ conversation);
+
+ // 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
- userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
+ userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
+ null);
// 3.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
- userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
+ userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
+ knowledgeSegments);
- // 3.2 召回段落
- List segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
-
- // 3.3 创建 chat 需要的 Prompt
- Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
+ // 3.2 创建 chat 需要的 Prompt
+ Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
ChatResponse chatResponse = chatModel.call(prompt);
- // 3.4 段式返回
+ // 3.3 段式返回
String newContent = chatResponse.getResult().getOutput().getText();
- chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)).setContent(newContent));
- return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
- .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
+ chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
+ return new AiChatMessageSendRespVO()
+ .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
+ .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
+ .setContent(newContent));
}
@Override
- public Flux> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) {
+ public Flux> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO,
+ Long userId) {
// 1.1 校验对话存在
- AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
+ AiChatConversationDO conversation = chatConversationService
+ .validateChatConversationExists(sendReqVO.getConversationId());
if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
throw exception(CHAT_CONVERSATION_NOT_EXISTS);
}
List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型
AiModelDO model = modalService.validateModel(conversation.getModelId());
- StreamingChatModel chatModel = modalService.getChatModel(model.getKeyId());
+ StreamingChatModel chatModel = modalService.getChatModel(model.getId());
- // 2. 插入 user 发送消息
+ // 2. 知识库找回
+ List knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(),
+ conversation);
+
+ // 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
- userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
+ userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
+ null);
- // 3.1 插入 assistant 接收消息
+ // 4.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
- userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
+ userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
+ knowledgeSegments);
- // 3.2 召回段落
- List segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
-
- // 3.3 构建 Prompt,并进行调用
- Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
+ // 4.2 构建 Prompt,并进行调用
+ Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
Flux streamResponse = chatModel.stream(prompt);
- // 3.4 流式返回
+ // 4.3 流式返回
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getText() : null;
newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况
contentBuffer.append(newContent);
// 响应结果
- return success(new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
- .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent)));
+ return success(new AiChatMessageSendRespVO()
+ .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
+ .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
+ .setContent(newContent)));
}).doOnComplete(() -> {
// 忽略租户,因为 Flux 异步无法透传租户
- TenantUtils.executeIgnore(() ->
- chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId))
- .setContent(contentBuffer.toString())));
+ TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(
+ new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString())));
}).doOnError(throwable -> {
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
// 忽略租户,因为 Flux 异步无法透传租户
- TenantUtils.executeIgnore(() ->
- chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage())));
+ TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(
+ new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage())));
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
}
- private List recallSegment(String content, Long knowledgeId) {
- if (Objects.isNull(knowledgeId)) {
+ private List recallKnowledgeSegment(String content,
+ AiChatConversationDO conversation) {
+ // 1. 查询聊天角色
+ if (conversation == null || conversation.getRoleId() == null) {
return Collections.emptyList();
}
-// return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
- return null;
- }
-
- private Prompt buildPrompt(AiChatConversationDO conversation, List messages, List segmentList,
- AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
- // 1. 构建 Prompt Message 列表
- List chatMessages = new ArrayList<>();
-
- // 1.1 召回内容消息构建
- if (CollUtil.isNotEmpty(segmentList)) {
- PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
- StringBuilder infoBuilder = StrUtil.builder();
- segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent()));
- Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
- chatMessages.add(message);
+ AiChatRoleDO role = chatRoleService.getChatRole(conversation.getRoleId());
+ if (role == null || CollUtil.isEmpty(role.getKnowledgeIds())) {
+ return Collections.emptyList();
}
- // 1.2 system context 角色设定
+ // 2. 遍历找回
+ List knowledgeSegments = new ArrayList<>();
+ for (Long knowledgeId : role.getKnowledgeIds()) {
+ knowledgeSegments.addAll(knowledgeSegmentService.searchKnowledgeSegment(new AiKnowledgeSegmentSearchReqBO()
+ .setKnowledgeId(knowledgeId).setContent(content)));
+ }
+ return knowledgeSegments;
+ }
+
+ private Prompt buildPrompt(AiChatConversationDO conversation, List messages,
+ List knowledgeSegments,
+ AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
+ List chatMessages = new ArrayList<>();
+ // 1.1 System Context 角色设定
if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
}
- // 1.3 history message 历史消息
+
+ // 1.2 历史 history message 历史消息
List contextMessages = filterContextMessages(messages, conversation, sendReqVO);
- contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
- // 1.4 user message 新发送消息
+ contextMessages
+ .forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
+
+ // 1.3 当前 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent()));
+ // 1.4 知识库,通过 UserMessage 实现
+ if (CollUtil.isNotEmpty(knowledgeSegments)) {
+ String reference = knowledgeSegments.stream()
+ .map(segment -> "\n" + segment.getContent() + "")
+ .collect(Collectors.joining("\n\n"));
+ chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference)));
+ }
+
// 2. 构建 ChatOptions 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
@@ -199,8 +235,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
* @return 消息上下文
*/
private List filterContextMessages(List messages,
- AiChatConversationDO conversation,
- AiChatMessageSendReqVO sendReqVO) {
+ AiChatConversationDO conversation,
+ AiChatMessageSendReqVO sendReqVO) {
if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
return Collections.emptyList();
}
@@ -211,7 +247,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
continue;
}
AiChatMessageDO userMessage = CollUtil.get(messages, i - 1);
- if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
+ if (userMessage == null
+ || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
|| StrUtil.isEmpty(assistantMessage.getContent())) {
continue;
}
@@ -228,11 +265,13 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}
private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
- AiModelDO model, Long userId, Long roleId,
- MessageType messageType, String content, Boolean useContext) {
+ AiModelDO model, Long userId, Long roleId,
+ MessageType messageType, String content, Boolean useContext,
+ List knowledgeSegments) {
AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
.setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
- .setType(messageType.getValue()).setContent(content).setUseContext(useContext);
+ .setType(messageType.getValue()).setContent(content).setUseContext(useContext)
+ .setSegmentIds(convertList(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getId));
message.setCreateTime(LocalDateTime.now());
chatMessageMapper.insert(message);
return message;