【功能新增】AI:百川模型的接入

This commit is contained in:
YunaiV 2025-03-23 12:18:37 +08:00
parent ef5e56d560
commit cd4813f7dd
8 changed files with 177 additions and 1 deletions

View File

@ -4,6 +4,7 @@ import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.baichuan.BaiChuanChatModel;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel;
@ -193,6 +194,33 @@ public class YudaoAiAutoConfiguration {
return new XingHuoChatModel(openAiChatModel);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.baichuan.enable", havingValue = "true")
public BaiChuanChatModel baiChuanChatClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.BaiChuanProperties properties = yudaoAiProperties.getBaichuan();
return buildBaiChuanChatClient(properties);
}
public BaiChuanChatModel buildBaiChuanChatClient(YudaoAiProperties.BaiChuanProperties properties) {
if (StrUtil.isEmpty(properties.getModel())) {
properties.setModel(BaiChuanChatModel.MODEL_DEFAULT);
}
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(BaiChuanChatModel.BASE_URL)
.apiKey(properties.getApiKey())
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model(properties.getModel())
.temperature(properties.getTemperature())
.maxTokens(properties.getMaxTokens())
.topP(properties.getTopP())
.build())
.toolCallingManager(getToolCallingManager())
.build();
return new BaiChuanChatModel(openAiChatModel);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
public MidjourneyApi midjourneyApi(YudaoAiProperties yudaoAiProperties) {

View File

@ -43,6 +43,12 @@ public class YudaoAiProperties {
@SuppressWarnings("SpellCheckingInspection")
private XingHuoProperties xinghuo;
/**
* 百川
*/
@SuppressWarnings("SpellCheckingInspection")
private BaiChuanProperties baichuan;
/**
* Midjourney 绘图
*/
@ -122,6 +128,19 @@ public class YudaoAiProperties {
}
@Data
public static class BaiChuanProperties {
private String enable;
private String apiKey;
private String model;
private Double temperature;
private Integer maxTokens;
private Double topP;
}
@Data
public static class MidjourneyProperties {

View File

@ -27,6 +27,7 @@ public enum AiPlatformEnum implements ArrayValuable<String> {
SILICON_FLOW("SiliconFlow", "硅基流动"), // 硅基流动
MINI_MAX("MiniMax", "MiniMax"), // 稀宇科技
MOONSHOT("Moonshot", "月之暗灭"), // KIMI
BAI_CHUAN("BaiChuan", "百川智能"), // 百川智能
// ========== 国外平台 ==========

View File

@ -11,6 +11,7 @@ import cn.hutool.extra.spring.SpringUtil;
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.baichuan.BaiChuanChatModel;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel;
@ -150,6 +151,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return buildMoonshotChatModel(apiKey, url);
case XING_HUO:
return buildXingHuoChatModel(apiKey);
case BAI_CHUAN:
return buildBaiChuanChatModel(apiKey);
case OPENAI:
return buildOpenAiChatModel(apiKey, url);
case AZURE_OPENAI:
@ -186,6 +189,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return SpringUtil.getBean(MoonshotChatModel.class);
case XING_HUO:
return SpringUtil.getBean(XingHuoChatModel.class);
case BAI_CHUAN:
return SpringUtil.getBean(AzureOpenAiChatModel.class);
case OPENAI:
return SpringUtil.getBean(OpenAiChatModel.class);
case AZURE_OPENAI:
@ -441,6 +446,15 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new YudaoAiAutoConfiguration().buildXingHuoChatClient(properties);
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#baiChuanChatClient(YudaoAiProperties)}
*/
private BaiChuanChatModel buildBaiChuanChatModel(String apiKey) {
YudaoAiProperties.BaiChuanProperties properties = new YudaoAiProperties.BaiChuanProperties()
.setApiKey(apiKey);
return new YudaoAiAutoConfiguration().buildBaiChuanChatClient(properties);
}
/**
* 可参考 {@link OpenAiAutoConfiguration} openAiChatModel 方法
*/

View File

@ -0,0 +1,45 @@
package cn.iocoder.yudao.framework.ai.core.model.baichuan;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import reactor.core.publisher.Flux;
/**
* 百川 {@link ChatModel} 实现类
*
* @author 芋道源码
*/
@Slf4j
@RequiredArgsConstructor
public class BaiChuanChatModel implements ChatModel {
public static final String BASE_URL = "https://api.baichuan-ai.com";
public static final String MODEL_DEFAULT = "Baichuan4-Turbo";
/**
* 兼容 OpenAI 接口进行复用
*/
private final OpenAiChatModel openAiChatModel;
@Override
public ChatResponse call(Prompt prompt) {
return openAiChatModel.call(prompt);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return openAiChatModel.stream(prompt);
}
@Override
public ChatOptions getDefaultOptions() {
return openAiChatModel.getDefaultOptions();
}
}

View File

@ -149,7 +149,7 @@ public class SiliconFlowImageModel implements ImageModel {
.batchSize(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getN(), defaultOptions.getN()))
.width(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getWidth(), defaultOptions.getWidth()))
.height(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getHeight(), defaultOptions.getHeight()))
// Handle OpenAI specific image options
// Handle SiliconFlow specific image options
.negativePrompt(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNegativePrompt(), defaultOptions.getNegativePrompt()))
.numInferenceSteps(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNumInferenceSteps(), defaultOptions.getNumInferenceSteps()))
.guidanceScale(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getGuidanceScale(), defaultOptions.getGuidanceScale()))

View File

@ -50,6 +50,7 @@ public class AiUtils {
case HUN_YUAN: // 复用 OpenAI 客户端
case XING_HUO: // 复用 OpenAI 客户端
case SILICON_FLOW: // 复用 OpenAI 客户端
case BAI_CHUAN: // 复用 OpenAI 客户端
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.toolNames(toolNames).build();
case AZURE_OPENAI:

View File

@ -0,0 +1,68 @@
package cn.iocoder.yudao.framework.ai.chat;
import cn.iocoder.yudao.framework.ai.core.model.baichuan.BaiChuanChatModel;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
/**
* {@link BaiChuanChatModel} 集成测试
*
* @author 芋道源码
*/
public class BaiChuanChatModelTests {
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(BaiChuanChatModel.BASE_URL)
.apiKey("sk-61b6766a94c70786ed02673f5e16af3c") // apiKey
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model("Baichuan4-Turbo") // 模型https://platform.baichuan-ai.com/docs/api
.temperature(0.7)
.build())
.build();
private final DeepSeekChatModel chatModel = new DeepSeekChatModel(openAiChatModel);
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
}
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(System.out::println).then().block();
}
}