【功能新增】AI:新增 document 向量的进度查询

This commit is contained in:
YunaiV 2025-03-02 20:54:02 +08:00
parent ebd93514b3
commit 5f5e77a392
20 changed files with 385 additions and 61 deletions

View File

@ -9,7 +9,8 @@ tenant-id: {{adminTenantId}}
"description": "测试描述", "description": "测试描述",
"embeddingModelId": 30, "embeddingModelId": 30,
"topK": 3, "topK": 3,
"similarityThreshold": 0.5 "similarityThreshold": 0.5,
"status": 0
} }
### 更新知识库 ### 更新知识库
@ -24,7 +25,8 @@ tenant-id: {{adminTenantId}}
"description": "测试描述", "description": "测试描述",
"embeddingModelId": 30, "embeddingModelId": 30,
"topK": 5, "topK": 5,
"similarityThreshold": 0.6 "similarityThreshold": 0.6,
"status": 0
} }
### 获取知识库分页 ### 获取知识库分页

View File

@ -5,7 +5,7 @@ Authorization: Bearer {{token}}
tenant-id: {{adminTenantId}} tenant-id: {{adminTenantId}}
{ {
"knowledgeId": 1, "knowledgeId": 2,
"name": "测试文档", "name": "测试文档",
"url": "https://static.iocoder.cn/README.md", "url": "https://static.iocoder.cn/README.md",
"segmentMaxTokens": 800 "segmentMaxTokens": 800

View File

@ -4,9 +4,14 @@ Content-Type: application/json
Authorization: Bearer {{token}} Authorization: Bearer {{token}}
tenant-id: {{adminTenantId}} tenant-id: {{adminTenantId}}
### 搜索段落内容
GET {{baseUrl}}/ai/knowledge/segment/search?knowledgeId=2&content=如何使用这个产品&topK=5&similarityThreshold=0.1
Content-Type: application/json
Authorization: Bearer {{token}}
tenant-id: {{adminTenantId}}
### 获取文档处理列表 ### 获取文档处理列表
GET {{baseUrl}}/ai/knowledge/segment/get-process-list?documentIds=1,2,3 GET {{baseUrl}}/ai/knowledge/segment/get-process-list?documentIds=1,2,3
Content-Type: application/json Content-Type: application/json
Authorization: Bearer {{token}} Authorization: Bearer {{token}}
tenant-id: {{adminTenantId}} tenant-id: {{adminTenantId}}

View File

@ -1,15 +1,17 @@
package cn.iocoder.yudao.module.ai.controller.admin.knowledge; package cn.iocoder.yudao.module.ai.controller.admin.knowledge;
import cn.hutool.core.collection.CollUtil;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.collection.MapUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.*;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentRespVO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService; import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.Parameters; import io.swagger.v3.oas.annotations.Parameters;
@ -20,9 +22,12 @@ import org.hibernate.validator.constraints.URL;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet;
// TODO @芋艿增加权限标识 // TODO @芋艿增加权限标识
@Tag(name = "管理后台 - AI 知识库段落") @Tag(name = "管理后台 - AI 知识库段落")
@ -34,6 +39,9 @@ public class AiKnowledgeSegmentController {
@Resource @Resource
private AiKnowledgeSegmentService segmentService; private AiKnowledgeSegmentService segmentService;
@Resource
private AiKnowledgeDocumentService documentService;
@GetMapping("/page") @GetMapping("/page")
@Operation(summary = "获取段落分页") @Operation(summary = "获取段落分页")
public CommonResult<PageResult<AiKnowledgeSegmentRespVO>> getKnowledgeSegmentPage( public CommonResult<PageResult<AiKnowledgeSegmentRespVO>> getKnowledgeSegmentPage(
@ -79,4 +87,23 @@ public class AiKnowledgeSegmentController {
return success(list); return success(list);
} }
@GetMapping("/search")
@Operation(summary = "搜索段落内容")
public CommonResult<List<AiKnowledgeSegmentSearchRespVO>> searchKnowledgeSegment(
        @Valid AiKnowledgeSegmentSearchReqVO reqVO) {
    // 1. Convert the request VO into the service-layer BO and run the segment search
    AiKnowledgeSegmentSearchReqBO searchReqBO = BeanUtils.toBean(reqVO, AiKnowledgeSegmentSearchReqBO.class);
    List<AiKnowledgeSegmentSearchRespBO> hits = segmentService.searchKnowledgeSegment(searchReqBO);
    if (CollUtil.isEmpty(hits)) {
        // Nothing matched: skip the document lookup entirely
        return success(Collections.emptyList());
    }

    // 2. Load the documents referenced by the hits, keyed by document id
    Map<Long, AiKnowledgeDocumentDO> documentsById = documentService.getKnowledgeDocumentMap(
            convertSet(hits, AiKnowledgeSegmentSearchRespBO::getDocumentId));

    // 3. Build the response VOs, copying the document name onto each VO through MapUtils.findAndThen
    List<AiKnowledgeSegmentSearchRespVO> respVOs = BeanUtils.toBean(hits, AiKnowledgeSegmentSearchRespVO.class,
            respVO -> MapUtils.findAndThen(documentsById, respVO.getDocumentId(),
                    document -> respVO.setDocumentName(document.getName())));
    return success(respVOs);
}
} }

View File

@ -0,0 +1,22 @@
package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.validation.InEnum;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
/**
 * Request VO carrying the id of a knowledge-base document and the status to apply to it.
 * Both fields are mandatory (see the {@code @NotNull} constraints).
 */
@Schema(description = "管理后台 - AI 知识库文档更新状态 Request VO")
@Data
public class AiKnowledgeDocumentUpdateStatusReqVO {
// Document id; required.
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "15583")
@NotNull(message = "编号不能为空")
private Long id;
// New status; required and constrained by @InEnum against CommonStatusEnum.
@Schema(description = "状态", requiredMode = Schema.RequiredMode.REQUIRED, example = "0")
@NotNull(message = "状态不能为空")
@InEnum(CommonStatusEnum.class)
private Integer status;
}

View File

@ -3,15 +3,25 @@ package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data; import lombok.Data;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
@Schema(description = "管理后台 - AI 知识库段落召回 Request VO") @Schema(description = "管理后台 - AI 知识库段落搜索 Request VO")
@Data @Data
public class AiKnowledgeSegmentSearchReqVO { public class AiKnowledgeSegmentSearchReqVO {
@Schema(description = "知识库编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "24790") @Schema(description = "知识库编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
@NotNull(message = "知识库编号不能为空")
private Long knowledgeId; private Long knowledgeId;
@Schema(description = "内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 学习路线") @Schema(description = "内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "如何使用这个产品")
@NotEmpty(message = "内容不能为空")
private String content; private String content;
@Schema(description = "最大返回数量", example = "5")
private Integer topK;
@Schema(description = "相似度阈值", example = "0.7")
private Double similarityThreshold;
} }

View File

@ -0,0 +1,16 @@
package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
/**
 * Response VO for a knowledge-base segment search hit.
 * Extends {@link AiKnowledgeSegmentRespVO} with the owning document's name and the similarity score.
 */
@Schema(description = "管理后台 - AI 知识库段落搜索 Response VO")
@Data
public class AiKnowledgeSegmentSearchRespVO extends AiKnowledgeSegmentRespVO {
// Name of the document the segment belongs to.
@Schema(description = "文档名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "产品使用手册")
private String documentName;
// Similarity score of this hit.
@Schema(description = "相似度分数", requiredMode = Schema.RequiredMode.REQUIRED, example = "0.95")
private Double score;
}

View File

@ -65,6 +65,7 @@ public class AiChatConversationDO extends BaseDO {
*/ */
private Long roleId; private Long roleId;
// TODO @芋艿可优化绑定多个知识库前提spring ai 支持 RerankModel 的封装
/** /**
* 知识库编号 * 知识库编号
* <p> * <p>

View File

@ -5,8 +5,11 @@ import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
import java.util.Collection;
/** /**
* AI 知识库文档 Mapper * AI 知识库文档 Mapper
* *
@ -22,4 +25,10 @@ public interface AiKnowledgeDocumentMapper extends BaseMapperX<AiKnowledgeDocume
.orderByDesc(AiKnowledgeDocumentDO::getId)); .orderByDesc(AiKnowledgeDocumentDO::getId));
} }
/**
 * Increments {@code retrieval_count} by one for every document whose id is contained in {@code ids}.
 * <p>
 * A {@code null} or empty collection is a no-op: issuing the update would otherwise build an
 * {@code IN} condition without any values.
 *
 * @param ids document ids whose retrieval counter should be incremented
 */
default void updateRetrievalCountIncr(Collection<Long> ids) {
    // Nothing to update; avoid sending a statement with an empty IN list
    if (ids == null || ids.isEmpty()) {
        return;
    }
    update(new LambdaUpdateWrapper<AiKnowledgeDocumentDO>()
            .setSql(" retrieval_count = retrieval_count + 1")
            .in(AiKnowledgeDocumentDO::getId, ids));
}
} }

View File

@ -7,11 +7,12 @@ import cn.iocoder.yudao.framework.mybatis.core.query.MPJLambdaWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.github.yulichang.wrapper.MPJLambdaWrapper; import com.github.yulichang.wrapper.MPJLambdaWrapper;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
import java.util.List;
import java.util.Collection; import java.util.Collection;
import java.util.List;
/** /**
* AI 知识库分片 Mapper * AI 知识库分片 Mapper
@ -52,4 +53,10 @@ public interface AiKnowledgeSegmentMapper extends BaseMapperX<AiKnowledgeSegment
return selectJoinList(AiKnowledgeSegmentProcessRespVO.class, wrapper); return selectJoinList(AiKnowledgeSegmentProcessRespVO.class, wrapper);
} }
/**
 * Increments {@code retrieval_count} by one for every segment whose id is contained in {@code ids}.
 * <p>
 * Accepts any {@link Collection} (existing callers passing a {@code List} keep working).
 * A {@code null} or empty collection is a no-op: issuing the update would otherwise build an
 * {@code IN} condition without any values.
 *
 * @param ids segment ids whose retrieval counter should be incremented
 */
default void updateRetrievalCountIncrByIds(Collection<Long> ids) {
    // Nothing to update; avoid sending a statement with an empty IN list
    if (ids == null || ids.isEmpty()) {
        return;
    }
    update(new LambdaUpdateWrapper<AiKnowledgeSegmentDO>()
            .setSql(" retrieval_count = retrieval_count + 1")
            .in(AiKnowledgeSegmentDO::getId, ids));
}
} }

View File

@ -12,7 +12,6 @@ import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
@ -133,7 +132,6 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
Flux<ChatResponse> streamResponse = chatModel.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 3.4 流式返回 // 3.4 流式返回
// TODO 注意Schedulers.immediate() 目的是避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> { return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getText() : null; String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getText() : null;
@ -159,7 +157,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
if (Objects.isNull(knowledgeId)) { if (Objects.isNull(knowledgeId)) {
return Collections.emptyList(); return Collections.emptyList();
} }
return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content)); // return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
return null;
} }
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList, private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList,

