【代码评审】AI:联网检索

This commit is contained in:
YunaiV 2025-03-15 07:38:22 +08:00
parent 4ea26c7e81
commit d7567e669c
8 changed files with 57 additions and 55 deletions

View File

@ -3,9 +3,7 @@ package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty; import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors;
@Schema(description = "管理后台 - AI 聊天消息发送 Request VO") @Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
@Data @Data
@ -22,6 +20,8 @@ public class AiChatMessageSendReqVO {
@Schema(description = "是否携带上下文", example = "true") @Schema(description = "是否携带上下文", example = "true")
private Boolean useContext; private Boolean useContext;
@Schema(description = "搜索enable", example = "true") // TODO @芋艿:改成 useSearch,保持和 useContext 一个风格
@Schema(description = "是否搜索", example = "true")
private Boolean searchEnable; private Boolean searchEnable;
} }

View File

@ -24,7 +24,7 @@ import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.websearch.WebSearchService; import cn.iocoder.yudao.module.ai.service.websearch.WebSearchService;
import cn.iocoder.yudao.module.ai.service.websearch.vo.WebSearchRespVO; import cn.iocoder.yudao.module.ai.service.websearch.vo.AiWebSearchRespVO;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
@ -97,9 +97,10 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId()); List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
// 3.3 联网搜索内容 // 3.3 联网搜索内容
List<WebSearchRespVO> webSearch = getWebSearch(sendReqVO.getContent(), sendReqVO.getSearchEnable(), 10); // TODO @芋艿可能要改成前端检索
List<AiWebSearchRespVO> webSearch = getWebSearch(sendReqVO.getContent(), sendReqVO.getSearchEnable(), 10);
// 3.4 创建 chat 需要的 Prompt // 3.4 创建 Chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO, webSearch); Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO, webSearch);
ChatResponse chatResponse = chatModel.call(prompt); ChatResponse chatResponse = chatModel.call(prompt);
@ -135,7 +136,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 3.3 联网搜索 // 3.3 联网搜索
// todo count 看是否需要放到配置文件 // todo count 看是否需要放到配置文件
List<WebSearchRespVO> webSearch = getWebSearch(sendReqVO.getContent(), sendReqVO.getSearchEnable(), 10); List<AiWebSearchRespVO> webSearch = getWebSearch(sendReqVO.getContent(), sendReqVO.getSearchEnable(), 10);
// 3.4 构建 Prompt并进行调用 // 3.4 构建 Prompt并进行调用
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO, webSearch); Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO, webSearch);
@ -172,12 +173,12 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
* @param count 查询数量 * @param count 查询数量
* @return 返回查询结果 * @return 返回查询结果
*/ */
private List<WebSearchRespVO> getWebSearch(String prompt, Boolean searchEnable, int count) { private List<AiWebSearchRespVO> getWebSearch(String prompt, Boolean searchEnable, int count) {
if (searchEnable != null && searchEnable) { if (searchEnable != null && searchEnable) {
List<WebSearchRespVO> webSearchRespList = webSearchService.bingSearch(prompt, count); List<AiWebSearchRespVO> webSearchRespList = webSearchService.bingSearch(prompt, count);
Map<String, String> webCrawlerRespMap Map<String, String> webCrawlerRespMap
= webSearchService.webCrawler(webSearchRespList.stream().map(WebSearchRespVO::getUrl).toList()); = webSearchService.webCrawler(webSearchRespList.stream().map(AiWebSearchRespVO::getUrl).toList());
for (WebSearchRespVO webSearchRespVO : webSearchRespList) { for (AiWebSearchRespVO webSearchRespVO : webSearchRespList) {
if (!webCrawlerRespMap.containsKey(webSearchRespVO.getUrl())) { if (!webCrawlerRespMap.containsKey(webSearchRespVO.getUrl())) {
continue; continue;
} }
@ -196,7 +197,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
List<AiKnowledgeSegmentDO> segmentList, AiChatModelDO model, List<AiKnowledgeSegmentDO> segmentList, AiChatModelDO model,
AiChatMessageSendReqVO sendReqVO, List<WebSearchRespVO> webSearchRespList) { AiChatMessageSendReqVO sendReqVO, List<AiWebSearchRespVO> webSearchRespList) {
// 1. 构建 Prompt Message 列表 // 1. 构建 Prompt Message 列表
List<Message> chatMessages = new ArrayList<>(); List<Message> chatMessages = new ArrayList<>();
@ -217,13 +218,14 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO); List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent()))); contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
// 1.4 user message 新发送消息 // 1.4 user message 新发送消息
// TODO @芋艿:处理下 prompt 模版
if (sendReqVO.getSearchEnable() != null if (sendReqVO.getSearchEnable() != null
&& sendReqVO.getSearchEnable() && CollUtil.isNotEmpty(webSearchRespList)) { && sendReqVO.getSearchEnable() && CollUtil.isNotEmpty(webSearchRespList)) {
StringBuilder promptBuilder = StrUtil.builder(); StringBuilder promptBuilder = StrUtil.builder();
promptBuilder.append("## 以下是联网搜索内容: \n"); promptBuilder.append("## 以下是联网搜索内容: \n");
int i = 1; int i = 1;
for (WebSearchRespVO webSearchRespVO : webSearchRespList) { for (AiWebSearchRespVO webSearchRespVO : webSearchRespList) {
promptBuilder.append("[内容%s begin]".formatted(i)).append("\n"); promptBuilder.append("[内容%s begin]".formatted(i)).append("\n");
promptBuilder.append("标题:").append(webSearchRespVO.getTitle()).append("\n"); promptBuilder.append("标题:").append(webSearchRespVO.getTitle()).append("\n");
promptBuilder.append("地址:").append(webSearchRespVO.getUrl()).append("\n"); promptBuilder.append("地址:").append(webSearchRespVO.getUrl()).append("\n");

View File

@ -1,6 +1,6 @@
package cn.iocoder.yudao.module.ai.service.websearch; package cn.iocoder.yudao.module.ai.service.websearch;
import cn.iocoder.yudao.module.ai.service.websearch.vo.WebSearchRespVO; import cn.iocoder.yudao.module.ai.service.websearch.vo.AiWebSearchRespVO;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -17,8 +17,8 @@ public interface WebSearchService {
* @param count 返回结果数量 * @param count 返回结果数量
* @return 搜索结果列表 * @return 搜索结果列表
*/ */
List<WebSearchRespVO> bingSearch(String query, Integer count); List<AiWebSearchRespVO> bingSearch(String query, Integer count);
/** /**
* Google 搜索 * Google 搜索
* *
@ -26,7 +26,7 @@ public interface WebSearchService {
* @param count 返回结果数量 * @param count 返回结果数量
* @return 搜索结果列表 * @return 搜索结果列表
*/ */
List<WebSearchRespVO> googleSearch(String query, Integer count); List<AiWebSearchRespVO> googleSearch(String query, Integer count);
/** /**
* web 爬虫 * web 爬虫

View File

@ -6,7 +6,7 @@ import cn.hutool.http.HttpResponse;
import cn.hutool.json.JSONArray; import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject; import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import cn.iocoder.yudao.module.ai.service.websearch.vo.WebSearchRespVO; import cn.iocoder.yudao.module.ai.service.websearch.vo.AiWebSearchRespVO;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -46,11 +46,11 @@ public class WebSearchServiceImpl implements WebSearchService {
* @return 搜索结果列表 * @return 搜索结果列表
*/ */
@Override @Override
public List<WebSearchRespVO> bingSearch(String query, Integer count) { public List<AiWebSearchRespVO> bingSearch(String query, Integer count) {
if (query == null || query.isEmpty()) { if (query == null || query.isEmpty()) {
return CollUtil.newArrayList(); return CollUtil.newArrayList();
} }
try { try {
// 发送请求 // 发送请求
HttpResponse response = HttpRequest.get(BING_URL) HttpResponse response = HttpRequest.get(BING_URL)
@ -60,41 +60,41 @@ public class WebSearchServiceImpl implements WebSearchService {
.form("responseFilter", "Webpages") .form("responseFilter", "Webpages")
.form("textFormat", "Raw") .form("textFormat", "Raw")
.execute(); .execute();
// 解析响应 // 解析响应
String body = response.body(); String body = response.body();
JSONObject json = JSONUtil.parseObj(body); JSONObject json = JSONUtil.parseObj(body);
// 处理结果 // 处理结果
List<WebSearchRespVO> results = new ArrayList<>(); List<AiWebSearchRespVO> results = new ArrayList<>();
if (json.containsKey("webPages") && json.getJSONObject("webPages").containsKey("value")) { if (json.containsKey("webPages") && json.getJSONObject("webPages").containsKey("value")) {
JSONArray items = json.getJSONObject("webPages").getJSONArray("value"); JSONArray items = json.getJSONObject("webPages").getJSONArray("value");
for (int i = 0; i < items.size(); i++) { for (int i = 0; i < items.size(); i++) {
JSONObject item = items.getJSONObject(i); JSONObject item = items.getJSONObject(i);
WebSearchRespVO result = new WebSearchRespVO() AiWebSearchRespVO result = new AiWebSearchRespVO()
.setTitle(item.getStr("name")) .setTitle(item.getStr("name"))
.setUrl(item.getStr("url")) .setUrl(item.getStr("url"))
.setSnippet(item.getStr("snippet")); .setSnippet(item.getStr("snippet"));
results.add(result); results.add(result);
} }
} }
return results; return results;
} catch (Exception e) { } catch (Exception e) {
log.error("[bingSearch][查询({}) 发生异常]", query, e); log.error("[bingSearch][查询({}) 发生异常]", query, e);
return CollUtil.newArrayList(); return CollUtil.newArrayList();
} }
} }
/** /**
* Google 搜索使用Serper API * Google 搜索使用 Serper API
* *
* @param query 搜索关键词 * @param query 搜索关键词
* @param count 返回结果数量 * @param count 返回结果数量
* @return 搜索结果列表 * @return 搜索结果列表
*/ */
@Override @Override
public List<WebSearchRespVO> googleSearch(String query, Integer count) { public List<AiWebSearchRespVO> googleSearch(String query, Integer count) {
if (query == null || query.isEmpty()) { if (query == null || query.isEmpty()) {
return CollUtil.newArrayList(); return CollUtil.newArrayList();
} }
@ -105,24 +105,24 @@ public class WebSearchServiceImpl implements WebSearchService {
payload.set("q", query); payload.set("q", query);
payload.set("gl", "cn"); payload.set("gl", "cn");
payload.set("num", count); payload.set("num", count);
// 发送请求 // 发送请求
HttpResponse response = HttpRequest.post(GOOGLE_URL) HttpResponse response = HttpRequest.post(GOOGLE_URL)
.header("X-API-KEY", googleApiKey) .header("X-API-KEY", googleApiKey)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.body(payload.toString()) .body(payload.toString())
.execute(); .execute();
// 解析响应 // 解析响应
String body = response.body(); String body = response.body();
JSONObject json = JSONUtil.parseObj(body); JSONObject json = JSONUtil.parseObj(body);
JSONArray organicResults = json.getJSONArray("organic"); JSONArray organicResults = json.getJSONArray("organic");
// 处理结果 // 处理结果
List<WebSearchRespVO> results = new ArrayList<>(); List<AiWebSearchRespVO> results = new ArrayList<>();
for (int i = 0; i < organicResults.size(); i++) { for (int i = 0; i < organicResults.size(); i++) {
JSONObject item = organicResults.getJSONObject(i); JSONObject item = organicResults.getJSONObject(i);
WebSearchRespVO result = new WebSearchRespVO() AiWebSearchRespVO result = new AiWebSearchRespVO()
.setTitle(item.getStr("title")) .setTitle(item.getStr("title"))
.setUrl(item.getStr("link")) .setUrl(item.getStr("link"))
.setSnippet(item.containsKey("snippet") ? item.getStr("snippet") : ""); .setSnippet(item.containsKey("snippet") ? item.getStr("snippet") : "");
@ -146,13 +146,13 @@ public class WebSearchServiceImpl implements WebSearchService {
if (CollUtil.isEmpty(urls)) { if (CollUtil.isEmpty(urls)) {
return Map.of(); return Map.of();
} }
Map<String, String> result = new HashMap<>(); Map<String, String> result = new HashMap<>();
for (String url : urls) { for (String url : urls) {
try { try {
// 解析URL以获取域名作为Origin // 解析URL以获取域名作为Origin
String origin = extractOrigin(url); String origin = extractOrigin(url);
// 发送HTTP请求获取网页内容 // 发送HTTP请求获取网页内容
HttpResponse response = HttpRequest.get(url) HttpResponse response = HttpRequest.get(url)
.header("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36") .header("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
@ -163,22 +163,22 @@ public class WebSearchServiceImpl implements WebSearchService {
.header("Cache-Control", "max-age=0") .header("Cache-Control", "max-age=0")
.timeout(10000) // 设置10秒超时 .timeout(10000) // 设置10秒超时
.execute(); .execute();
if (response.isOk()) { if (response.isOk()) {
String html = response.body(); String html = response.body();
// 使用Jsoup解析HTML并提取文本内容 // 使用Jsoup解析HTML并提取文本内容
org.jsoup.nodes.Document doc = org.jsoup.Jsoup.parse(html); org.jsoup.nodes.Document doc = org.jsoup.Jsoup.parse(html);
// 移除script和style元素,它们包含的内容不是我们需要的文本 // 移除script和style元素,它们包含的内容不是我们需要的文本
doc.select("script, style, meta, link").remove(); doc.select("script, style, meta, link").remove();
// 获取body中的文本内容 // 获取body中的文本内容
String text = doc.body().text(); String text = doc.body().text();
// 清理文本移除多余空格 // 清理文本移除多余空格
text = text.replaceAll("\\s+", " ").trim(); text = text.replaceAll("\\s+", " ").trim();
result.put(url, text); result.put(url, text);
} else { } else {
log.warn("[webCrawler][URL({}) 请求失败,状态码: {}]", url, response.getStatus()); log.warn("[webCrawler][URL({}) 请求失败,状态码: {}]", url, response.getStatus());
@ -189,20 +189,20 @@ public class WebSearchServiceImpl implements WebSearchService {
result.put(url, ""); result.put(url, "");
} }
} }
return result; return result;
} }
/** /**
* 从URL中提取Origin * 从URL中提取Origin
* *
* @param url 完整URL * @param url 完整URL
* @return Origin (scheme://host[:port]) * @return Origin (scheme://host[:port])
*/ */
private String extractOrigin(String url) { private String extractOrigin(String url) {
try { try {
java.net.URL parsedUrl = new java.net.URL(url); java.net.URL parsedUrl = new java.net.URL(url);
return parsedUrl.getProtocol() + "://" + parsedUrl.getHost() + return parsedUrl.getProtocol() + "://" + parsedUrl.getHost() +
(parsedUrl.getPort() == -1 ? "" : ":" + parsedUrl.getPort()); (parsedUrl.getPort() == -1 ? "" : ":" + parsedUrl.getPort());
} catch (Exception e) { } catch (Exception e) {
log.warn("[extractOrigin][URL({}) 解析异常]", url, e); log.warn("[extractOrigin][URL({}) 解析异常]", url, e);

View File

@ -1,14 +1,12 @@
package cn.iocoder.yudao.module.ai.service.websearch.vo; package cn.iocoder.yudao.module.ai.service.websearch.vo;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors;
/** /**
* 搜索结果 * AI 搜索结果
*/ */
@Data @Data
@Accessors(chain = true) public class AiWebSearchRespVO {
public class WebSearchRespVO {
/** /**
* 标题 * 标题
@ -26,4 +24,5 @@ public class WebSearchRespVO {
* 网站内容 * 网站内容
*/ */
private String content; private String content;
}
}

View File

@ -1 +1,2 @@
// TODO @芋艿:看情况删除
package cn.iocoder.yudao.module.ai; package cn.iocoder.yudao.module.ai;

View File

@ -1,7 +1,7 @@
package cn.iocoder.yudao.module.ai.service; package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.module.ai.service.websearch.WebSearchServiceImpl; import cn.iocoder.yudao.module.ai.service.websearch.WebSearchServiceImpl;
import cn.iocoder.yudao.module.ai.service.websearch.vo.WebSearchRespVO; import cn.iocoder.yudao.module.ai.service.websearch.vo.AiWebSearchRespVO;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -41,14 +41,14 @@ public class WebSearchServiceTests {
// 创建网络搜索服务实例 // 创建网络搜索服务实例
WebSearchServiceImpl webSearchService = new WebSearchServiceImpl(); WebSearchServiceImpl webSearchService = new WebSearchServiceImpl();
// 调用谷歌搜索方法,搜索"长沙今天天气",限制返回6条结果 // 调用谷歌搜索方法,搜索"长沙今天天气",限制返回6条结果
List<WebSearchRespVO> webSearchRespList = webSearchService.googleSearch("长沙今天天气", 6); List<AiWebSearchRespVO> webSearchRespList = webSearchService.googleSearch("长沙今天天气", 6);
// 从搜索结果中提取URL并爬取对应网页内容 // 从搜索结果中提取URL并爬取对应网页内容
Map<String, String> webCrawlerRespMap Map<String, String> webCrawlerRespMap
= webSearchService.webCrawler(webSearchRespList.stream().map(WebSearchRespVO::getUrl).toList()); = webSearchService.webCrawler(webSearchRespList.stream().map(AiWebSearchRespVO::getUrl).toList());
// 打印搜索结果 // 打印搜索结果
for (WebSearchRespVO webSearchRespVO : webSearchRespList) { for (AiWebSearchRespVO webSearchRespVO : webSearchRespList) {
System.err.println(JSON.toJSONString(webSearchRespVO)); System.err.println(JSON.toJSONString(webSearchRespVO));
} }

View File

@ -228,7 +228,7 @@ yudao:
wxa-subscribe-message: wxa-subscribe-message:
miniprogram-state: developer # 跳转小程序类型:开发版为 “developer”,体验版为 “trial”,正式版为 “formal” miniprogram-state: developer # 跳转小程序类型:开发版为 “developer”,体验版为 “trial”,正式版为 “formal”
tencent-lbs-key: TVDBZ-TDILD-4ON4B-PFDZA-RNLKH-VVF6E # QQ 地图的密钥 https://lbs.qq.com/service/staticV2/staticGuide/staticDoc tencent-lbs-key: TVDBZ-TDILD-4ON4B-PFDZA-RNLKH-VVF6E # QQ 地图的密钥 https://lbs.qq.com/service/staticV2/staticGuide/staticDoc
web-search: web-search: # TODO 芋艿:key 要不要放到 yudao ai 那去
bing-api-key: xx bing-api-key: xx
google-api-key: xx google-api-key: xx
justauth: justauth: