diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java index 2b236e6f6a..205fc80b10 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java @@ -12,31 +12,25 @@ public interface ErrorCodeConstants { // ========== API 密钥 1-040-000-000 ========== ErrorCode API_KEY_NOT_EXISTS = new ErrorCode(1_040_000_000, "API 密钥不存在"); ErrorCode API_KEY_DISABLE = new ErrorCode(1_040_000_001, "API 密钥已禁用!"); - ErrorCode API_KEY_MIDJOURNEY_NOT_FOUND = new ErrorCode(1_040_000_900, "Midjourney 模型不存在"); - ErrorCode API_KEY_SUNO_NOT_FOUND = new ErrorCode(1_040_000_901, "Suno 模型不存在"); - ErrorCode API_KEY_IMAGE_NODE_FOUND = new ErrorCode(1_040_000_902, "平台({}) 图片模型未配置"); - // ========== API 聊天模型 1-040-001-000 ========== - ErrorCode CHAT_MODEL_NOT_EXISTS = new ErrorCode(1_040_001_000, "模型不存在!"); - ErrorCode CHAT_MODEL_DISABLE = new ErrorCode(1_040_001_001, "模型({})已禁用!"); - ErrorCode CHAT_MODEL_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认聊天模型"); + // ========== API 模型 1-040-001-000 ========== + ErrorCode MODEL_NOT_EXISTS = new ErrorCode(1_040_001_000, "模型不存在!"); + ErrorCode MODEL_DISABLE = new ErrorCode(1_040_001_001, "模型({})已禁用!"); + ErrorCode MODEL_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认模型"); // ========== API 聊天角色 1-040-002-000 ========== ErrorCode CHAT_ROLE_NOT_EXISTS = new ErrorCode(1_040_002_000, "聊天角色不存在"); ErrorCode CHAT_ROLE_DISABLE = new ErrorCode(1_040_001_001, "聊天角色({})已禁用!"); // ========== API 聊天会话 1-040-003-000 ========== - ErrorCode CHAT_CONVERSATION_NOT_EXISTS = new ErrorCode(1_040_003_000, "对话不存在!"); ErrorCode CHAT_CONVERSATION_MODEL_ERROR = new ErrorCode(1_040_003_001, "操作失败,该聊天模型的配置不完整"); // ========== API 聊天消息 1-040-004-000 ========== - ErrorCode CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!"); ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "对话生成异常!"); // ========== API 绘画 1-040-005-000 ========== - ErrorCode IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!"); ErrorCode IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}"); ErrorCode IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}"); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationRespVO.java index 66eb24db5f..7da37ebc9b 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationRespVO.java @@ -1,6 +1,6 @@ package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import com.fhs.core.trans.anno.Trans; import com.fhs.core.trans.constant.TransType; @@ -31,7 +31,7 @@ public class AiChatConversationRespVO implements VO { private Long roleId; @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") - @Trans(type = TransType.SIMPLE, target = AiChatModelDO.class, fields = "name", ref = "modelName") + @Trans(type = TransType.SIMPLE, target = AiModelDO.class, fields = "name", ref = "modelName") private Long modelId; @Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "ERNIE-Bot-turbo-0922") diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageDrawReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageDrawReqVO.java index a38935ef71..02225ea496 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageDrawReqVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageDrawReqVO.java @@ -14,18 +14,15 @@ import java.util.Map; @Data public class AiImageDrawReqVO { - @Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI") - private String platform; // 参见 AiPlatformEnum 枚举 + @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") + @NotNull(message = "模型编号不能为空") + private Long modelId; @Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "画一个长城") @NotEmpty(message = "提示词不能为空") @Size(max = 1200, message = "提示词最大 1200") private String prompt; - @Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "stable-diffusion-v1-6") - @NotEmpty(message = "模型不能为空") - private String model; - /** * 1. dall-e-2 模型:256x256、512x512、1024x1024 * 2. dall-e-3 模型:1024x1024, 1792x1024, 或 1024x1792 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiApiKeyController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiApiKeyController.java index 2bc190051d..c109b033ca 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiApiKeyController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiApiKeyController.java @@ -6,9 +6,8 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyRespVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO; +import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelRespVO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; @@ -76,9 +75,9 @@ public class AiApiKeyController { @GetMapping("/simple-list") @Operation(summary = "获得 API 密钥分页列表") - public CommonResult> getApiKeySimpleList() { + public CommonResult> getApiKeySimpleList() { List list = apiKeyService.getApiKeyList(); - return success(convertList(list, key -> new AiChatModelRespVO().setId(key.getId()).setName(key.getName()))); + return success(convertList(list, key -> new AiModelRespVO().setId(key.getId()).setName(key.getName()))); } } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiChatModelController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiChatModelController.java deleted file mode 100644 index 08a53b286b..0000000000 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiChatModelController.java +++ /dev/null @@ -1,84 +0,0 @@ -package cn.iocoder.yudao.module.ai.controller.admin.model; - -import cn.iocoder.yudao.framework.common.pojo.CommonResult; -import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.framework.common.util.object.BeanUtils; -import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO; -import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; -import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; -import io.swagger.v3.oas.annotations.Operation; -import io.swagger.v3.oas.annotations.Parameter; -import io.swagger.v3.oas.annotations.tags.Tag; -import jakarta.annotation.Resource; -import jakarta.validation.Valid; -import org.springframework.security.access.prepost.PreAuthorize; -import org.springframework.validation.annotation.Validated; -import org.springframework.web.bind.annotation.*; - -import java.util.List; - -import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; -import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; - -@Tag(name = "管理后台 - AI 聊天模型") -@RestController -@RequestMapping("/ai/chat-model") -@Validated -public class AiChatModelController { - - @Resource - private AiChatModelService chatModelService; - - @PostMapping("/create") - @Operation(summary = "创建聊天模型") - @PreAuthorize("@ss.hasPermission('ai:chat-model:create')") - public CommonResult createChatModel(@Valid @RequestBody AiChatModelSaveReqVO createReqVO) { - return success(chatModelService.createChatModel(createReqVO)); - } - - @PutMapping("/update") - @Operation(summary = "更新聊天模型") - @PreAuthorize("@ss.hasPermission('ai:chat-model:update')") - public CommonResult updateChatModel(@Valid @RequestBody AiChatModelSaveReqVO updateReqVO) { - chatModelService.updateChatModel(updateReqVO); - return success(true); - } - - @DeleteMapping("/delete") - @Operation(summary = "删除聊天模型") - @Parameter(name = "id", description = "编号", required = true) - @PreAuthorize("@ss.hasPermission('ai:chat-model:delete')") - public CommonResult deleteChatModel(@RequestParam("id") Long id) { - chatModelService.deleteChatModel(id); - return success(true); - } - - @GetMapping("/get") - @Operation(summary = "获得聊天模型") - @Parameter(name = "id", description = "编号", required = true, example = "1024") - @PreAuthorize("@ss.hasPermission('ai:chat-model:query')") - public CommonResult getChatModel(@RequestParam("id") Long id) { - AiChatModelDO chatModel = chatModelService.getChatModel(id); - return success(BeanUtils.toBean(chatModel, AiChatModelRespVO.class)); - } - - @GetMapping("/page") - @Operation(summary = "获得聊天模型分页") - @PreAuthorize("@ss.hasPermission('ai:chat-model:query')") - public CommonResult> getChatModelPage(@Valid AiChatModelPageReqVO pageReqVO) { - PageResult pageResult = chatModelService.getChatModelPage(pageReqVO); - return success(BeanUtils.toBean(pageResult, AiChatModelRespVO.class)); - } - - @GetMapping("/simple-list") - @Operation(summary = "获得聊天模型列表") - @Parameter(name = "status", description = "状态", required = true, example = "1") - public CommonResult> getChatModelSimpleList(@RequestParam("status") Integer status) { - List list = chatModelService.getChatModelListByStatus(status); - return success(convertList(list, model -> new AiChatModelRespVO().setId(model.getId()) - .setName(model.getName()).setModel(model.getModel()))); - } - -} \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiModelController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiModelController.java new file mode 100644 index 0000000000..86dd4d0a61 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiModelController.java @@ -0,0 +1,89 @@ +package cn.iocoder.yudao.module.ai.controller.admin.model; + +import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; +import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.framework.common.pojo.PageResult; +import cn.iocoder.yudao.framework.common.util.object.BeanUtils; +import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelRespVO; +import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; +import cn.iocoder.yudao.module.ai.service.model.AiModelService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.annotation.Resource; +import jakarta.validation.Valid; +import org.springframework.security.access.prepost.PreAuthorize; +import org.springframework.validation.annotation.Validated; +import org.springframework.web.bind.annotation.*; + +import java.util.List; + +import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; +import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; + +@Tag(name = "管理后台 - AI 模型") +@RestController +@RequestMapping("/ai/model") +@Validated +public class AiModelController { + + @Resource + private AiModelService modelService; + + @PostMapping("/create") + @Operation(summary = "创建模型") + @PreAuthorize("@ss.hasPermission('ai:model:create')") + public CommonResult createModel(@Valid @RequestBody AiModelSaveReqVO createReqVO) { + return success(modelService.createModel(createReqVO)); + } + + @PutMapping("/update") + @Operation(summary = "更新模型") + @PreAuthorize("@ss.hasPermission('ai:model:update')") + public CommonResult updateModel(@Valid @RequestBody AiModelSaveReqVO updateReqVO) { + modelService.updateModel(updateReqVO); + return success(true); + } + + @DeleteMapping("/delete") + @Operation(summary = "删除模型") + @Parameter(name = "id", description = "编号", required = true) + @PreAuthorize("@ss.hasPermission('ai:model:delete')") + public CommonResult deleteModel(@RequestParam("id") Long id) { + modelService.deleteModel(id); + return success(true); + } + + @GetMapping("/get") + @Operation(summary = "获得模型") + @Parameter(name = "id", description = "编号", required = true, example = "1024") + @PreAuthorize("@ss.hasPermission('ai:model:query')") + public CommonResult getModel(@RequestParam("id") Long id) { + AiModelDO model = modelService.getModel(id); + return success(BeanUtils.toBean(model, AiModelRespVO.class)); + } + + @GetMapping("/page") + @Operation(summary = "获得模型分页") + @PreAuthorize("@ss.hasPermission('ai:model:query')") + public CommonResult> getModelPage(@Valid AiModelPageReqVO pageReqVO) { + PageResult pageResult = modelService.getModelPage(pageReqVO); + return success(BeanUtils.toBean(pageResult, AiModelRespVO.class)); + } + + @GetMapping("/simple-list") + @Operation(summary = "获得模型列表") + @Parameter(name = "type", description = "类型", required = true, example = "1") + @Parameter(name = "platform", description = "平台", example = "midjourney") + public CommonResult> getModelSimpleList( + @RequestParam("type") Integer type, + @RequestParam(value = "platform", required = false) String platform) { + List list = modelService.getModelListByStatusAndType( + CommonStatusEnum.ENABLE.getStatus(), type, platform); + return success(convertList(list, model -> new AiModelRespVO().setId(model.getId()) + .setName(model.getName()).setModel(model.getModel()).setPlatform(model.getPlatform()))); + } + +} \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleRespVO.java index eb34da2748..d47c66d2bd 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleRespVO.java @@ -1,6 +1,6 @@ package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import com.fhs.core.trans.anno.Trans; import com.fhs.core.trans.constant.TransType; import com.fhs.core.trans.vo.VO; @@ -20,7 +20,7 @@ public class AiChatRoleRespVO implements VO { private Long userId; @Schema(description = "模型编号", example = "17640") - @Trans(type = TransType.SIMPLE, target = AiChatModelDO.class, fields = {"name", "model"}, refs = {"modelName", "model"}) + @Trans(type = TransType.SIMPLE, target = AiModelDO.class, fields = {"name", "model"}, refs = {"modelName", "model"}) private Long modelId; @Schema(description = "模型名字", example = "张三") private String modelName; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatModel/AiChatModelPageReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelPageReqVO.java similarity index 67% rename from yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatModel/AiChatModelPageReqVO.java rename to yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelPageReqVO.java index ce2f83b4b1..af8d1121a4 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatModel/AiChatModelPageReqVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelPageReqVO.java @@ -1,12 +1,12 @@ -package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel; +package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model; import lombok.*; import io.swagger.v3.oas.annotations.media.Schema; import cn.iocoder.yudao.framework.common.pojo.PageParam; -@Schema(description = "管理后台 - API 聊天模型分页 Request VO") +@Schema(description = "管理后台 - API 模型分页 Request VO") @Data -public class AiChatModelPageReqVO extends PageParam { +public class AiModelPageReqVO extends PageParam { @Schema(description = "模型名字", example = "张三") private String name; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatModel/AiChatModelRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelRespVO.java similarity index 83% rename from yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatModel/AiChatModelRespVO.java rename to yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelRespVO.java index 681dabe687..b50b70a087 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatModel/AiChatModelRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelRespVO.java @@ -1,13 +1,13 @@ -package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel; +package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model; import io.swagger.v3.oas.annotations.media.Schema; import lombok.Data; import java.time.LocalDateTime; -@Schema(description = "管理后台 - AI 聊天模型 Response VO") +@Schema(description = "管理后台 - AI 模型 Response VO") @Data -public class AiChatModelRespVO { +public class AiModelRespVO { @Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "2630") private Long id; @@ -24,6 +24,9 @@ public class AiChatModelRespVO { @Schema(description = "模型平台", example = "OpenAI") private String platform; + @Schema(description = "模型类型", example = "1") + private Integer type; + @Schema(description = "排序", example = "1") private Integer sort; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatModel/AiChatModelSaveReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelSaveReqVO.java similarity index 71% rename from yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatModel/AiChatModelSaveReqVO.java rename to yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelSaveReqVO.java index 4fad5a1fc1..281bcd6732 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatModel/AiChatModelSaveReqVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelSaveReqVO.java @@ -1,14 +1,17 @@ -package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel; +package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model; +import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum; +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.validation.InEnum; import io.swagger.v3.oas.annotations.media.Schema; -import lombok.*; -import jakarta.validation.constraints.*; +import jakarta.validation.constraints.NotEmpty; +import jakarta.validation.constraints.NotNull; +import lombok.Data; -@Schema(description = "管理后台 - API 聊天模型新增/修改 Request VO") +@Schema(description = "管理后台 - API 模型新增/修改 Request VO") @Data -public class AiChatModelSaveReqVO { +public class AiModelSaveReqVO { @Schema(description = "编号", example = "2630") private Long id; @@ -27,8 +30,14 @@ public class AiChatModelSaveReqVO { @Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI") @NotEmpty(message = "模型平台不能为空") + @InEnum(AiPlatformEnum.class) private String platform; + @Schema(description = "模型类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") + @NotNull(message = "模型类型不能为空") + @InEnum(AiModelTypeEnum.class) + private Integer type; + @Schema(description = "排序", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @NotNull(message = "排序不能为空") private Integer sort; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java index b5e9d06541..a9c956deae 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java @@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.TableId; @@ -76,13 +76,13 @@ public class AiChatConversationDO extends BaseDO { /** * 模型编号 * - * 关联 {@link AiChatModelDO#getId()} 字段 + * 关联 {@link AiModelDO#getId()} 字段 */ private Long modelId; /** * 模型标志 * - * 冗余 {@link AiChatModelDO#getModel()} 字段 + * 冗余 {@link AiModelDO#getModel()} 字段 */ private String model; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java index c65360ccbb..a026ec8193 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java @@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.TableField; @@ -83,13 +83,13 @@ public class AiChatMessageDO extends BaseDO { /** * 模型标志 * - * 冗余 {@link AiChatModelDO#getModel()} + * 冗余 {@link AiModelDO#getModel()} */ private String model; /** * 模型编号 * - * 关联 {@link AiChatModelDO#getId()} 字段 + * 关联 {@link AiModelDO#getId()} 字段 */ private Long modelId; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/image/AiImageDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/image/AiImageDO.java index 5b43bc91e9..a18904c022 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/image/AiImageDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/image/AiImageDO.java @@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.image; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum; import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO; import com.baomidou.mybatisplus.annotation.KeySequence; @@ -52,11 +52,16 @@ public class AiImageDO extends BaseDO { * 枚举 {@link cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum} */ private String platform; - // TODO @芋艿:modelId? /** - * 模型 + * 模型编号 * - * 冗余 {@link AiChatModelDO#getModel()} + * 关联 {@link AiModelDO#getId()} + */ + private Long modelId; + /** + * 模型标识 + * + * 冗余 {@link AiModelDO#getModel()} */ private String model; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java index 732d54e24a..e1327a50ef 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java @@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.knowledge; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; @@ -35,13 +35,13 @@ public class AiKnowledgeDO extends BaseDO { /** * 向量模型编号 * - * 关联 {@link AiChatModelDO#getId()} + * 关联 {@link AiModelDO#getId()} */ private Long embeddingModelId; /** * 模型标识 * - * 冗余 {@link AiChatModelDO#getModel()} + * 冗余 {@link AiModelDO#getModel()} */ private String embeddingModel; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java index 47aa0ce4a5..6dd5d44302 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java @@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.mindmap; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; @@ -36,7 +37,12 @@ public class AiMindMapDO extends BaseDO { * 枚举 {@link AiPlatformEnum} */ private String platform; - // TODO @芋艿:modelId? + /** + * 模型编号 + * + * 关联 {@link AiModelDO#getId()} + */ + private Long modelId; /** * 模型 */ diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java index f5ed533a92..fd35b85795 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java @@ -58,7 +58,7 @@ public class AiChatRoleDO extends BaseDO { /** * 模型编号 * - * 关联 {@link AiChatModelDO#getId()} 字段 + * 关联 {@link AiModelDO#getId()} 字段 */ private Long modelId; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatModelDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiModelDO.java similarity index 77% rename from yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatModelDO.java rename to yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiModelDO.java index c3800bd4ca..f4ff09212a 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatModelDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiModelDO.java @@ -1,5 +1,6 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.model; +import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; @@ -8,23 +9,22 @@ import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; import lombok.*; -// TODO @芋艿,需要改造,增加 type /** - * AI 聊天模型 DO + * AI 模型 DO * - * 默认聊天模型:{@link #status} 为开启,并且 {@link #sort} 排序第一 + * 默认模型:{@link #status} 为开启,并且 {@link #sort} 排序第一 * * @author fansili * @since 2024/4/24 19:39 */ -@TableName("ai_chat_model") -@KeySequence("ai_chat_model_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。 +@TableName("ai_model") +@KeySequence("ai_model_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。 @Data @EqualsAndHashCode(callSuper = true) @Builder @NoArgsConstructor @AllArgsConstructor -public class AiChatModelDO extends BaseDO { +public class AiModelDO extends BaseDO { /** * 编号 @@ -51,6 +51,12 @@ public class AiChatModelDO extends BaseDO { * 枚举 {@link AiPlatformEnum} */ private String platform; + /** + * 类型 + * + * 枚举 {@link AiModelTypeEnum} + */ + private Integer type; /** * 排序值 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java index a99f94e82d..e07f994aad 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java @@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.write; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; +import cn.iocoder.yudao.module.ai.enums.DictTypeConstants; import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.TableId; @@ -44,7 +46,12 @@ public class AiWriteDO extends BaseDO { * 枚举 {@link AiPlatformEnum} */ private String platform; - // TODO @芋艿:modelId? + /** + * 模型编号 + * + * 关联 {@link AiModelDO#getId()} + */ + private Long modelId; /** * 模型 */ @@ -67,25 +74,25 @@ public class AiWriteDO extends BaseDO { /** * 长度提示词 * - * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LENGTH} + * 字典:{@link DictTypeConstants#AI_WRITE_LENGTH} */ private Integer length; /** * 格式提示词 * - * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_FORMAT} + * 字典:{@link DictTypeConstants#AI_WRITE_FORMAT} */ private Integer format; /** * 语气提示词 * - * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_TONE} + * 字典:{@link DictTypeConstants#AI_WRITE_TONE} */ private Integer tone; /** * 语言提示词 * - * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LANGUAGE} + * 字典:{@link DictTypeConstants#AI_WRITE_LANGUAGE} */ private Integer language; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatMapper.java new file mode 100644 index 0000000000..bfe2caf52a --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatMapper.java @@ -0,0 +1,47 @@ +package cn.iocoder.yudao.module.ai.dal.mysql.model; + +import cn.iocoder.yudao.framework.common.pojo.PageResult; +import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; +import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; +import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX; +import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; +import org.apache.ibatis.annotations.Mapper; + +import javax.annotation.Nullable; +import java.util.List; + +/** + * API 模型 Mapper + * + * @author fansili + */ +@Mapper +public interface AiChatMapper extends BaseMapperX { + + default AiModelDO selectFirstByStatus(Integer type, Integer status) { + return selectOne(new QueryWrapperX() + .eq("type", type) + .eq("status", status) + .limitN(1) + .orderByAsc("sort")); + } + + default PageResult selectPage(AiModelPageReqVO reqVO) { + return selectPage(reqVO, new LambdaQueryWrapperX() + .likeIfPresent(AiModelDO::getName, reqVO.getName()) + .eqIfPresent(AiModelDO::getModel, reqVO.getModel()) + .eqIfPresent(AiModelDO::getPlatform, reqVO.getPlatform()) + .orderByAsc(AiModelDO::getSort)); + } + + default List selectListByStatusAndType(Integer status, Integer type, + @Nullable String platform) { + return selectList(new LambdaQueryWrapperX() + .eq(AiModelDO::getStatus, status) + .eq(AiModelDO::getType, type) + .eqIfPresent(AiModelDO::getPlatform, platform) + .orderByAsc(AiModelDO::getSort)); + } + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatModelMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatModelMapper.java deleted file mode 100644 index a3136fa9f6..0000000000 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatModelMapper.java +++ /dev/null @@ -1,43 +0,0 @@ -package cn.iocoder.yudao.module.ai.dal.mysql.model; - -import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; -import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; -import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX; -import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; -import org.apache.ibatis.annotations.Mapper; - -import java.util.Collection; -import java.util.List; - -/** - * API 聊天模型 Mapper - * - * @author fansili - */ -@Mapper -public interface AiChatModelMapper extends BaseMapperX { - - default AiChatModelDO selectFirstByStatus(Integer status) { - return selectOne(new QueryWrapperX() - .eq("status", status) - .limitN(1) - .orderByAsc("sort")); - } - - default PageResult selectPage(AiChatModelPageReqVO reqVO) { - return selectPage(reqVO, new LambdaQueryWrapperX() - .likeIfPresent(AiChatModelDO::getName, reqVO.getName()) - .eqIfPresent(AiChatModelDO::getModel, reqVO.getModel()) - .eqIfPresent(AiChatModelDO::getPlatform, reqVO.getPlatform()) - .orderByAsc(AiChatModelDO::getSort)); - } - - default List selectList(Integer status) { - return selectList(new LambdaQueryWrapperX() - .eq(AiChatModelDO::getStatus, status) - .orderByAsc(AiChatModelDO::getSort)); - } - -} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java index 8f094087f1..792a8ae150 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java @@ -4,17 +4,18 @@ import cn.hutool.core.collection.CollUtil; import cn.hutool.core.lang.Assert; import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjectUtil; +import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService; -import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; +import cn.iocoder.yudao.module.ai.service.model.AiModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; @@ -44,7 +45,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService private AiChatConversationMapper chatConversationMapper; @Resource - private AiChatModelService chatModalService; + private AiModelService chatModalService; @Resource private AiChatRoleService chatRoleService; @Resource @@ -54,9 +55,9 @@ public class AiChatConversationServiceImpl implements AiChatConversationService public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) { // 1.1 获得 AiChatRoleDO 聊天角色 AiChatRoleDO role = createReqVO.getRoleId() != null ? chatRoleService.validateChatRole(createReqVO.getRoleId()) : null; - // 1.2 获得 AiChatModelDO 聊天模型 - AiChatModelDO model = role != null && role.getModelId() != null ? chatModalService.validateChatModel(role.getModelId()) - : chatModalService.getRequiredDefaultChatModel(); + // 1.2 获得 AiModelDO 聊天模型 + AiModelDO model = role != null && role.getModelId() != null ? chatModalService.validateModel(role.getModelId()) + : chatModalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType()); Assert.notNull(model, "必须找到默认模型"); validateChatModel(model); @@ -86,9 +87,9 @@ public class AiChatConversationServiceImpl implements AiChatConversationService throw exception(CHAT_CONVERSATION_NOT_EXISTS); } // 1.2 校验模型是否存在(修改模型的情况) - AiChatModelDO model = null; + AiModelDO model = null; if (updateReqVO.getModelId() != null) { - model = chatModalService.validateChatModel(updateReqVO.getModelId()); + model = chatModalService.validateModel(updateReqVO.getModelId()); } // 1.3 校验知识库是否存在 @@ -139,7 +140,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService chatConversationMapper.deleteById(id); } - private void validateChatModel(AiChatModelDO model) { + private void validateChatModel(AiModelDO model) { if (ObjectUtil.isAllNotEmpty(model.getTemperature(), model.getMaxTokens(), model.getMaxContexts())) { return; } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 1c8927e07a..2298d9b1d9 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -15,13 +15,12 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService; -import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; -import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; +import cn.iocoder.yudao.module.ai.service.model.AiModelService; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.Message; @@ -63,9 +62,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { @Resource private AiChatConversationService chatConversationService; @Resource - private AiChatModelService chatModalService; - @Resource - private AiApiKeyService apiKeyService; + private AiModelService modalService; @Resource private AiKnowledgeSegmentService knowledgeSegmentService; @@ -78,8 +75,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { } List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); // 1.2 校验模型 - AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); - ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); + AiModelDO model = modalService.validateModel(conversation.getModelId()); + ChatModel chatModel = modalService.getChatModel(model.getKeyId()); // 2. 插入 user 发送消息 AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, @@ -112,8 +109,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { } List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); // 1.2 校验模型 - AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); - StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); + AiModelDO model = modalService.validateModel(conversation.getModelId()); + StreamingChatModel chatModel = modalService.getChatModel(model.getKeyId()); // 2. 插入 user 发送消息 AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, @@ -161,8 +158,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { return null; } - private Prompt buildPrompt(AiChatConversationDO conversation, List messages,List segmentList, - AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) { + private Prompt buildPrompt(AiChatConversationDO conversation, List messages, List segmentList, + AiModelDO model, AiChatMessageSendReqVO sendReqVO) { // 1. 构建 Prompt Message 列表 List chatMessages = new ArrayList<>(); @@ -232,7 +229,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { } private AiChatMessageDO createChatMessage(Long conversationId, Long replyId, - AiChatModelDO model, Long userId, Long roleId, + AiModelDO model, Long userId, Long roleId, MessageType messageType, String content, Boolean useContext) { AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId) .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index 523014d530..7204f4d5b2 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -12,13 +12,17 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePublicPageReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum; -import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; +import cn.iocoder.yudao.module.ai.service.model.AiModelService; import cn.iocoder.yudao.module.infra.api.file.FileApi; import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions; import jakarta.annotation.Resource; @@ -54,15 +58,15 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*; @Slf4j public class AiImageServiceImpl implements AiImageService { + @Resource + private AiModelService modelService; + @Resource private AiImageMapper imageMapper; @Resource private FileApi fileApi; - @Resource - private AiApiKeyService apiKeyService; - @Override public PageResult getImagePageMy(Long userId, AiImagePageReqVO pageReqVO) { return imageMapper.selectPageMy(userId, pageReqVO); @@ -88,23 +92,31 @@ public class AiImageServiceImpl implements AiImageService { @Override public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) { - // 1. 保存数据库 - AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false) - .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()); + // 1. 校验模型 + AiModelDO model = modelService.validateModel(drawReqVO.getModelId()); + + // 2. 保存数据库 + AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId) + .setPlatform(model.getPlatform()).setModelId(model.getId()).setModel(model.getModel()) + .setPublicStatus(false).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()); imageMapper.insert(image); - // 2. 异步绘制,后续前端通过返回的 id 进行轮询结果 - getSelf().executeDrawImage(image, drawReqVO); + + // 3. 异步绘制,后续前端通过返回的 id 进行轮询结果 + getSelf().executeDrawImage(image, drawReqVO, model); return image.getId(); } @Async - public void executeDrawImage(AiImageDO image, AiImageDrawReqVO req) { + public void executeDrawImage(AiImageDO image, AiImageDrawReqVO reqVO, AiModelDO model) { try { // 1.1 构建请求 - ImageOptions request = buildImageOptions(req); + ImageOptions request = buildImageOptions(reqVO, model); // 1.2 执行请求 - ImageModel imageModel = apiKeyService.getImageModel(AiPlatformEnum.validatePlatform(req.getPlatform())); - ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request)); + ImageModel imageModel = modelService.getImageModel(model.getId()); + ImageResponse response = imageModel.call(new ImagePrompt(reqVO.getPrompt(), request)); + if (response.getResult() == null) { + throw new IllegalArgumentException("生成结果为空"); + } // 2. 上传到文件服务 String b64Json = response.getResult().getOutput().getB64Json(); @@ -116,25 +128,25 @@ public class AiImageServiceImpl implements AiImageService { imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(AiImageStatusEnum.SUCCESS.getStatus()) .setPicUrl(filePath).setFinishTime(LocalDateTime.now())); } catch (Exception ex) { - log.error("[doDall][image({}) 生成异常]", image, ex); + log.error("[executeDrawImage][image({}) 生成异常]", image, ex); imageMapper.updateById(new AiImageDO().setId(image.getId()) .setStatus(AiImageStatusEnum.FAIL.getStatus()) .setErrorMessage(ex.getMessage()).setFinishTime(LocalDateTime.now())); } } - private static ImageOptions buildImageOptions(AiImageDrawReqVO draw) { - if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) { + private static ImageOptions buildImageOptions(AiImageDrawReqVO draw, AiModelDO model) { + if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) { // https://platform.openai.com/docs/api-reference/images/create - return OpenAiImageOptions.builder().withModel(draw.getModel()) + return OpenAiImageOptions.builder().withModel(model.getModel()) .withHeight(draw.getHeight()).withWidth(draw.getWidth()) .withStyle(MapUtil.getStr(draw.getOptions(), "style")) // 风格 .withResponseFormat("b64_json") .build(); - } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) { + } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) { // https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage // https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage - return StabilityAiImageOptions.builder().model(draw.getModel()) + return StabilityAiImageOptions.builder().model(model.getModel()) .height(draw.getHeight()).width(draw.getWidth()) .seed(Long.valueOf(draw.getOptions().get("seed"))) .cfgScale(Float.valueOf(draw.getOptions().get("scale"))) @@ -143,22 +155,22 @@ public class AiImageServiceImpl implements AiImageService { .stylePreset(String.valueOf(draw.getOptions().get("stylePreset"))) .clipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset"))) .build(); - } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) { + } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) { return DashScopeImageOptions.builder() - .withModel(draw.getModel()).withN(1) + .withModel(model.getModel()).withN(1) .withHeight(draw.getHeight()).withWidth(draw.getWidth()) .build(); - } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) { + } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) { return QianFanImageOptions.builder() - .model(draw.getModel()).N(1) + .model(model.getModel()).N(1) .height(draw.getHeight()).width(draw.getWidth()) .build(); - } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) { + } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) { return ZhiPuAiImageOptions.builder() - .model(draw.getModel()) + .model(model.getModel()) .build(); } - throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform()); + throw new IllegalArgumentException("不支持的 AI 平台:" + model.getPlatform()); } @Override @@ -206,7 +218,7 @@ public class AiImageServiceImpl implements AiImageService { @Override @Transactional(rollbackFor = Exception.class) public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) { - MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi(); + MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(); // 1. 保存数据库 AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false) .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()) @@ -237,7 +249,7 @@ public class AiImageServiceImpl implements AiImageService { @Override public Integer midjourneySync() { - MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi(); + MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(); // 1.1 获取 Midjourney 平台,状态在 “进行中” 的 image List imageList = imageMapper.selectListByStatusAndPlatform( AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform()); @@ -308,7 +320,7 @@ public class AiImageServiceImpl implements AiImageService { @Override public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) { - MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi(); + MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(); // 1.1 检查 image AiImageDO image = validateImageExists(reqVO.getId()); if (ObjUtil.notEqual(userId, image.getUserId())) { diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java index adf1616cb0..910ff1249a 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java @@ -7,14 +7,17 @@ import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; -import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.*; +import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO; +import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSaveReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper; -import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO; import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO; +import cn.iocoder.yudao.module.ai.service.model.AiModelService; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.document.Document; @@ -33,8 +36,8 @@ import java.util.Objects; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; -import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG; +import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS; /** * AI 知识库分片 Service 实现类 @@ -58,7 +61,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService @Lazy // 延迟加载,避免循环依赖 private AiKnowledgeDocumentService knowledgeDocumentService; @Resource - private AiApiKeyService apiKeyService; + private AiModelService modelService; @Resource private TokenCountEstimator tokenCountEstimator; @@ -180,7 +183,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqBO.getKnowledgeId()); // 2.1 向量检索 - VectorStore vectorStore = apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId()); + VectorStore vectorStore = getVectorStoreById(knowledge); List documents = vectorStore.similaritySearch(SearchRequest.builder() .query(reqBO.getContent()) .topK(ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK())) @@ -251,11 +254,12 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService } private VectorStore getVectorStoreById(AiKnowledgeDO knowledge) { - return apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId()); + return modelService.getOrCreateVectorStore(knowledge.getEmbeddingModelId()); } private VectorStore getVectorStoreById(Long knowledgeId) { - return getVectorStoreById(knowledgeService.validateKnowledgeExists(knowledgeId)); + AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(knowledgeId); + return getVectorStoreById(knowledge); } private static List splitContentByToken(String content, Integer segmentMaxTokens) { diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java index 900cfb4a17..55822796ce 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java @@ -5,9 +5,9 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper; -import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; +import cn.iocoder.yudao.module.ai.service.model.AiModelService; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; @@ -28,12 +28,12 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService { private AiKnowledgeMapper knowledgeMapper; @Resource - private AiChatModelService chatModelService; + private AiModelService chatModelService; @Override public Long createKnowledge(AiKnowledgeSaveReqVO createReqVO) { // 1. 校验模型配置 - AiChatModelDO model = chatModelService.validateChatModel(createReqVO.getEmbeddingModelId()); + AiModelDO model = chatModelService.validateModel(createReqVO.getEmbeddingModelId()); // 2. 插入知识库 AiKnowledgeDO knowledge = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class) @@ -47,7 +47,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService { // 1.1 校验知识库存在 validateKnowledgeExists(updateReqVO.getId()); // 1.2 校验模型配置 - AiChatModelDO model = chatModelService.validateChatModel(updateReqVO.getEmbeddingModelId()); + AiModelDO model = chatModelService.validateModel(updateReqVO.getEmbeddingModelId()); // 2. 更新知识库 AiKnowledgeDO updateObj = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java index 5536831469..3eb49441b5 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java @@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.mindmap; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.lang.Assert; import cn.hutool.core.util.StrUtil; +import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.common.pojo.CommonResult; @@ -12,14 +13,13 @@ import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.mysql.mindmap.AiMindMapMapper; import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; -import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; -import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; +import cn.iocoder.yudao.module.ai.service.model.AiModelService; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.Message; @@ -50,9 +50,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MIND_MAP_NOT_E public class AiMindMapServiceImpl implements AiMindMapService { @Resource - private AiApiKeyService apiKeyService; - @Resource - private AiChatModelService chatModalService; + private AiModelService modalService; @Resource private AiChatRoleService chatRoleService; @@ -65,17 +63,17 @@ public class AiMindMapServiceImpl implements AiMindMapService { AiChatRoleDO role = CollUtil.getFirst( chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); // 1.1 获取导图执行模型 - AiChatModelDO model = getModel(role); + AiModelDO model = getModel(role); // 1.2 获取角色设定消息 String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage()) ? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage(); // 1.3 校验平台 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); - ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); + ChatModel chatModel = modalService.getChatModel(model.getId()); // 2. 插入思维导图信息 - AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, - mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); + AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, mindMap -> mindMap.setUserId(userId) + .setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel())); mindMapMapper.insert(mindMapDO); // 3.1 构建 Prompt,并进行调用 @@ -103,7 +101,7 @@ public class AiMindMapServiceImpl implements AiMindMapService { } - private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) { + private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) { // 1. 构建 message 列表 List chatMessages = buildMessages(generateReqVO, systemMessage); // 2. 构建 options 对象 @@ -123,13 +121,13 @@ public class AiMindMapServiceImpl implements AiMindMapService { return chatMessages; } - private AiChatModelDO getModel(AiChatRoleDO role) { - AiChatModelDO model = null; + private AiModelDO getModel(AiChatRoleDO role) { + AiModelDO model = null; if (role != null && role.getModelId() != null) { - model = chatModalService.getChatModel(role.getModelId()); + model = modalService.getModel(role.getModelId()); } if (model == null) { - model = chatModalService.getRequiredDefaultChatModel(); + model = modalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType()); } Assert.notNull(model, "[AI] 获取不到模型"); return model; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java index 29d9e3e07a..44da80041c 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java @@ -1,16 +1,10 @@ package cn.iocoder.yudao.module.ai.service.model; -import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; -import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; -import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import jakarta.validation.Valid; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.image.ImageModel; -import org.springframework.ai.vectorstore.VectorStore; import java.util.List; @@ -74,50 +68,13 @@ public interface AiApiKeyService { */ List getApiKeyList(); - // ========== 与 spring-ai 集成 ========== - /** - * 获得 ChatModel 对象 - * - * @param id 编号 - * @return ChatModel 对象 - */ - ChatModel getChatModel(Long id); - - /** - * 获得 ImageModel 对象 - * - * TODO 可优化点:目前默认获取 platform 对应的第一个开启的配置用于绘画;后续可以支持配置选择 + * 获得默认的 API 密钥 * * @param platform 平台 - * @return ImageModel 对象 + * @param status 状态 + * @return API 密钥 */ - ImageModel getImageModel(AiPlatformEnum platform); - - /** - * 获得 MidjourneyApi 对象 - * - * TODO 可优化点:目前默认获取 Midjourney 对应的第一个开启的配置用于绘画;后续可以支持配置选择 - * - * @return MidjourneyApi 对象 - */ - MidjourneyApi getMidjourneyApi(); - - /** - * 获得 SunoApi 对象 - * - * TODO 可优化点:目前默认获取 Suno 对应的第一个开启的配置用于音乐;后续可以支持配置选择 - * - * @return SunoApi 对象 - */ - SunoApi getSunoApi(); - - /** - * 获得 VectorStore 对象 - * - * @param modelId 编号 - * @return VectorStore 对象 - */ - VectorStore getOrCreateVectorStoreByModelId(Long modelId); + AiApiKeyDO getRequiredDefaultApiKey(String platform, Integer status); } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java index 38d6c95495..f0bac8a6d9 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java @@ -1,31 +1,21 @@ package cn.iocoder.yudao.module.ai.service.model; -import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; -import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; -import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; -import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper; import jakarta.annotation.Resource; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.image.ImageModel; -import org.springframework.ai.vectorstore.SimpleVectorStore; -import org.springframework.ai.vectorstore.VectorStore; -import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; import java.util.List; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; -import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*; +import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.API_KEY_DISABLE; +import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.API_KEY_NOT_EXISTS; /** * AI API 密钥 Service 实现类 @@ -39,14 +29,6 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { @Resource private AiApiKeyMapper apiKeyMapper; - // TODO @芋艿:后续要不要改? - @Resource - @Lazy // 延迟加载,解决渲染依赖 - private AiChatModelService chatModelService; - - @Resource - private AiModelFactory modelFactory; - @Override public Long createApiKey(AiApiKeySaveReqVO createReqVO) { // 插入 @@ -105,57 +87,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { return apiKeyMapper.selectList(); } - // ========== 与 spring-ai 集成 ========== - @Override - public ChatModel getChatModel(Long id) { - AiApiKeyDO apiKey = validateApiKey(id); - AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); - return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl()); - } - - @Override - public ImageModel getImageModel(AiPlatformEnum platform) { - AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus()); + public AiApiKeyDO getRequiredDefaultApiKey(String platform, Integer status) { + AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform, status); if (apiKey == null) { - throw exception(API_KEY_IMAGE_NODE_FOUND, platform.getName()); + throw exception(API_KEY_NOT_EXISTS); } - return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl()); - } - - @Override - public MidjourneyApi getMidjourneyApi() { - AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus( - AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus()); - if (apiKey == null) { - throw exception(API_KEY_MIDJOURNEY_NOT_FOUND); - } - return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl()); - } - - @Override - public SunoApi getSunoApi() { - AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus( - AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus()); - if (apiKey == null) { - throw exception(API_KEY_SUNO_NOT_FOUND); - } - return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl()); - } - - @Override - public VectorStore getOrCreateVectorStoreByModelId(Long modelId) { - // 获取模型 + 密钥 - AiChatModelDO chatModel = chatModelService.validateChatModel(modelId); - AiApiKeyDO apiKey = validateApiKey(chatModel.getKeyId()); - AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); - - // 创建或获取 EmbeddingModel 对象 - EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), - apiKey.getUrl(), chatModel.getModel()); - - // 创建或获取 VectorStore 对象 - return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel); + return apiKey; } } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java deleted file mode 100644 index f83ac73c91..0000000000 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java +++ /dev/null @@ -1,92 +0,0 @@ -package cn.iocoder.yudao.module.ai.service.model; - -import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; -import jakarta.validation.Valid; - -import java.util.Collection; -import java.util.List; - -import java.util.Set; - -/** - * AI 聊天模型 Service 接口 - * - * @author fansili - * @since 2024/4/24 19:42 - */ -public interface AiChatModelService { - - /** - * 创建聊天模型 - * - * @param createReqVO 创建信息 - * @return 编号 - */ - Long createChatModel(@Valid AiChatModelSaveReqVO createReqVO); - - /** - * 更新聊天模型 - * - * @param updateReqVO 更新信息 - */ - void updateChatModel(@Valid AiChatModelSaveReqVO updateReqVO); - - /** - * 删除聊天模型 - * - * @param id 编号 - */ - void deleteChatModel(Long id); - - /** - * 获得聊天模型 - * - * @param id 编号 - * @return 聊天模型 - */ - AiChatModelDO getChatModel(Long id); - - /** - * 获得默认的聊天模型 - * - * 如果获取不到,则抛出 {@link cn.iocoder.yudao.framework.common.exception.ServiceException} 业务异常 - * - * @return 聊天模型 - */ - AiChatModelDO getRequiredDefaultChatModel(); - - /** - * 获得聊天模型分页 - * - * @param pageReqVO 分页查询 - * @return 聊天模型分页 - */ - PageResult getChatModelPage(AiChatModelPageReqVO pageReqVO); - - /** - * 校验聊天模型 - * - * @param id 编号 - * @return 聊天模型 - */ - AiChatModelDO validateChatModel(Long id); - - /** - * 获得聊天模型列表 - * - * @param status 状态 - * @return 聊天模型列表 - */ - List getChatModelListByStatus(Integer status); - - /** - * 获得聊天模型列表 - * - * @param ids 编号数组 - * @return 模型列表 - */ - List getChatModelList(Collection ids); -} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java index 4b11602f54..9140882a1f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java @@ -1,114 +1,168 @@ package cn.iocoder.yudao.module.ai.service.model; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; +import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; +import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; -import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; -import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatModelMapper; +import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; +import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatMapper; import jakarta.annotation.Resource; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.image.ImageModel; +import org.springframework.ai.vectorstore.SimpleVectorStore; +import org.springframework.ai.vectorstore.VectorStore; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; -import java.util.Collection; import java.util.List; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*; /** - * AI 聊天模型 Service 实现类 + * AI 模型 Service 实现类 * * @author fansili */ @Service @Validated -public class AiChatModelServiceImpl implements AiChatModelService { +public class AiChatModelServiceImpl implements AiModelService { @Resource private AiApiKeyService apiKeyService; @Resource - private AiChatModelMapper chatModelMapper; + private AiChatMapper modelMapper; + + @Resource + private AiModelFactory modelFactory; @Override - public Long createChatModel(AiChatModelSaveReqVO createReqVO) { + public Long createModel(AiModelSaveReqVO createReqVO) { // 1. 校验 AiPlatformEnum.validatePlatform(createReqVO.getPlatform()); apiKeyService.validateApiKey(createReqVO.getKeyId()); // 2. 插入 - AiChatModelDO chatModel = BeanUtils.toBean(createReqVO, AiChatModelDO.class); - chatModelMapper.insert(chatModel); - return chatModel.getId(); + AiModelDO model = BeanUtils.toBean(createReqVO, AiModelDO.class); + modelMapper.insert(model); + return model.getId(); } @Override - public void updateChatModel(AiChatModelSaveReqVO updateReqVO) { + public void updateModel(AiModelSaveReqVO updateReqVO) { // 1. 校验 - validateChatModelExists(updateReqVO.getId()); + validateModelExists(updateReqVO.getId()); AiPlatformEnum.validatePlatform(updateReqVO.getPlatform()); apiKeyService.validateApiKey(updateReqVO.getKeyId()); // 2. 更新 - AiChatModelDO updateObj = BeanUtils.toBean(updateReqVO, AiChatModelDO.class); - chatModelMapper.updateById(updateObj); + AiModelDO updateObj = BeanUtils.toBean(updateReqVO, AiModelDO.class); + modelMapper.updateById(updateObj); } @Override - public void deleteChatModel(Long id) { + public void deleteModel(Long id) { // 校验存在 - validateChatModelExists(id); + validateModelExists(id); // 删除 - chatModelMapper.deleteById(id); + modelMapper.deleteById(id); } - private AiChatModelDO validateChatModelExists(Long id) { - AiChatModelDO model = chatModelMapper.selectById(id); - if (chatModelMapper.selectById(id) == null) { - throw exception(CHAT_MODEL_NOT_EXISTS); + private AiModelDO validateModelExists(Long id) { + AiModelDO model = modelMapper.selectById(id); + if (modelMapper.selectById(id) == null) { + throw exception(MODEL_NOT_EXISTS); } return model; } @Override - public AiChatModelDO getChatModel(Long id) { - return chatModelMapper.selectById(id); + public AiModelDO getModel(Long id) { + return modelMapper.selectById(id); } @Override - public AiChatModelDO getRequiredDefaultChatModel() { - AiChatModelDO model = chatModelMapper.selectFirstByStatus(CommonStatusEnum.ENABLE.getStatus()); + public AiModelDO getRequiredDefaultModel(Integer type) { + AiModelDO model = modelMapper.selectFirstByStatus(type, CommonStatusEnum.ENABLE.getStatus()); if (model == null) { - throw exception(CHAT_MODEL_DEFAULT_NOT_EXISTS); + throw exception(MODEL_DEFAULT_NOT_EXISTS); } return model; } @Override - public PageResult getChatModelPage(AiChatModelPageReqVO pageReqVO) { - return chatModelMapper.selectPage(pageReqVO); + public PageResult getModelPage(AiModelPageReqVO pageReqVO) { + return modelMapper.selectPage(pageReqVO); } @Override - public AiChatModelDO validateChatModel(Long id) { - AiChatModelDO model = validateChatModelExists(id); + public AiModelDO validateModel(Long id) { + AiModelDO model = validateModelExists(id); if (CommonStatusEnum.isDisable(model.getStatus())) { - throw exception(CHAT_MODEL_DISABLE); + throw exception(MODEL_DISABLE); } return model; } @Override - public List getChatModelListByStatus(Integer status) { - return chatModelMapper.selectList(status); + public List getModelListByStatusAndType(Integer status, Integer type, + String platform) { + return modelMapper.selectListByStatusAndType(status, type, platform); + } + + // ========== 与 Spring AI 集成 ========== + + @Override + public ChatModel getChatModel(Long id) { + AiModelDO model = validateModel(id); + AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId()); + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); + return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl()); } @Override - public List getChatModelList(Collection ids) { - return chatModelMapper.selectBatchIds(ids); + public ImageModel getImageModel(Long id) { + AiModelDO model = validateModel(id); + AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId()); + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); + return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl()); + } + + @Override + public MidjourneyApi getMidjourneyApi() { + AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey( + AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus()); + return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl()); + } + + @Override + public SunoApi getSunoApi() { + AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey( + AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus()); + return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl()); + } + + @Override + public VectorStore getOrCreateVectorStore(Long id) { + // 获取模型 + 密钥 + AiModelDO model = validateModel(id); + AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId()); + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); + + // 创建或获取 EmbeddingModel 对象 + EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel( + platform, apiKey.getApiKey(), apiKey.getUrl(), model.getModel()); + + // 创建或获取 VectorStore 对象 + return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel); } } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java new file mode 100644 index 0000000000..e3746b14d2 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java @@ -0,0 +1,131 @@ +package cn.iocoder.yudao.module.ai.service.model; + +import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; +import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; +import cn.iocoder.yudao.framework.common.pojo.PageResult; +import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; +import jakarta.validation.Valid; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.image.ImageModel; +import org.springframework.ai.vectorstore.VectorStore; + +import javax.annotation.Nullable; +import java.util.List; + +/** + * AI 模型 Service 接口 + * + * @author fansili + * @since 2024/4/24 19:42 + */ +public interface AiModelService { + + /** + * 创建模型 + * + * @param createReqVO 创建信息 + * @return 编号 + */ + Long createModel(@Valid AiModelSaveReqVO createReqVO); + + /** + * 更新模型 + * + * @param updateReqVO 更新信息 + */ + void updateModel(@Valid AiModelSaveReqVO updateReqVO); + + /** + * 删除模型 + * + * @param id 编号 + */ + void deleteModel(Long id); + + /** + * 获得模型 + * + * @param id 编号 + * @return 模型 + */ + AiModelDO getModel(Long id); + + /** + * 获得默认的模型 + * + * 如果获取不到,则抛出 {@link cn.iocoder.yudao.framework.common.exception.ServiceException} 业务异常 + * + * @return 模型 + */ + AiModelDO getRequiredDefaultModel(Integer type); + + /** + * 获得模型分页 + * + * @param pageReqVO 分页查询 + * @return 模型分页 + */ + PageResult getModelPage(AiModelPageReqVO pageReqVO); + + /** + * 校验模型是否可使用 + * + * @param id 编号 + * @return 模型 + */ + AiModelDO validateModel(Long id); + + /** + * 获得模型列表 + * + * @param status 状态 + * @param type 类型 + * @param platform 平台,允许空 + * @return 模型列表 + */ + List getModelListByStatusAndType(Integer status, Integer type, + @Nullable String platform); + + // ========== 与 Spring AI 集成 ========== + + /** + * 获得 ChatModel 对象 + * + * @param id 编号 + * @return ChatModel 对象 + */ + ChatModel getChatModel(Long id); + + /** + * 获得 ImageModel 对象 + * + * @param id 编号 + * @return ImageModel 对象 + */ + ImageModel getImageModel(Long id); + + /** + * 获得 MidjourneyApi 对象 + * + * @return MidjourneyApi 对象 + */ + MidjourneyApi getMidjourneyApi(); + + /** + * 获得 SunoApi 对象 + * + * @return SunoApi 对象 + */ + SunoApi getSunoApi(); + + /** + * 获得 VectorStore 对象 + * + * @param id 编号 + * @return VectorStore 对象 + */ + VectorStore getOrCreateVectorStore(Long id); + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java index 3f10ec8402..e4ff81a477 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java @@ -16,7 +16,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO; import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper; import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum; import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum; -import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; +import cn.iocoder.yudao.module.ai.service.model.AiModelService; import cn.iocoder.yudao.module.infra.api.file.FileApi; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; @@ -41,7 +41,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MUSIC_NOT_EXIS public class AiMusicServiceImpl implements AiMusicService { @Resource - private AiApiKeyService apiKeyService; + private AiModelService modelService; @Resource private AiMusicMapper musicMapper; @@ -53,7 +53,7 @@ public class AiMusicServiceImpl implements AiMusicService { @Transactional(rollbackFor = Exception.class) public List generateMusic(Long userId, AiSunoGenerateReqVO reqVO) { // 1. 调用 Suno 生成音乐 - SunoApi sunoApi = apiKeyService.getSunoApi(); + SunoApi sunoApi = modelService.getSunoApi(); List musicDataList; if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) { // 1.1 描述模式 @@ -88,7 +88,7 @@ public class AiMusicServiceImpl implements AiMusicService { log.info("[syncMusic][Suno 开始同步, 共 ({}) 个任务]", streamingTask.size()); // GET 请求,为避免参数过长,分批次处理 - SunoApi sunoApi = apiKeyService.getSunoApi(); + SunoApi sunoApi = modelService.getSunoApi(); CollUtil.split(streamingTask, 36).forEach(chunkList -> { Map taskIdMap = convertMap(chunkList, AiMusicDO::getTaskId, AiMusicDO::getId); List musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet())); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index f238bd3e30..ceae2fc9c9 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.write; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.lang.Assert; import cn.hutool.core.util.StrUtil; +import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.common.pojo.CommonResult; @@ -11,17 +12,16 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWritePageReqVO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper; import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum; import cn.iocoder.yudao.module.ai.enums.DictTypeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; -import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; -import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; +import cn.iocoder.yudao.module.ai.service.model.AiModelService; import cn.iocoder.yudao.module.system.api.dict.DictDataApi; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; @@ -54,17 +54,15 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.WRITE_NOT_EXIS public class AiWriteServiceImpl implements AiWriteService { @Resource - private AiApiKeyService apiKeyService; - @Resource - private AiChatModelService chatModalService; + private AiModelService chatModalService; @Resource private AiChatRoleService chatRoleService; @Resource - private DictDataApi dictDataApi; + private AiWriteMapper writeMapper; @Resource - private AiWriteMapper writeMapper; + private DictDataApi dictDataApi; @Override public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { @@ -72,17 +70,17 @@ public class AiWriteServiceImpl implements AiWriteService { AiChatRoleDO writeRole = CollUtil.getFirst( chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); // 1.1 获取写作执行模型 - AiChatModelDO model = getModel(writeRole); + AiModelDO model = getModel(writeRole); // 1.2 获取角色设定消息 String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage()) ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage(); // 1.3 校验平台 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); - StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); + StreamingChatModel chatModel = chatModalService.getChatModel(model.getKeyId()); // 2. 插入写作信息 - AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, - write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel())); + AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, write -> write.setUserId(userId) + .setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel())); writeMapper.insert(writeDO); // 3.1 构建 Prompt,并进行调用 @@ -109,19 +107,19 @@ public class AiWriteServiceImpl implements AiWriteService { }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); } - private AiChatModelDO getModel(AiChatRoleDO writeRole) { - AiChatModelDO model = null; + private AiModelDO getModel(AiChatRoleDO writeRole) { + AiModelDO model = null; if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { - model = chatModalService.getChatModel(writeRole.getModelId()); + model = chatModalService.getModel(writeRole.getModelId()); } if (model == null) { - model = chatModalService.getRequiredDefaultChatModel(); + model = chatModalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType()); } Assert.notNull(model, "[AI] 获取不到模型"); return model; } - private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) { + private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) { // 1. 构建 message 列表 List chatMessages = buildMessages(generateReqVO, systemMessage); // 2. 构建 options 对象 diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiModelTypeEnum.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiModelTypeEnum.java new file mode 100644 index 0000000000..4f7a4e462d --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiModelTypeEnum.java @@ -0,0 +1,41 @@ +package cn.iocoder.yudao.framework.ai.core.enums; + +import cn.iocoder.yudao.framework.common.core.ArrayValuable; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.util.Arrays; + +/** + * AI 模型类型的枚举 + * + * @author 芋道源码 + */ +@Getter +@RequiredArgsConstructor +public enum AiModelTypeEnum implements ArrayValuable { + + CHAT(1, "对话"), + IMAGE(2, "图片"), + VOICE(3, "语音"), + VIDEO(4, "视频"), + EMBEDDING(5, "向量"), + RERANK(6, "重排序"); + + /** + * 类型 + */ + private final Integer type; + /** + * 类型名 + */ + private final String name; + + public static final Integer[] ARRAYS = Arrays.stream(values()).map(AiModelTypeEnum::getType).toArray(Integer[]::new); + + @Override + public Integer[] array() { + return ARRAYS; + } + +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java index 7b479d8f1a..a154db8c88 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java @@ -1,8 +1,11 @@ package cn.iocoder.yudao.framework.ai.core.enums; +import cn.iocoder.yudao.framework.common.core.ArrayValuable; import lombok.AllArgsConstructor; import lombok.Getter; +import java.util.Arrays; + /** * AI 模型平台 * @@ -10,7 +13,7 @@ import lombok.Getter; */ @Getter @AllArgsConstructor -public enum AiPlatformEnum { +public enum AiPlatformEnum implements ArrayValuable { // ========== 国内平台 ========== @@ -44,6 +47,8 @@ public enum AiPlatformEnum { */ private final String name; + public static final String[] ARRAYS = Arrays.stream(values()).map(AiPlatformEnum::getPlatform).toArray(String[]::new); + public static AiPlatformEnum validatePlatform(String platform) { for (AiPlatformEnum platformEnum : AiPlatformEnum.values()) { if (platformEnum.getPlatform().equals(platform)) { @@ -53,4 +58,9 @@ public enum AiPlatformEnum { throw new IllegalArgumentException("非法平台: " + platform); } + @Override + public String[] array() { + return ARRAYS; + } + } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index 1f9b1f21a2..ffc052c5d3 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -456,25 +456,4 @@ public class AiModelFactoryImpl implements AiModelFactory { return vectorStore; } - /** - * 创建向量存储文件 - * - * @param embeddingModel 嵌入模型 - * @return 向量存储文件 - */ - private File createVectorStoreFile(EmbeddingModel embeddingModel) { - // 获取简单类名 - String simpleClassName = embeddingModel.getClass().getSimpleName(); - // 获取用户主目录 - String userHome = FileUtil.getUserHomePath(); - // 创建vector_store目录 - File vectorStoreDir = new File(userHome, "vector_store"); - if (!vectorStoreDir.exists()) { - vectorStoreDir.mkdirs(); - } - - // 创建文件 - return new File(vectorStoreDir, "simple_" + simpleClassName + ".json"); - } - }