【代码重构】AI:“聊天模型”重构为“模型”,支持 type 模型类型

This commit is contained in:
YunaiV 2025-03-03 21:26:31 +08:00
parent 3f460dc620
commit 89d079349c
36 changed files with 617 additions and 554 deletions

View File

@ -12,31 +12,25 @@ public interface ErrorCodeConstants {
// ========== API 密钥 1-040-000-000 ==========
ErrorCode API_KEY_NOT_EXISTS = new ErrorCode(1_040_000_000, "API 密钥不存在");
ErrorCode API_KEY_DISABLE = new ErrorCode(1_040_000_001, "API 密钥已禁用!");
ErrorCode API_KEY_MIDJOURNEY_NOT_FOUND = new ErrorCode(1_040_000_900, "Midjourney 模型不存在");
ErrorCode API_KEY_SUNO_NOT_FOUND = new ErrorCode(1_040_000_901, "Suno 模型不存在");
ErrorCode API_KEY_IMAGE_NODE_FOUND = new ErrorCode(1_040_000_902, "平台({}) 图片模型未配置");
// ========== API 聊天模型 1-040-001-000 ==========
ErrorCode CHAT_MODEL_NOT_EXISTS = new ErrorCode(1_040_001_000, "模型不存在!");
ErrorCode CHAT_MODEL_DISABLE = new ErrorCode(1_040_001_001, "模型({})已禁用!");
ErrorCode CHAT_MODEL_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认聊天模型");
// ========== API 模型 1-040-001-000 ==========
ErrorCode MODEL_NOT_EXISTS = new ErrorCode(1_040_001_000, "模型不存在!");
ErrorCode MODEL_DISABLE = new ErrorCode(1_040_001_001, "模型({})已禁用!");
ErrorCode MODEL_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认模型");
// ========== API 聊天角色 1-040-002-000 ==========
ErrorCode CHAT_ROLE_NOT_EXISTS = new ErrorCode(1_040_002_000, "聊天角色不存在");
ErrorCode CHAT_ROLE_DISABLE = new ErrorCode(1_040_001_001, "聊天角色({})已禁用!");
// ========== API 聊天会话 1-040-003-000 ==========
ErrorCode CHAT_CONVERSATION_NOT_EXISTS = new ErrorCode(1_040_003_000, "对话不存在!");
ErrorCode CHAT_CONVERSATION_MODEL_ERROR = new ErrorCode(1_040_003_001, "操作失败,该聊天模型的配置不完整");
// ========== API 聊天消息 1-040-004-000 ==========
ErrorCode CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!");
ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "对话生成异常!");
// ========== API 绘画 1-040-005-000 ==========
ErrorCode IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!");
ErrorCode IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}");
ErrorCode IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");

View File

@ -1,6 +1,6 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import com.fhs.core.trans.anno.Trans;
import com.fhs.core.trans.constant.TransType;
@ -31,7 +31,7 @@ public class AiChatConversationRespVO implements VO {
private Long roleId;
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@Trans(type = TransType.SIMPLE, target = AiChatModelDO.class, fields = "name", ref = "modelName")
@Trans(type = TransType.SIMPLE, target = AiModelDO.class, fields = "name", ref = "modelName")
private Long modelId;
@Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "ERNIE-Bot-turbo-0922")

View File