View File

@ -1,14 +1,18 @@
package cn.iocoder.yudao.module.ai.service.knowledge; package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentCreateListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateStatusReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeDocumentCreateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeDocumentCreateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentCreateListReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
/** /**
* AI 知识库文档 Service 接口 * AI 知识库文档 Service 接口
@ -63,6 +67,13 @@ public interface AiKnowledgeDocumentService {
*/ */
void updateKnowledgeDocumentStatus(AiKnowledgeDocumentUpdateStatusReqVO reqVO); void updateKnowledgeDocumentStatus(AiKnowledgeDocumentUpdateStatusReqVO reqVO);
/**
* 更新文档检索次数增加 +1
*
* @param ids 文档编号列表
*/
void updateKnowledgeDocumentRetrievalCountIncr(Collection<Long> ids);
/** /**
* 校验文档是否存在 * 校验文档是否存在
* *
@ -79,4 +90,22 @@ public interface AiKnowledgeDocumentService {
*/ */
String readUrl(String url); String readUrl(String url);
/**
* 获取文档列表
*
* @param ids 文档编号列表
* @return 文档列表
*/
List<AiKnowledgeDocumentDO> getKnowledgeDocumentList(Collection<Long> ids);
/**
 * Loads the documents for the given ids and indexes them by document id.
 * <p>
 * Delegates the lookup to {@link #getKnowledgeDocumentList(Collection)}.
 *
 * @param ids document ids to look up
 * @return map from document id to document
 */
