diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoChatModel.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoChatModel.java index 501d916db5..330d102a0f 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoChatModel.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoChatModel.java @@ -1,163 +1,45 @@ package cn.iocoder.yudao.framework.ai.core.model.xinghuo; -import cn.hutool.core.collection.ListUtil; -import cn.hutool.core.lang.Assert; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.openai.OpenAiChatOptions; -import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata; -import org.springframework.ai.retry.RetryUtils; -import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.ai.openai.OpenAiChatModel; import reactor.core.publisher.Flux; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions.MODEL_DEFAULT; - /** * 讯飞星火 {@link ChatModel} 实现类 * * @author fansili */ @Slf4j +@RequiredArgsConstructor public class XingHuoChatModel implements ChatModel { - private static final String BASE_URL = "https://spark-api-open.xf-yun.com"; + public static final String BASE_URL = "https://spark-api-open.xf-yun.com"; - private final XingHuoChatOptions defaultOptions; - private final RetryTemplate retryTemplate; + public static final String MODEL_DEFAULT = "generalv3.5"; /** - * 星火兼容 OpenAI 的 HTTP 接口,所以复用它的实现,简化接入成本 - * - * 不过要注意,星火没有完全兼容,所以不能使用 {@link org.springframework.ai.openai.OpenAiChatModel} 调用,但是实现会参考它 + * 兼容 OpenAI 接口,进行复用 */ - private final OpenAiApi openAiApi; - - public XingHuoChatModel(String apiKey, String secretKey) { - this(apiKey, secretKey, - XingHuoChatOptions.builder().model(MODEL_DEFAULT).temperature(0.7F).build()); - } - - public XingHuoChatModel(String apiKey, String secretKey, XingHuoChatOptions options) { - this(apiKey, secretKey, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); - } - - public XingHuoChatModel(String apiKey, String secretKey, XingHuoChatOptions options, RetryTemplate retryTemplate) { - Assert.notEmpty(apiKey, "apiKey 不能为空"); - Assert.notEmpty(secretKey, "secretKey 不能为空"); - Assert.notNull(options, "options 不能为空"); - Assert.notNull(retryTemplate, "retryTemplate 不能为空"); - this.openAiApi = new OpenAiApi(BASE_URL, apiKey + ":" + secretKey); - this.defaultOptions = options; - this.retryTemplate = retryTemplate; - } + private final OpenAiChatModel openAiChatModel; @Override public ChatResponse call(Prompt prompt) { - OpenAiApi.ChatCompletionRequest request = createRequest(prompt, false); - return this.retryTemplate.execute(ctx -> { - // 1.1 发起调用 - ResponseEntity completionEntity = openAiApi.chatCompletionEntity(request); - // 1.2 校验结果 - OpenAiApi.ChatCompletion chatCompletion = completionEntity.getBody(); - if (chatCompletion == null) { - log.warn("No chat completion returned for prompt: {}", prompt); - return new ChatResponse(ListUtil.of()); - } - List choices = chatCompletion.choices(); - if (choices == null) { - log.warn("No choices returned for prompt: {}", prompt); - return new ChatResponse(ListUtil.of()); - } - - // 2. 转换 ChatResponse 返回 - List generations = choices.stream().map(choice -> { - Generation generation = new Generation(choice.message().content(), toMap(chatCompletion.id(), choice)); - if (choice.finishReason() != null) { - generation.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)); - } - return generation; - }).toList(); - return new ChatResponse(generations, - OpenAiChatResponseMetadata.from(completionEntity.getBody())); - }); - } - - private Map toMap(String id, OpenAiApi.ChatCompletion.Choice choice) { - Map map = new HashMap<>(); - OpenAiApi.ChatCompletionMessage message = choice.message(); - if (message.role() != null) { - map.put("role", message.role().name()); - } - if (choice.finishReason() != null) { - map.put("finishReason", choice.finishReason().name()); - } - map.put("id", id); - return map; + return openAiChatModel.call(prompt); } @Override public Flux stream(Prompt prompt) { - OpenAiApi.ChatCompletionRequest request = createRequest(prompt, true); - return this.retryTemplate.execute(ctx -> { - // 1. 发起调用 - Flux response = this.openAiApi.chatCompletionStream(request); - return response.map(chatCompletion -> { - String id = chatCompletion.id(); - // 2. 转换 ChatResponse 返回 - List generations = chatCompletion.choices().stream().map(choice -> { - String finish = (choice.finishReason() != null ? choice.finishReason().name() : ""); - Generation generation = new Generation(choice.delta().content(), - Map.of("id", id, "role", choice.delta().role().name(), "finishReason", finish)); - if (choice.finishReason() != null) { - generation = generation.withGenerationMetadata( - ChatGenerationMetadata.from(choice.finishReason().name(), null)); - } - return generation; - }).toList(); - return new ChatResponse(generations); - }); - }); - } - - OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { - // 1. 构建 ChatCompletionMessage 对象 - List chatCompletionMessages = prompt.getInstructions().stream().map(m -> - new OpenAiApi.ChatCompletionMessage(m.getContent(), OpenAiApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))).toList(); - OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream); - - // 2.1 补充 prompt 内置的 options - if (prompt.getOptions() != null) { - if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { - OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, - ChatOptions.class, OpenAiChatOptions.class); - request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.ChatCompletionRequest.class); - } else { - throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " - + prompt.getOptions().getClass().getSimpleName()); - } - } - // 2.2 补充默认 options - if (this.defaultOptions != null) { - request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.ChatCompletionRequest.class); - } - return request; + return openAiChatModel.stream(prompt); } @Override public ChatOptions getDefaultOptions() { - return XingHuoChatOptions.fromOptions(defaultOptions); + return openAiChatModel.getDefaultOptions(); } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/XingHuoChatModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/XingHuoChatModelTests.java index 63f76b96d1..4140b4bfbd 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/XingHuoChatModelTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/XingHuoChatModelTests.java @@ -8,6 +8,9 @@ import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.api.OpenAiApi; import reactor.core.publisher.Flux; import java.util.ArrayList; @@ -20,9 +23,18 @@ import java.util.List; */ public class XingHuoChatModelTests { - private final XingHuoChatModel chatModel = new XingHuoChatModel( - "cb6415c19d6162cda07b47316fcb0416", - "Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh"); + private static final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder() + .openAiApi(OpenAiApi.builder() + .baseUrl(XingHuoChatModel.BASE_URL) + .apiKey("75b161ed2aef4719b275d6e7f2a4d4cd:YWYxYWI2MTA4ODI2NGZlYTQyNjAzZTcz") // appKey:secretKey + .build()) + .defaultOptions(OpenAiChatOptions.builder() + .model("generalv3.5") // 模型 + .temperature(0.7) + .build()) + .build(); + + private final XingHuoChatModel chatModel = new XingHuoChatModel(openAiChatModel); @Test @Disabled