@ -14,18 +14,15 @@ import java.util.Map;
@Data
public class AiImageDrawReqVO {
@Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI")
private String platform; // 参见 AiPlatformEnum 枚举
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
@NotNull(message = "模型编号不能为空")
private Long modelId;
@Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "画一个长城")
@NotEmpty(message = "提示词不能为空")
@Size(max = 1200, message = "提示词最大 1200")
private String prompt;
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "stable-diffusion-v1-6")
@NotEmpty(message = "模型不能为空")
private String model;
/**
* 1. dall-e-2 模型256x256512x5121024x1024
* 2. dall-e-3 模型1024x1024, 1792x1024, 1024x1792

View File

@ -6,9 +6,8 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
@ -76,9 +75,9 @@ public class AiApiKeyController {
@GetMapping("/simple-list")
@Operation(summary = "获得 API 密钥分页列表")
public CommonResult<List<AiChatModelRespVO>> getApiKeySimpleList() {
public CommonResult<List<AiModelRespVO>> getApiKeySimpleList() {
List<AiApiKeyDO> list = apiKeyService.getApiKeyList();
return success(convertList(list, key -> new AiChatModelRespVO().setId(key.getId()).setName(key.getName())));
return success(convertList(list, key -> new AiModelRespVO().setId(key.getId()).setName(key.getName())));
}
}

View File

@ -1,84 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.model;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import jakarta.validation.Valid;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@Tag(name = "管理后台 - AI 聊天模型")
@RestController
@RequestMapping("/ai/chat-model")
@Validated
public class AiChatModelController {
@Resource
private AiChatModelService chatModelService;
@PostMapping("/create")
@Operation(summary = "创建聊天模型")
@PreAuthorize("@ss.hasPermission('ai:chat-model:create')")
public CommonResult<Long> createChatModel(@Valid @RequestBody AiChatModelSaveReqVO createReqVO) {
return success(chatModelService.createChatModel(createReqVO));
}
@PutMapping("/update")
@Operation(summary = "更新聊天模型")
@PreAuthorize("@ss.hasPermission('ai:chat-model:update')")
public CommonResult<Boolean> updateChatModel(@Valid @RequestBody AiChatModelSaveReqVO updateReqVO) {
chatModelService.updateChatModel(updateReqVO);
return success(true);
}
@DeleteMapping("/delete")
@Operation(summary = "删除聊天模型")
@Parameter(name = "id", description = "编号", required = true)
@PreAuthorize("@ss.hasPermission('ai:chat-model:delete')")
public CommonResult<Boolean> deleteChatModel(@RequestParam("id") Long id) {
chatModelService.deleteChatModel(id);
return success(true);
}
@GetMapping("/get")
@Operation(summary = "获得聊天模型")
@Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('ai:chat-model:query')")
public CommonResult<AiChatModelRespVO> getChatModel(@RequestParam("id") Long id) {
AiChatModelDO chatModel = chatModelService.getChatModel(id);
return success(BeanUtils.toBean(chatModel, AiChatModelRespVO.class));
}
@GetMapping("/page")
@Operation(summary = "获得聊天模型分页")
@PreAuthorize("@ss.hasPermission('ai:chat-model:query')")
public CommonResult<PageResult<AiChatModelRespVO>> getChatModelPage(@Valid AiChatModelPageReqVO pageReqVO) {
PageResult<AiChatModelDO> pageResult = chatModelService.getChatModelPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiChatModelRespVO.class));
}
@GetMapping("/simple-list")
@Operation(summary = "获得聊天模型列表")
@Parameter(name = "status", description = "状态", required = true, example = "1")
public CommonResult<List<AiChatModelRespVO>> getChatModelSimpleList(@RequestParam("status") Integer status) {
List<AiChatModelDO> list = chatModelService.getChatModelListByStatus(status);
return success(convertList(list, model -> new AiChatModelRespVO().setId(model.getId())
.setName(model.getName()).setModel(model.getModel())));
}
}

View File

@ -0,0 +1,89 @@
package cn.iocoder.yudao.module.ai.controller.admin.model;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import jakarta.validation.Valid;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@Tag(name = "管理后台 - AI 模型")
@RestController
@RequestMapping("/ai/model")
@Validated
public class AiModelController {
@Resource
private AiModelService modelService;
@PostMapping("/create")
@Operation(summary = "创建模型")
@PreAuthorize("@ss.hasPermission('ai:model:create')")
public CommonResult<Long> createModel(@Valid @RequestBody AiModelSaveReqVO createReqVO) {
return success(modelService.createModel(createReqVO));
}
@PutMapping("/update")
@Operation(summary = "更新模型")
@PreAuthorize("@ss.hasPermission('ai:model:update')")
public CommonResult<Boolean> updateModel(@Valid @RequestBody AiModelSaveReqVO updateReqVO) {
modelService.updateModel(updateReqVO);
return success(true);
}
@DeleteMapping("/delete")
@Operation(summary = "删除模型")
@Parameter(name = "id", description = "编号", required = true)
@PreAuthorize("@ss.hasPermission('ai:model:delete')")
public CommonResult<Boolean> deleteModel(@RequestParam("id") Long id) {
modelService.deleteModel(id);
return success(true);
}
@GetMapping("/get")
@Operation(summary = "获得模型")
@Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('ai:model:query')")
public CommonResult<AiModelRespVO> getModel(@RequestParam("id") Long id) {
AiModelDO model = modelService.getModel(id);
return success(BeanUtils.toBean(model, AiModelRespVO.class));
}
@GetMapping("/page")
@Operation(summary = "获得模型分页")
@PreAuthorize("@ss.hasPermission('ai:model:query')")
public CommonResult<PageResult<AiModelRespVO>> getModelPage(@Valid AiModelPageReqVO pageReqVO) {
PageResult<AiModelDO> pageResult = modelService.getModelPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiModelRespVO.class));
}
@GetMapping("/simple-list")
@Operation(summary = "获得模型列表")
@Parameter(name = "type", description = "类型", required = true, example = "1")
@Parameter(name = "platform", description = "平台", example = "midjourney")
public CommonResult<List<AiModelRespVO>> getModelSimpleList(
@RequestParam("type") Integer type,
@RequestParam(value = "platform", required = false) String platform) {
List<AiModelDO> list = modelService.getModelListByStatusAndType(
CommonStatusEnum.ENABLE.getStatus(), type, platform);
return success(convertList(list, model -> new AiModelRespVO().setId(model.getId())
.setName(model.getName()).setModel(model.getModel()).setPlatform(model.getPlatform())));
}
}

View File

@ -1,6 +1,6 @@
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import com.fhs.core.trans.anno.Trans;
import com.fhs.core.trans.constant.TransType;
import com.fhs.core.trans.vo.VO;
@ -20,7 +20,7 @@ public class AiChatRoleRespVO implements VO {
private Long userId;
@Schema(description = "模型编号", example = "17640")
@Trans(type = TransType.SIMPLE, target = AiChatModelDO.class, fields = {"name", "model"}, refs = {"modelName", "model"})
@Trans(type = TransType.SIMPLE, target = AiModelDO.class, fields = {"name", "model"}, refs = {"modelName", "model"})
private Long modelId;
@Schema(description = "模型名字", example = "张三")
private String modelName;

View File

@ -1,12 +1,12 @@
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel;
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
import lombok.*;
import io.swagger.v3.oas.annotations.media.Schema;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
@Schema(description = "管理后台 - API 聊天模型分页 Request VO")
@Schema(description = "管理后台 - API 模型分页 Request VO")
@Data
public class AiChatModelPageReqVO extends PageParam {
public class AiModelPageReqVO extends PageParam {
@Schema(description = "模型名字", example = "张三")
private String name;

View File

@ -1,13 +1,13 @@
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel;
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.time.LocalDateTime;
@Schema(description = "管理后台 - AI 聊天模型 Response VO")
@Schema(description = "管理后台 - AI 模型 Response VO")
@Data
public class AiChatModelRespVO {
public class AiModelRespVO {
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "2630")
private Long id;
@ -24,6 +24,9 @@ public class AiChatModelRespVO {
@Schema(description = "模型平台", example = "OpenAI")
private String platform;
@Schema(description = "模型类型", example = "1")
private Integer type;
@Schema(description = "排序", example = "1")
private Integer sort;

View File

@ -1,14 +1,17 @@
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel;
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.validation.InEnum;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.*;
import jakarta.validation.constraints.*;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
@Schema(description = "管理后台 - API 聊天模型新增/修改 Request VO")
@Schema(description = "管理后台 - API 模型新增/修改 Request VO")
@Data
public class AiChatModelSaveReqVO {
public class AiModelSaveReqVO {
@Schema(description = "编号", example = "2630")
private Long id;
@ -27,8 +30,14 @@ public class AiChatModelSaveReqVO {
@Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI")
@NotEmpty(message = "模型平台不能为空")
@InEnum(AiPlatformEnum.class)
private String platform;
@Schema(description = "模型类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "模型类型不能为空")
@InEnum(AiModelTypeEnum.class)
private Integer type;
@Schema(description = "排序", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "排序不能为空")
private Integer sort;

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId;
@ -76,13 +76,13 @@ public class AiChatConversationDO extends BaseDO {
/**
* 模型编号
*
* 关联 {@link AiChatModelDO#getId()} 字段
* 关联 {@link AiModelDO#getId()} 字段
*/
private Long modelId;
/**
* 模型标志
*
* 冗余 {@link AiChatModelDO#getModel()} 字段
* 冗余 {@link AiModelDO#getModel()} 字段
*/
private String model;

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableField;
@ -83,13 +83,13 @@ public class AiChatMessageDO extends BaseDO {
/**
* 模型标志
*
* 冗余 {@link AiChatModelDO#getModel()}
* 冗余 {@link AiModelDO#getModel()}
*/
private String model;
/**
* 模型编号
*
* 关联 {@link AiChatModelDO#getId()} 字段
* 关联 {@link AiModelDO#getId()} 字段
*/
private Long modelId;

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.image;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
import com.baomidou.mybatisplus.annotation.KeySequence;
@ -52,11 +52,16 @@ public class AiImageDO extends BaseDO {
* 枚举 {@link cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum}
*/
private String platform;
// TODO @芋艿modelId
/**
* 模型
* 模型编号
*
* 冗余 {@link AiChatModelDO#getModel()}
* 关联 {@link AiModelDO#getId()}
*/
private Long modelId;
/**
* 模型标识
*
* 冗余 {@link AiModelDO#getModel()}
*/
private String model;

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.knowledge;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
@ -35,13 +35,13 @@ public class AiKnowledgeDO extends BaseDO {
/**
* 向量模型编号
*
* 关联 {@link AiChatModelDO#getId()}
* 关联 {@link AiModelDO#getId()}
*/
private Long embeddingModelId;
/**
* 模型标识
*
* 冗余 {@link AiChatModelDO#getModel()}
* 冗余 {@link AiModelDO#getModel()}
*/
private String embeddingModel;

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.mindmap;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
@ -36,7 +37,12 @@ public class AiMindMapDO extends BaseDO {
* 枚举 {@link AiPlatformEnum}
*/
private String platform;
// TODO @芋艿modelId
/**
* 模型编号
*
* 关联 {@link AiModelDO#getId()}
*/
private Long modelId;
/**
* 模型
*/

View File

@ -58,7 +58,7 @@ public class AiChatRoleDO extends BaseDO {
/**
* 模型编号
*
* 关联 {@link AiChatModelDO#getId()} 字段
* 关联 {@link AiModelDO#getId()} 字段
*/
private Long modelId;

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
@ -8,23 +9,22 @@ import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.*;
// TODO @芋艿需要改造增加 type
/**
* AI 聊天模型 DO
* AI 模型 DO
*
* 默认聊天模型{@link #status} 为开启并且 {@link #sort} 排序第一
* 默认模型{@link #status} 为开启并且 {@link #sort} 排序第一
*
* @author fansili
* @since 2024/4/24 19:39
*/
@TableName("ai_chat_model")
@KeySequence("ai_chat_model_seq") // 用于 OraclePostgreSQLKingbaseDB2H2 数据库的主键自增如果是 MySQL 等数据库可不写
@TableName("ai_model")
@KeySequence("ai_model_seq") // 用于 OraclePostgreSQLKingbaseDB2H2 数据库的主键自增如果是 MySQL 等数据库可不写
@Data
@EqualsAndHashCode(callSuper = true)
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class AiChatModelDO extends BaseDO {
public class AiModelDO extends BaseDO {
/**
* 编号
@ -51,6 +51,12 @@ public class AiChatModelDO extends BaseDO {
* 枚举 {@link AiPlatformEnum}
*/
private String platform;
/**
* 类型
*
* 枚举 {@link AiModelTypeEnum}
*/
private Integer type;
/**
* 排序值

View File

@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.write;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId;
@ -44,7 +46,12 @@ public class AiWriteDO extends BaseDO {
* 枚举 {@link AiPlatformEnum}
*/
private String platform;
// TODO @芋艿modelId
/**
* 模型编号
*
* 关联 {@link AiModelDO#getId()}
*/
private Long modelId;
/**
* 模型
*/
@ -67,25 +74,25 @@ public class AiWriteDO extends BaseDO {
/**
* 长度提示词
*
* 字典{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LENGTH}
* 字典{@link DictTypeConstants#AI_WRITE_LENGTH}
*/
private Integer length;
/**
* 格式提示词
*
* 字典{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_FORMAT}
* 字典{@link DictTypeConstants#AI_WRITE_FORMAT}
*/
private Integer format;
/**
* 语气提示词
*
* 字典{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_TONE}
* 字典{@link DictTypeConstants#AI_WRITE_TONE}
*/
private Integer tone;
/**
* 语言提示词
*
* 字典{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LANGUAGE}
* 字典{@link DictTypeConstants#AI_WRITE_LANGUAGE}
*/
private Integer language;

View File

@ -0,0 +1,47 @@
package cn.iocoder.yudao.module.ai.dal.mysql.model;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import org.apache.ibatis.annotations.Mapper;
import javax.annotation.Nullable;
import java.util.List;
/**
* API 模型 Mapper
*
* @author fansili
*/
@Mapper
public interface AiChatMapper extends BaseMapperX<AiModelDO> {
default AiModelDO selectFirstByStatus(Integer type, Integer status) {
return selectOne(new QueryWrapperX<AiModelDO>()
.eq("type", type)
.eq("status", status)
.limitN(1)
.orderByAsc("sort"));
}
default PageResult<AiModelDO> selectPage(AiModelPageReqVO reqVO) {
return selectPage(reqVO, new LambdaQueryWrapperX<AiModelDO>()
.likeIfPresent(AiModelDO::getName, reqVO.getName())
.eqIfPresent(AiModelDO::getModel, reqVO.getModel())
.eqIfPresent(AiModelDO::getPlatform, reqVO.getPlatform())
.orderByAsc(AiModelDO::getSort));
}
default List<AiModelDO> selectListByStatusAndType(Integer status, Integer type,
@Nullable String platform) {
return selectList(new LambdaQueryWrapperX<AiModelDO>()
.eq(AiModelDO::getStatus, status)
.eq(AiModelDO::getType, type)
.eqIfPresent(AiModelDO::getPlatform, platform)
.orderByAsc(AiModelDO::getSort));
}
}

View File

@ -1,43 +0,0 @@
package cn.iocoder.yudao.module.ai.dal.mysql.model;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import org.apache.ibatis.annotations.Mapper;
import java.util.Collection;
import java.util.List;
/**
* API 聊天模型 Mapper
*
* @author fansili
*/
@Mapper
public interface AiChatModelMapper extends BaseMapperX<AiChatModelDO> {
default AiChatModelDO selectFirstByStatus(Integer status) {
return selectOne(new QueryWrapperX<AiChatModelDO>()
.eq("status", status)
.limitN(1)
.orderByAsc("sort"));
}
default PageResult<AiChatModelDO> selectPage(AiChatModelPageReqVO reqVO) {
return selectPage(reqVO, new LambdaQueryWrapperX<AiChatModelDO>()
.likeIfPresent(AiChatModelDO::getName, reqVO.getName())
.eqIfPresent(AiChatModelDO::getModel, reqVO.getModel())
.eqIfPresent(AiChatModelDO::getPlatform, reqVO.getPlatform())
.orderByAsc(AiChatModelDO::getSort));
}
default List<AiChatModelDO> selectList(Integer status) {
return selectList(new LambdaQueryWrapperX<AiChatModelDO>()
.eq(AiChatModelDO::getStatus, status)
.orderByAsc(AiChatModelDO::getSort));
}
}

View File

@ -4,17 +4,18 @@ import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
@ -44,7 +45,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
private AiChatConversationMapper chatConversationMapper;
@Resource
private AiChatModelService chatModalService;
private AiModelService chatModalService;
@Resource
private AiChatRoleService chatRoleService;
@Resource
@ -54,9 +55,9 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) {
// 1.1 获得 AiChatRoleDO 聊天角色
AiChatRoleDO role = createReqVO.getRoleId() != null ? chatRoleService.validateChatRole(createReqVO.getRoleId()) : null;
// 1.2 获得 AiChatModelDO 聊天模型
AiChatModelDO model = role != null && role.getModelId() != null ? chatModalService.validateChatModel(role.getModelId())
: chatModalService.getRequiredDefaultChatModel();
// 1.2 获得 AiModelDO 聊天模型
AiModelDO model = role != null && role.getModelId() != null ? chatModalService.validateModel(role.getModelId())
: chatModalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
Assert.notNull(model, "必须找到默认模型");
validateChatModel(model);
@ -86,9 +87,9 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
throw exception(CHAT_CONVERSATION_NOT_EXISTS);
}
// 1.2 校验模型是否存在修改模型的情况
AiChatModelDO model = null;
AiModelDO model = null;
if (updateReqVO.getModelId() != null) {
model = chatModalService.validateChatModel(updateReqVO.getModelId());
model = chatModalService.validateModel(updateReqVO.getModelId());
}
// 1.3 校验知识库是否存在
@ -139,7 +140,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
chatConversationMapper.deleteById(id);
}
private void validateChatModel(AiChatModelDO model) {
private void validateChatModel(AiModelDO model) {
if (ObjectUtil.isAllNotEmpty(model.getTemperature(), model.getMaxTokens(), model.getMaxContexts())) {
return;
}

View File

@ -15,13 +15,12 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
@ -63,9 +62,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
@Resource
private AiChatConversationService chatConversationService;
@Resource
private AiChatModelService chatModalService;
@Resource
private AiApiKeyService apiKeyService;
private AiModelService modalService;
@Resource
private AiKnowledgeSegmentService knowledgeSegmentService;
@ -78,8 +75,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
AiModelDO model = modalService.validateModel(conversation.getModelId());
ChatModel chatModel = modalService.getChatModel(model.getKeyId());
// 2. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@ -112,8 +109,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
AiModelDO model = modalService.validateModel(conversation.getModelId());
StreamingChatModel chatModel = modalService.getChatModel(model.getKeyId());
// 2. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@ -161,8 +158,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
return null;
}
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList,
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, List<AiKnowledgeSegmentDO> segmentList,
AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
// 1. 构建 Prompt Message 列表
List<Message> chatMessages = new ArrayList<>();
@ -232,7 +229,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}
private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
AiChatModelDO model, Long userId, Long roleId,
AiModelDO model, Long userId, Long roleId,
MessageType messageType, String content, Boolean useContext) {
AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
.setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)

View File

@ -12,13 +12,17 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePublicPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.infra.api.file.FileApi;
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
import jakarta.annotation.Resource;
@ -54,15 +58,15 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
@Slf4j
public class AiImageServiceImpl implements AiImageService {
@Resource
private AiModelService modelService;
@Resource
private AiImageMapper imageMapper;
@Resource
private FileApi fileApi;
@Resource
private AiApiKeyService apiKeyService;
@Override
public PageResult<AiImageDO> getImagePageMy(Long userId, AiImagePageReqVO pageReqVO) {
return imageMapper.selectPageMy(userId, pageReqVO);
@ -88,23 +92,31 @@ public class AiImageServiceImpl implements AiImageService {
@Override
public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) {
// 1. 保存数据库
AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
// 1. 校验模型
AiModelDO model = modelService.validateModel(drawReqVO.getModelId());
// 2. 保存数据库
AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId)
.setPlatform(model.getPlatform()).setModelId(model.getId()).setModel(model.getModel())
.setPublicStatus(false).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
imageMapper.insert(image);
// 2. 异步绘制后续前端通过返回的 id 进行轮询结果
getSelf().executeDrawImage(image, drawReqVO);
// 3. 异步绘制后续前端通过返回的 id 进行轮询结果
getSelf().executeDrawImage(image, drawReqVO, model);
return image.getId();
}
@Async
public void executeDrawImage(AiImageDO image, AiImageDrawReqVO req) {
public void executeDrawImage(AiImageDO image, AiImageDrawReqVO reqVO, AiModelDO model) {
try {
// 1.1 构建请求
ImageOptions request = buildImageOptions(req);
ImageOptions request = buildImageOptions(reqVO, model);
// 1.2 执行请求
ImageModel imageModel = apiKeyService.getImageModel(AiPlatformEnum.validatePlatform(req.getPlatform()));
ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request));
ImageModel imageModel = modelService.getImageModel(model.getId());
ImageResponse response = imageModel.call(new ImagePrompt(reqVO.getPrompt(), request));
if (response.getResult() == null) {
throw new IllegalArgumentException("生成结果为空");
}
// 2. 上传到文件服务
String b64Json = response.getResult().getOutput().getB64Json();
@ -116,25 +128,25 @@ public class AiImageServiceImpl implements AiImageService {
imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(AiImageStatusEnum.SUCCESS.getStatus())
.setPicUrl(filePath).setFinishTime(LocalDateTime.now()));
} catch (Exception ex) {
log.error("[doDall][image({}) 生成异常]", image, ex);
log.error("[executeDrawImage][image({}) 生成异常]", image, ex);
imageMapper.updateById(new AiImageDO().setId(image.getId())
.setStatus(AiImageStatusEnum.FAIL.getStatus())
.setErrorMessage(ex.getMessage()).setFinishTime(LocalDateTime.now()));
}
}
private static ImageOptions buildImageOptions(AiImageDrawReqVO draw) {
if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
private static ImageOptions buildImageOptions(AiImageDrawReqVO draw, AiModelDO model) {
if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
// https://platform.openai.com/docs/api-reference/images/create
return OpenAiImageOptions.builder().withModel(draw.getModel())
return OpenAiImageOptions.builder().withModel(model.getModel())
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
.withStyle(MapUtil.getStr(draw.getOptions(), "style")) // 风格
.withResponseFormat("b64_json")
.build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
// https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
// https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
return StabilityAiImageOptions.builder().model(draw.getModel())
return StabilityAiImageOptions.builder().model(model.getModel())
.height(draw.getHeight()).width(draw.getWidth())
.seed(Long.valueOf(draw.getOptions().get("seed")))
.cfgScale(Float.valueOf(draw.getOptions().get("scale")))
@ -143,22 +155,22 @@ public class AiImageServiceImpl implements AiImageService {
.stylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
.clipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
.build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
return DashScopeImageOptions.builder()
.withModel(draw.getModel()).withN(1)
.withModel(model.getModel()).withN(1)
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
.build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
return QianFanImageOptions.builder()
.model(draw.getModel()).N(1)
.model(model.getModel()).N(1)
.height(draw.getHeight()).width(draw.getWidth())
.build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
return ZhiPuAiImageOptions.builder()
.model(draw.getModel())
.model(model.getModel())
.build();
}
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
throw new IllegalArgumentException("不支持的 AI 平台:" + model.getPlatform());
}
@Override
@ -206,7 +218,7 @@ public class AiImageServiceImpl implements AiImageService {
@Override
@Transactional(rollbackFor = Exception.class)
public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
// 1. 保存数据库
AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
@ -237,7 +249,7 @@ public class AiImageServiceImpl implements AiImageService {
@Override
public Integer midjourneySync() {
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
// 1.1 获取 Midjourney 平台状态在 进行中 image
List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
@ -308,7 +320,7 @@ public class AiImageServiceImpl implements AiImageService {
@Override
public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
// 1.1 检查 image
AiImageDO image = validateImageExists(reqVO.getId());
if (ObjUtil.notEqual(userId, image.getUserId())) {

View File

@ -7,14 +7,17 @@ import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.*;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSaveReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
@ -33,8 +36,8 @@ import java.util.Objects;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
/**
* AI 知识库分片 Service 实现类
@ -58,7 +61,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
@Lazy // 延迟加载避免循环依赖
private AiKnowledgeDocumentService knowledgeDocumentService;
@Resource
private AiApiKeyService apiKeyService;
private AiModelService modelService;
@Resource
private TokenCountEstimator tokenCountEstimator;
@ -180,7 +183,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqBO.getKnowledgeId());
// 2.1 向量检索
VectorStore vectorStore = apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId());
VectorStore vectorStore = getVectorStoreById(knowledge);
List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder()
.query(reqBO.getContent())
.topK(ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK()))
@ -251,11 +254,12 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
}
private VectorStore getVectorStoreById(AiKnowledgeDO knowledge) {
return apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId());
return modelService.getOrCreateVectorStore(knowledge.getEmbeddingModelId());
}
private VectorStore getVectorStoreById(Long knowledgeId) {
return getVectorStoreById(knowledgeService.validateKnowledgeExists(knowledgeId));
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(knowledgeId);
return getVectorStoreById(knowledge);
}
private static List<Document> splitContentByToken(String content, Integer segmentMaxTokens) {

View File

@ -5,9 +5,9 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@ -28,12 +28,12 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
private AiKnowledgeMapper knowledgeMapper;
@Resource
private AiChatModelService chatModelService;
private AiModelService chatModelService;
@Override
public Long createKnowledge(AiKnowledgeSaveReqVO createReqVO) {
// 1. 校验模型配置
AiChatModelDO model = chatModelService.validateChatModel(createReqVO.getEmbeddingModelId());
AiModelDO model = chatModelService.validateModel(createReqVO.getEmbeddingModelId());
// 2. 插入知识库
AiKnowledgeDO knowledge = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
@ -47,7 +47,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
// 1.1 校验知识库存在
validateKnowledgeExists(updateReqVO.getId());
// 1.2 校验模型配置
AiChatModelDO model = chatModelService.validateChatModel(updateReqVO.getEmbeddingModelId());
AiModelDO model = chatModelService.validateModel(updateReqVO.getEmbeddingModelId());
// 2. 更新知识库
AiKnowledgeDO updateObj = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class)

View File

@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.mindmap;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
@ -12,14 +13,13 @@ import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.mindmap.AiMindMapMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
@ -50,9 +50,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MIND_MAP_NOT_E
public class AiMindMapServiceImpl implements AiMindMapService {
@Resource
private AiApiKeyService apiKeyService;
@Resource
private AiChatModelService chatModalService;
private AiModelService modalService;
@Resource
private AiChatRoleService chatRoleService;
@ -65,17 +63,17 @@ public class AiMindMapServiceImpl implements AiMindMapService {
AiChatRoleDO role = CollUtil.getFirst(
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
// 1.1 获取导图执行模型
AiChatModelDO model = getModel(role);
AiModelDO model = getModel(role);
// 1.2 获取角色设定消息
String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
// 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
ChatModel chatModel = modalService.getChatModel(model.getId());
// 2. 插入思维导图信息
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, mindMap -> mindMap.setUserId(userId)
.setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel()));
mindMapMapper.insert(mindMapDO);
// 3.1 构建 Prompt并进行调用
@ -103,7 +101,7 @@ public class AiMindMapServiceImpl implements AiMindMapService {
}
private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) {
// 1. 构建 message 列表
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
// 2. 构建 options 对象
@ -123,13 +121,13 @@ public class AiMindMapServiceImpl implements AiMindMapService {
return chatMessages;
}
private AiChatModelDO getModel(AiChatRoleDO role) {
AiChatModelDO model = null;
private AiModelDO getModel(AiChatRoleDO role) {
AiModelDO model = null;
if (role != null && role.getModelId() != null) {
model = chatModalService.getChatModel(role.getModelId());
model = modalService.getModel(role.getModelId());
}
if (model == null) {
model = chatModalService.getRequiredDefaultChatModel();
model = modalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
}
Assert.notNull(model, "[AI] 获取不到模型");
return model;

View File

@ -1,16 +1,10 @@
package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
import java.util.List;
@ -74,50 +68,13 @@ public interface AiApiKeyService {
*/
List<AiApiKeyDO> getApiKeyList();
// ========== spring-ai 集成 ==========
/**
* 获得 ChatModel 对象
*
* @param id 编号
* @return ChatModel 对象
*/
ChatModel getChatModel(Long id);
/**
* 获得 ImageModel 对象
*
* TODO 可优化点目前默认获取 platform 对应的第一个开启的配置用于绘画后续可以支持配置选择
* 获得默认的 API 密钥
*
* @param platform 平台
* @return ImageModel 对象
* @param status 状态
* @return API 密钥
*/
ImageModel getImageModel(AiPlatformEnum platform);
/**
* 获得 MidjourneyApi 对象
*
* TODO 可优化点目前默认获取 Midjourney 对应的第一个开启的配置用于绘画后续可以支持配置选择
*
* @return MidjourneyApi 对象
*/
MidjourneyApi getMidjourneyApi();
/**
* 获得 SunoApi 对象
*
* TODO 可优化点目前默认获取 Suno 对应的第一个开启的配置用于音乐后续可以支持配置选择
*
* @return SunoApi 对象
*/
SunoApi getSunoApi();
/**
* 获得 VectorStore 对象
*
* @param modelId 编号
* @return VectorStore 对象
*/
VectorStore getOrCreateVectorStoreByModelId(Long modelId);
AiApiKeyDO getRequiredDefaultApiKey(String platform, Integer status);
}

View File

@ -1,31 +1,21 @@
package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.API_KEY_DISABLE;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.API_KEY_NOT_EXISTS;
/**
* AI API 密钥 Service 实现类
@ -39,14 +29,6 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
@Resource
private AiApiKeyMapper apiKeyMapper;
// TODO @芋艿后续要不要改
@Resource
@Lazy // 延迟加载解决渲染依赖
private AiChatModelService chatModelService;
@Resource
private AiModelFactory modelFactory;
@Override
public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
// 插入
@ -105,57 +87,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
return apiKeyMapper.selectList();
}
// ========== spring-ai 集成 ==========
@Override
public ChatModel getChatModel(Long id) {
AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public ImageModel getImageModel(AiPlatformEnum platform) {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
public AiApiKeyDO getRequiredDefaultApiKey(String platform, Integer status) {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform, status);
if (apiKey == null) {
throw exception(API_KEY_IMAGE_NODE_FOUND, platform.getName());
throw exception(API_KEY_NOT_EXISTS);
}
return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public MidjourneyApi getMidjourneyApi() {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
if (apiKey == null) {
throw exception(API_KEY_MIDJOURNEY_NOT_FOUND);
}
return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public SunoApi getSunoApi() {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
if (apiKey == null) {
throw exception(API_KEY_SUNO_NOT_FOUND);
}
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public VectorStore getOrCreateVectorStoreByModelId(Long modelId) {
// 获取模型 + 密钥
AiChatModelDO chatModel = chatModelService.validateChatModel(modelId);
AiApiKeyDO apiKey = validateApiKey(chatModel.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
// 创建或获取 EmbeddingModel 对象
EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(),
apiKey.getUrl(), chatModel.getModel());
// 创建或获取 VectorStore 对象
return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel);
return apiKey;
}
}

View File

@ -1,92 +0,0 @@
package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import jakarta.validation.Valid;
import java.util.Collection;
import java.util.List;
import java.util.Set;
/**
* AI 聊天模型 Service 接口
*
* @author fansili
* @since 2024/4/24 19:42
*/
public interface AiChatModelService {
/**
* 创建聊天模型
*
* @param createReqVO 创建信息
* @return 编号
*/
Long createChatModel(@Valid AiChatModelSaveReqVO createReqVO);
/**
* 更新聊天模型
*
* @param updateReqVO 更新信息
*/
void updateChatModel(@Valid AiChatModelSaveReqVO updateReqVO);
/**
* 删除聊天模型
*
* @param id 编号
*/
void deleteChatModel(Long id);
/**
* 获得聊天模型
*
* @param id 编号
* @return 聊天模型
*/
AiChatModelDO getChatModel(Long id);
/**
* 获得默认的聊天模型
*
* 如果获取不到则抛出 {@link cn.iocoder.yudao.framework.common.exception.ServiceException} 业务异常
*
* @return 聊天模型
*/
AiChatModelDO getRequiredDefaultChatModel();
/**
* 获得聊天模型分页
*
* @param pageReqVO 分页查询
* @return 聊天模型分页
*/
PageResult<AiChatModelDO> getChatModelPage(AiChatModelPageReqVO pageReqVO);
/**
* 校验聊天模型
*
* @param id 编号
* @return 聊天模型
*/
AiChatModelDO validateChatModel(Long id);
/**
* 获得聊天模型列表
*
* @param status 状态
* @return 聊天模型列表
*/
List<AiChatModelDO> getChatModelListByStatus(Integer status);
/**
* 获得聊天模型列表
*
* @param ids 编号数组
* @return 模型列表
*/
List<AiChatModelDO> getChatModelList(Collection<Long> ids);
}

View File

@ -1,114 +1,168 @@
package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatModelMapper;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatMapper;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
import java.util.Collection;
import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
/**
* AI 聊天模型 Service 实现类
* AI 模型 Service 实现类
*
* @author fansili
*/
@Service
@Validated
public class AiChatModelServiceImpl implements AiChatModelService {
public class AiChatModelServiceImpl implements AiModelService {
@Resource
private AiApiKeyService apiKeyService;
@Resource
private AiChatModelMapper chatModelMapper;
private AiChatMapper modelMapper;
@Resource
private AiModelFactory modelFactory;
@Override
public Long createChatModel(AiChatModelSaveReqVO createReqVO) {
public Long createModel(AiModelSaveReqVO createReqVO) {
// 1. 校验
AiPlatformEnum.validatePlatform(createReqVO.getPlatform());
apiKeyService.validateApiKey(createReqVO.getKeyId());
// 2. 插入
AiChatModelDO chatModel = BeanUtils.toBean(createReqVO, AiChatModelDO.class);
chatModelMapper.insert(chatModel);
return chatModel.getId();
AiModelDO model = BeanUtils.toBean(createReqVO, AiModelDO.class);
modelMapper.insert(model);
return model.getId();
}
@Override
public void updateChatModel(AiChatModelSaveReqVO updateReqVO) {
public void updateModel(AiModelSaveReqVO updateReqVO) {
// 1. 校验
validateChatModelExists(updateReqVO.getId());
validateModelExists(updateReqVO.getId());
AiPlatformEnum.validatePlatform(updateReqVO.getPlatform());
apiKeyService.validateApiKey(updateReqVO.getKeyId());
// 2. 更新
AiChatModelDO updateObj = BeanUtils.toBean(updateReqVO, AiChatModelDO.class);
chatModelMapper.updateById(updateObj);
AiModelDO updateObj = BeanUtils.toBean(updateReqVO, AiModelDO.class);
modelMapper.updateById(updateObj);
}
@Override
public void deleteChatModel(Long id) {
public void deleteModel(Long id) {
// 校验存在
validateChatModelExists(id);
validateModelExists(id);
// 删除
chatModelMapper.deleteById(id);
modelMapper.deleteById(id);
}
private AiChatModelDO validateChatModelExists(Long id) {
AiChatModelDO model = chatModelMapper.selectById(id);
if (chatModelMapper.selectById(id) == null) {
throw exception(CHAT_MODEL_NOT_EXISTS);
private AiModelDO validateModelExists(Long id) {
AiModelDO model = modelMapper.selectById(id);
if (modelMapper.selectById(id) == null) {
throw exception(MODEL_NOT_EXISTS);
}
return model;
}
@Override
public AiChatModelDO getChatModel(Long id) {
return chatModelMapper.selectById(id);
public AiModelDO getModel(Long id) {
return modelMapper.selectById(id);
}
@Override
public AiChatModelDO getRequiredDefaultChatModel() {
AiChatModelDO model = chatModelMapper.selectFirstByStatus(CommonStatusEnum.ENABLE.getStatus());
public AiModelDO getRequiredDefaultModel(Integer type) {
AiModelDO model = modelMapper.selectFirstByStatus(type, CommonStatusEnum.ENABLE.getStatus());
if (model == null) {
throw exception(CHAT_MODEL_DEFAULT_NOT_EXISTS);
throw exception(MODEL_DEFAULT_NOT_EXISTS);
}
return model;
}
@Override
public PageResult<AiChatModelDO> getChatModelPage(AiChatModelPageReqVO pageReqVO) {
return chatModelMapper.selectPage(pageReqVO);
public PageResult<AiModelDO> getModelPage(AiModelPageReqVO pageReqVO) {
return modelMapper.selectPage(pageReqVO);
}
@Override
public AiChatModelDO validateChatModel(Long id) {
AiChatModelDO model = validateChatModelExists(id);
public AiModelDO validateModel(Long id) {
AiModelDO model = validateModelExists(id);
if (CommonStatusEnum.isDisable(model.getStatus())) {
throw exception(CHAT_MODEL_DISABLE);
throw exception(MODEL_DISABLE);
}
return model;
}
@Override
public List<AiChatModelDO> getChatModelListByStatus(Integer status) {
return chatModelMapper.selectList(status);
public List<AiModelDO> getModelListByStatusAndType(Integer status, Integer type,
String platform) {
return modelMapper.selectListByStatusAndType(status, type, platform);
}
// ========== Spring AI 集成 ==========
@Override
public ChatModel getChatModel(Long id) {
AiModelDO model = validateModel(id);
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public List<AiChatModelDO> getChatModelList(Collection<Long> ids) {
return chatModelMapper.selectBatchIds(ids);
public ImageModel getImageModel(Long id) {
AiModelDO model = validateModel(id);
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public MidjourneyApi getMidjourneyApi() {
AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey(
AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public SunoApi getSunoApi() {
AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey(
AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public VectorStore getOrCreateVectorStore(Long id) {
// 获取模型 + 密钥
AiModelDO model = validateModel(id);
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
// 创建或获取 EmbeddingModel 对象
EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(
platform, apiKey.getApiKey(), apiKey.getUrl(), model.getModel());
// 创建或获取 VectorStore 对象
return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel);
}
}

View File

@ -0,0 +1,131 @@
package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import jakarta.validation.Valid;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
import javax.annotation.Nullable;
import java.util.List;
/**
* AI 模型 Service 接口
*
* @author fansili
* @since 2024/4/24 19:42
*/
public interface AiModelService {
/**
* 创建模型
*
* @param createReqVO 创建信息
* @return 编号
*/
Long createModel(@Valid AiModelSaveReqVO createReqVO);
/**
* 更新模型
*
* @param updateReqVO 更新信息
*/
void updateModel(@Valid AiModelSaveReqVO updateReqVO);
/**
* 删除模型
*
* @param id 编号
*/
void deleteModel(Long id);
/**
* 获得模型
*
* @param id 编号
* @return 模型
*/
AiModelDO getModel(Long id);
/**
* 获得默认的模型
*
* 如果获取不到则抛出 {@link cn.iocoder.yudao.framework.common.exception.ServiceException} 业务异常
*
* @return 模型
*/
AiModelDO getRequiredDefaultModel(Integer type);
/**
* 获得模型分页
*
* @param pageReqVO 分页查询
* @return 模型分页
*/
PageResult<AiModelDO> getModelPage(AiModelPageReqVO pageReqVO);
/**
* 校验模型是否可使用
*
* @param id 编号
* @return 模型
*/
AiModelDO validateModel(Long id);
/**
* 获得模型列表
*
* @param status 状态
* @param type 类型
* @param platform 平台允许空
* @return 模型列表
*/
List<AiModelDO> getModelListByStatusAndType(Integer status, Integer type,
@Nullable String platform);
// ========== Spring AI 集成 ==========
/**
* 获得 ChatModel 对象
*
* @param id 编号
* @return ChatModel 对象
*/
ChatModel getChatModel(Long id);
/**
* 获得 ImageModel 对象
*
* @param id 编号
* @return ImageModel 对象
*/
ImageModel getImageModel(Long id);
/**
* 获得 MidjourneyApi 对象
*
* @return MidjourneyApi 对象
*/
MidjourneyApi getMidjourneyApi();
/**
* 获得 SunoApi 对象
*
* @return SunoApi 对象
*/
SunoApi getSunoApi();
/**
* 获得 VectorStore 对象
*
* @param id 编号
* @return VectorStore 对象
*/
VectorStore getOrCreateVectorStore(Long id);
}

View File

@ -16,7 +16,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.infra.api.file.FileApi;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
@ -41,7 +41,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MUSIC_NOT_EXIS
public class AiMusicServiceImpl implements AiMusicService {
@Resource
private AiApiKeyService apiKeyService;
private AiModelService modelService;
@Resource
private AiMusicMapper musicMapper;
@ -53,7 +53,7 @@ public class AiMusicServiceImpl implements AiMusicService {
@Transactional(rollbackFor = Exception.class)
public List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO) {
// 1. 调用 Suno 生成音乐
SunoApi sunoApi = apiKeyService.getSunoApi();
SunoApi sunoApi = modelService.getSunoApi();
List<SunoApi.MusicData> musicDataList;
if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) {
// 1.1 描述模式
@ -88,7 +88,7 @@ public class AiMusicServiceImpl implements AiMusicService {
log.info("[syncMusic][Suno 开始同步, 共 ({}) 个任务]", streamingTask.size());
// GET 请求为避免参数过长分批次处理
SunoApi sunoApi = apiKeyService.getSunoApi();
SunoApi sunoApi = modelService.getSunoApi();
CollUtil.split(streamingTask, 36).forEach(chunkList -> {
Map<String, Long> taskIdMap = convertMap(chunkList, AiMusicDO::getTaskId, AiMusicDO::getId);
List<SunoApi.MusicData> musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet()));

View File

@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.write;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
@ -11,17 +12,16 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWritePageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
@ -54,17 +54,15 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.WRITE_NOT_EXIS
public class AiWriteServiceImpl implements AiWriteService {
@Resource
private AiApiKeyService apiKeyService;
@Resource
private AiChatModelService chatModalService;
private AiModelService chatModalService;
@Resource
private AiChatRoleService chatRoleService;
@Resource
private DictDataApi dictDataApi;
private AiWriteMapper writeMapper;
@Resource
private AiWriteMapper writeMapper;
private DictDataApi dictDataApi;
@Override
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
@ -72,17 +70,17 @@ public class AiWriteServiceImpl implements AiWriteService {
AiChatRoleDO writeRole = CollUtil.getFirst(
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
// 1.1 获取写作执行模型
AiChatModelDO model = getModel(writeRole);
AiModelDO model = getModel(writeRole);
// 1.2 获取角色设定消息
String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
// 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
StreamingChatModel chatModel = chatModalService.getChatModel(model.getKeyId());
// 2. 插入写作信息
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class,
write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, write -> write.setUserId(userId)
.setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel()));
writeMapper.insert(writeDO);
// 3.1 构建 Prompt并进行调用
@ -109,19 +107,19 @@ public class AiWriteServiceImpl implements AiWriteService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
}
private AiChatModelDO getModel(AiChatRoleDO writeRole) {
AiChatModelDO model = null;
private AiModelDO getModel(AiChatRoleDO writeRole) {
AiModelDO model = null;
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
model = chatModalService.getChatModel(writeRole.getModelId());
model = chatModalService.getModel(writeRole.getModelId());
}
if (model == null) {
model = chatModalService.getRequiredDefaultChatModel();
model = chatModalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
}
Assert.notNull(model, "[AI] 获取不到模型");
return model;
}
private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) {
// 1. 构建 message 列表
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
// 2. 构建 options 对象

View File

@ -0,0 +1,41 @@
package cn.iocoder.yudao.framework.ai.core.enums;
import cn.iocoder.yudao.framework.common.core.ArrayValuable;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.Arrays;
/**
* AI 模型类型的枚举
*
* @author 芋道源码
*/
@Getter
@RequiredArgsConstructor
public enum AiModelTypeEnum implements ArrayValuable<Integer> {
CHAT(1, "对话"),
IMAGE(2, "图片"),
VOICE(3, "语音"),
VIDEO(4, "视频"),
EMBEDDING(5, "向量"),
RERANK(6, "重排序");
/**
* 类型
*/
private final Integer type;
/**
* 类型名
*/
private final String name;
public static final Integer[] ARRAYS = Arrays.stream(values()).map(AiModelTypeEnum::getType).toArray(Integer[]::new);
@Override
public Integer[] array() {
return ARRAYS;
}
}

View File

@ -1,8 +1,11 @@
package cn.iocoder.yudao.framework.ai.core.enums;
import cn.iocoder.yudao.framework.common.core.ArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* AI 模型平台
*
@ -10,7 +13,7 @@ import lombok.Getter;
*/
@Getter
@AllArgsConstructor
public enum AiPlatformEnum {
public enum AiPlatformEnum implements ArrayValuable<String> {
// ========== 国内平台 ==========
@ -44,6 +47,8 @@ public enum AiPlatformEnum {
*/
private final String name;
public static final String[] ARRAYS = Arrays.stream(values()).map(AiPlatformEnum::getPlatform).toArray(String[]::new);
public static AiPlatformEnum validatePlatform(String platform) {
for (AiPlatformEnum platformEnum : AiPlatformEnum.values()) {
if (platformEnum.getPlatform().equals(platform)) {
@ -53,4 +58,9 @@ public enum AiPlatformEnum {
throw new IllegalArgumentException("非法平台: " + platform);
}
@Override
public String[] array() {
return ARRAYS;
}
}

View File

@ -456,25 +456,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
return vectorStore;
}
/**
* 创建向量存储文件
*
* @param embeddingModel 嵌入模型
* @return 向量存储文件
*/
private File createVectorStoreFile(EmbeddingModel embeddingModel) {
// 获取简单类名
String simpleClassName = embeddingModel.getClass().getSimpleName();
// 获取用户主目录
String userHome = FileUtil.getUserHomePath();
// 创建vector_store目录
File vectorStoreDir = new File(userHome, "vector_store");
if (!vectorStoreDir.exists()) {
vectorStoreDir.mkdirs();
}
// 创建文件
return new File(vectorStoreDir, "simple_" + simpleClassName + ".json");
}
}