default Map<Long, AiKnowledgeDocumentDO> getKnowledgeDocumentMap(Collection<Long> ids) {
    List<AiKnowledgeDocumentDO> documents = getKnowledgeDocumentList(ids);
    return convertMap(documents, AiKnowledgeDocumentDO::getId);
}
} }

View File

@ -7,10 +7,10 @@ import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentCreateListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateStatusReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentCreateListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeDocumentCreateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeDocumentCreateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper;
@ -25,6 +25,7 @@ import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.List; import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@ -148,6 +149,14 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
} }
} }
@Override
public void updateKnowledgeDocumentRetrievalCountIncr(Collection<Long> ids) {
    // Only reach the mapper when there is at least one id to update
    if (!CollUtil.isEmpty(ids)) {
        knowledgeDocumentMapper.updateRetrievalCountIncr(ids);
    }
}
@Override @Override
public AiKnowledgeDocumentDO validateKnowledgeDocumentExists(Long id) { public AiKnowledgeDocumentDO validateKnowledgeDocumentExists(Long id) {
AiKnowledgeDocumentDO knowledgeDocument = knowledgeDocumentMapper.selectById(id); AiKnowledgeDocumentDO knowledgeDocument = knowledgeDocumentMapper.selectById(id);
@ -182,4 +191,12 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
return document.getText(); return document.getText();
} }
@Override
public List<AiKnowledgeDocumentDO> getKnowledgeDocumentList(Collection<Long> ids) {
    // An empty id collection yields a fresh empty list without querying the mapper
    return CollUtil.isEmpty(ids)
            ? new ArrayList<>()
            : knowledgeDocumentMapper.selectByIds(ids);
}
} }

