【代码重构】AI:“聊天模型”重构为“模型”,支持 type 模型类型
This commit is contained in:
parent
3f460dc620
commit
89d079349c
|
@ -12,31 +12,25 @@ public interface ErrorCodeConstants {
|
|||
// ========== API 密钥 1-040-000-000 ==========
|
||||
ErrorCode API_KEY_NOT_EXISTS = new ErrorCode(1_040_000_000, "API 密钥不存在");
|
||||
ErrorCode API_KEY_DISABLE = new ErrorCode(1_040_000_001, "API 密钥已禁用!");
|
||||
ErrorCode API_KEY_MIDJOURNEY_NOT_FOUND = new ErrorCode(1_040_000_900, "Midjourney 模型不存在");
|
||||
ErrorCode API_KEY_SUNO_NOT_FOUND = new ErrorCode(1_040_000_901, "Suno 模型不存在");
|
||||
ErrorCode API_KEY_IMAGE_NODE_FOUND = new ErrorCode(1_040_000_902, "平台({}) 图片模型未配置");
|
||||
|
||||
// ========== API 聊天模型 1-040-001-000 ==========
|
||||
ErrorCode CHAT_MODEL_NOT_EXISTS = new ErrorCode(1_040_001_000, "模型不存在!");
|
||||
ErrorCode CHAT_MODEL_DISABLE = new ErrorCode(1_040_001_001, "模型({})已禁用!");
|
||||
ErrorCode CHAT_MODEL_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认聊天模型");
|
||||
// ========== API 模型 1-040-001-000 ==========
|
||||
ErrorCode MODEL_NOT_EXISTS = new ErrorCode(1_040_001_000, "模型不存在!");
|
||||
ErrorCode MODEL_DISABLE = new ErrorCode(1_040_001_001, "模型({})已禁用!");
|
||||
ErrorCode MODEL_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认模型");
|
||||
|
||||
// ========== API 聊天角色 1-040-002-000 ==========
|
||||
ErrorCode CHAT_ROLE_NOT_EXISTS = new ErrorCode(1_040_002_000, "聊天角色不存在");
|
||||
ErrorCode CHAT_ROLE_DISABLE = new ErrorCode(1_040_001_001, "聊天角色({})已禁用!");
|
||||
|
||||
// ========== API 聊天会话 1-040-003-000 ==========
|
||||
|
||||
ErrorCode CHAT_CONVERSATION_NOT_EXISTS = new ErrorCode(1_040_003_000, "对话不存在!");
|
||||
ErrorCode CHAT_CONVERSATION_MODEL_ERROR = new ErrorCode(1_040_003_001, "操作失败,该聊天模型的配置不完整");
|
||||
|
||||
// ========== API 聊天消息 1-040-004-000 ==========
|
||||
|
||||
ErrorCode CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!");
|
||||
ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "对话生成异常!");
|
||||
|
||||
// ========== API 绘画 1-040-005-000 ==========
|
||||
|
||||
ErrorCode IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!");
|
||||
ErrorCode IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}");
|
||||
ErrorCode IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||
import com.fhs.core.trans.anno.Trans;
|
||||
import com.fhs.core.trans.constant.TransType;
|
||||
|
@ -31,7 +31,7 @@ public class AiChatConversationRespVO implements VO {
|
|||
private Long roleId;
|
||||
|
||||
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||
@Trans(type = TransType.SIMPLE, target = AiChatModelDO.class, fields = "name", ref = "modelName")
|
||||
@Trans(type = TransType.SIMPLE, target = AiModelDO.class, fields = "name", ref = "modelName")
|
||||
private Long modelId;
|
||||
|
||||
@Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "ERNIE-Bot-turbo-0922")
|
||||
|
|
|
@ -14,18 +14,15 @@ import java.util.Map;
|
|||
@Data
|
||||
public class AiImageDrawReqVO {
|
||||
|
||||
@Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI")
|
||||
private String platform; // 参见 AiPlatformEnum 枚举
|
||||
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
|
||||
@NotNull(message = "模型编号不能为空")
|
||||
private Long modelId;
|
||||
|
||||
@Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "画一个长城")
|
||||
@NotEmpty(message = "提示词不能为空")
|
||||
@Size(max = 1200, message = "提示词最大 1200")
|
||||
private String prompt;
|
||||
|
||||
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "stable-diffusion-v1-6")
|
||||
@NotEmpty(message = "模型不能为空")
|
||||
private String model;
|
||||
|
||||
/**
|
||||
* 1. dall-e-2 模型:256x256、512x512、1024x1024
|
||||
* 2. dall-e-3 模型:1024x1024, 1792x1024, 或 1024x1792
|
||||
|
|
|
@ -6,9 +6,8 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
|||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelRespVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
|
@ -76,9 +75,9 @@ public class AiApiKeyController {
|
|||
|
||||
@GetMapping("/simple-list")
|
||||
@Operation(summary = "获得 API 密钥分页列表")
|
||||
public CommonResult<List<AiChatModelRespVO>> getApiKeySimpleList() {
|
||||
public CommonResult<List<AiModelRespVO>> getApiKeySimpleList() {
|
||||
List<AiApiKeyDO> list = apiKeyService.getApiKeyList();
|
||||
return success(convertList(list, key -> new AiChatModelRespVO().setId(key.getId()).setName(key.getName())));
|
||||
return success(convertList(list, key -> new AiModelRespVO().setId(key.getId()).setName(key.getName())));
|
||||
}
|
||||
|
||||
}
|
|
@ -1,84 +0,0 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.annotation.Resource;
|
||||
import jakarta.validation.Valid;
|
||||
import org.springframework.security.access.prepost.PreAuthorize;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
||||
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
|
||||
|
||||
@Tag(name = "管理后台 - AI 聊天模型")
|
||||
@RestController
|
||||
@RequestMapping("/ai/chat-model")
|
||||
@Validated
|
||||
public class AiChatModelController {
|
||||
|
||||
@Resource
|
||||
private AiChatModelService chatModelService;
|
||||
|
||||
@PostMapping("/create")
|
||||
@Operation(summary = "创建聊天模型")
|
||||
@PreAuthorize("@ss.hasPermission('ai:chat-model:create')")
|
||||
public CommonResult<Long> createChatModel(@Valid @RequestBody AiChatModelSaveReqVO createReqVO) {
|
||||
return success(chatModelService.createChatModel(createReqVO));
|
||||
}
|
||||
|
||||
@PutMapping("/update")
|
||||
@Operation(summary = "更新聊天模型")
|
||||
@PreAuthorize("@ss.hasPermission('ai:chat-model:update')")
|
||||
public CommonResult<Boolean> updateChatModel(@Valid @RequestBody AiChatModelSaveReqVO updateReqVO) {
|
||||
chatModelService.updateChatModel(updateReqVO);
|
||||
return success(true);
|
||||
}
|
||||
|
||||
@DeleteMapping("/delete")
|
||||
@Operation(summary = "删除聊天模型")
|
||||
@Parameter(name = "id", description = "编号", required = true)
|
||||
@PreAuthorize("@ss.hasPermission('ai:chat-model:delete')")
|
||||
public CommonResult<Boolean> deleteChatModel(@RequestParam("id") Long id) {
|
||||
chatModelService.deleteChatModel(id);
|
||||
return success(true);
|
||||
}
|
||||
|
||||
@GetMapping("/get")
|
||||
@Operation(summary = "获得聊天模型")
|
||||
@Parameter(name = "id", description = "编号", required = true, example = "1024")
|
||||
@PreAuthorize("@ss.hasPermission('ai:chat-model:query')")
|
||||
public CommonResult<AiChatModelRespVO> getChatModel(@RequestParam("id") Long id) {
|
||||
AiChatModelDO chatModel = chatModelService.getChatModel(id);
|
||||
return success(BeanUtils.toBean(chatModel, AiChatModelRespVO.class));
|
||||
}
|
||||
|
||||
@GetMapping("/page")
|
||||
@Operation(summary = "获得聊天模型分页")
|
||||
@PreAuthorize("@ss.hasPermission('ai:chat-model:query')")
|
||||
public CommonResult<PageResult<AiChatModelRespVO>> getChatModelPage(@Valid AiChatModelPageReqVO pageReqVO) {
|
||||
PageResult<AiChatModelDO> pageResult = chatModelService.getChatModelPage(pageReqVO);
|
||||
return success(BeanUtils.toBean(pageResult, AiChatModelRespVO.class));
|
||||
}
|
||||
|
||||
@GetMapping("/simple-list")
|
||||
@Operation(summary = "获得聊天模型列表")
|
||||
@Parameter(name = "status", description = "状态", required = true, example = "1")
|
||||
public CommonResult<List<AiChatModelRespVO>> getChatModelSimpleList(@RequestParam("status") Integer status) {
|
||||
List<AiChatModelDO> list = chatModelService.getChatModelListByStatus(status);
|
||||
return success(convertList(list, model -> new AiChatModelRespVO().setId(model.getId())
|
||||
.setName(model.getName()).setModel(model.getModel())));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.annotation.Resource;
|
||||
import jakarta.validation.Valid;
|
||||
import org.springframework.security.access.prepost.PreAuthorize;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
||||
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
|
||||
|
||||
@Tag(name = "管理后台 - AI 模型")
|
||||
@RestController
|
||||
@RequestMapping("/ai/model")
|
||||
@Validated
|
||||
public class AiModelController {
|
||||
|
||||
@Resource
|
||||
private AiModelService modelService;
|
||||
|
||||
@PostMapping("/create")
|
||||
@Operation(summary = "创建模型")
|
||||
@PreAuthorize("@ss.hasPermission('ai:model:create')")
|
||||
public CommonResult<Long> createModel(@Valid @RequestBody AiModelSaveReqVO createReqVO) {
|
||||
return success(modelService.createModel(createReqVO));
|
||||
}
|
||||
|
||||
@PutMapping("/update")
|
||||
@Operation(summary = "更新模型")
|
||||
@PreAuthorize("@ss.hasPermission('ai:model:update')")
|
||||
public CommonResult<Boolean> updateModel(@Valid @RequestBody AiModelSaveReqVO updateReqVO) {
|
||||
modelService.updateModel(updateReqVO);
|
||||
return success(true);
|
||||
}
|
||||
|
||||
@DeleteMapping("/delete")
|
||||
@Operation(summary = "删除模型")
|
||||
@Parameter(name = "id", description = "编号", required = true)
|
||||
@PreAuthorize("@ss.hasPermission('ai:model:delete')")
|
||||
public CommonResult<Boolean> deleteModel(@RequestParam("id") Long id) {
|
||||
modelService.deleteModel(id);
|
||||
return success(true);
|
||||
}
|
||||
|
||||
@GetMapping("/get")
|
||||
@Operation(summary = "获得模型")
|
||||
@Parameter(name = "id", description = "编号", required = true, example = "1024")
|
||||
@PreAuthorize("@ss.hasPermission('ai:model:query')")
|
||||
public CommonResult<AiModelRespVO> getModel(@RequestParam("id") Long id) {
|
||||
AiModelDO model = modelService.getModel(id);
|
||||
return success(BeanUtils.toBean(model, AiModelRespVO.class));
|
||||
}
|
||||
|
||||
@GetMapping("/page")
|
||||
@Operation(summary = "获得模型分页")
|
||||
@PreAuthorize("@ss.hasPermission('ai:model:query')")
|
||||
public CommonResult<PageResult<AiModelRespVO>> getModelPage(@Valid AiModelPageReqVO pageReqVO) {
|
||||
PageResult<AiModelDO> pageResult = modelService.getModelPage(pageReqVO);
|
||||
return success(BeanUtils.toBean(pageResult, AiModelRespVO.class));
|
||||
}
|
||||
|
||||
@GetMapping("/simple-list")
|
||||
@Operation(summary = "获得模型列表")
|
||||
@Parameter(name = "type", description = "类型", required = true, example = "1")
|
||||
@Parameter(name = "platform", description = "平台", example = "midjourney")
|
||||
public CommonResult<List<AiModelRespVO>> getModelSimpleList(
|
||||
@RequestParam("type") Integer type,
|
||||
@RequestParam(value = "platform", required = false) String platform) {
|
||||
List<AiModelDO> list = modelService.getModelListByStatusAndType(
|
||||
CommonStatusEnum.ENABLE.getStatus(), type, platform);
|
||||
return success(convertList(list, model -> new AiModelRespVO().setId(model.getId())
|
||||
.setName(model.getName()).setModel(model.getModel()).setPlatform(model.getPlatform())));
|
||||
}
|
||||
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import com.fhs.core.trans.anno.Trans;
|
||||
import com.fhs.core.trans.constant.TransType;
|
||||
import com.fhs.core.trans.vo.VO;
|
||||
|
@ -20,7 +20,7 @@ public class AiChatRoleRespVO implements VO {
|
|||
private Long userId;
|
||||
|
||||
@Schema(description = "模型编号", example = "17640")
|
||||
@Trans(type = TransType.SIMPLE, target = AiChatModelDO.class, fields = {"name", "model"}, refs = {"modelName", "model"})
|
||||
@Trans(type = TransType.SIMPLE, target = AiModelDO.class, fields = {"name", "model"}, refs = {"modelName", "model"})
|
||||
private Long modelId;
|
||||
@Schema(description = "模型名字", example = "张三")
|
||||
private String modelName;
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel;
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
|
||||
|
||||
import lombok.*;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
|
||||
@Schema(description = "管理后台 - API 聊天模型分页 Request VO")
|
||||
@Schema(description = "管理后台 - API 模型分页 Request VO")
|
||||
@Data
|
||||
public class AiChatModelPageReqVO extends PageParam {
|
||||
public class AiModelPageReqVO extends PageParam {
|
||||
|
||||
@Schema(description = "模型名字", example = "张三")
|
||||
private String name;
|
|
@ -1,13 +1,13 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel;
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
@Schema(description = "管理后台 - AI 聊天模型 Response VO")
|
||||
@Schema(description = "管理后台 - AI 模型 Response VO")
|
||||
@Data
|
||||
public class AiChatModelRespVO {
|
||||
public class AiModelRespVO {
|
||||
|
||||
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "2630")
|
||||
private Long id;
|
||||
|
@ -24,6 +24,9 @@ public class AiChatModelRespVO {
|
|||
@Schema(description = "模型平台", example = "OpenAI")
|
||||
private String platform;
|
||||
|
||||
@Schema(description = "模型类型", example = "1")
|
||||
private Integer type;
|
||||
|
||||
@Schema(description = "排序", example = "1")
|
||||
private Integer sort;
|
||||
|
|
@ -1,14 +1,17 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel;
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.common.validation.InEnum;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.*;
|
||||
import jakarta.validation.constraints.*;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
|
||||
@Schema(description = "管理后台 - API 聊天模型新增/修改 Request VO")
|
||||
@Schema(description = "管理后台 - API 模型新增/修改 Request VO")
|
||||
@Data
|
||||
public class AiChatModelSaveReqVO {
|
||||
public class AiModelSaveReqVO {
|
||||
|
||||
@Schema(description = "编号", example = "2630")
|
||||
private Long id;
|
||||
|
@ -27,8 +30,14 @@ public class AiChatModelSaveReqVO {
|
|||
|
||||
@Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI")
|
||||
@NotEmpty(message = "模型平台不能为空")
|
||||
@InEnum(AiPlatformEnum.class)
|
||||
private String platform;
|
||||
|
||||
@Schema(description = "模型类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||
@NotNull(message = "模型类型不能为空")
|
||||
@InEnum(AiModelTypeEnum.class)
|
||||
private Integer type;
|
||||
|
||||
@Schema(description = "排序", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||
@NotNull(message = "排序不能为空")
|
||||
private Integer sort;
|
|
@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
|
|||
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
|
@ -76,13 +76,13 @@ public class AiChatConversationDO extends BaseDO {
|
|||
/**
|
||||
* 模型编号
|
||||
*
|
||||
* 关联 {@link AiChatModelDO#getId()} 字段
|
||||
* 关联 {@link AiModelDO#getId()} 字段
|
||||
*/
|
||||
private Long modelId;
|
||||
/**
|
||||
* 模型标志
|
||||
*
|
||||
* 冗余 {@link AiChatModelDO#getModel()} 字段
|
||||
* 冗余 {@link AiModelDO#getModel()} 字段
|
||||
*/
|
||||
private String model;
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
|
|||
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
import com.baomidou.mybatisplus.annotation.TableField;
|
||||
|
@ -83,13 +83,13 @@ public class AiChatMessageDO extends BaseDO {
|
|||
/**
|
||||
* 模型标志
|
||||
*
|
||||
* 冗余 {@link AiChatModelDO#getModel()}
|
||||
* 冗余 {@link AiModelDO#getModel()}
|
||||
*/
|
||||
private String model;
|
||||
/**
|
||||
* 模型编号
|
||||
*
|
||||
* 关联 {@link AiChatModelDO#getId()} 字段
|
||||
* 关联 {@link AiModelDO#getId()} 字段
|
||||
*/
|
||||
private Long modelId;
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.image;
|
|||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
|
||||
import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
|
@ -52,11 +52,16 @@ public class AiImageDO extends BaseDO {
|
|||
* 枚举 {@link cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum}
|
||||
*/
|
||||
private String platform;
|
||||
// TODO @芋艿:modelId?
|
||||
/**
|
||||
* 模型
|
||||
* 模型编号
|
||||
*
|
||||
* 冗余 {@link AiChatModelDO#getModel()}
|
||||
* 关联 {@link AiModelDO#getId()}
|
||||
*/
|
||||
private Long modelId;
|
||||
/**
|
||||
* 模型标识
|
||||
*
|
||||
* 冗余 {@link AiModelDO#getModel()}
|
||||
*/
|
||||
private String model;
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.knowledge;
|
|||
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
|
@ -35,13 +35,13 @@ public class AiKnowledgeDO extends BaseDO {
|
|||
/**
|
||||
* 向量模型编号
|
||||
*
|
||||
* 关联 {@link AiChatModelDO#getId()}
|
||||
* 关联 {@link AiModelDO#getId()}
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
/**
|
||||
* 模型标识
|
||||
*
|
||||
* 冗余 {@link AiChatModelDO#getModel()}
|
||||
* 冗余 {@link AiModelDO#getModel()}
|
||||
*/
|
||||
private String embeddingModel;
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.mindmap;
|
|||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
|
@ -36,7 +37,12 @@ public class AiMindMapDO extends BaseDO {
|
|||
* 枚举 {@link AiPlatformEnum}
|
||||
*/
|
||||
private String platform;
|
||||
// TODO @芋艿:modelId?
|
||||
/**
|
||||
* 模型编号
|
||||
*
|
||||
* 关联 {@link AiModelDO#getId()}
|
||||
*/
|
||||
private Long modelId;
|
||||
/**
|
||||
* 模型
|
||||
*/
|
||||
|
|
|
@ -58,7 +58,7 @@ public class AiChatRoleDO extends BaseDO {
|
|||
/**
|
||||
* 模型编号
|
||||
*
|
||||
* 关联 {@link AiChatModelDO#getId()} 字段
|
||||
* 关联 {@link AiModelDO#getId()} 字段
|
||||
*/
|
||||
private Long modelId;
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package cn.iocoder.yudao.module.ai.dal.dataobject.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
|
@ -8,23 +9,22 @@ import com.baomidou.mybatisplus.annotation.TableId;
|
|||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.*;
|
||||
|
||||
// TODO @芋艿,需要改造,增加 type
|
||||
/**
|
||||
* AI 聊天模型 DO
|
||||
* AI 模型 DO
|
||||
*
|
||||
* 默认聊天模型:{@link #status} 为开启,并且 {@link #sort} 排序第一
|
||||
* 默认模型:{@link #status} 为开启,并且 {@link #sort} 排序第一
|
||||
*
|
||||
* @author fansili
|
||||
* @since 2024/4/24 19:39
|
||||
*/
|
||||
@TableName("ai_chat_model")
|
||||
@KeySequence("ai_chat_model_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
|
||||
@TableName("ai_model")
|
||||
@KeySequence("ai_model_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class AiChatModelDO extends BaseDO {
|
||||
public class AiModelDO extends BaseDO {
|
||||
|
||||
/**
|
||||
* 编号
|
||||
|
@ -51,6 +51,12 @@ public class AiChatModelDO extends BaseDO {
|
|||
* 枚举 {@link AiPlatformEnum}
|
||||
*/
|
||||
private String platform;
|
||||
/**
|
||||
* 类型
|
||||
*
|
||||
* 枚举 {@link AiModelTypeEnum}
|
||||
*/
|
||||
private Integer type;
|
||||
|
||||
/**
|
||||
* 排序值
|
|
@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.write;
|
|||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
|
||||
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
|
@ -44,7 +46,12 @@ public class AiWriteDO extends BaseDO {
|
|||
* 枚举 {@link AiPlatformEnum}
|
||||
*/
|
||||
private String platform;
|
||||
// TODO @芋艿:modelId?
|
||||
/**
|
||||
* 模型编号
|
||||
*
|
||||
* 关联 {@link AiModelDO#getId()}
|
||||
*/
|
||||
private Long modelId;
|
||||
/**
|
||||
* 模型
|
||||
*/
|
||||
|
@ -67,25 +74,25 @@ public class AiWriteDO extends BaseDO {
|
|||
/**
|
||||
* 长度提示词
|
||||
*
|
||||
* 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LENGTH}
|
||||
* 字典:{@link DictTypeConstants#AI_WRITE_LENGTH}
|
||||
*/
|
||||
private Integer length;
|
||||
/**
|
||||
* 格式提示词
|
||||
*
|
||||
* 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_FORMAT}
|
||||
* 字典:{@link DictTypeConstants#AI_WRITE_FORMAT}
|
||||
*/
|
||||
private Integer format;
|
||||
/**
|
||||
* 语气提示词
|
||||
*
|
||||
* 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_TONE}
|
||||
* 字典:{@link DictTypeConstants#AI_WRITE_TONE}
|
||||
*/
|
||||
private Integer tone;
|
||||
/**
|
||||
* 语言提示词
|
||||
*
|
||||
* 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LANGUAGE}
|
||||
* 字典:{@link DictTypeConstants#AI_WRITE_LANGUAGE}
|
||||
*/
|
||||
private Integer language;
|
||||
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
package cn.iocoder.yudao.module.ai.dal.mysql.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* API 模型 Mapper
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Mapper
|
||||
public interface AiChatMapper extends BaseMapperX<AiModelDO> {
|
||||
|
||||
default AiModelDO selectFirstByStatus(Integer type, Integer status) {
|
||||
return selectOne(new QueryWrapperX<AiModelDO>()
|
||||
.eq("type", type)
|
||||
.eq("status", status)
|
||||
.limitN(1)
|
||||
.orderByAsc("sort"));
|
||||
}
|
||||
|
||||
default PageResult<AiModelDO> selectPage(AiModelPageReqVO reqVO) {
|
||||
return selectPage(reqVO, new LambdaQueryWrapperX<AiModelDO>()
|
||||
.likeIfPresent(AiModelDO::getName, reqVO.getName())
|
||||
.eqIfPresent(AiModelDO::getModel, reqVO.getModel())
|
||||
.eqIfPresent(AiModelDO::getPlatform, reqVO.getPlatform())
|
||||
.orderByAsc(AiModelDO::getSort));
|
||||
}
|
||||
|
||||
default List<AiModelDO> selectListByStatusAndType(Integer status, Integer type,
|
||||
@Nullable String platform) {
|
||||
return selectList(new LambdaQueryWrapperX<AiModelDO>()
|
||||
.eq(AiModelDO::getStatus, status)
|
||||
.eq(AiModelDO::getType, type)
|
||||
.eqIfPresent(AiModelDO::getPlatform, platform)
|
||||
.orderByAsc(AiModelDO::getSort));
|
||||
}
|
||||
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
package cn.iocoder.yudao.module.ai.dal.mysql.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* API 聊天模型 Mapper
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Mapper
|
||||
public interface AiChatModelMapper extends BaseMapperX<AiChatModelDO> {
|
||||
|
||||
default AiChatModelDO selectFirstByStatus(Integer status) {
|
||||
return selectOne(new QueryWrapperX<AiChatModelDO>()
|
||||
.eq("status", status)
|
||||
.limitN(1)
|
||||
.orderByAsc("sort"));
|
||||
}
|
||||
|
||||
default PageResult<AiChatModelDO> selectPage(AiChatModelPageReqVO reqVO) {
|
||||
return selectPage(reqVO, new LambdaQueryWrapperX<AiChatModelDO>()
|
||||
.likeIfPresent(AiChatModelDO::getName, reqVO.getName())
|
||||
.eqIfPresent(AiChatModelDO::getModel, reqVO.getModel())
|
||||
.eqIfPresent(AiChatModelDO::getPlatform, reqVO.getPlatform())
|
||||
.orderByAsc(AiChatModelDO::getSort));
|
||||
}
|
||||
|
||||
default List<AiChatModelDO> selectList(Integer status) {
|
||||
return selectList(new LambdaQueryWrapperX<AiChatModelDO>()
|
||||
.eq(AiChatModelDO::getStatus, status)
|
||||
.orderByAsc(AiChatModelDO::getSort));
|
||||
}
|
||||
|
||||
}
|
|
@ -4,17 +4,18 @@ import cn.hutool.core.collection.CollUtil;
|
|||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.core.util.ObjectUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
|
||||
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -44,7 +45,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
|
|||
private AiChatConversationMapper chatConversationMapper;
|
||||
|
||||
@Resource
|
||||
private AiChatModelService chatModalService;
|
||||
private AiModelService chatModalService;
|
||||
@Resource
|
||||
private AiChatRoleService chatRoleService;
|
||||
@Resource
|
||||
|
@ -54,9 +55,9 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
|
|||
public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) {
|
||||
// 1.1 获得 AiChatRoleDO 聊天角色
|
||||
AiChatRoleDO role = createReqVO.getRoleId() != null ? chatRoleService.validateChatRole(createReqVO.getRoleId()) : null;
|
||||
// 1.2 获得 AiChatModelDO 聊天模型
|
||||
AiChatModelDO model = role != null && role.getModelId() != null ? chatModalService.validateChatModel(role.getModelId())
|
||||
: chatModalService.getRequiredDefaultChatModel();
|
||||
// 1.2 获得 AiModelDO 聊天模型
|
||||
AiModelDO model = role != null && role.getModelId() != null ? chatModalService.validateModel(role.getModelId())
|
||||
: chatModalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
|
||||
Assert.notNull(model, "必须找到默认模型");
|
||||
validateChatModel(model);
|
||||
|
||||
|
@ -86,9 +87,9 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
|
|||
throw exception(CHAT_CONVERSATION_NOT_EXISTS);
|
||||
}
|
||||
// 1.2 校验模型是否存在(修改模型的情况)
|
||||
AiChatModelDO model = null;
|
||||
AiModelDO model = null;
|
||||
if (updateReqVO.getModelId() != null) {
|
||||
model = chatModalService.validateChatModel(updateReqVO.getModelId());
|
||||
model = chatModalService.validateModel(updateReqVO.getModelId());
|
||||
}
|
||||
|
||||
// 1.3 校验知识库是否存在
|
||||
|
@ -139,7 +140,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
|
|||
chatConversationMapper.deleteById(id);
|
||||
}
|
||||
|
||||
private void validateChatModel(AiChatModelDO model) {
|
||||
private void validateChatModel(AiModelDO model) {
|
||||
if (ObjectUtil.isAllNotEmpty(model.getTemperature(), model.getMaxTokens(), model.getMaxContexts())) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -15,13 +15,12 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage
|
|||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
|
||||
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
||||
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
|
@ -63,9 +62,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||
@Resource
|
||||
private AiChatConversationService chatConversationService;
|
||||
@Resource
|
||||
private AiChatModelService chatModalService;
|
||||
@Resource
|
||||
private AiApiKeyService apiKeyService;
|
||||
private AiModelService modalService;
|
||||
@Resource
|
||||
private AiKnowledgeSegmentService knowledgeSegmentService;
|
||||
|
||||
|
@ -78,8 +75,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||
}
|
||||
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
|
||||
// 1.2 校验模型
|
||||
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
|
||||
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
||||
AiModelDO model = modalService.validateModel(conversation.getModelId());
|
||||
ChatModel chatModel = modalService.getChatModel(model.getKeyId());
|
||||
|
||||
// 2. 插入 user 发送消息
|
||||
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
||||
|
@ -112,8 +109,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||
}
|
||||
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
|
||||
// 1.2 校验模型
|
||||
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
|
||||
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
||||
AiModelDO model = modalService.validateModel(conversation.getModelId());
|
||||
StreamingChatModel chatModel = modalService.getChatModel(model.getKeyId());
|
||||
|
||||
// 2. 插入 user 发送消息
|
||||
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
||||
|
@ -161,8 +158,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||
return null;
|
||||
}
|
||||
|
||||
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList,
|
||||
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
|
||||
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, List<AiKnowledgeSegmentDO> segmentList,
|
||||
AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
|
||||
// 1. 构建 Prompt Message 列表
|
||||
List<Message> chatMessages = new ArrayList<>();
|
||||
|
||||
|
@ -232,7 +229,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||
}
|
||||
|
||||
private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
|
||||
AiChatModelDO model, Long userId, Long roleId,
|
||||
AiModelDO model, Long userId, Long roleId,
|
||||
MessageType messageType, String content, Boolean useContext) {
|
||||
AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
|
||||
.setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
|
||||
|
|
|
@ -12,13 +12,17 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
|||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePublicPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
|
||||
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
||||
import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
||||
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
|
||||
import jakarta.annotation.Resource;
|
||||
|
@ -54,15 +58,15 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
|
|||
@Slf4j
|
||||
public class AiImageServiceImpl implements AiImageService {
|
||||
|
||||
@Resource
|
||||
private AiModelService modelService;
|
||||
|
||||
@Resource
|
||||
private AiImageMapper imageMapper;
|
||||
|
||||
@Resource
|
||||
private FileApi fileApi;
|
||||
|
||||
@Resource
|
||||
private AiApiKeyService apiKeyService;
|
||||
|
||||
@Override
|
||||
public PageResult<AiImageDO> getImagePageMy(Long userId, AiImagePageReqVO pageReqVO) {
|
||||
return imageMapper.selectPageMy(userId, pageReqVO);
|
||||
|
@ -88,23 +92,31 @@ public class AiImageServiceImpl implements AiImageService {
|
|||
|
||||
@Override
|
||||
public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) {
|
||||
// 1. 保存数据库
|
||||
AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
|
||||
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
||||
// 1. 校验模型
|
||||
AiModelDO model = modelService.validateModel(drawReqVO.getModelId());
|
||||
|
||||
// 2. 保存数据库
|
||||
AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId)
|
||||
.setPlatform(model.getPlatform()).setModelId(model.getId()).setModel(model.getModel())
|
||||
.setPublicStatus(false).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
||||
imageMapper.insert(image);
|
||||
// 2. 异步绘制,后续前端通过返回的 id 进行轮询结果
|
||||
getSelf().executeDrawImage(image, drawReqVO);
|
||||
|
||||
// 3. 异步绘制,后续前端通过返回的 id 进行轮询结果
|
||||
getSelf().executeDrawImage(image, drawReqVO, model);
|
||||
return image.getId();
|
||||
}
|
||||
|
||||
@Async
|
||||
public void executeDrawImage(AiImageDO image, AiImageDrawReqVO req) {
|
||||
public void executeDrawImage(AiImageDO image, AiImageDrawReqVO reqVO, AiModelDO model) {
|
||||
try {
|
||||
// 1.1 构建请求
|
||||
ImageOptions request = buildImageOptions(req);
|
||||
ImageOptions request = buildImageOptions(reqVO, model);
|
||||
// 1.2 执行请求
|
||||
ImageModel imageModel = apiKeyService.getImageModel(AiPlatformEnum.validatePlatform(req.getPlatform()));
|
||||
ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request));
|
||||
ImageModel imageModel = modelService.getImageModel(model.getId());
|
||||
ImageResponse response = imageModel.call(new ImagePrompt(reqVO.getPrompt(), request));
|
||||
if (response.getResult() == null) {
|
||||
throw new IllegalArgumentException("生成结果为空");
|
||||
}
|
||||
|
||||
// 2. 上传到文件服务
|
||||
String b64Json = response.getResult().getOutput().getB64Json();
|
||||
|
@ -116,25 +128,25 @@ public class AiImageServiceImpl implements AiImageService {
|
|||
imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(AiImageStatusEnum.SUCCESS.getStatus())
|
||||
.setPicUrl(filePath).setFinishTime(LocalDateTime.now()));
|
||||
} catch (Exception ex) {
|
||||
log.error("[doDall][image({}) 生成异常]", image, ex);
|
||||
log.error("[executeDrawImage][image({}) 生成异常]", image, ex);
|
||||
imageMapper.updateById(new AiImageDO().setId(image.getId())
|
||||
.setStatus(AiImageStatusEnum.FAIL.getStatus())
|
||||
.setErrorMessage(ex.getMessage()).setFinishTime(LocalDateTime.now()));
|
||||
}
|
||||
}
|
||||
|
||||
private static ImageOptions buildImageOptions(AiImageDrawReqVO draw) {
|
||||
if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
|
||||
private static ImageOptions buildImageOptions(AiImageDrawReqVO draw, AiModelDO model) {
|
||||
if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
|
||||
// https://platform.openai.com/docs/api-reference/images/create
|
||||
return OpenAiImageOptions.builder().withModel(draw.getModel())
|
||||
return OpenAiImageOptions.builder().withModel(model.getModel())
|
||||
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
|
||||
.withStyle(MapUtil.getStr(draw.getOptions(), "style")) // 风格
|
||||
.withResponseFormat("b64_json")
|
||||
.build();
|
||||
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
|
||||
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
|
||||
// https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
|
||||
// https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
|
||||
return StabilityAiImageOptions.builder().model(draw.getModel())
|
||||
return StabilityAiImageOptions.builder().model(model.getModel())
|
||||
.height(draw.getHeight()).width(draw.getWidth())
|
||||
.seed(Long.valueOf(draw.getOptions().get("seed")))
|
||||
.cfgScale(Float.valueOf(draw.getOptions().get("scale")))
|
||||
|
@ -143,22 +155,22 @@ public class AiImageServiceImpl implements AiImageService {
|
|||
.stylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
|
||||
.clipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
|
||||
.build();
|
||||
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
|
||||
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
|
||||
return DashScopeImageOptions.builder()
|
||||
.withModel(draw.getModel()).withN(1)
|
||||
.withModel(model.getModel()).withN(1)
|
||||
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
|
||||
.build();
|
||||
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
|
||||
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
|
||||
return QianFanImageOptions.builder()
|
||||
.model(draw.getModel()).N(1)
|
||||
.model(model.getModel()).N(1)
|
||||
.height(draw.getHeight()).width(draw.getWidth())
|
||||
.build();
|
||||
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
|
||||
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
|
||||
return ZhiPuAiImageOptions.builder()
|
||||
.model(draw.getModel())
|
||||
.model(model.getModel())
|
||||
.build();
|
||||
}
|
||||
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
|
||||
throw new IllegalArgumentException("不支持的 AI 平台:" + model.getPlatform());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -206,7 +218,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|||
@Override
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
|
||||
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
|
||||
MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
|
||||
// 1. 保存数据库
|
||||
AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
|
||||
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
|
||||
|
@ -237,7 +249,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|||
|
||||
@Override
|
||||
public Integer midjourneySync() {
|
||||
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
|
||||
MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
|
||||
// 1.1 获取 Midjourney 平台,状态在 “进行中” 的 image
|
||||
List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
|
||||
AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
|
||||
|
@ -308,7 +320,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|||
|
||||
@Override
|
||||
public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
|
||||
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
|
||||
MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
|
||||
// 1.1 检查 image
|
||||
AiImageDO image = validateImageExists(reqVO.getId());
|
||||
if (ObjUtil.notEqual(userId, image.getUserId())) {
|
||||
|
|
|
@ -7,14 +7,17 @@ import cn.hutool.core.util.StrUtil;
|
|||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.*;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
|
||||
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.document.Document;
|
||||
|
@ -33,8 +36,8 @@ import java.util.Objects;
|
|||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
|
||||
|
||||
/**
|
||||
* AI 知识库分片 Service 实现类
|
||||
|
@ -58,7 +61,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
|
|||
@Lazy // 延迟加载,避免循环依赖
|
||||
private AiKnowledgeDocumentService knowledgeDocumentService;
|
||||
@Resource
|
||||
private AiApiKeyService apiKeyService;
|
||||
private AiModelService modelService;
|
||||
|
||||
@Resource
|
||||
private TokenCountEstimator tokenCountEstimator;
|
||||
|
@ -180,7 +183,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
|
|||
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqBO.getKnowledgeId());
|
||||
|
||||
// 2.1 向量检索
|
||||
VectorStore vectorStore = apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId());
|
||||
VectorStore vectorStore = getVectorStoreById(knowledge);
|
||||
List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder()
|
||||
.query(reqBO.getContent())
|
||||
.topK(ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK()))
|
||||
|
@ -251,11 +254,12 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
|
|||
}
|
||||
|
||||
private VectorStore getVectorStoreById(AiKnowledgeDO knowledge) {
|
||||
return apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId());
|
||||
return modelService.getOrCreateVectorStore(knowledge.getEmbeddingModelId());
|
||||
}
|
||||
|
||||
private VectorStore getVectorStoreById(Long knowledgeId) {
|
||||
return getVectorStoreById(knowledgeService.validateKnowledgeExists(knowledgeId));
|
||||
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(knowledgeId);
|
||||
return getVectorStoreById(knowledge);
|
||||
}
|
||||
|
||||
private static List<Document> splitContentByToken(String content, Integer segmentMaxTokens) {
|
||||
|
|
|
@ -5,9 +5,9 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
|||
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
@ -28,12 +28,12 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
|
|||
private AiKnowledgeMapper knowledgeMapper;
|
||||
|
||||
@Resource
|
||||
private AiChatModelService chatModelService;
|
||||
private AiModelService chatModelService;
|
||||
|
||||
@Override
|
||||
public Long createKnowledge(AiKnowledgeSaveReqVO createReqVO) {
|
||||
// 1. 校验模型配置
|
||||
AiChatModelDO model = chatModelService.validateChatModel(createReqVO.getEmbeddingModelId());
|
||||
AiModelDO model = chatModelService.validateModel(createReqVO.getEmbeddingModelId());
|
||||
|
||||
// 2. 插入知识库
|
||||
AiKnowledgeDO knowledge = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
|
||||
|
@ -47,7 +47,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
|
|||
// 1.1 校验知识库存在
|
||||
validateKnowledgeExists(updateReqVO.getId());
|
||||
// 1.2 校验模型配置
|
||||
AiChatModelDO model = chatModelService.validateChatModel(updateReqVO.getEmbeddingModelId());
|
||||
AiModelDO model = chatModelService.validateModel(updateReqVO.getEmbeddingModelId());
|
||||
|
||||
// 2. 更新知识库
|
||||
AiKnowledgeDO updateObj = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class)
|
||||
|
|
|
@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.mindmap;
|
|||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
|
@ -12,14 +13,13 @@ import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
|||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.mindmap.AiMindMapMapper;
|
||||
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
|
@ -50,9 +50,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MIND_MAP_NOT_E
|
|||
public class AiMindMapServiceImpl implements AiMindMapService {
|
||||
|
||||
@Resource
|
||||
private AiApiKeyService apiKeyService;
|
||||
@Resource
|
||||
private AiChatModelService chatModalService;
|
||||
private AiModelService modalService;
|
||||
@Resource
|
||||
private AiChatRoleService chatRoleService;
|
||||
|
||||
|
@ -65,17 +63,17 @@ public class AiMindMapServiceImpl implements AiMindMapService {
|
|||
AiChatRoleDO role = CollUtil.getFirst(
|
||||
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
|
||||
// 1.1 获取导图执行模型
|
||||
AiChatModelDO model = getModel(role);
|
||||
AiModelDO model = getModel(role);
|
||||
// 1.2 获取角色设定消息
|
||||
String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
|
||||
? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
|
||||
// 1.3 校验平台
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
||||
ChatModel chatModel = modalService.getChatModel(model.getId());
|
||||
|
||||
// 2. 插入思维导图信息
|
||||
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
|
||||
mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
||||
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, mindMap -> mindMap.setUserId(userId)
|
||||
.setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel()));
|
||||
mindMapMapper.insert(mindMapDO);
|
||||
|
||||
// 3.1 构建 Prompt,并进行调用
|
||||
|
@ -103,7 +101,7 @@ public class AiMindMapServiceImpl implements AiMindMapService {
|
|||
|
||||
}
|
||||
|
||||
private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
|
||||
private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) {
|
||||
// 1. 构建 message 列表
|
||||
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
|
||||
// 2. 构建 options 对象
|
||||
|
@ -123,13 +121,13 @@ public class AiMindMapServiceImpl implements AiMindMapService {
|
|||
return chatMessages;
|
||||
}
|
||||
|
||||
private AiChatModelDO getModel(AiChatRoleDO role) {
|
||||
AiChatModelDO model = null;
|
||||
private AiModelDO getModel(AiChatRoleDO role) {
|
||||
AiModelDO model = null;
|
||||
if (role != null && role.getModelId() != null) {
|
||||
model = chatModalService.getChatModel(role.getModelId());
|
||||
model = modalService.getModel(role.getModelId());
|
||||
}
|
||||
if (model == null) {
|
||||
model = chatModalService.getRequiredDefaultChatModel();
|
||||
model = modalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
|
||||
}
|
||||
Assert.notNull(model, "[AI] 获取不到模型");
|
||||
return model;
|
||||
|
|
|
@ -1,16 +1,10 @@
|
|||
package cn.iocoder.yudao.module.ai.service.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
||||
import jakarta.validation.Valid;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.image.ImageModel;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
@ -74,50 +68,13 @@ public interface AiApiKeyService {
|
|||
*/
|
||||
List<AiApiKeyDO> getApiKeyList();
|
||||
|
||||
// ========== 与 spring-ai 集成 ==========
|
||||
|
||||
/**
|
||||
* 获得 ChatModel 对象
|
||||
*
|
||||
* @param id 编号
|
||||
* @return ChatModel 对象
|
||||
*/
|
||||
ChatModel getChatModel(Long id);
|
||||
|
||||
/**
|
||||
* 获得 ImageModel 对象
|
||||
*
|
||||
* TODO 可优化点:目前默认获取 platform 对应的第一个开启的配置用于绘画;后续可以支持配置选择
|
||||
* 获得默认的 API 密钥
|
||||
*
|
||||
* @param platform 平台
|
||||
* @return ImageModel 对象
|
||||
* @param status 状态
|
||||
* @return API 密钥
|
||||
*/
|
||||
ImageModel getImageModel(AiPlatformEnum platform);
|
||||
|
||||
/**
|
||||
* 获得 MidjourneyApi 对象
|
||||
*
|
||||
* TODO 可优化点:目前默认获取 Midjourney 对应的第一个开启的配置用于绘画;后续可以支持配置选择
|
||||
*
|
||||
* @return MidjourneyApi 对象
|
||||
*/
|
||||
MidjourneyApi getMidjourneyApi();
|
||||
|
||||
/**
|
||||
* 获得 SunoApi 对象
|
||||
*
|
||||
* TODO 可优化点:目前默认获取 Suno 对应的第一个开启的配置用于音乐;后续可以支持配置选择
|
||||
*
|
||||
* @return SunoApi 对象
|
||||
*/
|
||||
SunoApi getSunoApi();
|
||||
|
||||
/**
|
||||
* 获得 VectorStore 对象
|
||||
*
|
||||
* @param modelId 编号
|
||||
* @return VectorStore 对象
|
||||
*/
|
||||
VectorStore getOrCreateVectorStoreByModelId(Long modelId);
|
||||
AiApiKeyDO getRequiredDefaultApiKey(String platform, Integer status);
|
||||
|
||||
}
|
|
@ -1,31 +1,21 @@
|
|||
package cn.iocoder.yudao.module.ai.service.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
|
||||
import jakarta.annotation.Resource;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.image.ImageModel;
|
||||
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.API_KEY_DISABLE;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.API_KEY_NOT_EXISTS;
|
||||
|
||||
/**
|
||||
* AI API 密钥 Service 实现类
|
||||
|
@ -39,14 +29,6 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
|
|||
@Resource
|
||||
private AiApiKeyMapper apiKeyMapper;
|
||||
|
||||
// TODO @芋艿:后续要不要改?
|
||||
@Resource
|
||||
@Lazy // 延迟加载,解决渲染依赖
|
||||
private AiChatModelService chatModelService;
|
||||
|
||||
@Resource
|
||||
private AiModelFactory modelFactory;
|
||||
|
||||
@Override
|
||||
public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
|
||||
// 插入
|
||||
|
@ -105,57 +87,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
|
|||
return apiKeyMapper.selectList();
|
||||
}
|
||||
|
||||
// ========== 与 spring-ai 集成 ==========
|
||||
|
||||
@Override
|
||||
public ChatModel getChatModel(Long id) {
|
||||
AiApiKeyDO apiKey = validateApiKey(id);
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
|
||||
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ImageModel getImageModel(AiPlatformEnum platform) {
|
||||
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
|
||||
public AiApiKeyDO getRequiredDefaultApiKey(String platform, Integer status) {
|
||||
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform, status);
|
||||
if (apiKey == null) {
|
||||
throw exception(API_KEY_IMAGE_NODE_FOUND, platform.getName());
|
||||
throw exception(API_KEY_NOT_EXISTS);
|
||||
}
|
||||
return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
|
||||
}
|
||||
|
||||
@Override
|
||||
public MidjourneyApi getMidjourneyApi() {
|
||||
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
|
||||
AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
|
||||
if (apiKey == null) {
|
||||
throw exception(API_KEY_MIDJOURNEY_NOT_FOUND);
|
||||
}
|
||||
return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
|
||||
}
|
||||
|
||||
@Override
|
||||
public SunoApi getSunoApi() {
|
||||
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
|
||||
AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
|
||||
if (apiKey == null) {
|
||||
throw exception(API_KEY_SUNO_NOT_FOUND);
|
||||
}
|
||||
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorStore getOrCreateVectorStoreByModelId(Long modelId) {
|
||||
// 获取模型 + 密钥
|
||||
AiChatModelDO chatModel = chatModelService.validateChatModel(modelId);
|
||||
AiApiKeyDO apiKey = validateApiKey(chatModel.getKeyId());
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
|
||||
|
||||
// 创建或获取 EmbeddingModel 对象
|
||||
EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(),
|
||||
apiKey.getUrl(), chatModel.getModel());
|
||||
|
||||
// 创建或获取 VectorStore 对象
|
||||
return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel);
|
||||
return apiKey;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,92 +0,0 @@
|
|||
package cn.iocoder.yudao.module.ai.service.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import jakarta.validation.Valid;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* AI 聊天模型 Service 接口
|
||||
*
|
||||
* @author fansili
|
||||
* @since 2024/4/24 19:42
|
||||
*/
|
||||
public interface AiChatModelService {
|
||||
|
||||
/**
|
||||
* 创建聊天模型
|
||||
*
|
||||
* @param createReqVO 创建信息
|
||||
* @return 编号
|
||||
*/
|
||||
Long createChatModel(@Valid AiChatModelSaveReqVO createReqVO);
|
||||
|
||||
/**
|
||||
* 更新聊天模型
|
||||
*
|
||||
* @param updateReqVO 更新信息
|
||||
*/
|
||||
void updateChatModel(@Valid AiChatModelSaveReqVO updateReqVO);
|
||||
|
||||
/**
|
||||
* 删除聊天模型
|
||||
*
|
||||
* @param id 编号
|
||||
*/
|
||||
void deleteChatModel(Long id);
|
||||
|
||||
/**
|
||||
* 获得聊天模型
|
||||
*
|
||||
* @param id 编号
|
||||
* @return 聊天模型
|
||||
*/
|
||||
AiChatModelDO getChatModel(Long id);
|
||||
|
||||
/**
|
||||
* 获得默认的聊天模型
|
||||
*
|
||||
* 如果获取不到,则抛出 {@link cn.iocoder.yudao.framework.common.exception.ServiceException} 业务异常
|
||||
*
|
||||
* @return 聊天模型
|
||||
*/
|
||||
AiChatModelDO getRequiredDefaultChatModel();
|
||||
|
||||
/**
|
||||
* 获得聊天模型分页
|
||||
*
|
||||
* @param pageReqVO 分页查询
|
||||
* @return 聊天模型分页
|
||||
*/
|
||||
PageResult<AiChatModelDO> getChatModelPage(AiChatModelPageReqVO pageReqVO);
|
||||
|
||||
/**
|
||||
* 校验聊天模型
|
||||
*
|
||||
* @param id 编号
|
||||
* @return 聊天模型
|
||||
*/
|
||||
AiChatModelDO validateChatModel(Long id);
|
||||
|
||||
/**
|
||||
* 获得聊天模型列表
|
||||
*
|
||||
* @param status 状态
|
||||
* @return 聊天模型列表
|
||||
*/
|
||||
List<AiChatModelDO> getChatModelListByStatus(Integer status);
|
||||
|
||||
/**
|
||||
* 获得聊天模型列表
|
||||
*
|
||||
* @param ids 编号数组
|
||||
* @return 模型列表
|
||||
*/
|
||||
List<AiChatModelDO> getChatModelList(Collection<Long> ids);
|
||||
}
|
|
@ -1,114 +1,168 @@
|
|||
package cn.iocoder.yudao.module.ai.service.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatModelMapper;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatMapper;
|
||||
import jakarta.annotation.Resource;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.image.ImageModel;
|
||||
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
|
||||
|
||||
/**
|
||||
* AI 聊天模型 Service 实现类
|
||||
* AI 模型 Service 实现类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Service
|
||||
@Validated
|
||||
public class AiChatModelServiceImpl implements AiChatModelService {
|
||||
public class AiChatModelServiceImpl implements AiModelService {
|
||||
|
||||
@Resource
|
||||
private AiApiKeyService apiKeyService;
|
||||
|
||||
@Resource
|
||||
private AiChatModelMapper chatModelMapper;
|
||||
private AiChatMapper modelMapper;
|
||||
|
||||
@Resource
|
||||
private AiModelFactory modelFactory;
|
||||
|
||||
@Override
|
||||
public Long createChatModel(AiChatModelSaveReqVO createReqVO) {
|
||||
public Long createModel(AiModelSaveReqVO createReqVO) {
|
||||
// 1. 校验
|
||||
AiPlatformEnum.validatePlatform(createReqVO.getPlatform());
|
||||
apiKeyService.validateApiKey(createReqVO.getKeyId());
|
||||
|
||||
// 2. 插入
|
||||
AiChatModelDO chatModel = BeanUtils.toBean(createReqVO, AiChatModelDO.class);
|
||||
chatModelMapper.insert(chatModel);
|
||||
return chatModel.getId();
|
||||
AiModelDO model = BeanUtils.toBean(createReqVO, AiModelDO.class);
|
||||
modelMapper.insert(model);
|
||||
return model.getId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateChatModel(AiChatModelSaveReqVO updateReqVO) {
|
||||
public void updateModel(AiModelSaveReqVO updateReqVO) {
|
||||
// 1. 校验
|
||||
validateChatModelExists(updateReqVO.getId());
|
||||
validateModelExists(updateReqVO.getId());
|
||||
AiPlatformEnum.validatePlatform(updateReqVO.getPlatform());
|
||||
apiKeyService.validateApiKey(updateReqVO.getKeyId());
|
||||
|
||||
// 2. 更新
|
||||
AiChatModelDO updateObj = BeanUtils.toBean(updateReqVO, AiChatModelDO.class);
|
||||
chatModelMapper.updateById(updateObj);
|
||||
AiModelDO updateObj = BeanUtils.toBean(updateReqVO, AiModelDO.class);
|
||||
modelMapper.updateById(updateObj);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteChatModel(Long id) {
|
||||
public void deleteModel(Long id) {
|
||||
// 校验存在
|
||||
validateChatModelExists(id);
|
||||
validateModelExists(id);
|
||||
// 删除
|
||||
chatModelMapper.deleteById(id);
|
||||
modelMapper.deleteById(id);
|
||||
}
|
||||
|
||||
private AiChatModelDO validateChatModelExists(Long id) {
|
||||
AiChatModelDO model = chatModelMapper.selectById(id);
|
||||
if (chatModelMapper.selectById(id) == null) {
|
||||
throw exception(CHAT_MODEL_NOT_EXISTS);
|
||||
private AiModelDO validateModelExists(Long id) {
|
||||
AiModelDO model = modelMapper.selectById(id);
|
||||
if (modelMapper.selectById(id) == null) {
|
||||
throw exception(MODEL_NOT_EXISTS);
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AiChatModelDO getChatModel(Long id) {
|
||||
return chatModelMapper.selectById(id);
|
||||
public AiModelDO getModel(Long id) {
|
||||
return modelMapper.selectById(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AiChatModelDO getRequiredDefaultChatModel() {
|
||||
AiChatModelDO model = chatModelMapper.selectFirstByStatus(CommonStatusEnum.ENABLE.getStatus());
|
||||
public AiModelDO getRequiredDefaultModel(Integer type) {
|
||||
AiModelDO model = modelMapper.selectFirstByStatus(type, CommonStatusEnum.ENABLE.getStatus());
|
||||
if (model == null) {
|
||||
throw exception(CHAT_MODEL_DEFAULT_NOT_EXISTS);
|
||||
throw exception(MODEL_DEFAULT_NOT_EXISTS);
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageResult<AiChatModelDO> getChatModelPage(AiChatModelPageReqVO pageReqVO) {
|
||||
return chatModelMapper.selectPage(pageReqVO);
|
||||
public PageResult<AiModelDO> getModelPage(AiModelPageReqVO pageReqVO) {
|
||||
return modelMapper.selectPage(pageReqVO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AiChatModelDO validateChatModel(Long id) {
|
||||
AiChatModelDO model = validateChatModelExists(id);
|
||||
public AiModelDO validateModel(Long id) {
|
||||
AiModelDO model = validateModelExists(id);
|
||||
if (CommonStatusEnum.isDisable(model.getStatus())) {
|
||||
throw exception(CHAT_MODEL_DISABLE);
|
||||
throw exception(MODEL_DISABLE);
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AiChatModelDO> getChatModelListByStatus(Integer status) {
|
||||
return chatModelMapper.selectList(status);
|
||||
public List<AiModelDO> getModelListByStatusAndType(Integer status, Integer type,
|
||||
String platform) {
|
||||
return modelMapper.selectListByStatusAndType(status, type, platform);
|
||||
}
|
||||
|
||||
// ========== 与 Spring AI 集成 ==========
|
||||
|
||||
@Override
|
||||
public ChatModel getChatModel(Long id) {
|
||||
AiModelDO model = validateModel(id);
|
||||
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
|
||||
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AiChatModelDO> getChatModelList(Collection<Long> ids) {
|
||||
return chatModelMapper.selectBatchIds(ids);
|
||||
public ImageModel getImageModel(Long id) {
|
||||
AiModelDO model = validateModel(id);
|
||||
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
|
||||
return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
|
||||
}
|
||||
|
||||
@Override
|
||||
public MidjourneyApi getMidjourneyApi() {
|
||||
AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey(
|
||||
AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
|
||||
return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
|
||||
}
|
||||
|
||||
@Override
|
||||
public SunoApi getSunoApi() {
|
||||
AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey(
|
||||
AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
|
||||
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorStore getOrCreateVectorStore(Long id) {
|
||||
// 获取模型 + 密钥
|
||||
AiModelDO model = validateModel(id);
|
||||
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
|
||||
|
||||
// 创建或获取 EmbeddingModel 对象
|
||||
EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(
|
||||
platform, apiKey.getApiKey(), apiKey.getUrl(), model.getModel());
|
||||
|
||||
// 创建或获取 VectorStore 对象
|
||||
return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,131 @@
|
|||
package cn.iocoder.yudao.module.ai.service.model;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import jakarta.validation.Valid;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.image.ImageModel;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* AI 模型 Service 接口
|
||||
*
|
||||
* @author fansili
|
||||
* @since 2024/4/24 19:42
|
||||
*/
|
||||
public interface AiModelService {
|
||||
|
||||
/**
|
||||
* 创建模型
|
||||
*
|
||||
* @param createReqVO 创建信息
|
||||
* @return 编号
|
||||
*/
|
||||
Long createModel(@Valid AiModelSaveReqVO createReqVO);
|
||||
|
||||
/**
|
||||
* 更新模型
|
||||
*
|
||||
* @param updateReqVO 更新信息
|
||||
*/
|
||||
void updateModel(@Valid AiModelSaveReqVO updateReqVO);
|
||||
|
||||
/**
|
||||
* 删除模型
|
||||
*
|
||||
* @param id 编号
|
||||
*/
|
||||
void deleteModel(Long id);
|
||||
|
||||
/**
|
||||
* 获得模型
|
||||
*
|
||||
* @param id 编号
|
||||
* @return 模型
|
||||
*/
|
||||
AiModelDO getModel(Long id);
|
||||
|
||||
/**
|
||||
* 获得默认的模型
|
||||
*
|
||||
* 如果获取不到,则抛出 {@link cn.iocoder.yudao.framework.common.exception.ServiceException} 业务异常
|
||||
*
|
||||
* @return 模型
|
||||
*/
|
||||
AiModelDO getRequiredDefaultModel(Integer type);
|
||||
|
||||
/**
|
||||
* 获得模型分页
|
||||
*
|
||||
* @param pageReqVO 分页查询
|
||||
* @return 模型分页
|
||||
*/
|
||||
PageResult<AiModelDO> getModelPage(AiModelPageReqVO pageReqVO);
|
||||
|
||||
/**
|
||||
* 校验模型是否可使用
|
||||
*
|
||||
* @param id 编号
|
||||
* @return 模型
|
||||
*/
|
||||
AiModelDO validateModel(Long id);
|
||||
|
||||
/**
|
||||
* 获得模型列表
|
||||
*
|
||||
* @param status 状态
|
||||
* @param type 类型
|
||||
* @param platform 平台,允许空
|
||||
* @return 模型列表
|
||||
*/
|
||||
List<AiModelDO> getModelListByStatusAndType(Integer status, Integer type,
|
||||
@Nullable String platform);
|
||||
|
||||
// ========== 与 Spring AI 集成 ==========
|
||||
|
||||
/**
|
||||
* 获得 ChatModel 对象
|
||||
*
|
||||
* @param id 编号
|
||||
* @return ChatModel 对象
|
||||
*/
|
||||
ChatModel getChatModel(Long id);
|
||||
|
||||
/**
|
||||
* 获得 ImageModel 对象
|
||||
*
|
||||
* @param id 编号
|
||||
* @return ImageModel 对象
|
||||
*/
|
||||
ImageModel getImageModel(Long id);
|
||||
|
||||
/**
|
||||
* 获得 MidjourneyApi 对象
|
||||
*
|
||||
* @return MidjourneyApi 对象
|
||||
*/
|
||||
MidjourneyApi getMidjourneyApi();
|
||||
|
||||
/**
|
||||
* 获得 SunoApi 对象
|
||||
*
|
||||
* @return SunoApi 对象
|
||||
*/
|
||||
SunoApi getSunoApi();
|
||||
|
||||
/**
|
||||
* 获得 VectorStore 对象
|
||||
*
|
||||
* @param id 编号
|
||||
* @return VectorStore 对象
|
||||
*/
|
||||
VectorStore getOrCreateVectorStore(Long id);
|
||||
|
||||
}
|
|
@ -16,7 +16,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
|
|||
import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper;
|
||||
import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
||||
import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -41,7 +41,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MUSIC_NOT_EXIS
|
|||
public class AiMusicServiceImpl implements AiMusicService {
|
||||
|
||||
@Resource
|
||||
private AiApiKeyService apiKeyService;
|
||||
private AiModelService modelService;
|
||||
|
||||
@Resource
|
||||
private AiMusicMapper musicMapper;
|
||||
|
@ -53,7 +53,7 @@ public class AiMusicServiceImpl implements AiMusicService {
|
|||
@Transactional(rollbackFor = Exception.class)
|
||||
public List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO) {
|
||||
// 1. 调用 Suno 生成音乐
|
||||
SunoApi sunoApi = apiKeyService.getSunoApi();
|
||||
SunoApi sunoApi = modelService.getSunoApi();
|
||||
List<SunoApi.MusicData> musicDataList;
|
||||
if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) {
|
||||
// 1.1 描述模式
|
||||
|
@ -88,7 +88,7 @@ public class AiMusicServiceImpl implements AiMusicService {
|
|||
log.info("[syncMusic][Suno 开始同步, 共 ({}) 个任务]", streamingTask.size());
|
||||
|
||||
// GET 请求,为避免参数过长,分批次处理
|
||||
SunoApi sunoApi = apiKeyService.getSunoApi();
|
||||
SunoApi sunoApi = modelService.getSunoApi();
|
||||
CollUtil.split(streamingTask, 36).forEach(chunkList -> {
|
||||
Map<String, Long> taskIdMap = convertMap(chunkList, AiMusicDO::getTaskId, AiMusicDO::getId);
|
||||
List<SunoApi.MusicData> musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet()));
|
||||
|
|
|
@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.write;
|
|||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
|
@ -11,17 +12,16 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
|||
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWritePageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
|
||||
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
|
||||
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
||||
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
||||
import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -54,17 +54,15 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.WRITE_NOT_EXIS
|
|||
public class AiWriteServiceImpl implements AiWriteService {
|
||||
|
||||
@Resource
|
||||
private AiApiKeyService apiKeyService;
|
||||
@Resource
|
||||
private AiChatModelService chatModalService;
|
||||
private AiModelService chatModalService;
|
||||
@Resource
|
||||
private AiChatRoleService chatRoleService;
|
||||
|
||||
@Resource
|
||||
private DictDataApi dictDataApi;
|
||||
private AiWriteMapper writeMapper;
|
||||
|
||||
@Resource
|
||||
private AiWriteMapper writeMapper;
|
||||
private DictDataApi dictDataApi;
|
||||
|
||||
@Override
|
||||
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
||||
|
@ -72,17 +70,17 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|||
AiChatRoleDO writeRole = CollUtil.getFirst(
|
||||
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
|
||||
// 1.1 获取写作执行模型
|
||||
AiChatModelDO model = getModel(writeRole);
|
||||
AiModelDO model = getModel(writeRole);
|
||||
// 1.2 获取角色设定消息
|
||||
String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
|
||||
? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
|
||||
// 1.3 校验平台
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
||||
StreamingChatModel chatModel = chatModalService.getChatModel(model.getKeyId());
|
||||
|
||||
// 2. 插入写作信息
|
||||
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class,
|
||||
write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
|
||||
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, write -> write.setUserId(userId)
|
||||
.setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel()));
|
||||
writeMapper.insert(writeDO);
|
||||
|
||||
// 3.1 构建 Prompt,并进行调用
|
||||
|
@ -109,19 +107,19 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|||
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
|
||||
}
|
||||
|
||||
private AiChatModelDO getModel(AiChatRoleDO writeRole) {
|
||||
AiChatModelDO model = null;
|
||||
private AiModelDO getModel(AiChatRoleDO writeRole) {
|
||||
AiModelDO model = null;
|
||||
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
|
||||
model = chatModalService.getChatModel(writeRole.getModelId());
|
||||
model = chatModalService.getModel(writeRole.getModelId());
|
||||
}
|
||||
if (model == null) {
|
||||
model = chatModalService.getRequiredDefaultChatModel();
|
||||
model = chatModalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
|
||||
}
|
||||
Assert.notNull(model, "[AI] 获取不到模型");
|
||||
return model;
|
||||
}
|
||||
|
||||
private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
|
||||
private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) {
|
||||
// 1. 构建 message 列表
|
||||
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
|
||||
// 2. 构建 options 对象
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
package cn.iocoder.yudao.framework.ai.core.enums;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.core.ArrayValuable;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* AI 模型类型的枚举
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
@Getter
|
||||
@RequiredArgsConstructor
|
||||
public enum AiModelTypeEnum implements ArrayValuable<Integer> {
|
||||
|
||||
CHAT(1, "对话"),
|
||||
IMAGE(2, "图片"),
|
||||
VOICE(3, "语音"),
|
||||
VIDEO(4, "视频"),
|
||||
EMBEDDING(5, "向量"),
|
||||
RERANK(6, "重排序");
|
||||
|
||||
/**
|
||||
* 类型
|
||||
*/
|
||||
private final Integer type;
|
||||
/**
|
||||
* 类型名
|
||||
*/
|
||||
private final String name;
|
||||
|
||||
public static final Integer[] ARRAYS = Arrays.stream(values()).map(AiModelTypeEnum::getType).toArray(Integer[]::new);
|
||||
|
||||
@Override
|
||||
public Integer[] array() {
|
||||
return ARRAYS;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,8 +1,11 @@
|
|||
package cn.iocoder.yudao.framework.ai.core.enums;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.core.ArrayValuable;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* AI 模型平台
|
||||
*
|
||||
|
@ -10,7 +13,7 @@ import lombok.Getter;
|
|||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum AiPlatformEnum {
|
||||
public enum AiPlatformEnum implements ArrayValuable<String> {
|
||||
|
||||
// ========== 国内平台 ==========
|
||||
|
||||
|
@ -44,6 +47,8 @@ public enum AiPlatformEnum {
|
|||
*/
|
||||
private final String name;
|
||||
|
||||
public static final String[] ARRAYS = Arrays.stream(values()).map(AiPlatformEnum::getPlatform).toArray(String[]::new);
|
||||
|
||||
public static AiPlatformEnum validatePlatform(String platform) {
|
||||
for (AiPlatformEnum platformEnum : AiPlatformEnum.values()) {
|
||||
if (platformEnum.getPlatform().equals(platform)) {
|
||||
|
@ -53,4 +58,9 @@ public enum AiPlatformEnum {
|
|||
throw new IllegalArgumentException("非法平台: " + platform);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] array() {
|
||||
return ARRAYS;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -456,25 +456,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||
return vectorStore;
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建向量存储文件
|
||||
*
|
||||
* @param embeddingModel 嵌入模型
|
||||
* @return 向量存储文件
|
||||
*/
|
||||
private File createVectorStoreFile(EmbeddingModel embeddingModel) {
|
||||
// 获取简单类名
|
||||
String simpleClassName = embeddingModel.getClass().getSimpleName();
|
||||
// 获取用户主目录
|
||||
String userHome = FileUtil.getUserHomePath();
|
||||
// 创建vector_store目录
|
||||
File vectorStoreDir = new File(userHome, "vector_store");
|
||||
if (!vectorStoreDir.exists()) {
|
||||
vectorStoreDir.mkdirs();
|
||||
}
|
||||
|
||||
// 创建文件
|
||||
return new File(vectorStoreDir, "simple_" + simpleClassName + ".json");
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue