【代码重构】AI:知识库相关的表结构

This commit is contained in:
YunaiV 2025-02-28 07:43:43 +08:00
parent deca69ada6
commit 0a8c75625a
31 changed files with 356 additions and 365 deletions

View File

@ -21,7 +21,7 @@ public interface ErrorCodeConstants {
ErrorCode CHAT_MODEL_DISABLE = new ErrorCode(1_040_001_001, "模型({})已禁用!");
ErrorCode CHAT_MODEL_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认聊天模型");
// ========== API 聊天模型 1-040-002-000 ==========
// ========== API 聊天角色 1-040-002-000 ==========
ErrorCode CHAT_ROLE_NOT_EXISTS = new ErrorCode(1_040_002_000, "聊天角色不存在");
// Chat-role errors use the 1-040-002-xxx range; this keeps the code distinct from CHAT_MODEL_DISABLE (1_040_001_001).
ErrorCode CHAT_ROLE_DISABLE = new ErrorCode(1_040_002_001, "聊天角色({})已禁用!");
@ -40,7 +40,6 @@ public interface ErrorCodeConstants {
ErrorCode IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!");
ErrorCode IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}");
ErrorCode IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");
// Uses the next free code in the image range so it no longer collides with IMAGE_CUSTOM_ID_NOT_EXISTS (1_022_005_002).
ErrorCode IMAGE_FAIL = new ErrorCode(1_022_005_003, "图片绘画失败! {}");
// ========== API 音乐 1-040-006-000 ==========
ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!");
@ -54,7 +53,11 @@ public interface ErrorCodeConstants {
// ========== API 知识库 1-022-008-000 ==========
ErrorCode KNOWLEDGE_NOT_EXISTS = new ErrorCode(1_022_008_000, "知识库不存在!");
ErrorCode KNOWLEDGE_DOCUMENT_NOT_EXISTS = new ErrorCode(1_022_008_001, "文档不存在!");
ErrorCode KNOWLEDGE_SEGMENT_NOT_EXISTS = new ErrorCode(1_022_008_002, "段落不存在!");
ErrorCode KNOWLEDGE_DOCUMENT_NOT_EXISTS = new ErrorCode(1_022_008_101, "文档不存在!");
ErrorCode KNOWLEDGE_DOCUMENT_FILE_EMPTY = new ErrorCode(1_022_008_102, "文档内容为空!");
// Distinct code so "load failed" can be told apart from KNOWLEDGE_DOCUMENT_FILE_EMPTY (1_022_008_102).
ErrorCode KNOWLEDGE_DOCUMENT_FILE_READ_FAIL = new ErrorCode(1_022_008_103, "文档加载失败!");
ErrorCode KNOWLEDGE_SEGMENT_NOT_EXISTS = new ErrorCode(1_022_008_202, "段落不存在!");
}

View File

@ -1,39 +0,0 @@
package cn.iocoder.yudao.module.ai.enums.knowledge;
import cn.iocoder.yudao.framework.common.core.ArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
 * Status values of an AI knowledge-base document.
 *
 * @author xiaoxin
 */
@Getter
@AllArgsConstructor
public enum AiKnowledgeDocumentStatusEnum implements ArrayValuable<Integer> {

    IN_PROGRESS(10, "索引中"),
    SUCCESS(20, "可用"),
    FAIL(30, "失败");

    /**
     * Status codes of every constant, in declaration order.
     */
    public static final Integer[] ARRAYS = Arrays.stream(values())
            .map(item -> item.getStatus())
            .toArray(size -> new Integer[size]);

    /**
     * Numeric status code.
     */
    private final Integer status;
    /**
     * Display name of the status.
     */
    private final String name;

    /**
     * @return the shared {@link #ARRAYS} array of status codes
     */
    @Override
    public Integer[] array() {
        return ARRAYS;
    }

}

View File

@ -0,0 +1,33 @@
### 创建知识库
POST {{baseUrl}}/ai/knowledge/create
Content-Type: application/json
Authorization: {{token}}
tenant-id: {{adminTenantId}}
{
"name": "测试标题",
"description": "测试描述",
"embeddingModelId": 30,
"topK": 3,
"similarityThreshold": 0.5
}
### 更新知识库
PUT {{baseUrl}}/ai/knowledge/update
Content-Type: application/json
Authorization: {{token}}
tenant-id: {{adminTenantId}}
{
"id": 1,
"name": "测试标题(更新)",
"description": "测试描述",
"embeddingModelId": 30,
"topK": 5,
"similarityThreshold": 0.6
}
### 获取知识库分页
GET {{baseUrl}}/ai/knowledge/page?pageNo=1&pageSize=10
Authorization: {{token}}
tenant-id: {{adminTenantId}}

View File

@ -3,10 +3,9 @@ package cn.iocoder.yudao.module.ai.controller.admin.knowledge;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeCreateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService;
import io.swagger.v3.oas.annotations.Operation;
@ -17,7 +16,6 @@ import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
@Tag(name = "管理后台 - AI 知识库")
@RestController
@ -31,20 +29,22 @@ public class AiKnowledgeController {
@GetMapping("/page")
@Operation(summary = "获取知识库分页")
public CommonResult<PageResult<AiKnowledgeRespVO>> getKnowledgePage(@Valid AiKnowledgePageReqVO pageReqVO) {
PageResult<AiKnowledgeDO> pageResult = knowledgeService.getKnowledgePage(getLoginUserId(), pageReqVO);
PageResult<AiKnowledgeDO> pageResult = knowledgeService.getKnowledgePage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiKnowledgeRespVO.class));
}
@PostMapping("/create")
@Operation(summary = "创建知识库")
public CommonResult<Long> createKnowledge(@RequestBody @Valid AiKnowledgeCreateReqVO createReqVO) {
return success(knowledgeService.createKnowledge(createReqVO, getLoginUserId()));
public CommonResult<Long> createKnowledge(@RequestBody @Valid AiKnowledgeSaveReqVO createReqVO) {
return success(knowledgeService.createKnowledge(createReqVO));
}
@PutMapping("/update")
@Operation(summary = "更新知识库")
public CommonResult<Boolean> updateKnowledge(@RequestBody @Valid AiKnowledgeUpdateReqVO updateReqVO) {
knowledgeService.updateKnowledge(updateReqVO, getLoginUserId());
public CommonResult<Boolean> updateKnowledge(@RequestBody @Valid AiKnowledgeSaveReqVO updateReqVO) {
knowledgeService.updateKnowledge(updateReqVO);
return success(true);
}
}

View File

@ -0,0 +1,12 @@
### 创建知识文档
POST {{baseUrl}}/ai/knowledge/document/create
Content-Type: application/json
Authorization: Bearer {{token}}
tenant-id: {{adminTenantId}}
{
"knowledgeId": 1,
"name": "测试文档",
"url": "https://static.iocoder.cn/README.md",
"segmentMaxTokens": 800
}

View File

@ -27,20 +27,21 @@ public class AiKnowledgeDocumentController {
@Resource
private AiKnowledgeDocumentService documentService;
@PostMapping("/create")
@Operation(summary = "新建文档")
public CommonResult<Long> createKnowledgeDocument(@Valid AiKnowledgeDocumentCreateReqVO reqVO) {
Long knowledgeDocumentId = documentService.createKnowledgeDocument(reqVO);
return success(knowledgeDocumentId);
}
@GetMapping("/page")
@Operation(summary = "获取文档分页")
public CommonResult<PageResult<AiKnowledgeDocumentRespVO>> getKnowledgeDocumentPage(@Valid AiKnowledgeDocumentPageReqVO pageReqVO) {
public CommonResult<PageResult<AiKnowledgeDocumentRespVO>> getKnowledgeDocumentPage(
@Valid AiKnowledgeDocumentPageReqVO pageReqVO) {
PageResult<AiKnowledgeDocumentDO> pageResult = documentService.getKnowledgeDocumentPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiKnowledgeDocumentRespVO.class));
}
@PostMapping("/create")
@Operation(summary = "新建文档")
public CommonResult<Long> createKnowledgeDocument(@RequestBody @Valid AiKnowledgeDocumentCreateReqVO reqVO) {
Long knowledgeDocumentId = documentService.createKnowledgeDocument(reqVO);
return success(knowledgeDocumentId);
}
@PutMapping("/update")
@Operation(summary = "更新文档")
public CommonResult<Boolean> updateKnowledgeDocument(@Valid @RequestBody AiKnowledgeDocumentUpdateReqVO reqVO) {

View File

@ -11,16 +11,15 @@ import lombok.Data;
@Data
public class AiKnowledgeDocumentUpdateReqVO {
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "15583")
@NotNull(message = "编号不能为空")
private Long id;
@Schema(description = "名称", example = "Java 开发手册")
private String name;
@Schema(description = "是否启用", example = "1")
@InEnum(CommonStatusEnum.class)
private Integer status;
@Schema(description = "名称", example = "Java 开发手册")
private String name;
}

View File

@ -23,24 +23,8 @@ public class AiKnowledgeDocumentCreateReqVO {
@URL(message = "文档 URL 格式不正确")
private String url;
@Schema(description = "每个段落的目标 token 数", requiredMode = Schema.RequiredMode.REQUIRED, example = "800")
@NotNull(message = "每个段落的目标 token 数不能为空")
private Integer defaultSegmentTokens;
@Schema(description = "每个段落的最小字符数", requiredMode = Schema.RequiredMode.REQUIRED, example = "350")
@NotNull(message = "每个段落的最小字符数不能为空")
private Integer minSegmentWordCount;
@Schema(description = "丢弃阈值:低于此阈值的段落会被丢弃", requiredMode = Schema.RequiredMode.REQUIRED, example = "5")
@NotNull(message = "丢弃阈值不能为空")
private Integer minChunkLengthToEmbed;
@Schema(description = "最大段落数", requiredMode = Schema.RequiredMode.REQUIRED, example = "10000")
@NotNull(message = "最大段落数不能为空")
private Integer maxNumSegments;
@Schema(description = "分块是否保留分隔符", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
@NotNull(message = "分块是否保留分隔符不能为空")
private Boolean keepSeparator;
@Schema(description = "分段的最大 Token 数", requiredMode = Schema.RequiredMode.REQUIRED, example = "800")
@NotNull(message = "分段的最大 Token 数不能为空")
private Integer segmentMaxTokens;
}

View File

@ -5,11 +5,12 @@ import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import java.util.List;
@Schema(description = "管理后台 - AI 知识库创建 Request VO")
@Schema(description = "管理后台 - AI 知识库新增/修改 Request VO")
@Data
public class AiKnowledgeCreateReqVO {
public class AiKnowledgeSaveReqVO {
// Knowledge-base ID. This VO is shared by the create and update endpoints, so the field carries no
// "required" constraint here. The description previously read "对话编号" (conversation ID), a copy-paste slip.
@Schema(description = "知识库编号", example = "1204")
private Long id;
@Schema(description = "知识库名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "ruoyi-vue-pro 用户指南")
@NotBlank(message = "知识库名称不能为空")
@ -18,19 +19,18 @@ public class AiKnowledgeCreateReqVO {
@Schema(description = "知识库描述", requiredMode = Schema.RequiredMode.REQUIRED, example = "存储 ruoyi-vue-pro 操作文档")
private String description;
@Schema(description = "可见权限,只能选择哪些人可见", requiredMode = Schema.RequiredMode.REQUIRED, example = "[1,2,3]")
private List<Long> visibilityPermissions;
@Schema(description = "嵌入模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "嵌入模型不能为空")
private Long modelId;
@Schema(description = "相似性阈值", requiredMode = Schema.RequiredMode.REQUIRED, example = "0.5")
@NotNull(message = "相似性阈值不能为空")
private Double similarityThreshold;
@Schema(description = "向量模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "向量模型不能为空")
private Long embeddingModelId;
@Schema(description = "topK", requiredMode = Schema.RequiredMode.REQUIRED, example = "3")
@NotNull(message = "topK 不能为空")
private Integer topK;
@Schema(description = "相似性阈值", requiredMode = Schema.RequiredMode.REQUIRED, example = "0.5")
@NotNull(message = "相似性阈值不能为空")
private Double similarityThreshold;
// TODO @芋艿status
}

View File

@ -1,32 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import java.util.List;
/**
 * Request VO for updating an AI knowledge base from the admin console.
 */
@Schema(description = "管理后台 - AI 知识库更新【我的】 Request VO")
@Data
public class AiKnowledgeUpdateReqVO {

    /**
     * ID of the knowledge base to update (required).
     * The schema description now matches the validation message; it used to say "对话编号" (conversation ID).
     */
    @Schema(description = "知识库编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1204")
    @NotNull(message = "知识库编号不能为空")
    private Long id;

    /**
     * Knowledge-base name (must not be blank).
     */
    @Schema(description = "知识库名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "")
    @NotBlank(message = "知识库名称不能为空")
    private String name;

    /**
     * Knowledge-base description.
     */
    @Schema(description = "知识库描述", requiredMode = Schema.RequiredMode.REQUIRED, example = "")
    private String description;

    /**
     * Visibility permissions: the users allowed to see this knowledge base.
     */
    @Schema(description = "可见权限,只能选择哪些人可见", requiredMode = Schema.RequiredMode.REQUIRED, example = "1,2,3")
    private List<Long> visibilityPermissions;

    /**
     * Embedding model ID (required).
     */
    @Schema(description = "嵌入模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
    @NotNull(message = "嵌入模型不能为空")
    private Long modelId;

}

View File

@ -80,6 +80,8 @@ public class AiChatConversationDO extends BaseDO {
private Long modelId;
/**
* 模型标志
*
* 冗余 {@link AiChatModelDO#getModel()} 字段
*/
private String model;

View File

@ -82,6 +82,8 @@ public class AiChatMessageDO extends BaseDO {
/**
* 模型标志
*
* 冗余 {@link AiChatModelDO#getModel()}
*/
private String model;
/**

View File

@ -52,6 +52,7 @@ public class AiImageDO extends BaseDO {
* 枚举 {@link cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum}
*/
private String platform;
// TODO @芋艿modelId
/**
* 模型
*

View File

@ -2,15 +2,12 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.knowledge;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import java.util.List;
/**
* AI 知识库 DO
*
@ -26,12 +23,6 @@ public class AiKnowledgeDO extends BaseDO {
*/
@TableId
private Long id;
/**
* 用户编号
* <p>
* 关联 AdminUserDO userId 字段
*/
private Long userId;
/**
* 知识库名称
*/
@ -42,20 +33,17 @@ public class AiKnowledgeDO extends BaseDO {
private String description;
/**
* 可见权限,选择哪些人可见
* <p>
* -1 所有人可见其他为各自用户编号
* 向量模型编号
*
* 关联 {@link AiChatModelDO#getId()}
*/
@TableField(typeHandler = LongListTypeHandler.class)
private List<Long> visibilityPermissions;
/**
* 嵌入模型编号
*/
private Long modelId;
private Long embeddingModelId;
/**
* 模型标识
*
* 冗余 {@link AiChatModelDO#getModel()}
*/
private String model;
private String embeddingModel;
/**
* topK

View File

@ -2,7 +2,6 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.knowledge;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
@ -33,54 +32,27 @@ public class AiKnowledgeDocumentDO extends BaseDO {
* 文件名称
*/
private String name;
/**
* 文件 URL
*/
private String url;
/**
* 内容
*/
private String content;
/**
* URL
* 档长度
*/
private String url;
private Integer contentLength;
/**
* 文档 token 数量
*/
private Integer tokens;
/**
* 文档字符数
*/
private Integer wordCount;
// ========== 自定义分段所用参数 ==========
// TODO @新3defaultChunkSizedefaultChunkSizeminChunkSizeCharsmaxNumChunks 这几个字段的命名可能要微信一起讨论下尽量命名保持风格统一哈
/**
* 每个文本块的目标 token
*/
private Integer defaultSegmentTokens;
/**
* 每个文本块的最小字符数
*/
private Integer minSegmentWordCount;
/**
* 低于此值的块会被丢弃
*/
private Integer minChunkLengthToEmbed;
/**
* 最大块数
*/
private Integer maxNumSegments;
/**
* 分块是否保留分隔符
*/
private Boolean keepSeparator;
// ===================================
/**
* 切片状态
* <p>
* 枚举 {@link AiKnowledgeDocumentStatusEnum}
*/
private Integer sliceStatus;
private Integer segmentMaxTokens;
/**
* 状态

View File

@ -17,17 +17,16 @@ import lombok.Data;
@Data
public class AiKnowledgeSegmentDO extends BaseDO {
public static final String FIELD_KNOWLEDGE_ID = "knowledgeId";
/**
* 向量库的编号 - 空值
*/
public static final String VECTOR_ID_EMPTY = "";
/**
* 编号
*/
@TableId
private Long id;
/**
* 向量库的编号
*/
private String vectorId;
/**
* 知识库编号
* <p>
@ -45,13 +44,19 @@ public class AiKnowledgeSegmentDO extends BaseDO {
*/
private String content;
/**
* 字符数
* 切片内容长度
*/
private Integer wordCount;
private Integer contentLength;
/**
* 向量库的编号
*/
private String vectorId;
/**
* token 数量
*/
private Integer tokens;
/**
* 状态
* <p>

View File

@ -36,6 +36,7 @@ public class AiMindMapDO extends BaseDO {
* 枚举 {@link AiPlatformEnum}
*/
private String platform;
// TODO @芋艿modelId
/**
* 模型
*/

View File

@ -8,6 +8,7 @@ import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.*;
// TODO @芋艿需要改造增加 type
/**
* AI 聊天模型 DO
*

View File

@ -84,6 +84,7 @@ public class AiMusicDO extends BaseDO {
* 枚举 {@link AiPlatformEnum}
*/
private String platform;
// TODO @芋艿modelId
/**
* 模型
*/

View File

@ -44,6 +44,7 @@ public class AiWriteDO extends BaseDO {
* 枚举 {@link AiPlatformEnum}
*/
private String platform;
// TODO @芋艿modelId
/**
* 模型
*/

View File

@ -16,11 +16,9 @@ import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface AiKnowledgeMapper extends BaseMapperX<AiKnowledgeDO> {
default PageResult<AiKnowledgeDO> selectPage(Long userId, AiKnowledgePageReqVO pageReqVO) {
default PageResult<AiKnowledgeDO> selectPage(AiKnowledgePageReqVO pageReqVO) {
return selectPage(pageReqVO, new LambdaQueryWrapperX<AiKnowledgeDO>()
.eq(AiKnowledgeDO::getStatus, CommonStatusEnum.ENABLE.getStatus())
.likeIfPresent(AiKnowledgeDO::getName, pageReqVO.getName())
.and(e -> e.apply("FIND_IN_SET(" + userId + ",visibility_permissions)").or(m -> m.apply("FIND_IN_SET(-1,visibility_permissions)")))
.orderByDesc(AiKnowledgeDO::getId));
.likeIfPresent(AiKnowledgeDO::getName, pageReqVO.getName()));
}
}

View File

@ -21,7 +21,6 @@ public interface AiKnowledgeDocumentService {
*/
Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO);
/**
* 获取文档分页
*
@ -36,4 +35,13 @@ public interface AiKnowledgeDocumentService {
* @param reqVO 更新信息
*/
void updateKnowledgeDocument(AiKnowledgeDocumentUpdateReqVO reqVO);
/**
 * Validates that a knowledge document exists.
 *
 * @param id document ID
 * @return the document record
 */
AiKnowledgeDocumentDO validateKnowledgeDocumentExists(Long id);
}

View File

@ -1,26 +1,22 @@
package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeDocumentCreateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Lazy;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@ -28,7 +24,7 @@ import org.springframework.transaction.annotation.Transactional;
import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_DOCUMENT_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
/**
* AI 知识库文档 Service 实现类
@ -40,58 +36,46 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_DOCU
public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentService {
@Resource
private AiKnowledgeDocumentMapper documentMapper;
@Resource
private AiKnowledgeSegmentMapper segmentMapper;
private AiKnowledgeDocumentMapper knowledgeDocumentMapper;
@Resource
private TokenCountEstimator tokenCountEstimator;
@Resource
private AiKnowledgeSegmentService knowledgeSegmentService;
@Resource
@Lazy // 延迟加载避免循环依赖
private AiKnowledgeService knowledgeService;
@Override
@Transactional(rollbackFor = Exception.class)
public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) {
// 0. 校验并获取向量存储实例
VectorStore vectorStore = knowledgeService.getVectorStoreById(createReqVO.getKnowledgeId());
// 1. 校验参数
knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId());
// 1.1 下载文档
// 2. 下载文档
TikaDocumentReader loader = new TikaDocumentReader(downloadFile(createReqVO.getUrl()));
List<Document> documents = loader.get();
Document document = CollUtil.getFirst(documents);
// 1.2 文档记录入库
String content = document.getText();
AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class)
.setTokens(tokenCountEstimator.estimate(content)).setWordCount(content.length())
.setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus());
documentMapper.insert(documentDO);
Long documentId = documentDO.getId();
if (CollUtil.isEmpty(documents)) {
return documentId;
if (document == null || StrUtil.isEmpty(document.getText())) {
throw exception(KNOWLEDGE_DOCUMENT_FILE_READ_FAIL);
}
// 2 构造文本分段器
TokenTextSplitter tokenTextSplitter = new TokenTextSplitter(createReqVO.getDefaultSegmentTokens(), createReqVO.getMinSegmentWordCount(), createReqVO.getMinChunkLengthToEmbed(),
createReqVO.getMaxNumSegments(), createReqVO.getKeepSeparator());
// 2.1 文档分段
List<Document> segments = tokenTextSplitter.apply(documents);
// 2.2 分段内容入库
List<AiKnowledgeSegmentDO> segmentDOList = CollectionUtils.convertList(segments,
segment -> new AiKnowledgeSegmentDO().setContent(segment.getText()).setDocumentId(documentId)
.setKnowledgeId(createReqVO.getKnowledgeId()).setVectorId(segment.getId())
.setTokens(tokenCountEstimator.estimate(segment.getText())).setWordCount(segment.getText().length())
.setStatus(CommonStatusEnum.ENABLE.getStatus()));
segmentMapper.insertBatch(segmentDOList);
// 3. 文档记录入库
String content = document.getText();
AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class)
.setContent(content).setContentLength(content.length()).setTokens(tokenCountEstimator.estimate(content))
.setStatus(CommonStatusEnum.ENABLE.getStatus());
knowledgeDocumentMapper.insert(documentDO);
// 3. 向量化并存储
segments.forEach(segment -> segment.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, createReqVO.getKnowledgeId()));
vectorStore.add(segments);
return documentId;
// 4. 文档切片入库
knowledgeSegmentService.createKnowledgeSegmentBySplitContent(documentDO.getId(), document.getText());
return documentDO.getId();
}
@Override
public PageResult<AiKnowledgeDocumentDO> getKnowledgeDocumentPage(AiKnowledgeDocumentPageReqVO pageReqVO) {
return documentMapper.selectPage(pageReqVO);
return knowledgeDocumentMapper.selectPage(pageReqVO);
}
@Override
@ -100,17 +84,13 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
validateKnowledgeDocumentExists(reqVO.getId());
// 2. 更新文档
AiKnowledgeDocumentDO document = BeanUtils.toBean(reqVO, AiKnowledgeDocumentDO.class);
documentMapper.updateById(document);
knowledgeDocumentMapper.updateById(document);
// TODO @芋艿这里要处理状态的变更
}
/**
* 校验文档是否存在
*
* @param id 文档编号
* @return 文档信息
*/
private AiKnowledgeDocumentDO validateKnowledgeDocumentExists(Long id) {
AiKnowledgeDocumentDO knowledgeDocument = documentMapper.selectById(id);
@Override
public AiKnowledgeDocumentDO validateKnowledgeDocumentExists(Long id) {
AiKnowledgeDocumentDO knowledgeDocument = knowledgeDocumentMapper.selectById(id);
if (knowledgeDocument == null) {
throw exception(KNOWLEDGE_DOCUMENT_NOT_EXISTS);
}
@ -120,6 +100,9 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
private org.springframework.core.io.Resource downloadFile(String url) {
try {
byte[] bytes = HttpUtil.downloadBytes(url);
if (bytes.length == 0) {
throw exception(KNOWLEDGE_DOCUMENT_FILE_EMPTY);
}
return new ByteArrayResource(bytes);
} catch (Exception e) {
log.error("[downloadFile][url({}) 下载失败]", url, e);

View File

@ -24,6 +24,14 @@ public interface AiKnowledgeSegmentService {
*/
PageResult<AiKnowledgeSegmentDO> getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO);
/**
 * Creates multiple segments by splitting the given content.
 *
 * @param documentId knowledge-base document ID
 * @param content document content
 */
void createKnowledgeSegmentBySplitContent(Long documentId, String content);
/**
* 更新段落的内容
*

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@ -10,23 +11,28 @@ import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowle
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.ai.transformer.splitter.TextSplitter;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
/**
@ -38,85 +44,138 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGM
@Slf4j
public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService {
// Metadata keys attached to each document written to the vector store.
// Key holding the knowledge-base ID.
public static final String VECTOR_STORE_METADATA_KNOWLEDGE_ID = "knowledgeId";
// Key holding the knowledge document ID.
public static final String VECTOR_STORE_METADATA_DOCUMENT_ID = "documentId";
// Key holding the segment ID.
public static final String VECTOR_STORE_METADATA_SEGMENT_ID = "segmentId";
@Resource
private AiKnowledgeSegmentMapper segmentMapper;
@Resource
private AiKnowledgeService knowledgeService;
@Resource
private AiChatModelService chatModelService;
@Lazy // 延迟加载避免循环依赖
private AiKnowledgeDocumentService knowledgeDocumentService;
@Resource
private AiApiKeyService apiKeyService;
@Resource
private TokenCountEstimator tokenCountEstimator;
@Override
public PageResult<AiKnowledgeSegmentDO> getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO) {
    // Paging and filtering are delegated entirely to the mapper.
    PageResult<AiKnowledgeSegmentDO> page = segmentMapper.selectPage(pageReqVO);
    return page;
}
/**
 * Splits {@code content} into segments, stores them, and writes each one to the vector store.
 *
 * Empty segments are removed BEFORE the segment rows are built. Step 3.2 pairs
 * {@code documentSegments.get(i)} with {@code segmentDOs.get(i)} by index, so both lists must stay
 * one-to-one. The earlier form returned {@code null} from the mapping lambda for empty segments,
 * which breaks that pairing whether the converter drops or keeps the nulls.
 *
 * @param documentId knowledge-base document ID
 * @param content document content
 */
@Override
public void createKnowledgeSegmentBySplitContent(Long documentId, String content) {
    // 1. Validate the document and its knowledge base, then resolve the vector store
    AiKnowledgeDocumentDO documentDO = knowledgeDocumentService.validateKnowledgeDocumentExists(documentId);
    AiKnowledgeDO knowledgeDO = knowledgeService.validateKnowledgeExists(documentDO.getKnowledgeId());
    VectorStore vectorStore = getVectorStoreById(knowledgeDO);

    // 2. Split the document, keeping only segments that actually carry text
    Document document = new Document(content);
    TextSplitter textSplitter = buildTokenTextSplitter(documentDO.getSegmentMaxTokens());
    List<Document> documentSegments = textSplitter.apply(Collections.singletonList(document)).stream()
            .filter(segment -> StrUtil.isNotEmpty(segment.getText()))
            .toList();
    if (CollUtil.isEmpty(documentSegments)) {
        // Nothing to store or vectorize
        return;
    }

    // 3.1 Store the segments (the mapper never returns null, so the list stays aligned with documentSegments)
    List<AiKnowledgeSegmentDO> segmentDOs = convertList(documentSegments, segment ->
            new AiKnowledgeSegmentDO().setKnowledgeId(documentDO.getKnowledgeId()).setDocumentId(documentId)
                    .setContent(segment.getText()).setContentLength(segment.getText().length())
                    .setVectorId(AiKnowledgeSegmentDO.VECTOR_ID_EMPTY).setTokens(tokenCountEstimator.estimate(segment.getText()))
                    .setStatus(CommonStatusEnum.ENABLE.getStatus()));
    segmentMapper.insertBatch(segmentDOs);

    // 3.2 Vectorize each segment together with its stored row
    for (int i = 0; i < documentSegments.size(); i++) {
        writeVectorStore(vectorStore, segmentDOs.get(i), documentSegments.get(i));
    }
}
@Override
public void updateKnowledgeSegment(AiKnowledgeSegmentUpdateReqVO reqVO) {
// 1. 校验
AiKnowledgeSegmentDO oldKnowledgeSegment = validateKnowledgeSegmentExists(reqVO.getId());
AiKnowledgeSegmentDO segment = validateKnowledgeSegmentExists(reqVO.getId());
// 2.1 获取知识库向量实例
VectorStore vectorStore = knowledgeService.getVectorStoreById(oldKnowledgeSegment.getKnowledgeId());
// 2.2 删除原向量
vectorStore.delete(List.of(oldKnowledgeSegment.getVectorId()));
// 2.3 重新向量化
Document document = new Document(reqVO.getContent());
document.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, oldKnowledgeSegment.getKnowledgeId());
vectorStore.add(List.of(document));
// 2. 删除向量
VectorStore vectorStore = getVectorStoreById(segment.getKnowledgeId());
deleteVectorStore(vectorStore, segment);
// 3. 更新段落内容
AiKnowledgeSegmentDO knowledgeSegment = BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class);
knowledgeSegment.setVectorId(document.getId());
segmentMapper.updateById(knowledgeSegment);
// 3.1 更新切片
AiKnowledgeSegmentDO segmentDO = BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class);
segmentMapper.updateById(segmentDO);
// 3.2 重新向量化
writeVectorStore(vectorStore, segmentDO, new Document(segmentDO.getContent()));
}
@Override
public void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO) {
// 0 校验
AiKnowledgeSegmentDO oldKnowledgeSegment = validateKnowledgeSegmentExists(reqVO.getId());
// 1 获取知识库向量实例
VectorStore vectorStore = knowledgeService.getVectorStoreById(oldKnowledgeSegment.getKnowledgeId());
AiKnowledgeSegmentDO knowledgeSegment = BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class);
// 1. 校验
AiKnowledgeSegmentDO segment = validateKnowledgeSegmentExists(reqVO.getId());
// 2. 获取知识库向量实例
VectorStore vectorStore = getVectorStoreById(segment.getKnowledgeId());
// 3. 更新状态
segmentMapper.updateById(new AiKnowledgeSegmentDO().setId(reqVO.getId()).setStatus(reqVO.getStatus()));
// 4. 更新向量
if (Objects.equals(reqVO.getStatus(), CommonStatusEnum.ENABLE.getStatus())) {
// 2.1 启用重新向量化
Document document = new Document(oldKnowledgeSegment.getContent());
document.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, oldKnowledgeSegment.getKnowledgeId());
vectorStore.add(List.of(document));
knowledgeSegment.setVectorId(document.getId());
writeVectorStore(vectorStore, segment, new Document(segment.getContent()));
} else {
// 2.2 禁用删除向量
vectorStore.delete(List.of(oldKnowledgeSegment.getVectorId()));
knowledgeSegment.setVectorId("");
deleteVectorStore(vectorStore, segment);
}
// 3 更新段落状态
segmentMapper.updateById(knowledgeSegment);
}
/**
 * Writes one segment to the vector store and records the resulting vector ID on its row.
 *
 * @param vectorStore target vector store
 * @param segmentDO stored segment row; its knowledge, document and segment IDs become vector metadata
 * @param segment document to add to the vector store
 */
private void writeVectorStore(VectorStore vectorStore, AiKnowledgeSegmentDO segmentDO, Document segment) {
    // 1. Tag the vector with the IDs it belongs to, then store it
    var metadata = segment.getMetadata();
    metadata.put(VECTOR_STORE_METADATA_KNOWLEDGE_ID, segmentDO.getKnowledgeId());
    metadata.put(VECTOR_STORE_METADATA_DOCUMENT_ID, segmentDO.getDocumentId());
    metadata.put(VECTOR_STORE_METADATA_SEGMENT_ID, segmentDO.getId());
    vectorStore.add(List.of(segment));

    // 2. Save the vector ID on the segment row (only id and vectorId are set on the update object)
    AiKnowledgeSegmentDO vectorIdPatch = new AiKnowledgeSegmentDO();
    vectorIdPatch.setId(segmentDO.getId());
    vectorIdPatch.setVectorId(segment.getId());
    segmentMapper.updateById(vectorIdPatch);
}
/**
 * Removes a segment's vector from the vector store and resets the vector ID stored
 * on the segment row. Does nothing when the segment has no vector ID.
 *
 * @param vectorStore vector store the vector is deleted from
 * @param segmentDO   segment whose vector should be removed
 */
private void deleteVectorStore(VectorStore vectorStore, AiKnowledgeSegmentDO segmentDO) {
    String vectorId = segmentDO.getVectorId();
    // No vector ID recorded on the segment: nothing to reset or delete
    if (StrUtil.isEmpty(vectorId)) {
        return;
    }

    // 1. Reset the vector ID on the segment row (only id + vectorId are set)
    AiKnowledgeSegmentDO vectorIdReset = new AiKnowledgeSegmentDO()
            .setId(segmentDO.getId())
            .setVectorId(AiKnowledgeSegmentDO.VECTOR_ID_EMPTY);
    segmentMapper.updateById(vectorIdReset);

    // 2. Delete the vector itself
    vectorStore.delete(List.of(vectorId));
}
// Runs a similarity search in the knowledge base's vector store and returns the segments
// whose vector IDs match the documents found.
// NOTE(review): this span is taken from a rendered diff and interleaves removed and added
// lines (e.g. `vectorStore` and the document list are each declared twice), so it is not
// compilable as shown.
@Override
public List<AiKnowledgeSegmentDO> similaritySearch(AiKnowledgeSegmentSearchReqVO reqVO) {
// 1. Validate
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqVO.getKnowledgeId());
AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
// 2. Obtain the vector store instance
VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());
VectorStore vectorStore = apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId());
// 3.1 Vector search, filtered to the requested knowledge base via metadata
List<Document> documentList = vectorStore.similaritySearch(SearchRequest.builder()
List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder()
.query(reqVO.getContent())
.topK(knowledge.getTopK())
.similarityThreshold(knowledge.getSimilarityThreshold())
.filterExpression(new FilterExpressionBuilder().eq(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, reqVO.getKnowledgeId()).build())
.topK(knowledge.getTopK()).similarityThreshold(knowledge.getSimilarityThreshold())
.filterExpression(new FilterExpressionBuilder()
.eq(VECTOR_STORE_METADATA_KNOWLEDGE_ID, reqVO.getKnowledgeId()).build())
.build());
if (CollUtil.isEmpty(documentList)) {
if (CollUtil.isEmpty(documents)) {
return ListUtil.empty();
}
// 3.2 Recall the segments by the vector IDs of the matched documents
return segmentMapper.selectListByVectorIds(CollUtil.getFieldValues(documentList, "id", String.class));
return segmentMapper.selectListByVectorIds(convertList(documents, Document::getId));
}
/**
@ -133,4 +192,22 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
return knowledgeSegment;
}
/**
 * Resolves the vector store for a knowledge base through its embedding model.
 *
 * @param knowledge knowledge base whose embedding model ID is used for the lookup
 * @return vector store obtained (or created) for that embedding model
 */
private VectorStore getVectorStoreById(AiKnowledgeDO knowledge) {
    Long embeddingModelId = knowledge.getEmbeddingModelId();
    return apiKeyService.getOrCreateVectorStoreByModelId(embeddingModelId);
}
/**
 * Resolves the vector store for a knowledge base identified by its ID.
 *
 * @param knowledgeId knowledge base ID; validated via the knowledge service first
 * @return vector store of the validated knowledge base
 */
private VectorStore getVectorStoreById(Long knowledgeId) {
    AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(knowledgeId);
    return getVectorStoreById(knowledge);
}
/**
 * Builds a token-based text splitter whose chunk size is the given token limit.
 *
 * @param segmentMaxTokens chunk size passed to the splitter
 * @return configured splitter
 */
private static TextSplitter buildTokenTextSplitter(Integer segmentMaxTokens) {
    TokenTextSplitter splitter = TokenTextSplitter.builder()
            .withChunkSize(segmentMaxTokens)
            // Ignore character-based truncation
            .withMinChunkSizeChars(Integer.MAX_VALUE)
            // Minimum valid chunk length that is still accepted
            .withMinChunkLengthToEmbed(1)
            .withMaxNumChunks(Integer.MAX_VALUE)
            // Keep separators in the output
            .withKeepSeparator(true)
            .build();
    return splitter;
}
}

View File

@ -1,11 +1,9 @@
package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeCreateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import org.springframework.ai.vectorstore.VectorStore;
/**
* AI 知识库-基础信息 Service 接口
@ -18,18 +16,16 @@ public interface AiKnowledgeService {
* 创建知识库
*
* @param createReqVO 创建信息
* @param userId 用户编号
* @return 编号
*/
Long createKnowledge(AiKnowledgeCreateReqVO createReqVO, Long userId);
Long createKnowledge(AiKnowledgeSaveReqVO createReqVO);
/**
* Updates a knowledge base.
* <p>
* NOTE(review): this span is taken from a rendered diff; both the old and the new signature are shown.
*
* @param updateReqVO update information
* @param userId user ID
*/
void updateKnowledge(AiKnowledgeUpdateReqVO updateReqVO, Long userId);
void updateKnowledge(AiKnowledgeSaveReqVO updateReqVO);
/**
* 校验知识库是否存在
@ -41,18 +37,9 @@ public interface AiKnowledgeService {
/**
* Gets a page of knowledge bases.
* <p>
* NOTE(review): this span is taken from a rendered diff; the old and new signatures of
* getKnowledgePage are both shown, with the getVectorStoreById declaration between them.
*
* @param userId user ID
* @param pageReqVO paging query
* @return page of knowledge bases
*/
PageResult<AiKnowledgeDO> getKnowledgePage(Long userId, AiKnowledgePageReqVO pageReqVO);
/**
* Gets the vector store instance for a knowledge base ID.
*
* @param id knowledge base ID
* @return vector store instance
*/
VectorStore getVectorStoreById(Long id);
PageResult<AiKnowledgeDO> getKnowledgePage(AiKnowledgePageReqVO pageReqVO);
}

View File

@ -1,12 +1,10 @@
package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.hutool.core.util.ObjUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeCreateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper;
@ -34,35 +32,30 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
@Resource
private AiChatModelService chatModelService;
@Resource
private AiApiKeyService apiKeyService;
// Creates a knowledge base after validating the referenced model, and returns the new ID.
// NOTE(review): this span is taken from a rendered diff and interleaves the old
// (modelId / userId) and new (embeddingModelId) versions of the method.
@Override
public Long createKnowledge(AiKnowledgeCreateReqVO createReqVO, Long userId) {
public Long createKnowledge(AiKnowledgeSaveReqVO createReqVO) {
// 1. Validate the model configuration
AiChatModelDO model = chatModelService.validateChatModel(createReqVO.getModelId());
AiChatModelDO model = chatModelService.validateChatModel(createReqVO.getEmbeddingModelId());
// 2. Insert the knowledge base (status is set to ENABLE)
AiKnowledgeDO knowledgeBase = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
.setModel(model.getModel()).setUserId(userId).setStatus(CommonStatusEnum.ENABLE.getStatus());
knowledgeMapper.insert(knowledgeBase);
return knowledgeBase.getId();
AiKnowledgeDO knowledge = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
.setEmbeddingModel(model.getModel()).setStatus(CommonStatusEnum.ENABLE.getStatus());
knowledgeMapper.insert(knowledge);
return knowledge.getId();
}
// Updates a knowledge base after validating that it exists and that the referenced model is valid.
// NOTE(review): this span is taken from a rendered diff and interleaves the old version
// (with the userId ownership check) and the new version (without it).
@Override
public void updateKnowledge(AiKnowledgeUpdateReqVO updateReqVO, Long userId) {
public void updateKnowledge(AiKnowledgeSaveReqVO updateReqVO) {
// 1.1 Validate that the knowledge base exists
AiKnowledgeDO knowledgeBaseDO = validateKnowledgeExists(updateReqVO.getId());
if (ObjUtil.notEqual(knowledgeBaseDO.getUserId(), userId)) {
throw exception(KNOWLEDGE_NOT_EXISTS);
}
validateKnowledgeExists(updateReqVO.getId());
// 1.2 Validate the model configuration
AiChatModelDO model = chatModelService.validateChatModel(updateReqVO.getModelId());
AiChatModelDO model = chatModelService.validateChatModel(updateReqVO.getEmbeddingModelId());
// 2. Update the knowledge base
AiKnowledgeDO updateDO = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class);
updateDO.setModel(model.getModel());
knowledgeMapper.updateById(updateDO);
AiKnowledgeDO updateObj = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class)
.setEmbeddingModel(model.getModel());
knowledgeMapper.updateById(updateObj);
}
@Override
@ -75,16 +68,8 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
}
// Returns a page of knowledge bases by delegating to the mapper.
// NOTE(review): this span is taken from a rendered diff; it shows the old getKnowledgePage
// (with userId), the old getVectorStoreById body, and the new getKnowledgePage interleaved,
// so the braces do not balance as shown.
@Override
public PageResult<AiKnowledgeDO> getKnowledgePage(Long userId, AiKnowledgePageReqVO pageReqVO) {
return knowledgeMapper.selectPage(userId, pageReqVO);
}
@Override
public VectorStore getVectorStoreById(Long id) {
AiKnowledgeDO knowledge = validateKnowledgeExists(id);
AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
// Create or get the VectorStore object
return apiKeyService.getOrCreateVectorStore(model.getKeyId());
public PageResult<AiKnowledgeDO> getKnowledgePage(AiKnowledgePageReqVO pageReqVO) {
return knowledgeMapper.selectPage(pageReqVO);
}
}

View File

@ -9,7 +9,6 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
@ -113,20 +112,12 @@ public interface AiApiKeyService {
*/
SunoApi getSunoApi();
/**
* Gets the EmbeddingModel object.
* <p>
* NOTE(review): this span is taken from a rendered diff; this declaration and the old
* getOrCreateVectorStore signature below are shown next to the new
* getOrCreateVectorStoreByModelId signature.
*
* @param id ID
* @return EmbeddingModel object
*/
EmbeddingModel getEmbeddingModel(Long id);
/**
* Gets the VectorStore object.
*
* @param id ID
* @param modelId model ID
* @return VectorStore object
*/
VectorStore getOrCreateVectorStore(Long id);
VectorStore getOrCreateVectorStoreByModelId(Long modelId);
}

View File

@ -10,12 +10,14 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
@ -36,6 +38,11 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
@Resource
private AiApiKeyMapper apiKeyMapper;
// TODO @芋艿后续要不要改
@Resource
@Lazy // 延迟加载解决渲染依赖
private AiChatModelService chatModelService;
@Resource
private AiModelFactory modelFactory;
@ -136,18 +143,18 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
}
// Resolves model -> API key -> platform, obtains the EmbeddingModel from the factory and
// returns the VectorStore built on top of it.
// NOTE(review): this span is taken from a rendered diff and interleaves the old
// getEmbeddingModel / getOrCreateVectorStore methods with the new
// getOrCreateVectorStoreByModelId, so it is not compilable as shown.
@Override
public EmbeddingModel getEmbeddingModel(Long id) {
AiApiKeyDO apiKey = validateApiKey(id);
public VectorStore getOrCreateVectorStoreByModelId(Long modelId) {
// Get the model and its API key
AiChatModelDO chatModel = chatModelService.validateChatModel(modelId);
AiApiKeyDO apiKey = validateApiKey(chatModel.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public VectorStore getOrCreateVectorStore(Long id) {
AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
// Create or get the EmbeddingModel object
EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(),
apiKey.getUrl(), chatModel.getModel());
// Create or get the VectorStore object
return modelFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl());
return modelFactory.getOrCreateVectorStore(embeddingModel);
}
}

View File

@ -89,21 +89,19 @@ public interface AiModelFactory {
* @param platform 平台
* @param apiKey API KEY
* @param url API URL
* @param model 模型
* @return ChatModel 对象
*/
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url);
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url, String model);
/**
* Gets the VectorStore object for the given configuration.
* <p>
* If it does not exist yet, it is created.
* <p>
* NOTE(review): this span is taken from a rendered diff; the old parameter list
* (platform, apiKey, url) and the old signature are shown next to the new ones.
*
* @param embeddingModel embedding model
* @param platform platform
* @param apiKey API KEY
* @param url API URL
* @param embeddingModel embedding (vector) model
* @return VectorStore object
*/
VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url);
VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel);
}