View File

@ -2,11 +2,12 @@ package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import java.util.List; import java.util.List;
@ -67,12 +68,12 @@ public interface AiKnowledgeSegmentService {
void deleteKnowledgeSegmentByDocumentId(Long documentId); void deleteKnowledgeSegmentByDocumentId(Long documentId);
/** /**
* 召回段落 * 搜索知识库段落并返回结果
* *
* @param reqVO 召回请求信息 * @param reqBO 搜索请求信息
* @return 召回的段落 * @return 搜索结果段落列表
*/ */
List<AiKnowledgeSegmentDO> similaritySearch(AiKnowledgeSegmentSearchReqVO reqVO); List<AiKnowledgeSegmentSearchRespBO> searchKnowledgeSegment(AiKnowledgeSegmentSearchReqBO reqBO);
/** /**
* 根据 URL 内容切片创建多个段落 * 根据 URL 内容切片创建多个段落

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.ListUtil; import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
@ -12,6 +13,8 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document; import org.springframework.ai.document.Document;
@ -171,25 +174,45 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
} }
@Override @Override
public List<AiKnowledgeSegmentDO> similaritySearch(AiKnowledgeSegmentSearchReqVO reqVO) { public List<AiKnowledgeSegmentSearchRespBO> searchKnowledgeSegment(AiKnowledgeSegmentSearchReqBO reqBO) {
// 1. 校验 // 1. 校验
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqVO.getKnowledgeId()); AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqBO.getKnowledgeId());
// 2. 获取向量存储实例 // 2.1 向量检索
VectorStore vectorStore = apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId()); VectorStore vectorStore = apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId());
// 3.1 向量检索
List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder() List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder()
.query(reqVO.getContent()) .query(reqBO.getContent())
.topK(knowledge.getTopK()).similarityThreshold(knowledge.getSimilarityThreshold()) .topK(ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK()))
.similarityThreshold(ObjUtil.defaultIfNull(reqBO.getSimilarityThreshold(), knowledge.getSimilarityThreshold()))
.filterExpression(new FilterExpressionBuilder() .filterExpression(new FilterExpressionBuilder()
.eq(VECTOR_STORE_METADATA_KNOWLEDGE_ID, reqVO.getKnowledgeId()).build()) .eq(VECTOR_STORE_METADATA_KNOWLEDGE_ID, reqBO.getKnowledgeId()).build())
.build()); .build());
if (CollUtil.isEmpty(documents)) { if (CollUtil.isEmpty(documents)) {
return ListUtil.empty(); return ListUtil.empty();
} }
// 3.2 段落召回 // 2.2 段落召回
return segmentMapper.selectListByVectorIds(convertList(documents, Document::getId)); List<AiKnowledgeSegmentDO> segments = segmentMapper
.selectListByVectorIds(convertList(documents, Document::getId));
if (CollUtil.isEmpty(segments)) {
return ListUtil.empty();
}
// 3. 增加召回次数
segmentMapper.updateRetrievalCountIncrByIds(convertList(segments, AiKnowledgeSegmentDO::getId));
// 4. 构建结果
List<AiKnowledgeSegmentSearchRespBO> result = convertList(segments, segment -> {
Document document = CollUtil.findOne(documents, // 找到对应的文档
doc -> Objects.equals(doc.getId(), segment.getVectorId()));
if (document == null) {
return null;
}
return BeanUtils.toBean(segment, AiKnowledgeSegmentSearchRespBO.class)
.setScore(document.getScore());
});
result.sort((o1, o2)
-> Double.compare(o2.getScore(), o1.getScore())); // 按照分数降序排序
return result;
} }
@Override @Override

View File

@ -0,0 +1,39 @@
package cn.iocoder.yudao.module.ai.service.knowledge.bo;

import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;

/**
 * Request BO for searching segments of an AI knowledge base.
 * <p>
 * {@code topK} and {@code similarityThreshold} are optional.
 *
 * @author 芋道源码
 */
@Data
public class AiKnowledgeSegmentSearchReqBO {

    /**
     * Knowledge base id; required.
     */
    @NotNull(message = "知识库编号不能为空")
    private Long knowledgeId;

    /**
     * Text to search for; required and must not be empty.
     */
    @NotEmpty(message = "内容不能为空")
    private String content;

    /**
     * Maximum number of results to return; optional.
     */
    private Integer topK;

    /**
     * Similarity threshold; optional.
     */
    private Double similarityThreshold;

}

View File

@ -0,0 +1,45 @@
package cn.iocoder.yudao.module.ai.service.knowledge.bo;
import lombok.Data;
/**
 * Response BO describing one segment returned by an AI knowledge base segment search.
 *
 * @author 芋道源码
 */
@Data
public class AiKnowledgeSegmentSearchRespBO {
/**
 * Segment id
 */
private Long id;
/**
 * Id of the document the segment belongs to
 */
private Long documentId;
/**
 * Id of the knowledge base the segment belongs to
 */
private Long knowledgeId;
/**
 * Segment content
 */
private String content;
/**
 * Length of the content
 */
private Integer contentLength;
/**
 * Number of tokens
 */
private Integer tokens;
/**
 * Similarity score
 */
private Double score;
}

View File

@ -16,6 +16,7 @@ import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Lazy; import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -154,7 +155,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
apiKey.getUrl(), chatModel.getModel()); apiKey.getUrl(), chatModel.getModel());
// 创建或获取 VectorStore 对象 // 创建或获取 VectorStore 对象
return modelFactory.getOrCreateVectorStore(embeddingModel); return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel);
} }
} }

