Compare commits

...

8 Commits

Author SHA1 Message Date
YunaiV d7567e669c 【代码评审】AI:联网检索 2025-03-15 07:38:22 +08:00
cherishsince 4ea26c7e81 【优化】增加注释 2025-03-02 15:03:14 +08:00
cherishsince d963f8910e 【测试】google search 2025-03-02 14:57:00 +08:00
cherishsince b2d288d584 【优化】调整注释序号 2025-03-02 12:37:12 +08:00
cherishsince 4dd63b77cf 【优化】删除没用的import 2025-03-02 12:36:21 +08:00
cherishsince cb2e9e044f 【增加】增加模型联网搜索,重写用户 prompt 2025-03-02 12:34:09 +08:00
cherishsince f638f90afc 【新增】联网搜索,爬虫抓取网页内容 2025-03-02 12:19:12 +08:00
cherishsince 86801517d1 【新增】bing、google search 2025-03-02 12:05:39 +08:00
8 changed files with 412 additions and 14 deletions

View File

@ -3,9 +3,7 @@ package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty; import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors;
@Schema(description = "管理后台 - AI 聊天消息发送 Request VO") @Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
@Data @Data
@ -22,4 +20,8 @@ public class AiChatMessageSendReqVO {
@Schema(description = "是否携带上下文", example = "true") @Schema(description = "是否携带上下文", example = "true")
private Boolean useContext; private Boolean useContext;
// TODO @芋艿改成 useSearch保持和 useContext 一个风格
@Schema(description = "是否搜索", example = "true")
private Boolean searchEnable;
} }

View File

@ -23,6 +23,8 @@ import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService; import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.websearch.WebSearchService;
import cn.iocoder.yudao.module.ai.service.websearch.vo.AiWebSearchRespVO;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
@ -60,7 +62,6 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
@Resource @Resource
private AiChatMessageMapper chatMessageMapper; private AiChatMessageMapper chatMessageMapper;
@Resource @Resource
private AiChatConversationService chatConversationService; private AiChatConversationService chatConversationService;
@Resource @Resource
@ -69,6 +70,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
private AiApiKeyService apiKeyService; private AiApiKeyService apiKeyService;
@Resource @Resource
private AiKnowledgeSegmentService knowledgeSegmentService; private AiKnowledgeSegmentService knowledgeSegmentService;
@Resource
private WebSearchService webSearchService;
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) { public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
@ -93,11 +96,15 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 3.2 召回段落 // 3.2 召回段落
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId()); List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
// 3.3 创建 chat 需要的 Prompt // 3.3 联网搜索内容
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO); // TODO @芋艿可能要改成前端检索
List<AiWebSearchRespVO> webSearch = getWebSearch(sendReqVO.getContent(), sendReqVO.getSearchEnable(), 10);
// 3.4 创建 Chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO, webSearch);
ChatResponse chatResponse = chatModel.call(prompt); ChatResponse chatResponse = chatModel.call(prompt);
// 3.4 段式返回 // 3.5 段式返回
String newContent = chatResponse.getResult().getOutput().getContent(); String newContent = chatResponse.getResult().getOutput().getContent();
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)).setContent(newContent)); chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)).setContent(newContent));
return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
@ -124,15 +131,18 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
// 3.2 召回段落 // 3.2 召回段落
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId()); List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
// 3.3 构建 Prompt并进行调用 // 3.3 联网搜索
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO); // todo count 看是否需要放到配置文件
List<AiWebSearchRespVO> webSearch = getWebSearch(sendReqVO.getContent(), sendReqVO.getSearchEnable(), 10);
// 3.4 构建 Prompt并进行调用
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO, webSearch);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 3.4 流式返回 // 3.5 流式返回
// TODO 注意Schedulers.immediate() 目的是避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题 // TODO 注意Schedulers.immediate() 目的是避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> { return streamResponse.map(chunk -> {
@ -155,6 +165,29 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR))); }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
} }
/**
 * Runs a web search for the given prompt and attaches crawled page text to each hit.
 *
 * @param prompt       user prompt, passed as-is to the search service as the query
 * @param searchEnable whether web search is enabled; {@code null} is treated as disabled
 * @param count        number of search results to request
 * @return the search results with {@code content} filled in where the crawler returned
 *         an entry for the result URL; an empty list when search is disabled
 */
