From cd4813f7dd78c9ce6bf256bbc351ff622905641f Mon Sep 17 00:00:00 2001 From: YunaiV Date: Sun, 23 Mar 2025 12:18:37 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E5=8A=9F=E8=83=BD=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E3=80=91AI=EF=BC=9A=E7=99=BE=E5=B7=9D=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E7=9A=84=E6=8E=A5=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/config/YudaoAiAutoConfiguration.java | 28 ++++++++ .../ai/config/YudaoAiProperties.java | 19 ++++++ .../ai/core/enums/AiPlatformEnum.java | 1 + .../ai/core/factory/AiModelFactoryImpl.java | 14 ++++ .../model/baichuan/BaiChuanChatModel.java | 45 ++++++++++++ .../siliconflow/SiliconFlowImageModel.java | 2 +- .../yudao/framework/ai/core/util/AiUtils.java | 1 + .../ai/chat/BaiChuanChatModelTests.java | 68 +++++++++++++++++++ 8 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/baichuan/BaiChuanChatModel.java create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/BaiChuanChatModelTests.java diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java index e014a4cd9f..a454e40e8b 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java @@ -4,6 +4,7 @@ import cn.hutool.core.util.StrUtil; import cn.hutool.extra.spring.SpringUtil; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl; +import cn.iocoder.yudao.framework.ai.core.model.baichuan.BaiChuanChatModel; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel; import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel; import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel; @@ -193,6 +194,33 @@ public class YudaoAiAutoConfiguration { return new XingHuoChatModel(openAiChatModel); } + @Bean + @ConditionalOnProperty(value = "yudao.ai.baichuan.enable", havingValue = "true") + public BaiChuanChatModel baiChuanChatClient(YudaoAiProperties yudaoAiProperties) { + YudaoAiProperties.BaiChuanProperties properties = yudaoAiProperties.getBaichuan(); + return buildBaiChuanChatClient(properties); + } + + public BaiChuanChatModel buildBaiChuanChatClient(YudaoAiProperties.BaiChuanProperties properties) { + if (StrUtil.isEmpty(properties.getModel())) { + properties.setModel(BaiChuanChatModel.MODEL_DEFAULT); + } + OpenAiChatModel openAiChatModel = OpenAiChatModel.builder() + .openAiApi(OpenAiApi.builder() + .baseUrl(BaiChuanChatModel.BASE_URL) + .apiKey(properties.getApiKey()) + .build()) + .defaultOptions(OpenAiChatOptions.builder() + .model(properties.getModel()) + .temperature(properties.getTemperature()) + .maxTokens(properties.getMaxTokens()) + .topP(properties.getTopP()) + .build()) + .toolCallingManager(getToolCallingManager()) + .build(); + return new BaiChuanChatModel(openAiChatModel); + } + @Bean @ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true") public MidjourneyApi midjourneyApi(YudaoAiProperties yudaoAiProperties) { diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java index 296e0af8bc..86d1084ccc 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java @@ -43,6 +43,12 @@ public class YudaoAiProperties { @SuppressWarnings("SpellCheckingInspection") private XingHuoProperties xinghuo; + /** + * 百川 + */ + @SuppressWarnings("SpellCheckingInspection") + private BaiChuanProperties baichuan; + /** * Midjourney 绘图 */ @@ -122,6 +128,19 @@ public class YudaoAiProperties { } + @Data + public static class BaiChuanProperties { + + private String enable; + private String apiKey; + + private String model; + private Double temperature; + private Integer maxTokens; + private Double topP; + + } + @Data public static class MidjourneyProperties { diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java index 5a8a5c4539..be65f2986f 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java @@ -27,6 +27,7 @@ public enum AiPlatformEnum implements ArrayValuable { SILICON_FLOW("SiliconFlow", "硅基流动"), // 硅基流动 MINI_MAX("MiniMax", "MiniMax"), // 稀宇科技 MOONSHOT("Moonshot", "月之暗灭"), // KIMI + BAI_CHUAN("BaiChuan", "百川智能"), // 百川智能 // ========== 国外平台 ========== diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index 3c9f51cf63..6d664eb65f 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -11,6 +11,7 @@ import cn.hutool.extra.spring.SpringUtil; import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration; import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.ai.core.model.baichuan.BaiChuanChatModel; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel; import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel; import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel; @@ -150,6 +151,8 @@ public class AiModelFactoryImpl implements AiModelFactory { return buildMoonshotChatModel(apiKey, url); case XING_HUO: return buildXingHuoChatModel(apiKey); + case BAI_CHUAN: + return buildBaiChuanChatModel(apiKey); case OPENAI: return buildOpenAiChatModel(apiKey, url); case AZURE_OPENAI: @@ -186,6 +189,8 @@ public class AiModelFactoryImpl implements AiModelFactory { return SpringUtil.getBean(MoonshotChatModel.class); case XING_HUO: return SpringUtil.getBean(XingHuoChatModel.class); + case BAI_CHUAN: + return SpringUtil.getBean(AzureOpenAiChatModel.class); case OPENAI: return SpringUtil.getBean(OpenAiChatModel.class); case AZURE_OPENAI: @@ -441,6 +446,15 @@ public class AiModelFactoryImpl implements AiModelFactory { return new YudaoAiAutoConfiguration().buildXingHuoChatClient(properties); } + /** + * 可参考 {@link YudaoAiAutoConfiguration#baiChuanChatClient(YudaoAiProperties)} + */ + private BaiChuanChatModel buildBaiChuanChatModel(String apiKey) { + YudaoAiProperties.BaiChuanProperties properties = new YudaoAiProperties.BaiChuanProperties() + .setApiKey(apiKey); + return new YudaoAiAutoConfiguration().buildBaiChuanChatClient(properties); + } + /** * 可参考 {@link OpenAiAutoConfiguration} 的 openAiChatModel 方法 */ diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/baichuan/BaiChuanChatModel.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/baichuan/BaiChuanChatModel.java new file mode 100644 index 0000000000..ac59b70266 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/baichuan/BaiChuanChatModel.java @@ -0,0 +1,45 @@ +package cn.iocoder.yudao.framework.ai.core.model.baichuan; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.openai.OpenAiChatModel; +import reactor.core.publisher.Flux; + +/** + * 百川 {@link ChatModel} 实现类 + * + * @author 芋道源码 + */ +@Slf4j +@RequiredArgsConstructor +public class BaiChuanChatModel implements ChatModel { + + public static final String BASE_URL = "https://api.baichuan-ai.com"; + + public static final String MODEL_DEFAULT = "Baichuan4-Turbo"; + + /** + * 兼容 OpenAI 接口,进行复用 + */ + private final OpenAiChatModel openAiChatModel; + + @Override + public ChatResponse call(Prompt prompt) { + return openAiChatModel.call(prompt); + } + + @Override + public Flux stream(Prompt prompt) { + return openAiChatModel.stream(prompt); + } + + @Override + public ChatOptions getDefaultOptions() { + return openAiChatModel.getDefaultOptions(); + } + +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/siliconflow/SiliconFlowImageModel.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/siliconflow/SiliconFlowImageModel.java index e345ebaf8f..235699ee66 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/siliconflow/SiliconFlowImageModel.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/siliconflow/SiliconFlowImageModel.java @@ -149,7 +149,7 @@ public class SiliconFlowImageModel implements ImageModel { .batchSize(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getN(), defaultOptions.getN())) .width(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getWidth(), defaultOptions.getWidth())) .height(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getHeight(), defaultOptions.getHeight())) - // Handle OpenAI specific image options + // Handle SiliconFlow specific image options .negativePrompt(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNegativePrompt(), defaultOptions.getNegativePrompt())) .numInferenceSteps(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNumInferenceSteps(), defaultOptions.getNumInferenceSteps())) .guidanceScale(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getGuidanceScale(), defaultOptions.getGuidanceScale())) diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java index becc54ee43..3b858b4bc5 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java @@ -50,6 +50,7 @@ public class AiUtils { case HUN_YUAN: // 复用 OpenAI 客户端 case XING_HUO: // 复用 OpenAI 客户端 case SILICON_FLOW: // 复用 OpenAI 客户端 + case BAI_CHUAN: // 复用 OpenAI 客户端 return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens) .toolNames(toolNames).build(); case AZURE_OPENAI: diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/BaiChuanChatModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/BaiChuanChatModelTests.java new file mode 100644 index 0000000000..9ae36dbb87 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/BaiChuanChatModelTests.java @@ -0,0 +1,68 @@ +package cn.iocoder.yudao.framework.ai.chat; + +import cn.iocoder.yudao.framework.ai.core.model.baichuan.BaiChuanChatModel; +import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.api.OpenAiApi; +import reactor.core.publisher.Flux; + +import java.util.ArrayList; +import java.util.List; + +/** + * {@link BaiChuanChatModel} 集成测试 + * + * @author 芋道源码 + */ +public class BaiChuanChatModelTests { + + private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder() + .openAiApi(OpenAiApi.builder() + .baseUrl(BaiChuanChatModel.BASE_URL) + .apiKey("sk-61b6766a94c70786ed02673f5e16af3c") // apiKey + .build()) + .defaultOptions(OpenAiChatOptions.builder() + .model("Baichuan4-Turbo") // 模型(https://platform.baichuan-ai.com/docs/api) + .temperature(0.7) + .build()) + .build(); + + private final DeepSeekChatModel chatModel = new DeepSeekChatModel(openAiChatModel); + + @Test + @Disabled + public void testCall() { + // 准备参数 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + ChatResponse response = chatModel.call(new Prompt(messages)); + // 打印结果 + System.out.println(response); + } + + @Test + @Disabled + public void testStream() { + // 准备参数 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + Flux flux = chatModel.stream(new Prompt(messages)); + // 打印结果 + flux.doOnNext(System.out::println).then().block(); + } + +}