【代码重构】AI:“聊天模型”重构为“模型”,支持 type 模型类型

This commit is contained in:
YunaiV 2025-03-03 21:26:31 +08:00
parent 3f460dc620
commit 89d079349c
36 changed files with 617 additions and 554 deletions

View File

@ -12,31 +12,25 @@ public interface ErrorCodeConstants {
// ========== API 密钥 1-040-000-000 ========== // ========== API 密钥 1-040-000-000 ==========
ErrorCode API_KEY_NOT_EXISTS = new ErrorCode(1_040_000_000, "API 密钥不存在"); ErrorCode API_KEY_NOT_EXISTS = new ErrorCode(1_040_000_000, "API 密钥不存在");
ErrorCode API_KEY_DISABLE = new ErrorCode(1_040_000_001, "API 密钥已禁用!"); ErrorCode API_KEY_DISABLE = new ErrorCode(1_040_000_001, "API 密钥已禁用!");
ErrorCode API_KEY_MIDJOURNEY_NOT_FOUND = new ErrorCode(1_040_000_900, "Midjourney 模型不存在");
ErrorCode API_KEY_SUNO_NOT_FOUND = new ErrorCode(1_040_000_901, "Suno 模型不存在");
ErrorCode API_KEY_IMAGE_NODE_FOUND = new ErrorCode(1_040_000_902, "平台({}) 图片模型未配置");
// ========== API 聊天模型 1-040-001-000 ========== // ========== API 模型 1-040-001-000 ==========
ErrorCode CHAT_MODEL_NOT_EXISTS = new ErrorCode(1_040_001_000, "模型不存在!"); ErrorCode MODEL_NOT_EXISTS = new ErrorCode(1_040_001_000, "模型不存在!");
ErrorCode CHAT_MODEL_DISABLE = new ErrorCode(1_040_001_001, "模型({})已禁用!"); ErrorCode MODEL_DISABLE = new ErrorCode(1_040_001_001, "模型({})已禁用!");
ErrorCode CHAT_MODEL_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认聊天模型"); ErrorCode MODEL_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认模型");
// ========== API 聊天角色 1-040-002-000 ========== // ========== API 聊天角色 1-040-002-000 ==========
ErrorCode CHAT_ROLE_NOT_EXISTS = new ErrorCode(1_040_002_000, "聊天角色不存在"); ErrorCode CHAT_ROLE_NOT_EXISTS = new ErrorCode(1_040_002_000, "聊天角色不存在");
ErrorCode CHAT_ROLE_DISABLE = new ErrorCode(1_040_001_001, "聊天角色({})已禁用!"); ErrorCode CHAT_ROLE_DISABLE = new ErrorCode(1_040_001_001, "聊天角色({})已禁用!");
// ========== API 聊天会话 1-040-003-000 ========== // ========== API 聊天会话 1-040-003-000 ==========
ErrorCode CHAT_CONVERSATION_NOT_EXISTS = new ErrorCode(1_040_003_000, "对话不存在!"); ErrorCode CHAT_CONVERSATION_NOT_EXISTS = new ErrorCode(1_040_003_000, "对话不存在!");
ErrorCode CHAT_CONVERSATION_MODEL_ERROR = new ErrorCode(1_040_003_001, "操作失败,该聊天模型的配置不完整"); ErrorCode CHAT_CONVERSATION_MODEL_ERROR = new ErrorCode(1_040_003_001, "操作失败,该聊天模型的配置不完整");
// ========== API 聊天消息 1-040-004-000 ========== // ========== API 聊天消息 1-040-004-000 ==========
ErrorCode CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!"); ErrorCode CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!");
ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "对话生成异常!"); ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "对话生成异常!");
// ========== API 绘画 1-040-005-000 ========== // ========== API 绘画 1-040-005-000 ==========
ErrorCode IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!"); ErrorCode IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!");
ErrorCode IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}"); ErrorCode IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}");
ErrorCode IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}"); ErrorCode IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");

View File

@ -1,6 +1,6 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation; package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import com.fhs.core.trans.anno.Trans; import com.fhs.core.trans.anno.Trans;
import com.fhs.core.trans.constant.TransType; import com.fhs.core.trans.constant.TransType;
@ -31,7 +31,7 @@ public class AiChatConversationRespVO implements VO {
private Long roleId; private Long roleId;
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@Trans(type = TransType.SIMPLE, target = AiChatModelDO.class, fields = "name", ref = "modelName") @Trans(type = TransType.SIMPLE, target = AiModelDO.class, fields = "name", ref = "modelName")
private Long modelId; private Long modelId;
@Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "ERNIE-Bot-turbo-0922") @Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "ERNIE-Bot-turbo-0922")

View File