private List<AiWebSearchRespVO> getWebSearch(String prompt, Boolean searchEnable, int count) {
    // Search is opt-in: null and false both mean "disabled".
    if (!Boolean.TRUE.equals(searchEnable)) {
        return Collections.emptyList();
    }
    List<AiWebSearchRespVO> webSearchRespList = webSearchService.bingSearch(prompt, count);
    // Crawl every hit's URL, then copy the page text back onto the matching search result.
    Map<String, String> webCrawlerRespMap = webSearchService.webCrawler(
            webSearchRespList.stream().map(AiWebSearchRespVO::getUrl).toList());
    for (AiWebSearchRespVO webSearchRespVO : webSearchRespList) {
        if (!webCrawlerRespMap.containsKey(webSearchRespVO.getUrl())) {
            continue;
        }
        webSearchRespVO.setContent(webCrawlerRespMap.get(webSearchRespVO.getUrl()));
    }
    // Return the enriched results; previously they were computed and then discarded.
    return webSearchRespList;
}
private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) { private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) {
if (Objects.isNull(knowledgeId)) { if (Objects.isNull(knowledgeId)) {
return Collections.emptyList(); return Collections.emptyList();
@ -162,8 +195,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content)); return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
} }
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList, private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) { List<AiKnowledgeSegmentDO> segmentList, AiChatModelDO model,
AiChatMessageSendReqVO sendReqVO, List<AiWebSearchRespVO> webSearchRespList) {
// 1. 构建 Prompt Message 列表 // 1. 构建 Prompt Message 列表
List<Message> chatMessages = new ArrayList<>(); List<Message> chatMessages = new ArrayList<>();
@ -184,7 +218,26 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO); List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent()))); contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
// 1.4 user message 新发送消息 // 1.4 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent())); // TODO @芋艿处理下 prompt 模版
if (sendReqVO.getSearchEnable() != null
&& sendReqVO.getSearchEnable() && CollUtil.isNotEmpty(webSearchRespList)) {
StringBuilder promptBuilder = StrUtil.builder();
promptBuilder.append("## 以下是联网搜索内容: \n");
int i = 1;
for (AiWebSearchRespVO webSearchRespVO : webSearchRespList) {
promptBuilder.append("[内容%s begin]".formatted(i)).append("\n");
promptBuilder.append("标题:").append(webSearchRespVO.getTitle()).append("\n");
promptBuilder.append("地址:").append(webSearchRespVO.getUrl()).append("\n");
promptBuilder.append("内容:").append(webSearchRespVO.getContent()).append("\n");
promptBuilder.append("[内容%s end]".formatted(i)).append("\n");
i++;
}
promptBuilder.append("## 用户问题如下: \n");
promptBuilder.append(sendReqVO.getContent()).append("\n");
} else {
chatMessages.add(new UserMessage(sendReqVO.getContent()));
}
// 2. 构建 ChatOptions 对象 // 2. 构建 ChatOptions 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());

View File

@ -0,0 +1,38 @@
package cn.iocoder.yudao.module.ai.service.websearch;
import cn.iocoder.yudao.module.ai.service.websearch.vo.AiWebSearchRespVO;
import java.util.List;
import java.util.Map;
/**
* Web search Service interface
*/
public interface WebSearchService {
/**
* Bing search
*
* @param query search keywords
* @param count number of results to return
* @return list of search results
*/
List<AiWebSearchRespVO> bingSearch(String query, Integer count);
/**
* Google search
*
* @param query search keywords
* @param count number of results to return
* @return list of search results
*/
List<AiWebSearchRespVO> googleSearch(String query, Integer count);
/**
* Web crawler
*
* @param urls URLs to crawl
* @return map of key: url, value: crawled content
*/
Map<String, String> webCrawler(List<String> urls);
}

View File

@ -0,0 +1,212 @@
package cn.iocoder.yudao.module.ai.service.websearch;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import cn.iocoder.yudao.module.ai.service.websearch.vo.AiWebSearchRespVO;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Web search service implementation.
 *
 * Provides Bing search (Bing Web Search API v7), Google search (through the Serper API)
 * and a simple crawler that downloads pages and extracts their visible text.
 */
@Service
@Slf4j
public class WebSearchServiceImpl implements WebSearchService {

    /**
     * Serper endpoint used for Google search
     */
    private static final String GOOGLE_URL = "https://google.serper.dev/search";

    /**
     * Bing Web Search endpoint
     */
    private static final String BING_URL = "https://api.bing.microsoft.com/v7.0/search";

    /**
     * Bing subscription key, read from yudao.web-search.bing-api-key (empty when not configured)
     */
    @Value("${yudao.web-search.bing-api-key:}")
    private String bingApiKey;

    /**
     * Serper API key, read from yudao.web-search.google-api-key (empty when not configured)
     */
    @Value("${yudao.web-search.google-api-key:}")
    private String googleApiKey;