View File

@ -99,9 +99,10 @@ public interface AiModelFactory {
* <p> * <p>
* 如果不存在则进行创建 * 如果不存在则进行创建
* *
* @param type 向量存储类型
* @param embeddingModel 向量模型 * @param embeddingModel 向量模型
* @return VectorStore 对象 * @return VectorStore 对象
*/ */
VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel); VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type, EmbeddingModel embeddingModel);
} }

View File

@ -1,9 +1,11 @@
package cn.iocoder.yudao.framework.ai.core.factory; package cn.iocoder.yudao.framework.ai.core.factory;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.lang.Assert; import cn.hutool.core.lang.Assert;
import cn.hutool.core.lang.Singleton; import cn.hutool.core.lang.Singleton;
import cn.hutool.core.lang.func.Func0; import cn.hutool.core.lang.func.Func0;
import cn.hutool.core.util.ArrayUtil; import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.RuntimeUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil; import cn.hutool.extra.spring.SpringUtil;
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration; import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
@ -24,6 +26,7 @@ import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingModel;
import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingOptions; import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingOptions;
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel; import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.ai.openai.OpenAIClientBuilder;
import lombok.SneakyThrows;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
@ -60,7 +63,11 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
import org.springframework.web.client.RestClient; import org.springframework.web.client.RestClient;
import java.io.File;
import java.time.Duration;
import java.util.List; import java.util.List;
import java.util.Timer;
import java.util.TimerTask;
/** /**
* AI Model 模型工厂的实现类 * AI Model 模型工厂的实现类
@ -73,7 +80,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
public ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url) { public ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url); String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<ChatModel>) () -> { return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
//noinspection EnhancedSwitchMigration // noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case TONG_YI: case TONG_YI:
return buildTongYiChatModel(apiKey); return buildTongYiChatModel(apiKey);
@ -105,7 +112,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
@Override @Override
public ChatModel getDefaultChatModel(AiPlatformEnum platform) { public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration // noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case TONG_YI: case TONG_YI:
return SpringUtil.getBean(DashScopeChatModel.class); return SpringUtil.getBean(DashScopeChatModel.class);
@ -136,7 +143,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
@Override @Override
public ImageModel getDefaultImageModel(AiPlatformEnum platform) { public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration // noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case TONG_YI: case TONG_YI:
return SpringUtil.getBean(DashScopeImageModel.class); return SpringUtil.getBean(DashScopeImageModel.class);
@ -155,7 +162,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
@Override @Override
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) { public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration // noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case TONG_YI: case TONG_YI:
return buildTongYiImagesModel(apiKey); return buildTongYiImagesModel(apiKey);
@ -174,9 +181,11 @@ public class AiModelFactoryImpl implements AiModelFactory {
@Override @Override
public MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url) { public MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url) {
String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey, url); String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey,
url);
return Singleton.get(cacheKey, (Func0<MidjourneyApi>) () -> { return Singleton.get(cacheKey, (Func0<MidjourneyApi>) () -> {
YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class).getMidjourney(); YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class)
.getMidjourney();
return new MidjourneyApi(url, apiKey, properties.getNotifyUrl()); return new MidjourneyApi(url, apiKey, properties.getNotifyUrl());
}); });
} }
@ -204,25 +213,31 @@ public class AiModelFactoryImpl implements AiModelFactory {
} }
@Override @Override
public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel) { public VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type, EmbeddingModel embeddingModel) {
// String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url); // String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey,
String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel); // url);
String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel, type);
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> { return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
if (type == SimpleVectorStore.class) {
return buildSimpleVectorStore(embeddingModel);
}
throw new IllegalArgumentException(StrUtil.format("未知类型({})", type));
// TODO @芋艿先临时使用 store // TODO @芋艿先临时使用 store
return SimpleVectorStore.builder(embeddingModel).build();
// TODO @芋艿@xin后续看看是不是切到阿里云之类的 // TODO @芋艿@xin后续看看是不是切到阿里云之类的
// String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey); // String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
// var config = RedisVectorStore.RedisVectorStoreConfig.builder() // var config = RedisVectorStore.RedisVectorStoreConfig.builder()
// .withIndexName(cacheKey) // .withIndexName(cacheKey)
// .withPrefix(prefix) // .withPrefix(prefix)
// .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC)) // .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId",
// .build(); // Schema.FieldType.NUMERIC))
// RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class); // .build();
// RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel, // RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()), // RedisVectorStore redisVectorStore = new RedisVectorStore(config,
// true); // embeddingModel,
// redisVectorStore.afterPropertiesSet(); // new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
// return redisVectorStore; // true);
// redisVectorStore.afterPropertiesSet();
// return redisVectorStore;
}); });
} }
@ -307,7 +322,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
*/ */
private ChatModel buildSiliconFlowChatModel(String apiKey) { private ChatModel buildSiliconFlowChatModel(String apiKey) {
YudaoAiProperties.SiliconFlowProperties properties = new YudaoAiProperties.SiliconFlowProperties() YudaoAiProperties.SiliconFlowProperties properties = new YudaoAiProperties.SiliconFlowProperties()
.setApiKey(apiKey); .setApiKey(apiKey);
return new YudaoAiAutoConfiguration().buildSiliconFlowChatClient(properties); return new YudaoAiAutoConfiguration().buildSiliconFlowChatClient(properties);
} }
@ -397,7 +412,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
*/ */
private DashScopeEmbeddingModel buildTongYiEmbeddingModel(String apiKey, String model) { private DashScopeEmbeddingModel buildTongYiEmbeddingModel(String apiKey, String model) {
DashScopeApi dashScopeApi = new DashScopeApi(apiKey); DashScopeApi dashScopeApi = new DashScopeApi(apiKey);
DashScopeEmbeddingOptions dashScopeEmbeddingOptions = DashScopeEmbeddingOptions.builder().withModel(model).build(); DashScopeEmbeddingOptions dashScopeEmbeddingOptions = DashScopeEmbeddingOptions.builder().withModel(model)
.build();
return new DashScopeEmbeddingModel(dashScopeApi, MetadataMode.EMBED, dashScopeEmbeddingOptions); return new DashScopeEmbeddingModel(dashScopeApi, MetadataMode.EMBED, dashScopeEmbeddingOptions);
} }
@ -407,4 +423,58 @@ public class AiModelFactoryImpl implements AiModelFactory {
return OllamaEmbeddingModel.builder().ollamaApi(ollamaApi).defaultOptions(ollamaOptions).build(); return OllamaEmbeddingModel.builder().ollamaApi(ollamaApi).defaultOptions(ollamaOptions).build();
} }
// ========== 各种创建 VectorStore 的方法 ==========
/**
 * Builds a {@link SimpleVectorStore} for the given embedding model and wires up
 * file-based persistence for it.
 *
 * <p>Note: only suitable for local testing; for production prefer a dedicated
 * vector database such as Qdrant or Milvus.
 *
 * <p>The backing file is {@code <user home>/vector_store/simple_<EmbeddingModel simple class name>.json}.
 * It is created (empty) when missing, loaded into the store when it already has
 * content, rewritten once per minute by a background timer, and rewritten one
 * last time from a JVM shutdown hook.
 *
 * @param embeddingModel embedding model the store is built with
 * @return the newly built vector store
 */