@ -14,18 +14,15 @@ import java.util.Map;
@Data @Data
public class AiImageDrawReqVO { public class AiImageDrawReqVO {
@Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI") @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
private String platform; // 参见 AiPlatformEnum 枚举 @NotNull(message = "模型编号不能为空")
private Long modelId;
@Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "画一个长城") @Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "画一个长城")
@NotEmpty(message = "提示词不能为空") @NotEmpty(message = "提示词不能为空")
@Size(max = 1200, message = "提示词最大 1200") @Size(max = 1200, message = "提示词最大 1200")
private String prompt; private String prompt;
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "stable-diffusion-v1-6")
@NotEmpty(message = "模型不能为空")
private String model;
/** /**
* 1. dall-e-2 模型256x256512x5121024x1024 * 1. dall-e-2 模型256x256512x5121024x1024
* 2. dall-e-3 模型1024x1024, 1792x1024, 1024x1792 * 2. dall-e-3 模型1024x1024, 1792x1024, 1024x1792

View File

@ -6,9 +6,8 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyRespVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.Parameter;
@ -76,9 +75,9 @@ public class AiApiKeyController {
@GetMapping("/simple-list") @GetMapping("/simple-list")
@Operation(summary = "获得 API 密钥分页列表") @Operation(summary = "获得 API 密钥分页列表")
public CommonResult<List<AiChatModelRespVO>> getApiKeySimpleList() { public CommonResult<List<AiModelRespVO>> getApiKeySimpleList() {
List<AiApiKeyDO> list = apiKeyService.getApiKeyList(); List<AiApiKeyDO> list = apiKeyService.getApiKeyList();
return success(convertList(list, key -> new AiChatModelRespVO().setId(key.getId()).setName(key.getName()))); return success(convertList(list, key -> new AiModelRespVO().setId(key.getId()).setName(key.getName())));
} }
} }

View File

@ -1,84 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.model;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import jakarta.validation.Valid;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@Tag(name = "管理后台 - AI 聊天模型")
@RestController
@RequestMapping("/ai/chat-model")
@Validated
public class AiChatModelController {
@Resource
private AiChatModelService chatModelService;
@PostMapping("/create")
@Operation(summary = "创建聊天模型")
@PreAuthorize("@ss.hasPermission('ai:chat-model:create')")
public CommonResult<Long> createChatModel(@Valid @RequestBody AiChatModelSaveReqVO createReqVO) {
return success(chatModelService.createChatModel(createReqVO));
}
@PutMapping("/update")
@Operation(summary = "更新聊天模型")
@PreAuthorize("@ss.hasPermission('ai:chat-model:update')")
public CommonResult<Boolean> updateChatModel(@Valid @RequestBody AiChatModelSaveReqVO updateReqVO) {
chatModelService.updateChatModel(updateReqVO);
return success(true);
}
@DeleteMapping("/delete")
@Operation(summary = "删除聊天模型")
@Parameter(name = "id", description = "编号", required = true)
@PreAuthorize("@ss.hasPermission('ai:chat-model:delete')")
public CommonResult<Boolean> deleteChatModel(@RequestParam("id") Long id) {
chatModelService.deleteChatModel(id);
return success(true);
}
@GetMapping("/get")
@Operation(summary = "获得聊天模型")
@Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('ai:chat-model:query')")
public CommonResult<AiChatModelRespVO> getChatModel(@RequestParam("id") Long id) {
AiChatModelDO chatModel = chatModelService.getChatModel(id);
return success(BeanUtils.toBean(chatModel, AiChatModelRespVO.class));
}
@GetMapping("/page")
@Operation(summary = "获得聊天模型分页")
@PreAuthorize("@ss.hasPermission('ai:chat-model:query')")
public CommonResult<PageResult<AiChatModelRespVO>> getChatModelPage(@Valid AiChatModelPageReqVO pageReqVO) {
PageResult<AiChatModelDO> pageResult = chatModelService.getChatModelPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiChatModelRespVO.class));
}
@GetMapping("/simple-list")
@Operation(summary = "获得聊天模型列表")
@Parameter(name = "status", description = "状态", required = true, example = "1")
public CommonResult<List<AiChatModelRespVO>> getChatModelSimpleList(@RequestParam("status") Integer status) {
List<AiChatModelDO> list = chatModelService.getChatModelListByStatus(status);
return success(convertList(list, model -> new AiChatModelRespVO().setId(model.getId())
.setName(model.getName()).setModel(model.getModel())));
}
}

View File

@ -0,0 +1,89 @@
package cn.iocoder.yudao.module.ai.controller.admin.model;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import jakarta.validation.Valid;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@Tag(name = "管理后台 - AI 模型")
@RestController
@RequestMapping("/ai/model")
@Validated
public class AiModelController {
@Resource
private AiModelService modelService;
@PostMapping("/create")
@Operation(summary = "创建模型")
@PreAuthorize("@ss.hasPermission('ai:model:create')")
public CommonResult<Long> createModel(@Valid @RequestBody AiModelSaveReqVO createReqVO) {
return success(modelService.createModel(createReqVO));
}
@PutMapping("/update")
@Operation(summary = "更新模型")
@PreAuthorize("@ss.hasPermission('ai:model:update')")
public CommonResult<Boolean> updateModel(@Valid @RequestBody AiModelSaveReqVO updateReqVO) {
modelService.updateModel(updateReqVO);
return success(true);
}
@DeleteMapping("/delete")
@Operation(summary = "删除模型")
@Parameter(name = "id", description = "编号", required = true)
@PreAuthorize("@ss.hasPermission('ai:model:delete')")
public CommonResult<Boolean> deleteModel(@RequestParam("id") Long id) {
modelService.deleteModel(id);
return success(true);
}
@GetMapping("/get")
@Operation(summary = "获得模型")
@Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('ai:model:query')")
public CommonResult<AiModelRespVO> getModel(@RequestParam("id") Long id) {
AiModelDO model = modelService.getModel(id);
return success(BeanUtils.toBean(model, AiModelRespVO.class));
}
@GetMapping("/page")
@Operation(summary = "获得模型分页")
@PreAuthorize("@ss.hasPermission('ai:model:query')")
public CommonResult<PageResult<AiModelRespVO>> getModelPage(@Valid AiModelPageReqVO pageReqVO) {
PageResult<AiModelDO> pageResult = modelService.getModelPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiModelRespVO.class));
}
@GetMapping("/simple-list")
@Operation(summary = "获得模型列表")
@Parameter(name = "type", description = "类型", required = true, example = "1")
@Parameter(name = "platform", description = "平台", example = "midjourney")
public CommonResult<List<AiModelRespVO>> getModelSimpleList(
@RequestParam("type") Integer type,
@RequestParam(value = "platform", required = false) String platform) {
List<AiModelDO> list = modelService.getModelListByStatusAndType(
CommonStatusEnum.ENABLE.getStatus(), type, platform);
return success(convertList(list, model -> new AiModelRespVO().setId(model.getId())
.setName(model.getName()).setModel(model.getModel()).setPlatform(model.getPlatform())));
}
}

View File

@ -1,6 +1,6 @@
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole; package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import com.fhs.core.trans.anno.Trans; import com.fhs.core.trans.anno.Trans;
import com.fhs.core.trans.constant.TransType; import com.fhs.core.trans.constant.TransType;
import com.fhs.core.trans.vo.VO; import com.fhs.core.trans.vo.VO;
@ -20,7 +20,7 @@ public class AiChatRoleRespVO implements VO {
private Long userId; private Long userId;
@Schema(description = "模型编号", example = "17640") @Schema(description = "模型编号", example = "17640")
@Trans(type = TransType.SIMPLE, target = AiChatModelDO.class, fields = {"name", "model"}, refs = {"modelName", "model"}) @Trans(type = TransType.SIMPLE, target = AiModelDO.class, fields = {"name", "model"}, refs = {"modelName", "model"})
private Long modelId; private Long modelId;
@Schema(description = "模型名字", example = "张三") @Schema(description = "模型名字", example = "张三")
private String modelName; private String modelName;

View File

@ -1,12 +1,12 @@
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel; package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
import lombok.*; import lombok.*;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageParam;
@Schema(description = "管理后台 - API 聊天模型分页 Request VO") @Schema(description = "管理后台 - API 模型分页 Request VO")
@Data @Data
public class AiChatModelPageReqVO extends PageParam { public class AiModelPageReqVO extends PageParam {
@Schema(description = "模型名字", example = "张三") @Schema(description = "模型名字", example = "张三")
private String name; private String name;

View File

@ -1,13 +1,13 @@
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel; package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data; import lombok.Data;
import java.time.LocalDateTime; import java.time.LocalDateTime;
@Schema(description = "管理后台 - AI 聊天模型 Response VO") @Schema(description = "管理后台 - AI 模型 Response VO")
@Data @Data
public class AiChatModelRespVO { public class AiModelRespVO {
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "2630") @Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "2630")
private Long id; private Long id;
@ -24,6 +24,9 @@ public class AiChatModelRespVO {
@Schema(description = "模型平台", example = "OpenAI") @Schema(description = "模型平台", example = "OpenAI")
private String platform; private String platform;
@Schema(description = "模型类型", example = "1")
private Integer type;
@Schema(description = "排序", example = "1") @Schema(description = "排序", example = "1")
private Integer sort; private Integer sort;

View File

@ -1,14 +1,17 @@
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel; package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.validation.InEnum; import cn.iocoder.yudao.framework.common.validation.InEnum;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import lombok.*; import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.*; import jakarta.validation.constraints.NotNull;
import lombok.Data;
@Schema(description = "管理后台 - API 聊天模型新增/修改 Request VO") @Schema(description = "管理后台 - API 模型新增/修改 Request VO")
@Data @Data
public class AiChatModelSaveReqVO { public class AiModelSaveReqVO {
@Schema(description = "编号", example = "2630") @Schema(description = "编号", example = "2630")
private Long id; private Long id;
@ -27,8 +30,14 @@ public class AiChatModelSaveReqVO {
@Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI") @Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI")
@NotEmpty(message = "模型平台不能为空") @NotEmpty(message = "模型平台不能为空")
@InEnum(AiPlatformEnum.class)
private String platform; private String platform;
@Schema(description = "模型类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "模型类型不能为空")
@InEnum(AiModelTypeEnum.class)
private Integer type;
@Schema(description = "排序", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "排序", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "排序不能为空") @NotNull(message = "排序不能为空")
private Integer sort; private Integer sort;

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
@ -76,13 +76,13 @@ public class AiChatConversationDO extends BaseDO {
/** /**
* 模型编号 * 模型编号
* *
* 关联 {@link AiChatModelDO#getId()} 字段 * 关联 {@link AiModelDO#getId()} 字段
*/ */
private Long modelId; private Long modelId;
/** /**
* 模型标志 * 模型标志
* *
* 冗余 {@link AiChatModelDO#getModel()} 字段 * 冗余 {@link AiModelDO#getModel()} 字段
*/ */
private String model; private String model;

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableField;
@ -83,13 +83,13 @@ public class AiChatMessageDO extends BaseDO {
/** /**
* 模型标志 * 模型标志
* *
* 冗余 {@link AiChatModelDO#getModel()} * 冗余 {@link AiModelDO#getModel()}
*/ */
private String model; private String model;
/** /**
* 模型编号 * 模型编号
* *
* 关联 {@link AiChatModelDO#getId()} 字段 * 关联 {@link AiModelDO#getId()} 字段
*/ */
private Long modelId; private Long modelId;

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.image;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum; import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO; import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.KeySequence;
@ -52,11 +52,16 @@ public class AiImageDO extends BaseDO {
* 枚举 {@link cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum} * 枚举 {@link cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum}
*/ */
private String platform; private String platform;
// TODO @芋艿modelId
/** /**
* 模型 * 模型编号
* *
* 冗余 {@link AiChatModelDO#getModel()} * 关联 {@link AiModelDO#getId()}
*/
private Long modelId;
/**
* 模型标识
*
* 冗余 {@link AiModelDO#getModel()}
*/ */
private String model; private String model;

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.knowledge;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
@ -35,13 +35,13 @@ public class AiKnowledgeDO extends BaseDO {
/** /**
* 向量模型编号 * 向量模型编号
* *
* 关联 {@link AiChatModelDO#getId()} * 关联 {@link AiModelDO#getId()}
*/ */
private Long embeddingModelId; private Long embeddingModelId;
/** /**
* 模型标识 * 模型标识
* *
* 冗余 {@link AiChatModelDO#getModel()} * 冗余 {@link AiModelDO#getModel()}
*/ */
private String embeddingModel; private String embeddingModel;

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.mindmap;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
@ -36,7 +37,12 @@ public class AiMindMapDO extends BaseDO {
* 枚举 {@link AiPlatformEnum} * 枚举 {@link AiPlatformEnum}
*/ */
private String platform; private String platform;
// TODO @芋艿modelId /**
* 模型编号
*
* 关联 {@link AiModelDO#getId()}
*/
private Long modelId;
/** /**
* 模型 * 模型
*/ */

View File

@ -58,7 +58,7 @@ public class AiChatRoleDO extends BaseDO {
/** /**
* 模型编号 * 模型编号
* *
* 关联 {@link AiChatModelDO#getId()} 字段 * 关联 {@link AiModelDO#getId()} 字段
*/ */
private Long modelId; private Long modelId;

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.model; package cn.iocoder.yudao.module.ai.dal.dataobject.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
@ -8,23 +9,22 @@ import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import lombok.*; import lombok.*;
// TODO @芋艿需要改造增加 type
/** /**
* AI 聊天模型 DO * AI 模型 DO
* *
* 默认聊天模型{@link #status} 为开启并且 {@link #sort} 排序第一 * 默认模型{@link #status} 为开启并且 {@link #sort} 排序第一
* *
* @author fansili * @author fansili
* @since 2024/4/24 19:39 * @since 2024/4/24 19:39
*/ */
@TableName("ai_chat_model") @TableName("ai_model")
@KeySequence("ai_chat_model_seq") // 用于 OraclePostgreSQLKingbaseDB2H2 数据库的主键自增如果是 MySQL 等数据库可不写 @KeySequence("ai_model_seq") // 用于 OraclePostgreSQLKingbaseDB2H2 数据库的主键自增如果是 MySQL 等数据库可不写
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
public class AiChatModelDO extends BaseDO { public class AiModelDO extends BaseDO {
/** /**
* 编号 * 编号
@ -51,6 +51,12 @@ public class AiChatModelDO extends BaseDO {
* 枚举 {@link AiPlatformEnum} * 枚举 {@link AiPlatformEnum}
*/ */
private String platform; private String platform;
/**
* 类型
*
* 枚举 {@link AiModelTypeEnum}
*/
private Integer type;
/** /**
* 排序值 * 排序值

View File

@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.write;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
@ -44,7 +46,12 @@ public class AiWriteDO extends BaseDO {
* 枚举 {@link AiPlatformEnum} * 枚举 {@link AiPlatformEnum}
*/ */
private String platform; private String platform;
// TODO @芋艿modelId /**
* 模型编号
*
* 关联 {@link AiModelDO#getId()}
*/
private Long modelId;
/** /**
* 模型 * 模型
*/ */
@ -67,25 +74,25 @@ public class AiWriteDO extends BaseDO {
/** /**
* 长度提示词 * 长度提示词
* *
* 字典{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LENGTH} * 字典{@link DictTypeConstants#AI_WRITE_LENGTH}
*/ */
private Integer length; private Integer length;
/** /**
* 格式提示词 * 格式提示词
* *
* 字典{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_FORMAT} * 字典{@link DictTypeConstants#AI_WRITE_FORMAT}
*/ */
private Integer format; private Integer format;
/** /**
* 语气提示词 * 语气提示词
* *
* 字典{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_TONE} * 字典{@link DictTypeConstants#AI_WRITE_TONE}
*/ */
private Integer tone; private Integer tone;
/** /**
* 语言提示词 * 语言提示词
* *
* 字典{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LANGUAGE} * 字典{@link DictTypeConstants#AI_WRITE_LANGUAGE}
*/ */
private Integer language; private Integer language;

View File

@ -0,0 +1,47 @@
package cn.iocoder.yudao.module.ai.dal.mysql.model;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import org.apache.ibatis.annotations.Mapper;
import javax.annotation.Nullable;
import java.util.List;
/**
* API 模型 Mapper
*
* @author fansili
*/
@Mapper
public interface AiChatMapper extends BaseMapperX<AiModelDO> {
default AiModelDO selectFirstByStatus(Integer type, Integer status) {
return selectOne(new QueryWrapperX<AiModelDO>()
.eq("type", type)
.eq("status", status)
.limitN(1)
.orderByAsc("sort"));
}
default PageResult<AiModelDO> selectPage(AiModelPageReqVO reqVO) {
return selectPage(reqVO, new LambdaQueryWrapperX<AiModelDO>()
.likeIfPresent(AiModelDO::getName, reqVO.getName())
.eqIfPresent(AiModelDO::getModel, reqVO.getModel())
.eqIfPresent(AiModelDO::getPlatform, reqVO.getPlatform())
.orderByAsc(AiModelDO::getSort));
}
default List<AiModelDO> selectListByStatusAndType(Integer status, Integer type,
@Nullable String platform) {
return selectList(new LambdaQueryWrapperX<AiModelDO>()
.eq(AiModelDO::getStatus, status)
.eq(AiModelDO::getType, type)
.eqIfPresent(AiModelDO::getPlatform, platform)
.orderByAsc(AiModelDO::getSort));
}
}

View File

@ -1,43 +0,0 @@
package cn.iocoder.yudao.module.ai.dal.mysql.model;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import org.apache.ibatis.annotations.Mapper;
import java.util.Collection;
import java.util.List;
/**
* API 聊天模型 Mapper
*
* @author fansili
*/
@Mapper
public interface AiChatModelMapper extends BaseMapperX<AiChatModelDO> {
default AiChatModelDO selectFirstByStatus(Integer status) {
return selectOne(new QueryWrapperX<AiChatModelDO>()
.eq("status", status)
.limitN(1)
.orderByAsc("sort"));
}
default PageResult<AiChatModelDO> selectPage(AiChatModelPageReqVO reqVO) {
return selectPage(reqVO, new LambdaQueryWrapperX<AiChatModelDO>()
.likeIfPresent(AiChatModelDO::getName, reqVO.getName())
.eqIfPresent(AiChatModelDO::getModel, reqVO.getModel())
.eqIfPresent(AiChatModelDO::getPlatform, reqVO.getPlatform())
.orderByAsc(AiChatModelDO::getSort));
}
default List<AiChatModelDO> selectList(Integer status) {
return selectList(new LambdaQueryWrapperX<AiChatModelDO>()
.eq(AiChatModelDO::getStatus, status)
.orderByAsc(AiChatModelDO::getSort));
}
}

View File

@ -4,17 +4,18 @@ import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert; import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ObjectUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService; import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -44,7 +45,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
private AiChatConversationMapper chatConversationMapper; private AiChatConversationMapper chatConversationMapper;
@Resource @Resource
private AiChatModelService chatModalService; private AiModelService chatModalService;
@Resource @Resource
private AiChatRoleService chatRoleService; private AiChatRoleService chatRoleService;
@Resource @Resource
@ -54,9 +55,9 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) { public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) {
// 1.1 获得 AiChatRoleDO 聊天角色 // 1.1 获得 AiChatRoleDO 聊天角色
AiChatRoleDO role = createReqVO.getRoleId() != null ? chatRoleService.validateChatRole(createReqVO.getRoleId()) : null; AiChatRoleDO role = createReqVO.getRoleId() != null ? chatRoleService.validateChatRole(createReqVO.getRoleId()) : null;
// 1.2 获得 AiChatModelDO 聊天模型 // 1.2 获得 AiModelDO 聊天模型
AiChatModelDO model = role != null && role.getModelId() != null ? chatModalService.validateChatModel(role.getModelId()) AiModelDO model = role != null && role.getModelId() != null ? chatModalService.validateModel(role.getModelId())
: chatModalService.getRequiredDefaultChatModel(); : chatModalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
Assert.notNull(model, "必须找到默认模型"); Assert.notNull(model, "必须找到默认模型");
validateChatModel(model); validateChatModel(model);
@ -86,9 +87,9 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
throw exception(CHAT_CONVERSATION_NOT_EXISTS); throw exception(CHAT_CONVERSATION_NOT_EXISTS);
} }
// 1.2 校验模型是否存在修改模型的情况 // 1.2 校验模型是否存在修改模型的情况
AiChatModelDO model = null; AiModelDO model = null;
if (updateReqVO.getModelId() != null) { if (updateReqVO.getModelId() != null) {
model = chatModalService.validateChatModel(updateReqVO.getModelId()); model = chatModalService.validateModel(updateReqVO.getModelId());
} }
// 1.3 校验知识库是否存在 // 1.3 校验知识库是否存在
@ -139,7 +140,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
chatConversationMapper.deleteById(id); chatConversationMapper.deleteById(id);
} }
private void validateChatModel(AiChatModelDO model) { private void validateChatModel(AiModelDO model) {
if (ObjectUtil.isAllNotEmpty(model.getTemperature(), model.getMaxTokens(), model.getMaxContexts())) { if (ObjectUtil.isAllNotEmpty(model.getTemperature(), model.getMaxTokens(), model.getMaxContexts())) {
return; return;
} }

View File

@ -15,13 +15,12 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum; import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService; import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
@ -63,9 +62,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
@Resource @Resource
private AiChatConversationService chatConversationService; private AiChatConversationService chatConversationService;
@Resource @Resource
private AiChatModelService chatModalService; private AiModelService modalService;
@Resource
private AiApiKeyService apiKeyService;
@Resource @Resource
private AiKnowledgeSegmentService knowledgeSegmentService; private AiKnowledgeSegmentService knowledgeSegmentService;
@ -78,8 +75,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
} }
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型 // 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); AiModelDO model = modalService.validateModel(conversation.getModelId());
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); ChatModel chatModel = modalService.getChatModel(model.getKeyId());
// 2. 插入 user 发送消息 // 2. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@ -112,8 +109,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
} }
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型 // 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); AiModelDO model = modalService.validateModel(conversation.getModelId());
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); StreamingChatModel chatModel = modalService.getChatModel(model.getKeyId());
// 2. 插入 user 发送消息 // 2. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@ -161,8 +158,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
return null; return null;
} }
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList, private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, List<AiKnowledgeSegmentDO> segmentList,
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) { AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
// 1. 构建 Prompt Message 列表 // 1. 构建 Prompt Message 列表
List<Message> chatMessages = new ArrayList<>(); List<Message> chatMessages = new ArrayList<>();
@ -232,7 +229,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
} }
private AiChatMessageDO createChatMessage(Long conversationId, Long replyId, private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
AiChatModelDO model, Long userId, Long roleId, AiModelDO model, Long userId, Long roleId,
MessageType messageType, String content, Boolean useContext) { MessageType messageType, String content, Boolean useContext) {
AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId) AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
.setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId) .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)