    /**
     * Bing search
     *
     * @param query search keywords; null or empty yields an empty result
     * @param count number of results to return
     * @return list of search results; empty when the request fails or nothing is found
     */
    @Override
    public List<AiWebSearchRespVO> bingSearch(String query, Integer count) {
        if (query == null || query.isEmpty()) {
            return CollUtil.newArrayList();
        }
        try {
            // Send the request
            HttpResponse response = HttpRequest.get(BING_URL)
                    .header("Ocp-Apim-Subscription-Key", bingApiKey)
                    .form("q", query)
                    .form("count", String.valueOf(count))
                    .form("responseFilter", "Webpages")
                    .form("textFormat", "Raw")
                    .execute();
            // A non-2xx response (bad key, quota, ...) carries no results; report it instead of parsing it
            if (!response.isOk()) {
                log.warn("[bingSearch][查询({}) 请求失败,状态码: {}]", query, response.getStatus());
                return CollUtil.newArrayList();
            }
            // Parse the response
            JSONObject json = JSONUtil.parseObj(response.body());
            JSONObject webPages = json.getJSONObject("webPages");
            JSONArray items = webPages == null ? null : webPages.getJSONArray("value");
            // Convert each hit
            List<AiWebSearchRespVO> results = new ArrayList<>();
            if (items == null) {
                return results;
            }
            for (int i = 0; i < items.size(); i++) {
                JSONObject item = items.getJSONObject(i);
                results.add(new AiWebSearchRespVO()
                        .setTitle(item.getStr("name"))
                        .setUrl(item.getStr("url"))
                        .setSnippet(item.getStr("snippet")));
            }
            return results;
        } catch (Exception e) {
            // Best effort: a failed search must not break the chat request
            log.error("[bingSearch][查询({}) 发生异常]", query, e);
            return CollUtil.newArrayList();
        }
    }

    /**
     * Google search, using the Serper API
     *
     * @param query search keywords; null or empty yields an empty result
     * @param count number of results to return
     * @return list of search results; empty when the request fails or nothing is found
     */
    @Override
    public List<AiWebSearchRespVO> googleSearch(String query, Integer count) {
        if (query == null || query.isEmpty()) {
            return CollUtil.newArrayList();
        }
        try {
            // Build the request body
            JSONObject payload = new JSONObject();
            payload.set("q", query);
            payload.set("gl", "cn");
            payload.set("num", count);
            // Send the request
            HttpResponse response = HttpRequest.post(GOOGLE_URL)
                    .header("X-API-KEY", googleApiKey)
                    .header("Content-Type", "application/json")
                    .body(payload.toString())
                    .execute();
            // A non-2xx response carries no results; report it instead of parsing it
            if (!response.isOk()) {
                log.warn("[googleSearch][查询({}) 请求失败,状态码: {}]", query, response.getStatus());
                return CollUtil.newArrayList();
            }
            // Parse the response
            JSONObject json = JSONUtil.parseObj(response.body());
            JSONArray organicResults = json.getJSONArray("organic");
            // Convert each hit; "organic" may be absent when there are no results
            List<AiWebSearchRespVO> results = new ArrayList<>();
            if (organicResults == null) {
                return results;
            }
            for (int i = 0; i < organicResults.size(); i++) {
                JSONObject item = organicResults.getJSONObject(i);
                results.add(new AiWebSearchRespVO()
                        .setTitle(item.getStr("title"))
                        .setUrl(item.getStr("link"))
                        .setSnippet(item.containsKey("snippet") ? item.getStr("snippet") : ""));
            }
            return results;
        } catch (Exception e) {
            // Best effort: a failed search must not break the chat request
            log.error("[googleSearch][查询({}) 发生异常]", query, e);
            return CollUtil.newArrayList();
        }
    }

