【代码重构】AI:“聊天模型”重构为“模型”,支持 type 模型类型

This commit is contained in:
YunaiV 2025-03-03 21:48:22 +08:00
parent 89d079349c
commit 1c9c9790cd
4 changed files with 33 additions and 27 deletions

View File

@ -13,9 +13,9 @@ public class AiMidjourneyImagineReqVO {
@NotEmpty(message = "提示词不能为空!") @NotEmpty(message = "提示词不能为空!")
private String prompt; private String prompt;
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "midjourney") @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotEmpty(message = "模型不能为空") @NotNull(message = "模型编号不能为空")
private String model; // 参考 MidjourneyApi.ModelEnum private Long modelId;
@Schema(description = "图片宽度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "图片宽度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "图片宽度不能为空") @NotNull(message = "图片宽度不能为空")

View File

@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.image;
import cn.hutool.core.bean.BeanUtil; import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.codec.Base64; import cn.hutool.core.codec.Base64;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.map.MapUtil; import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
@ -217,52 +218,56 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) { public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO drawReqVO) {
MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(); // 1. 校验模型
// 1. 保存数据库 AiModelDO model = modelService.validateModel(drawReqVO.getModelId());
AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false) Assert.equals(model.getPlatform(), AiPlatformEnum.MIDJOURNEY.getPlatform(), "平台不匹配");
MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(model.getId());
// 2. 保存数据库
AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()) .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform()); .setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform()).setModelId(model.getId()).setModel(model.getName());
imageMapper.insert(image); imageMapper.insert(image);
// 2. 调用 Midjourney Proxy 提交任务 // 3. 调用 Midjourney Proxy 提交任务
List<String> base64Array = StrUtil.isBlank(reqVO.getReferImageUrl()) ? null : List<String> base64Array = StrUtil.isBlank(drawReqVO.getReferImageUrl()) ? null :
Collections.singletonList("data:image/jpeg;base64,".concat(Base64.encode(HttpUtil.downloadBytes(reqVO.getReferImageUrl())))); Collections.singletonList("data:image/jpeg;base64,".concat(Base64.encode(HttpUtil.downloadBytes(drawReqVO.getReferImageUrl()))));
MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest( MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(
base64Array, reqVO.getPrompt(),null, base64Array, drawReqVO.getPrompt(),null,
MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(), MidjourneyApi.ImagineRequest.buildState(drawReqVO.getWidth(),
reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel())); drawReqVO.getHeight(), drawReqVO.getVersion(), model.getModel()));
MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest); MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest);
// 3. 情况一失败抛出业务异常 // 4.1 情况一失败抛出业务异常
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(imagineResponse.code())) { if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(imagineResponse.code())) {
String description = imagineResponse.description().contains("quota_not_enough") ? String description = imagineResponse.description().contains("quota_not_enough") ?
"账户余额不足" : imagineResponse.description(); "账户余额不足" : imagineResponse.description();
throw exception(IMAGE_MIDJOURNEY_SUBMIT_FAIL, description); throw exception(IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
} }
// 4. 情况二成功更新 taskId 和参数 // 4.2 情况二成功更新 taskId 和参数
imageMapper.updateById(new AiImageDO().setId(image.getId()) imageMapper.updateById(new AiImageDO().setId(image.getId())
.setTaskId(imagineResponse.result()).setOptions(BeanUtil.beanToMap(reqVO))); .setTaskId(imagineResponse.result()).setOptions(BeanUtil.beanToMap(drawReqVO)));
return image.getId(); return image.getId();
} }
@Override @Override
public Integer midjourneySync() { public Integer midjourneySync() {
MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
// 1.1 获取 Midjourney 平台状态在 进行中 image // 1.1 获取 Midjourney 平台状态在 进行中 image
List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform( List<AiImageDO> images = imageMapper.selectListByStatusAndPlatform(
AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform()); AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
if (CollUtil.isEmpty(imageList)) { if (CollUtil.isEmpty(images)) {
return 0; return 0;
} }
// 1.2 调用 Midjourney Proxy 获取任务进展 // 1.2 调用 Midjourney Proxy 获取任务进展
List<MidjourneyApi.Notify> taskList = midjourneyApi.getTaskList(convertSet(imageList, AiImageDO::getTaskId)); MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(images.get(0).getModelId());
List<MidjourneyApi.Notify> taskList = midjourneyApi.getTaskList(convertSet(images, AiImageDO::getTaskId));
Map<String, MidjourneyApi.Notify> taskMap = convertMap(taskList, MidjourneyApi.Notify::id); Map<String, MidjourneyApi.Notify> taskMap = convertMap(taskList, MidjourneyApi.Notify::id);
// 2. 逐个处理更新进展 // 2. 逐个处理更新进展
int count = 0; int count = 0;
for (AiImageDO image : imageList) { for (AiImageDO image : images) {
MidjourneyApi.Notify notify = taskMap.get(image.getTaskId()); MidjourneyApi.Notify notify = taskMap.get(image.getTaskId());
if (notify == null) { if (notify == null) {
log.error("[midjourneySync][image({}) 查询不到进展]", image); log.error("[midjourneySync][image({}) 查询不到进展]", image);
@ -320,12 +325,12 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) { public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
// 1.1 检查 image // 1.1 检查 image
AiImageDO image = validateImageExists(reqVO.getId()); AiImageDO image = validateImageExists(reqVO.getId());
if (ObjUtil.notEqual(userId, image.getUserId())) { if (ObjUtil.notEqual(userId, image.getUserId())) {
throw exception(IMAGE_NOT_EXISTS); throw exception(IMAGE_NOT_EXISTS);
} }
MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(image.getModelId());
// 1.2 检查 customId // 1.2 检查 customId
MidjourneyApi.Button button = CollUtil.findOne(image.getButtons(), MidjourneyApi.Button button = CollUtil.findOne(image.getButtons(),
buttonX -> buttonX.customId().equals(reqVO.getCustomId())); buttonX -> buttonX.customId().equals(reqVO.getCustomId()));

View File

@ -137,9 +137,9 @@ public class AiChatModelServiceImpl implements AiModelService {
} }
@Override @Override
public MidjourneyApi getMidjourneyApi() { public MidjourneyApi getMidjourneyApi(Long id) {
AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey( AiModelDO model = validateModel(id);
AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus()); AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl()); return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
} }

View File

@ -109,9 +109,10 @@ public interface AiModelService {
/** /**
* 获得 MidjourneyApi 对象 * 获得 MidjourneyApi 对象
* *
* @param id 编号
* @return MidjourneyApi 对象 * @return MidjourneyApi 对象
*/ */
MidjourneyApi getMidjourneyApi(); MidjourneyApi getMidjourneyApi(Long id);
/** /**
* 获得 SunoApi 对象 * 获得 SunoApi 对象