View File

@ -12,13 +12,17 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePublicPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum; import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.infra.api.file.FileApi; import cn.iocoder.yudao.module.infra.api.file.FileApi;
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions; import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
@ -54,15 +58,15 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
@Slf4j @Slf4j
public class AiImageServiceImpl implements AiImageService { public class AiImageServiceImpl implements AiImageService {
@Resource
private AiModelService modelService;
@Resource @Resource
private AiImageMapper imageMapper; private AiImageMapper imageMapper;
@Resource @Resource
private FileApi fileApi; private FileApi fileApi;
@Resource
private AiApiKeyService apiKeyService;
@Override @Override
public PageResult<AiImageDO> getImagePageMy(Long userId, AiImagePageReqVO pageReqVO) { public PageResult<AiImageDO> getImagePageMy(Long userId, AiImagePageReqVO pageReqVO) {
return imageMapper.selectPageMy(userId, pageReqVO); return imageMapper.selectPageMy(userId, pageReqVO);
@ -88,23 +92,31 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) { public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) {
// 1. 保存数据库 // 1. 校验模型
AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false) AiModelDO model = modelService.validateModel(drawReqVO.getModelId());
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
// 2. 保存数据库
AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId)
.setPlatform(model.getPlatform()).setModelId(model.getId()).setModel(model.getModel())
.setPublicStatus(false).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
imageMapper.insert(image); imageMapper.insert(image);
// 2. 异步绘制后续前端通过返回的 id 进行轮询结果
getSelf().executeDrawImage(image, drawReqVO); // 3. 异步绘制后续前端通过返回的 id 进行轮询结果
getSelf().executeDrawImage(image, drawReqVO, model);
return image.getId(); return image.getId();
} }
@Async @Async
public void executeDrawImage(AiImageDO image, AiImageDrawReqVO req) { public void executeDrawImage(AiImageDO image, AiImageDrawReqVO reqVO, AiModelDO model) {
try { try {
// 1.1 构建请求 // 1.1 构建请求
ImageOptions request = buildImageOptions(req); ImageOptions request = buildImageOptions(reqVO, model);
// 1.2 执行请求 // 1.2 执行请求
ImageModel imageModel = apiKeyService.getImageModel(AiPlatformEnum.validatePlatform(req.getPlatform())); ImageModel imageModel = modelService.getImageModel(model.getId());
ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request)); ImageResponse response = imageModel.call(new ImagePrompt(reqVO.getPrompt(), request));
if (response.getResult() == null) {
throw new IllegalArgumentException("生成结果为空");
}
// 2. 上传到文件服务 // 2. 上传到文件服务
String b64Json = response.getResult().getOutput().getB64Json(); String b64Json = response.getResult().getOutput().getB64Json();
@ -116,25 +128,25 @@ public class AiImageServiceImpl implements AiImageService {
imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(AiImageStatusEnum.SUCCESS.getStatus()) imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(AiImageStatusEnum.SUCCESS.getStatus())
.setPicUrl(filePath).setFinishTime(LocalDateTime.now())); .setPicUrl(filePath).setFinishTime(LocalDateTime.now()));
} catch (Exception ex) { } catch (Exception ex) {
log.error("[doDall][image({}) 生成异常]", image, ex); log.error("[executeDrawImage][image({}) 生成异常]", image, ex);
imageMapper.updateById(new AiImageDO().setId(image.getId()) imageMapper.updateById(new AiImageDO().setId(image.getId())
.setStatus(AiImageStatusEnum.FAIL.getStatus()) .setStatus(AiImageStatusEnum.FAIL.getStatus())
.setErrorMessage(ex.getMessage()).setFinishTime(LocalDateTime.now())); .setErrorMessage(ex.getMessage()).setFinishTime(LocalDateTime.now()));
} }
} }
private static ImageOptions buildImageOptions(AiImageDrawReqVO draw) { private static ImageOptions buildImageOptions(AiImageDrawReqVO draw, AiModelDO model) {
if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) { if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
// https://platform.openai.com/docs/api-reference/images/create // https://platform.openai.com/docs/api-reference/images/create
return OpenAiImageOptions.builder().withModel(draw.getModel()) return OpenAiImageOptions.builder().withModel(model.getModel())
.withHeight(draw.getHeight()).withWidth(draw.getWidth()) .withHeight(draw.getHeight()).withWidth(draw.getWidth())
.withStyle(MapUtil.getStr(draw.getOptions(), "style")) // 风格 .withStyle(MapUtil.getStr(draw.getOptions(), "style")) // 风格
.withResponseFormat("b64_json") .withResponseFormat("b64_json")
.build(); .build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) { } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
// https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage // https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
// https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage // https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
return StabilityAiImageOptions.builder().model(draw.getModel()) return StabilityAiImageOptions.builder().model(model.getModel())
.height(draw.getHeight()).width(draw.getWidth()) .height(draw.getHeight()).width(draw.getWidth())
.seed(Long.valueOf(draw.getOptions().get("seed"))) .seed(Long.valueOf(draw.getOptions().get("seed")))
.cfgScale(Float.valueOf(draw.getOptions().get("scale"))) .cfgScale(Float.valueOf(draw.getOptions().get("scale")))
@ -143,22 +155,22 @@ public class AiImageServiceImpl implements AiImageService {
.stylePreset(String.valueOf(draw.getOptions().get("stylePreset"))) .stylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
.clipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset"))) .clipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
.build(); .build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) { } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
return DashScopeImageOptions.builder() return DashScopeImageOptions.builder()
.withModel(draw.getModel()).withN(1) .withModel(model.getModel()).withN(1)
.withHeight(draw.getHeight()).withWidth(draw.getWidth()) .withHeight(draw.getHeight()).withWidth(draw.getWidth())
.build(); .build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) { } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
return QianFanImageOptions.builder() return QianFanImageOptions.builder()
.model(draw.getModel()).N(1) .model(model.getModel()).N(1)
.height(draw.getHeight()).width(draw.getWidth()) .height(draw.getHeight()).width(draw.getWidth())
.build(); .build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) { } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
return ZhiPuAiImageOptions.builder() return ZhiPuAiImageOptions.builder()
.model(draw.getModel()) .model(model.getModel())
.build(); .build();
} }
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform()); throw new IllegalArgumentException("不支持的 AI 平台:" + model.getPlatform());
} }
@Override @Override
@ -206,7 +218,7 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) { public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi(); MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
// 1. 保存数据库 // 1. 保存数据库
AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false) AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()) .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
@ -237,7 +249,7 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
public Integer midjourneySync() { public Integer midjourneySync() {
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi(); MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
// 1.1 获取 Midjourney 平台状态在 进行中 image // 1.1 获取 Midjourney 平台状态在 进行中 image
List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform( List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform()); AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
@ -308,7 +320,7 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) { public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi(); MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
// 1.1 检查 image // 1.1 检查 image
AiImageDO image = validateImageExists(reqVO.getId()); AiImageDO image = validateImageExists(reqVO.getId());
if (ObjUtil.notEqual(userId, image.getUserId())) { if (ObjUtil.notEqual(userId, image.getUserId())) {

View File

@ -7,14 +7,17 @@ import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.*; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSaveReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO; import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO; import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document; import org.springframework.ai.document.Document;
@ -33,8 +36,8 @@ import java.util.Objects;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
/** /**
* AI 知识库分片 Service 实现类 * AI 知识库分片 Service 实现类
@ -58,7 +61,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
@Lazy // 延迟加载避免循环依赖 @Lazy // 延迟加载避免循环依赖
private AiKnowledgeDocumentService knowledgeDocumentService; private AiKnowledgeDocumentService knowledgeDocumentService;
@Resource @Resource
private AiApiKeyService apiKeyService; private AiModelService modelService;
@Resource @Resource
private TokenCountEstimator tokenCountEstimator; private TokenCountEstimator tokenCountEstimator;
@ -180,7 +183,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqBO.getKnowledgeId()); AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqBO.getKnowledgeId());
// 2.1 向量检索 // 2.1 向量检索
VectorStore vectorStore = apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId()); VectorStore vectorStore = getVectorStoreById(knowledge);
List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder() List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder()
.query(reqBO.getContent()) .query(reqBO.getContent())
.topK(ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK())) .topK(ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK()))
@ -251,11 +254,12 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
} }
private VectorStore getVectorStoreById(AiKnowledgeDO knowledge) { private VectorStore getVectorStoreById(AiKnowledgeDO knowledge) {
return apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId()); return modelService.getOrCreateVectorStore(knowledge.getEmbeddingModelId());
} }
private VectorStore getVectorStoreById(Long knowledgeId) { private VectorStore getVectorStoreById(Long knowledgeId) {
return getVectorStoreById(knowledgeService.validateKnowledgeExists(knowledgeId)); AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(knowledgeId);
return getVectorStoreById(knowledge);
} }
private static List<Document> splitContentByToken(String content, Integer segmentMaxTokens) { private static List<Document> splitContentByToken(String content, Integer segmentMaxTokens) {

View File

@ -5,9 +5,9 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -28,12 +28,12 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
private AiKnowledgeMapper knowledgeMapper; private AiKnowledgeMapper knowledgeMapper;
@Resource @Resource
private AiChatModelService chatModelService; private AiModelService chatModelService;
@Override @Override
public Long createKnowledge(AiKnowledgeSaveReqVO createReqVO) { public Long createKnowledge(AiKnowledgeSaveReqVO createReqVO) {
// 1. 校验模型配置 // 1. 校验模型配置
AiChatModelDO model = chatModelService.validateChatModel(createReqVO.getEmbeddingModelId()); AiModelDO model = chatModelService.validateModel(createReqVO.getEmbeddingModelId());
// 2. 插入知识库 // 2. 插入知识库
AiKnowledgeDO knowledge = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class) AiKnowledgeDO knowledge = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
@ -47,7 +47,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
// 1.1 校验知识库存在 // 1.1 校验知识库存在
validateKnowledgeExists(updateReqVO.getId()); validateKnowledgeExists(updateReqVO.getId());
// 1.2 校验模型配置 // 1.2 校验模型配置
AiChatModelDO model = chatModelService.validateChatModel(updateReqVO.getEmbeddingModelId()); AiModelDO model = chatModelService.validateModel(updateReqVO.getEmbeddingModelId());
// 2. 更新知识库 // 2. 更新知识库
AiKnowledgeDO updateObj = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class) AiKnowledgeDO updateObj = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class)

View File

@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.mindmap;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert; import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
@ -12,14 +13,13 @@ import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.mindmap.AiMindMapMapper; import cn.iocoder.yudao.module.ai.dal.mysql.mindmap.AiMindMapMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum; import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
@ -50,9 +50,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MIND_MAP_NOT_E
public class AiMindMapServiceImpl implements AiMindMapService { public class AiMindMapServiceImpl implements AiMindMapService {
@Resource @Resource
private AiApiKeyService apiKeyService; private AiModelService modalService;
@Resource
private AiChatModelService chatModalService;
@Resource @Resource
private AiChatRoleService chatRoleService; private AiChatRoleService chatRoleService;
@ -65,17 +63,17 @@ public class AiMindMapServiceImpl implements AiMindMapService {
AiChatRoleDO role = CollUtil.getFirst( AiChatRoleDO role = CollUtil.getFirst(
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
// 1.1 获取导图执行模型 // 1.1 获取导图执行模型
AiChatModelDO model = getModel(role); AiModelDO model = getModel(role);
// 1.2 获取角色设定消息 // 1.2 获取角色设定消息
String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage()) String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage(); ? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
// 1.3 校验平台 // 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); ChatModel chatModel = modalService.getChatModel(model.getId());
// 2. 插入思维导图信息 // 2. 插入思维导图信息
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, mindMap -> mindMap.setUserId(userId)
mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); .setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel()));
mindMapMapper.insert(mindMapDO); mindMapMapper.insert(mindMapDO);
// 3.1 构建 Prompt并进行调用 // 3.1 构建 Prompt并进行调用
@ -103,7 +101,7 @@ public class AiMindMapServiceImpl implements AiMindMapService {
} }
private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) { private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) {
// 1. 构建 message 列表 // 1. 构建 message 列表
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage); List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
// 2. 构建 options 对象 // 2. 构建 options 对象
@ -123,13 +121,13 @@ public class AiMindMapServiceImpl implements AiMindMapService {
return chatMessages; return chatMessages;
} }
private AiChatModelDO getModel(AiChatRoleDO role) { private AiModelDO getModel(AiChatRoleDO role) {
AiChatModelDO model = null; AiModelDO model = null;
if (role != null && role.getModelId() != null) { if (role != null && role.getModelId() != null) {
model = chatModalService.getChatModel(role.getModelId()); model = modalService.getModel(role.getModelId());
} }
if (model == null) { if (model == null) {
model = chatModalService.getRequiredDefaultChatModel(); model = modalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
} }
Assert.notNull(model, "[AI] 获取不到模型"); Assert.notNull(model, "[AI] 获取不到模型");
return model; return model;

View File

@ -1,16 +1,10 @@
package cn.iocoder.yudao.module.ai.service.model; package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
import java.util.List; import java.util.List;
@ -74,50 +68,13 @@ public interface AiApiKeyService {
*/ */
List<AiApiKeyDO> getApiKeyList(); List<AiApiKeyDO> getApiKeyList();
// ========== spring-ai 集成 ==========
/** /**
* 获得 ChatModel 对象 * 获得默认的 API 密钥
*
* @param id 编号
* @return ChatModel 对象
*/
ChatModel getChatModel(Long id);
/**
* 获得 ImageModel 对象
*
* TODO 可优化点目前默认获取 platform 对应的第一个开启的配置用于绘画后续可以支持配置选择
* *
* @param platform 平台 * @param platform 平台
* @return ImageModel 对象 * @param status 状态
* @return API 密钥
*/ */
ImageModel getImageModel(AiPlatformEnum platform); AiApiKeyDO getRequiredDefaultApiKey(String platform, Integer status);
/**
* 获得 MidjourneyApi 对象
*
* TODO 可优化点目前默认获取 Midjourney 对应的第一个开启的配置用于绘画后续可以支持配置选择
*
* @return MidjourneyApi 对象
*/
MidjourneyApi getMidjourneyApi();
/**
* 获得 SunoApi 对象
*
* TODO 可优化点目前默认获取 Suno 对应的第一个开启的配置用于音乐后续可以支持配置选择
*
* @return SunoApi 对象
*/
SunoApi getSunoApi();
/**
* 获得 VectorStore 对象
*
* @param modelId 编号
* @return VectorStore 对象
*/
VectorStore getOrCreateVectorStoreByModelId(Long modelId);
} }

View File

@ -1,31 +1,21 @@
package cn.iocoder.yudao.module.ai.service.model; package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper; import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import java.util.List; import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.API_KEY_DISABLE;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.API_KEY_NOT_EXISTS;
/** /**
* AI API 密钥 Service 实现类 * AI API 密钥 Service 实现类
@ -39,14 +29,6 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
@Resource @Resource
private AiApiKeyMapper apiKeyMapper; private AiApiKeyMapper apiKeyMapper;
// TODO @芋艿后续要不要改
@Resource
@Lazy // 延迟加载解决渲染依赖
private AiChatModelService chatModelService;
@Resource
private AiModelFactory modelFactory;
@Override @Override
public Long createApiKey(AiApiKeySaveReqVO createReqVO) { public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
// 插入 // 插入
@ -105,57 +87,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
return apiKeyMapper.selectList(); return apiKeyMapper.selectList();
} }
// ========== spring-ai 集成 ==========
@Override @Override
public ChatModel getChatModel(Long id) { public AiApiKeyDO getRequiredDefaultApiKey(String platform, Integer status) {
AiApiKeyDO apiKey = validateApiKey(id); AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform, status);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public ImageModel getImageModel(AiPlatformEnum platform) {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
if (apiKey == null) { if (apiKey == null) {
throw exception(API_KEY_IMAGE_NODE_FOUND, platform.getName()); throw exception(API_KEY_NOT_EXISTS);
} }
return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl()); return apiKey;
}
@Override
public MidjourneyApi getMidjourneyApi() {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
if (apiKey == null) {
throw exception(API_KEY_MIDJOURNEY_NOT_FOUND);
}
return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public SunoApi getSunoApi() {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
if (apiKey == null) {
throw exception(API_KEY_SUNO_NOT_FOUND);
}
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public VectorStore getOrCreateVectorStoreByModelId(Long modelId) {
// 获取模型 + 密钥
AiChatModelDO chatModel = chatModelService.validateChatModel(modelId);
AiApiKeyDO apiKey = validateApiKey(chatModel.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
// 创建或获取 EmbeddingModel 对象
EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(),
apiKey.getUrl(), chatModel.getModel());
// 创建或获取 VectorStore 对象
return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel);
} }
} }

View File

@ -1,92 +0,0 @@
package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import jakarta.validation.Valid;
import java.util.Collection;
import java.util.List;
import java.util.Set;
/**
* AI 聊天模型 Service 接口
*
* @author fansili
* @since 2024/4/24 19:42
*/
public interface AiChatModelService {
/**
* 创建聊天模型
*
* @param createReqVO 创建信息
* @return 编号
*/
Long createChatModel(@Valid AiChatModelSaveReqVO createReqVO);
/**
* 更新聊天模型
*
* @param updateReqVO 更新信息
*/
void updateChatModel(@Valid AiChatModelSaveReqVO updateReqVO);
/**
* 删除聊天模型
*
* @param id 编号
*/
void deleteChatModel(Long id);
/**
* 获得聊天模型
*
* @param id 编号
* @return 聊天模型
*/
AiChatModelDO getChatModel(Long id);
/**
* 获得默认的聊天模型
*
* 如果获取不到则抛出 {@link cn.iocoder.yudao.framework.common.exception.ServiceException} 业务异常
*
* @return 聊天模型
*/
AiChatModelDO getRequiredDefaultChatModel();
/**
* 获得聊天模型分页
*
* @param pageReqVO 分页查询
* @return 聊天模型分页
*/
PageResult<AiChatModelDO> getChatModelPage(AiChatModelPageReqVO pageReqVO);
/**
* 校验聊天模型
*
* @param id 编号
* @return 聊天模型
*/
AiChatModelDO validateChatModel(Long id);
/**
* 获得聊天模型列表
*
* @param status 状态
* @return 聊天模型列表
*/
List<AiChatModelDO> getChatModelListByStatus(Integer status);
/**
* 获得聊天模型列表
*
* @param ids 编号数组
* @return 模型列表
*/
List<AiChatModelDO> getChatModelList(Collection<Long> ids);
}

View File

@ -1,114 +1,168 @@
package cn.iocoder.yudao.module.ai.service.model; package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatModelMapper; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatMapper;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import java.util.Collection;
import java.util.List; import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
/** /**
* AI 聊天模型 Service 实现类 * AI 模型 Service 实现类
* *
* @author fansili * @author fansili
*/ */
@Service @Service
@Validated @Validated
public class AiChatModelServiceImpl implements AiChatModelService { public class AiChatModelServiceImpl implements AiModelService {
@Resource @Resource
private AiApiKeyService apiKeyService; private AiApiKeyService apiKeyService;
@Resource @Resource
private AiChatModelMapper chatModelMapper; private AiChatMapper modelMapper;
@Resource
private AiModelFactory modelFactory;
@Override @Override
public Long createChatModel(AiChatModelSaveReqVO createReqVO) { public Long createModel(AiModelSaveReqVO createReqVO) {
// 1. 校验 // 1. 校验
AiPlatformEnum.validatePlatform(createReqVO.getPlatform()); AiPlatformEnum.validatePlatform(createReqVO.getPlatform());
apiKeyService.validateApiKey(createReqVO.getKeyId()); apiKeyService.validateApiKey(createReqVO.getKeyId());
// 2. 插入 // 2. 插入
AiChatModelDO chatModel = BeanUtils.toBean(createReqVO, AiChatModelDO.class); AiModelDO model = BeanUtils.toBean(createReqVO, AiModelDO.class);
chatModelMapper.insert(chatModel); modelMapper.insert(model);
return chatModel.getId(); return model.getId();
} }
@Override @Override
public void updateChatModel(AiChatModelSaveReqVO updateReqVO) { public void updateModel(AiModelSaveReqVO updateReqVO) {
// 1. 校验 // 1. 校验
validateChatModelExists(updateReqVO.getId()); validateModelExists(updateReqVO.getId());
AiPlatformEnum.validatePlatform(updateReqVO.getPlatform()); AiPlatformEnum.validatePlatform(updateReqVO.getPlatform());
apiKeyService.validateApiKey(updateReqVO.getKeyId()); apiKeyService.validateApiKey(updateReqVO.getKeyId());
// 2. 更新 // 2. 更新
AiChatModelDO updateObj = BeanUtils.toBean(updateReqVO, AiChatModelDO.class); AiModelDO updateObj = BeanUtils.toBean(updateReqVO, AiModelDO.class);
chatModelMapper.updateById(updateObj); modelMapper.updateById(updateObj);
} }
@Override @Override
public void deleteChatModel(Long id) { public void deleteModel(Long id) {
// 校验存在 // 校验存在
validateChatModelExists(id); validateModelExists(id);
// 删除 // 删除
chatModelMapper.deleteById(id); modelMapper.deleteById(id);
} }
private AiChatModelDO validateChatModelExists(Long id) { private AiModelDO validateModelExists(Long id) {
AiChatModelDO model = chatModelMapper.selectById(id); AiModelDO model = modelMapper.selectById(id);
if (chatModelMapper.selectById(id) == null) { if (modelMapper.selectById(id) == null) {
throw exception(CHAT_MODEL_NOT_EXISTS); throw exception(MODEL_NOT_EXISTS);
} }
return model; return model;
} }
@Override @Override
public AiChatModelDO getChatModel(Long id) { public AiModelDO getModel(Long id) {
return chatModelMapper.selectById(id); return modelMapper.selectById(id);
} }
@Override @Override
public AiChatModelDO getRequiredDefaultChatModel() { public AiModelDO getRequiredDefaultModel(Integer type) {
AiChatModelDO model = chatModelMapper.selectFirstByStatus(CommonStatusEnum.ENABLE.getStatus()); AiModelDO model = modelMapper.selectFirstByStatus(type, CommonStatusEnum.ENABLE.getStatus());
if (model == null) { if (model == null) {
throw exception(CHAT_MODEL_DEFAULT_NOT_EXISTS); throw exception(MODEL_DEFAULT_NOT_EXISTS);
} }
return model; return model;
} }
@Override @Override
public PageResult<AiChatModelDO> getChatModelPage(AiChatModelPageReqVO pageReqVO) { public PageResult<AiModelDO> getModelPage(AiModelPageReqVO pageReqVO) {
return chatModelMapper.selectPage(pageReqVO); return modelMapper.selectPage(pageReqVO);
} }
@Override @Override
public AiChatModelDO validateChatModel(Long id) { public AiModelDO validateModel(Long id) {
AiChatModelDO model = validateChatModelExists(id); AiModelDO model = validateModelExists(id);
if (CommonStatusEnum.isDisable(model.getStatus())) { if (CommonStatusEnum.isDisable(model.getStatus())) {
throw exception(CHAT_MODEL_DISABLE); throw exception(MODEL_DISABLE);
} }
return model; return model;
} }
@Override @Override
public List<AiChatModelDO> getChatModelListByStatus(Integer status) { public List<AiModelDO> getModelListByStatusAndType(Integer status, Integer type,
return chatModelMapper.selectList(status); String platform) {
return modelMapper.selectListByStatusAndType(status, type, platform);
}
// ========== Spring AI 集成 ==========
@Override
public ChatModel getChatModel(Long id) {
AiModelDO model = validateModel(id);
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
} }
@Override @Override
public List<AiChatModelDO> getChatModelList(Collection<Long> ids) { public ImageModel getImageModel(Long id) {
return chatModelMapper.selectBatchIds(ids); AiModelDO model = validateModel(id);
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public MidjourneyApi getMidjourneyApi() {
AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey(
AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public SunoApi getSunoApi() {
AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey(
AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public VectorStore getOrCreateVectorStore(Long id) {
// 获取模型 + 密钥
AiModelDO model = validateModel(id);
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
// 创建或获取 EmbeddingModel 对象
EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(
platform, apiKey.getApiKey(), apiKey.getUrl(), model.getModel());
// 创建或获取 VectorStore 对象
return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel);
} }
} }

View File

@ -0,0 +1,131 @@
package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import jakarta.validation.Valid;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
import javax.annotation.Nullable;
import java.util.List;
/**
* AI 模型 Service 接口
*
* @author fansili
* @since 2024/4/24 19:42
*/
public interface AiModelService {
/**
* 创建模型
*
* @param createReqVO 创建信息
* @return 编号
*/
Long createModel(@Valid AiModelSaveReqVO createReqVO);
/**
* 更新模型
*
* @param updateReqVO 更新信息
*/
void updateModel(@Valid AiModelSaveReqVO updateReqVO);
/**
* 删除模型
*
* @param id 编号
*/
void deleteModel(Long id);
/**
* 获得模型
*
* @param id 编号
* @return 模型
*/
AiModelDO getModel(Long id);
/**
* 获得默认的模型
*
* 如果获取不到则抛出 {@link cn.iocoder.yudao.framework.common.exception.ServiceException} 业务异常
*
* @return 模型
*/
AiModelDO getRequiredDefaultModel(Integer type);
/**
* 获得模型分页
*
* @param pageReqVO 分页查询
* @return 模型分页
*/
PageResult<AiModelDO> getModelPage(AiModelPageReqVO pageReqVO);
/**
* 校验模型是否可使用
*
* @param id 编号
* @return 模型
*/
AiModelDO validateModel(Long id);
/**
* 获得模型列表
*
* @param status 状态
* @param type 类型
* @param platform 平台允许空
* @return 模型列表
*/
List<AiModelDO> getModelListByStatusAndType(Integer status, Integer type,
@Nullable String platform);
// ========== Spring AI 集成 ==========
/**
* 获得 ChatModel 对象
*
* @param id 编号
* @return ChatModel 对象
*/
ChatModel getChatModel(Long id);
/**
* 获得 ImageModel 对象
*
* @param id 编号
* @return ImageModel 对象
*/
ImageModel getImageModel(Long id);
/**
* 获得 MidjourneyApi 对象
*
* @return MidjourneyApi 对象
*/
MidjourneyApi getMidjourneyApi();
/**
* 获得 SunoApi 对象
*
* @return SunoApi 对象
*/
SunoApi getSunoApi();
/**
* 获得 VectorStore 对象
*
* @param id 编号
* @return VectorStore 对象
*/
VectorStore getOrCreateVectorStore(Long id);
}

View File

@ -16,7 +16,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper; import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum; import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum; import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.infra.api.file.FileApi; import cn.iocoder.yudao.module.infra.api.file.FileApi;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -41,7 +41,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MUSIC_NOT_EXIS
public class AiMusicServiceImpl implements AiMusicService { public class AiMusicServiceImpl implements AiMusicService {
@Resource @Resource
private AiApiKeyService apiKeyService; private AiModelService modelService;
@Resource @Resource
private AiMusicMapper musicMapper; private AiMusicMapper musicMapper;
@ -53,7 +53,7 @@ public class AiMusicServiceImpl implements AiMusicService {
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO) { public List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO) {
// 1. 调用 Suno 生成音乐 // 1. 调用 Suno 生成音乐
SunoApi sunoApi = apiKeyService.getSunoApi(); SunoApi sunoApi = modelService.getSunoApi();
List<SunoApi.MusicData> musicDataList; List<SunoApi.MusicData> musicDataList;
if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) { if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) {
// 1.1 描述模式 // 1.1 描述模式
@ -88,7 +88,7 @@ public class AiMusicServiceImpl implements AiMusicService {
log.info("[syncMusic][Suno 开始同步, 共 ({}) 个任务]", streamingTask.size()); log.info("[syncMusic][Suno 开始同步, 共 ({}) 个任务]", streamingTask.size());
// GET 请求为避免参数过长分批次处理 // GET 请求为避免参数过长分批次处理
SunoApi sunoApi = apiKeyService.getSunoApi(); SunoApi sunoApi = modelService.getSunoApi();
CollUtil.split(streamingTask, 36).forEach(chunkList -> { CollUtil.split(streamingTask, 36).forEach(chunkList -> {
Map<String, Long> taskIdMap = convertMap(chunkList, AiMusicDO::getTaskId, AiMusicDO::getId); Map<String, Long> taskIdMap = convertMap(chunkList, AiMusicDO::getTaskId, AiMusicDO::getId);
List<SunoApi.MusicData> musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet())); List<SunoApi.MusicData> musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet()));

View File

@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.write;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert; import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
@ -11,17 +12,16 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWritePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWritePageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper; import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum; import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants; import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.system.api.dict.DictDataApi; import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -54,17 +54,15 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.WRITE_NOT_EXIS
public class AiWriteServiceImpl implements AiWriteService { public class AiWriteServiceImpl implements AiWriteService {
@Resource @Resource
private AiApiKeyService apiKeyService; private AiModelService chatModalService;
@Resource
private AiChatModelService chatModalService;
@Resource @Resource
private AiChatRoleService chatRoleService; private AiChatRoleService chatRoleService;
@Resource @Resource
private DictDataApi dictDataApi; private AiWriteMapper writeMapper;
@Resource @Resource
private AiWriteMapper writeMapper; private DictDataApi dictDataApi;
@Override @Override
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
@ -72,17 +70,17 @@ public class AiWriteServiceImpl implements AiWriteService {
AiChatRoleDO writeRole = CollUtil.getFirst( AiChatRoleDO writeRole = CollUtil.getFirst(
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
// 1.1 获取写作执行模型 // 1.1 获取写作执行模型
AiChatModelDO model = getModel(writeRole); AiModelDO model = getModel(writeRole);
// 1.2 获取角色设定消息 // 1.2 获取角色设定消息
String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage()) String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage(); ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
// 1.3 校验平台 // 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); StreamingChatModel chatModel = chatModalService.getChatModel(model.getKeyId());
// 2. 插入写作信息 // 2. 插入写作信息
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, write -> write.setUserId(userId)
write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel())); .setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel()));
writeMapper.insert(writeDO); writeMapper.insert(writeDO);
// 3.1 构建 Prompt并进行调用 // 3.1 构建 Prompt并进行调用
@ -109,19 +107,19 @@ public class AiWriteServiceImpl implements AiWriteService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
} }
private AiChatModelDO getModel(AiChatRoleDO writeRole) { private AiModelDO getModel(AiChatRoleDO writeRole) {
AiChatModelDO model = null; AiModelDO model = null;
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
model = chatModalService.getChatModel(writeRole.getModelId()); model = chatModalService.getModel(writeRole.getModelId());
} }
if (model == null) { if (model == null) {
model = chatModalService.getRequiredDefaultChatModel(); model = chatModalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
} }
Assert.notNull(model, "[AI] 获取不到模型"); Assert.notNull(model, "[AI] 获取不到模型");
return model; return model;
} }
private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) { private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) {
// 1. 构建 message 列表 // 1. 构建 message 列表
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage); List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
// 2. 构建 options 对象 // 2. 构建 options 对象

View File

@ -0,0 +1,41 @@
package cn.iocoder.yudao.framework.ai.core.enums;
import cn.iocoder.yudao.framework.common.core.ArrayValuable;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.Arrays;
/**
* AI 模型类型的枚举
*
* @author 芋道源码
*/
@Getter
@RequiredArgsConstructor
public enum AiModelTypeEnum implements ArrayValuable<Integer> {
CHAT(1, "对话"),
IMAGE(2, "图片"),
VOICE(3, "语音"),
VIDEO(4, "视频"),
EMBEDDING(5, "向量"),
RERANK(6, "重排序");
/**
* 类型
*/
private final Integer type;
/**
* 类型名
*/
private final String name;
public static final Integer[] ARRAYS = Arrays.stream(values()).map(AiModelTypeEnum::getType).toArray(Integer[]::new);
@Override
public Integer[] array() {
return ARRAYS;
}
}

View File

@ -1,8 +1,11 @@
package cn.iocoder.yudao.framework.ai.core.enums; package cn.iocoder.yudao.framework.ai.core.enums;
import cn.iocoder.yudao.framework.common.core.ArrayValuable;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
import java.util.Arrays;
/** /**
* AI 模型平台 * AI 模型平台
* *
@ -10,7 +13,7 @@ import lombok.Getter;
*/ */
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
public enum AiPlatformEnum { public enum AiPlatformEnum implements ArrayValuable<String> {
// ========== 国内平台 ========== // ========== 国内平台 ==========
@ -44,6 +47,8 @@ public enum AiPlatformEnum {
*/ */
private final String name; private final String name;
public static final String[] ARRAYS = Arrays.stream(values()).map(AiPlatformEnum::getPlatform).toArray(String[]::new);
public static AiPlatformEnum validatePlatform(String platform) { public static AiPlatformEnum validatePlatform(String platform) {
for (AiPlatformEnum platformEnum : AiPlatformEnum.values()) { for (AiPlatformEnum platformEnum : AiPlatformEnum.values()) {
if (platformEnum.getPlatform().equals(platform)) { if (platformEnum.getPlatform().equals(platform)) {
@ -53,4 +58,9 @@ public enum AiPlatformEnum {
throw new IllegalArgumentException("非法平台: " + platform); throw new IllegalArgumentException("非法平台: " + platform);
} }
@Override
public String[] array() {
return ARRAYS;
}
} }

View File

@ -456,25 +456,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
return vectorStore; return vectorStore;
} }
/**
* 创建向量存储文件
*
* @param embeddingModel 嵌入模型
* @return 向量存储文件
*/
private File createVectorStoreFile(EmbeddingModel embeddingModel) {
// 获取简单类名
String simpleClassName = embeddingModel.getClass().getSimpleName();
// 获取用户主目录
String userHome = FileUtil.getUserHomePath();
// 创建vector_store目录
File vectorStoreDir = new File(userHome, "vector_store");
if (!vectorStoreDir.exists()) {
vectorStoreDir.mkdirs();
}
// 创建文件
return new File(vectorStoreDir, "simple_" + simpleClassName + ".json");
}
} }