【增加】增加模型联网搜索,重写用户 prompt

This commit is contained in:
cherishsince 2025-03-02 12:34:09 +08:00
parent f638f90afc
commit cb2e9e044f
5 changed files with 68 additions and 8 deletions

View File

@ -22,4 +22,6 @@ public class AiChatMessageSendReqVO {
@Schema(description = "是否携带上下文", example = "true") @Schema(description = "是否携带上下文", example = "true")
private Boolean useContext; private Boolean useContext;
@Schema(description = "搜索enable", example = "true")
private Boolean searchEnable;
} }

View File

@ -23,6 +23,8 @@ import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService; import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.websearch.WebSearchService;
import cn.iocoder.yudao.module.ai.service.websearch.vo.WebSearchRespVO;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
@ -60,7 +62,6 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
@Resource @Resource
private AiChatMessageMapper chatMessageMapper; private AiChatMessageMapper chatMessageMapper;
@Resource @Resource
private AiChatConversationService chatConversationService; private AiChatConversationService chatConversationService;
@Resource @Resource
@ -69,6 +70,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
private AiApiKeyService apiKeyService; private AiApiKeyService apiKeyService;
@Resource @Resource
private AiKnowledgeSegmentService knowledgeSegmentService; private AiKnowledgeSegmentService knowledgeSegmentService;
@Resource
private WebSearchService webSearchService;
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) { public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
@ -93,8 +96,11 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 3.2 召回段落 // 3.2 召回段落
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId()); List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
// 3.3 联网搜索内容
List<WebSearchRespVO> webSearch = getWebSearch(sendReqVO.getContent(), sendReqVO.getSearchEnable(), 10);
// 3.3 创建 chat 需要的 Prompt // 3.3 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO); Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO, webSearch);
ChatResponse chatResponse = chatModel.call(prompt); ChatResponse chatResponse = chatModel.call(prompt);
// 3.4 段式返回 // 3.4 段式返回
@ -124,12 +130,15 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
// 3.2 召回段落 // 3.2 召回段落
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId()); List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
// 3.3 联网搜索
// todo count 看是否需要放到配置文件
List<WebSearchRespVO> webSearch = getWebSearch(sendReqVO.getContent(), sendReqVO.getSearchEnable(), 10);
// 3.3 构建 Prompt并进行调用 // 3.3 构建 Prompt并进行调用
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO); Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO, webSearch);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 3.4 流式返回 // 3.4 流式返回
@ -155,6 +164,29 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR))); }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
} }
/**
* 获取 web search
*
* @param prompt 提示词
* @param searchEnable 查询到会否开启
* @param count 查询数量
* @return 返回查询结果
*/
private List<WebSearchRespVO> getWebSearch(String prompt, Boolean searchEnable, int count) {
if (searchEnable != null && searchEnable) {
List<WebSearchRespVO> webSearchRespList = webSearchService.bingSearch(prompt, count);
Map<String, String> webCrawlerRespMap
= webSearchService.webCrawler(webSearchRespList.stream().map(WebSearchRespVO::getUrl).toList());
for (WebSearchRespVO webSearchRespVO : webSearchRespList) {
if (!webCrawlerRespMap.containsKey(webSearchRespVO.getUrl())) {
continue;
}
webSearchRespVO.setContent(webCrawlerRespMap.get(webSearchRespVO.getUrl()));
}
}
return Collections.emptyList();
}
private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) { private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) {
if (Objects.isNull(knowledgeId)) { if (Objects.isNull(knowledgeId)) {
return Collections.emptyList(); return Collections.emptyList();
@ -162,8 +194,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content)); return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
} }
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList, private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) { List<AiKnowledgeSegmentDO> segmentList, AiChatModelDO model,
AiChatMessageSendReqVO sendReqVO, List<WebSearchRespVO> webSearchRespList) {
// 1. 构建 Prompt Message 列表 // 1. 构建 Prompt Message 列表
List<Message> chatMessages = new ArrayList<>(); List<Message> chatMessages = new ArrayList<>();
@ -184,7 +217,25 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO); List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent()))); contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
// 1.4 user message 新发送消息 // 1.4 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent())); if (sendReqVO.getSearchEnable() != null
&& sendReqVO.getSearchEnable() && CollUtil.isNotEmpty(webSearchRespList)) {
StringBuilder promptBuilder = StrUtil.builder();
promptBuilder.append("## 以下是联网搜索内容: \n");
int i = 1;
for (WebSearchRespVO webSearchRespVO : webSearchRespList) {
promptBuilder.append("[内容%s begin]".formatted(i)).append("\n");
promptBuilder.append("标题:").append(webSearchRespVO.getTitle()).append("\n");
promptBuilder.append("地址:").append(webSearchRespVO.getUrl()).append("\n");
promptBuilder.append("内容:").append(webSearchRespVO.getContent()).append("\n");
promptBuilder.append("[内容%s end]".formatted(i)).append("\n");
i++;
}
promptBuilder.append("## 用户问题如下: \n");
promptBuilder.append(sendReqVO.getContent()).append("\n");
} else {
chatMessages.add(new UserMessage(sendReqVO.getContent()));
}
// 2. 构建 ChatOptions 对象 // 2. 构建 ChatOptions 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());

View File

@ -22,4 +22,8 @@ public class WebSearchRespVO {
* 摘要 * 摘要
*/ */
private String snippet; private String snippet;
/**
* 网站内容
*/
private String content;
} }

View File

@ -0,0 +1 @@
package cn.iocoder.yudao.module.ai;

View File

@ -228,7 +228,9 @@ yudao:
wxa-subscribe-message: wxa-subscribe-message:
miniprogram-state: developer # 跳转小程序类型:开发版为 “developer”体验版为 “trial”为正式版为 “formal” miniprogram-state: developer # 跳转小程序类型:开发版为 “developer”体验版为 “trial”为正式版为 “formal”
tencent-lbs-key: TVDBZ-TDILD-4ON4B-PFDZA-RNLKH-VVF6E # QQ 地图的密钥 https://lbs.qq.com/service/staticV2/staticGuide/staticDoc tencent-lbs-key: TVDBZ-TDILD-4ON4B-PFDZA-RNLKH-VVF6E # QQ 地图的密钥 https://lbs.qq.com/service/staticV2/staticGuide/staticDoc
web-search:
bing-api-key: xx
google-api-key: xx
justauth: justauth:
enabled: true enabled: true
type: type: