diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/midjourney/AiMidjourneyImagineReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/midjourney/AiMidjourneyImagineReqVO.java index b90882639d..efb5906157 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/midjourney/AiMidjourneyImagineReqVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/midjourney/AiMidjourneyImagineReqVO.java @@ -13,9 +13,9 @@ public class AiMidjourneyImagineReqVO { @NotEmpty(message = "提示词不能为空!") private String prompt; - @Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "midjourney") - @NotEmpty(message = "模型不能为空") - private String model; // 参考 MidjourneyApi.ModelEnum + @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") + @NotNull(message = "模型编号不能为空") + private Long modelId; @Schema(description = "图片宽度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @NotNull(message = "图片宽度不能为空") diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index 7204f4d5b2..60ca9ac996 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.image; import cn.hutool.core.bean.BeanUtil; import cn.hutool.core.codec.Base64; import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.lang.Assert; import cn.hutool.core.map.MapUtil; import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.StrUtil; @@ -217,52 +218,56 @@ public class AiImageServiceImpl implements AiImageService { @Override @Transactional(rollbackFor = Exception.class) - public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) { - MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(); - // 1. 保存数据库 - AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false) + public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO drawReqVO) { + // 1. 校验模型 + AiModelDO model = modelService.validateModel(drawReqVO.getModelId()); + Assert.equals(model.getPlatform(), AiPlatformEnum.MIDJOURNEY.getPlatform(), "平台不匹配"); + MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(model.getId()); + + // 2. 保存数据库 + AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false) .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()) - .setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform()); + .setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform()).setModelId(model.getId()).setModel(model.getName()); imageMapper.insert(image); - // 2. 调用 Midjourney Proxy 提交任务 - List base64Array = StrUtil.isBlank(reqVO.getReferImageUrl()) ? null : - Collections.singletonList("data:image/jpeg;base64,".concat(Base64.encode(HttpUtil.downloadBytes(reqVO.getReferImageUrl())))); + // 3. 调用 Midjourney Proxy 提交任务 + List base64Array = StrUtil.isBlank(drawReqVO.getReferImageUrl()) ? null : + Collections.singletonList("data:image/jpeg;base64,".concat(Base64.encode(HttpUtil.downloadBytes(drawReqVO.getReferImageUrl())))); MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest( - base64Array, reqVO.getPrompt(),null, - MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(), - reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel())); + base64Array, drawReqVO.getPrompt(),null, + MidjourneyApi.ImagineRequest.buildState(drawReqVO.getWidth(), + drawReqVO.getHeight(), drawReqVO.getVersion(), model.getModel())); MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest); - // 3. 情况一【失败】:抛出业务异常 + // 4.1 情况一【失败】:抛出业务异常 if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(imagineResponse.code())) { String description = imagineResponse.description().contains("quota_not_enough") ? "账户余额不足" : imagineResponse.description(); throw exception(IMAGE_MIDJOURNEY_SUBMIT_FAIL, description); } - // 4. 情况二【成功】:更新 taskId 和参数 + // 4.2 情况二【成功】:更新 taskId 和参数 imageMapper.updateById(new AiImageDO().setId(image.getId()) - .setTaskId(imagineResponse.result()).setOptions(BeanUtil.beanToMap(reqVO))); + .setTaskId(imagineResponse.result()).setOptions(BeanUtil.beanToMap(drawReqVO))); return image.getId(); } @Override public Integer midjourneySync() { - MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(); // 1.1 获取 Midjourney 平台,状态在 “进行中” 的 image - List imageList = imageMapper.selectListByStatusAndPlatform( + List images = imageMapper.selectListByStatusAndPlatform( AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform()); - if (CollUtil.isEmpty(imageList)) { + if (CollUtil.isEmpty(images)) { return 0; } // 1.2 调用 Midjourney Proxy 获取任务进展 - List taskList = midjourneyApi.getTaskList(convertSet(imageList, AiImageDO::getTaskId)); + MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(images.get(0).getModelId()); + List taskList = midjourneyApi.getTaskList(convertSet(images, AiImageDO::getTaskId)); Map taskMap = convertMap(taskList, MidjourneyApi.Notify::id); // 2. 逐个处理,更新进展 int count = 0; - for (AiImageDO image : imageList) { + for (AiImageDO image : images) { MidjourneyApi.Notify notify = taskMap.get(image.getTaskId()); if (notify == null) { log.error("[midjourneySync][image({}) 查询不到进展]", image); @@ -320,12 +325,12 @@ public class AiImageServiceImpl implements AiImageService { @Override public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) { - MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(); // 1.1 检查 image AiImageDO image = validateImageExists(reqVO.getId()); if (ObjUtil.notEqual(userId, image.getUserId())) { throw exception(IMAGE_NOT_EXISTS); } + MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(image.getModelId()); // 1.2 检查 customId MidjourneyApi.Button button = CollUtil.findOne(image.getButtons(), buttonX -> buttonX.customId().equals(reqVO.getCustomId())); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java index 9140882a1f..2b38a68cf7 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java @@ -137,9 +137,9 @@ public class AiChatModelServiceImpl implements AiModelService { } @Override - public MidjourneyApi getMidjourneyApi() { - AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey( - AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus()); + public MidjourneyApi getMidjourneyApi(Long id) { + AiModelDO model = validateModel(id); + AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId()); return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl()); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java index e3746b14d2..9e6185d450 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java @@ -109,9 +109,10 @@ public interface AiModelService { /** * 获得 MidjourneyApi 对象 * + * @param id 编号 * @return MidjourneyApi 对象 */ - MidjourneyApi getMidjourneyApi(); + MidjourneyApi getMidjourneyApi(Long id); /** * 获得 SunoApi 对象