View File

@ -21,6 +21,7 @@ import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingModel;
import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingOptions;
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
import com.azure.ai.openai.OpenAIClientBuilder;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
@ -33,10 +34,13 @@ import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaEmbeddingModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.OpenAiApi;
@ -184,13 +188,15 @@ public class AiModelFactoryImpl implements AiModelFactory {
}
@Override
public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url);
public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url, String model) {
String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url, model);
return Singleton.get(cacheKey, (Func0<EmbeddingModel>) () -> {
// TODO @xin 先测试一个
switch (platform) {
case TONG_YI:
return buildTongYiEmbeddingModel(apiKey);
return buildTongYiEmbeddingModel(apiKey, model);
case OLLAMA:
return buildOllamaEmbeddingModel(url, model);
// TODO @芋艿各个平台的向量化能力
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
@ -198,13 +204,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
}
@Override
public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel) {
// String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel);
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
// TODO @芋艿先临时使用 store
return SimpleVectorStore.builder(embeddingModel).build();
// TODO @芋艿@xin后续看看是不是切到阿里云之类的
// String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
// var config = RedisVectorStore.RedisVectorStoreConfig.builder()
// .withIndexName(cacheKey)
// .withPrefix(prefix)
@ -388,9 +395,16 @@ public class AiModelFactoryImpl implements AiModelFactory {
/**
* Builds the DashScope (TongYi) embedding model from an API key; the newer variant also
* passes the model name through DashScopeEmbeddingOptions.
* See the dashscopeEmbeddingModel method of {@link DashScopeAutoConfiguration} for reference.
* <p>
* NOTE(review): this span is taken from a rendered diff and shows the old and new
* signature and return statement interleaved.
*/
private EmbeddingModel buildTongYiEmbeddingModel(String apiKey) {
private DashScopeEmbeddingModel buildTongYiEmbeddingModel(String apiKey, String model) {
DashScopeApi dashScopeApi = new DashScopeApi(apiKey);
return new DashScopeEmbeddingModel(dashScopeApi);
DashScopeEmbeddingOptions dashScopeEmbeddingOptions = DashScopeEmbeddingOptions.builder().withModel(model).build();
return new DashScopeEmbeddingModel(dashScopeApi, MetadataMode.EMBED, dashScopeEmbeddingOptions);
}
/**
 * Builds an Ollama embedding model for the given endpoint and model name.
 *
 * @param url   URL passed to the {@link OllamaApi} constructor
 * @param model model name set on the default {@link OllamaOptions}
 * @return embedding model built from that API client and those default options
 */
private OllamaEmbeddingModel buildOllamaEmbeddingModel(String url, String model) {
    return OllamaEmbeddingModel.builder()
            .ollamaApi(new OllamaApi(url))
            .defaultOptions(OllamaOptions.builder().model(model).build())
            .build();
}
}