    /**
     * Web crawler
     *
     * @param urls URLs to crawl
     * @return map of key: url, value: extracted page text; the value is an empty string
     *         when the page could not be fetched
     */
    @Override
    public Map<String, String> webCrawler(List<String> urls) {
        if (CollUtil.isEmpty(urls)) {
            return Map.of();
        }
        Map<String, String> result = new HashMap<>();
        for (String url : urls) {
            try {
                // Derive the Origin/Referer headers from the URL itself
                String origin = extractOrigin(url);
                // Fetch the page with browser-like headers
                HttpResponse response = HttpRequest.get(url)
                        .header("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
                        .header("Origin", origin)
                        .header("Referer", origin)
                        .header("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7")
                        .header("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8")
                        .header("Cache-Control", "max-age=0")
                        .timeout(10000) // 10 second timeout
                        .execute();
                if (response.isOk()) {
                    String html = response.body();
                    // Parse the HTML with Jsoup and extract the text content
                    org.jsoup.nodes.Document doc = org.jsoup.Jsoup.parse(html);
                    // Drop elements whose content is not readable page text
                    doc.select("script, style, meta, link").remove();
                    // Take the text of the body
                    String text = doc.body().text();
                    // Collapse runs of whitespace
                    text = text.replaceAll("\\s+", " ").trim();
                    result.put(url, text);
                } else {
                    log.warn("[webCrawler][URL({}) 请求失败,状态码: {}]", url, response.getStatus());
                    result.put(url, "");
                }
            } catch (Exception e) {
                // One bad page must not abort the remaining URLs
                log.error("[webCrawler][URL({}) 爬取异常]", url, e);
                result.put(url, "");
            }
        }
        return result;
    }

    /**
     * Extracts the Origin from a URL
     *
     * @param url full URL
     * @return Origin (scheme://host[:port]), or an empty string when the URL cannot be parsed
     */
    private String extractOrigin(String url) {
        try {
            java.net.URL parsedUrl = new java.net.URL(url);
            return parsedUrl.getProtocol() + "://" + parsedUrl.getHost() +
                    (parsedUrl.getPort() == -1 ? "" : ":" + parsedUrl.getPort());
        } catch (Exception e) {
            log.warn("[extractOrigin][URL({}) 解析异常]", url, e);
            return "";
        }
    }

}

View File

@ -0,0 +1,28 @@
package cn.iocoder.yudao.module.ai.service.websearch.vo;
import lombok.Data;
/**
* AI search result
*/
@Data
public class AiWebSearchRespVO {
/**
* Title
*/
private String title;
/**
* URL
*/
private String url;
/**
* Snippet (summary)
*/
private String snippet;
/**
* Website content
*/
private String content;
}

View File

@ -0,0 +1,2 @@
// TODO @芋艿看情况删除
package cn.iocoder.yudao.module.ai;

View File

@ -0,0 +1,61 @@
package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.module.ai.service.websearch.WebSearchServiceImpl;
import cn.iocoder.yudao.module.ai.service.websearch.vo.AiWebSearchRespVO;
import com.alibaba.fastjson.JSON;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Map;
/**
 * Tests for the web search service.
 * Covers page crawling and Google search.
 */
public class WebSearchServiceTests {

    /**
     * Crawls a single URL and prints the extracted page content.
     */
    @Test
    public void webCrawlerTest() {
        WebSearchServiceImpl service = new WebSearchServiceImpl();
        // Crawl the Changsha weather page
        Map<String, String> crawled = service.webCrawler(
                Lists.newArrayList("https://tianqi.eastday.com/changsha/40/"));
        // Print every crawled page body
        crawled.values().forEach(System.err::println);
    }

    /**
     * Searches Google for a keyword, crawls the result URLs, then prints both
     * the search hits and the crawled page content.
     */
    @Test
    public void googleSearchTest() {
        WebSearchServiceImpl service = new WebSearchServiceImpl();
        // Search, limited to 6 results
        List<AiWebSearchRespVO> hits = service.googleSearch("长沙今天天气", 6);
        // Crawl the URL of every hit
        List<String> urls = hits.stream().map(AiWebSearchRespVO::getUrl).toList();
        Map<String, String> crawled = service.webCrawler(urls);
        // Print the search hits
        hits.forEach(hit -> System.err.println(JSON.toJSONString(hit)));
        // Print the crawled page content
        crawled.forEach((url, text) -> {
            System.err.println("url:" + url);
            System.err.println("value" + text);
        });
    }

}

View File

@ -228,7 +228,9 @@ yudao:
wxa-subscribe-message: wxa-subscribe-message:
miniprogram-state: developer # 跳转小程序类型:开发版为 “developer”体验版为 “trial”为正式版为 “formal” miniprogram-state: developer # 跳转小程序类型:开发版为 “developer”体验版为 “trial”为正式版为 “formal”
tencent-lbs-key: TVDBZ-TDILD-4ON4B-PFDZA-RNLKH-VVF6E # QQ 地图的密钥 https://lbs.qq.com/service/staticV2/staticGuide/staticDoc tencent-lbs-key: TVDBZ-TDILD-4ON4B-PFDZA-RNLKH-VVF6E # QQ 地图的密钥 https://lbs.qq.com/service/staticV2/staticGuide/staticDoc
web-search: # TODO 芋艿key 要不要放到 yudao ai 那去
bing-api-key: xx
google-api-key: xx
justauth: justauth:
enabled: true enabled: true
type: type: