diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index 3e8c09f079..523014d530 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -20,6 +20,7 @@ import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.infra.api.file.FileApi; +import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.image.ImageModel; @@ -133,14 +134,14 @@ public class AiImageServiceImpl implements AiImageService { } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) { // https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage // https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage - return StabilityAiImageOptions.builder().withModel(draw.getModel()) - .withHeight(draw.getHeight()).withWidth(draw.getWidth()) - .withSeed(Long.valueOf(draw.getOptions().get("seed"))) - .withCfgScale(Float.valueOf(draw.getOptions().get("scale"))) - .withSteps(Integer.valueOf(draw.getOptions().get("steps"))) - .withSampler(String.valueOf(draw.getOptions().get("sampler"))) - .withStylePreset(String.valueOf(draw.getOptions().get("stylePreset"))) - .withClipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset"))) + return StabilityAiImageOptions.builder().model(draw.getModel()) + .height(draw.getHeight()).width(draw.getWidth()) + .seed(Long.valueOf(draw.getOptions().get("seed"))) + .cfgScale(Float.valueOf(draw.getOptions().get("scale"))) + .steps(Integer.valueOf(draw.getOptions().get("steps"))) + .sampler(String.valueOf(draw.getOptions().get("sampler"))) + .stylePreset(String.valueOf(draw.getOptions().get("stylePreset"))) + .clipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset"))) .build(); } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) { return DashScopeImageOptions.builder() @@ -149,12 +150,12 @@ public class AiImageServiceImpl implements AiImageService { .build(); } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) { return QianFanImageOptions.builder() - .withModel(draw.getModel()).withN(1) - .withHeight(draw.getHeight()).withWidth(draw.getWidth()) + .model(draw.getModel()).N(1) + .height(draw.getHeight()).width(draw.getWidth()) .build(); } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) { return ZhiPuAiImageOptions.builder() - .withModel(draw.getModel()) + .model(draw.getModel()) .build(); } throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform()); diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java index 3aed3ee1b9..d9b143bb04 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java @@ -35,7 +35,6 @@ public class YudaoAiAutoConfiguration { return new AiModelFactoryImpl(); } - // ========== 各种 AI Client 创建 ========== @Bean diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java index a72b556fcd..497a6fe9a9 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java @@ -1,6 +1,20 @@ package cn.iocoder.yudao.framework.ai.chat; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.ollama.OllamaChatModel; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; +import reactor.core.publisher.Flux; + +import java.util.ArrayList; +import java.util.List; /** * {@link OllamaChatModel} 集成测试 @@ -9,41 +23,43 @@ import org.springframework.ai.ollama.OllamaChatModel; */ public class LlamaChatModelTests { -// private final OllamaApi ollamaApi = new OllamaApi( -// "http://127.0.0.1:11434"); -// private final OllamaChatModel chatModel = new OllamaChatModel(ollamaApi, -// OllamaOptions.create().withModel(OllamaModel.LLAMA3.getModelName())); -// -// @Test -// @Disabled -// public void testCall() { -// // 准备参数 -// List messages = new ArrayList<>(); -// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); -// messages.add(new UserMessage("1 + 1 = ?")); -// -// // 调用 -// ChatResponse response = chatModel.call(new Prompt(messages)); -// // 打印结果 -// System.out.println(response); -// System.out.println(response.getResult().getOutput()); -// } -// -// @Test -// @Disabled -// public void testStream() { -// // 准备参数 -// List messages = new ArrayList<>(); -// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); -// messages.add(new UserMessage("1 + 1 = ?")); -// -// // 调用 -// Flux flux = chatModel.stream(new Prompt(messages)); -// // 打印结果 -// flux.doOnNext(response -> { -//// System.out.println(response); -// System.out.println(response.getResult().getOutput()); -// }).then().block(); -// } + private final OllamaChatModel chatModel = OllamaChatModel.builder() + .ollamaApi(new OllamaApi("http://127.0.0.1:11434")) // Ollama 服务地址 + .defaultOptions(OllamaOptions.builder() + .model(OllamaModel.LLAMA3.getName()) // 模型 + .build()) + .build(); + + @Test + @Disabled + public void testCall() { + // 准备参数 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + ChatResponse response = chatModel.call(new Prompt(messages)); + // 打印结果 + System.out.println(response); + System.out.println(response.getResult().getOutput()); + } + + @Test + @Disabled + public void testStream() { + // 准备参数 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + Flux flux = chatModel.stream(new Prompt(messages)); + // 打印结果 + flux.doOnNext(response -> { +// System.out.println(response); + System.out.println(response.getResult().getOutput()); + }).then().block(); + } }