@SneakyThrows
@SuppressWarnings("ResultOfMethodCallIgnored")
private SimpleVectorStore buildSimpleVectorStore(EmbeddingModel embeddingModel) {
    SimpleVectorStore vectorStore = SimpleVectorStore.builder(embeddingModel).build();
    // Load on startup.
    // NOTE(review): the file name depends only on the embedding model's simple class name,
    // so two models of the same class resolve to the same file - confirm this is intended.
    File file = new File(StrUtil.format("{}/vector_store/simple_{}.json",
            FileUtil.getUserHomePath(), embeddingModel.getClass().getSimpleName()));
    if (!file.exists()) {
        FileUtil.mkParentDirs(file);
        file.createNewFile();
    } else if (file.length() > 0) {
        vectorStore.load(file);
    }
    // Periodic persistence, once per minute.
    // The timer thread is a daemon so that it can never keep the JVM alive on its own.
    long periodMillis = Duration.ofMinutes(1).toMillis();
    Timer timer = new Timer("SimpleVectorStoreTimer-" + file.getAbsolutePath(), true);
    timer.scheduleAtFixedRate(new TimerTask() {
        @Override
        public void run() {
            try {
                vectorStore.save(file);
            } catch (RuntimeException e) {
                // An exception escaping run() terminates the Timer thread and cancels every
                // future execution. Report it through the thread's uncaught-exception handler
                // instead, so the failure stays visible and the next tick still runs.
                Thread current = Thread.currentThread();
                current.getUncaughtExceptionHandler().uncaughtException(current, e);
            }
        }
    }, periodMillis, periodMillis);
    // Persist on shutdown; stop scheduling further periodic saves before the final write.
    RuntimeUtil.addShutdownHook(() -> {
        timer.cancel();
        vectorStore.save(file);
    });
    return vectorStore;
}
/**
 * Resolves the vector store file for the given embedding model.
 *
 * <p>The file is {@code simple_<EmbeddingModel simple class name>.json} inside the
 * {@code vector_store} directory under the user home path. The directory is
 * created when it does not exist; the file itself is not created here.
 *
 * @param embeddingModel embedding model whose simple class name is used in the file name
 * @return the vector store file
 */
private File createVectorStoreFile(EmbeddingModel embeddingModel) {
    // File name is derived from the embedding model's simple class name
    final String fileName = "simple_" + embeddingModel.getClass().getSimpleName() + ".json";
    // <user home>/vector_store, created on demand
    final File storeDir = new File(FileUtil.getUserHomePath(), "vector_store");
    final boolean dirMissing = !storeDir.exists();
    if (dirMissing) {
        storeDir.mkdirs();
    }
    return new File(storeDir, fileName);
}
} }