From f582c9cfa3901921d3b6f75a85ab17bf42cb9fde Mon Sep 17 00:00:00 2001 From: YunaiV Date: Fri, 21 Feb 2025 13:46:54 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BB=A3=E7=A0=81=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E3=80=91AI=EF=BC=9A=E4=BD=BF=E7=94=A8=20OpenAiApi=20=E6=8E=A5?= =?UTF-8?q?=E5=85=A5=20deepseek?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../model/deepseek/DeepSeekChatModel.java | 141 ++---------------- .../model/deepseek/DeepSeekChatOptions.java | 55 ------- .../model/xinghuo/XingHuoChatOptions.java | 55 ------- .../ai/chat/DeepSeekChatModelTests.java | 16 +- 4 files changed, 25 insertions(+), 242 deletions(-) delete mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/deepseek/DeepSeekChatOptions.java delete mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoChatOptions.java diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/deepseek/DeepSeekChatModel.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/deepseek/DeepSeekChatModel.java index e3097b83a3..a136b5a2b5 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/deepseek/DeepSeekChatModel.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/deepseek/DeepSeekChatModel.java @@ -1,166 +1,45 @@ package cn.iocoder.yudao.framework.ai.core.model.deepseek; -import cn.hutool.core.collection.ListUtil; -import cn.hutool.core.lang.Assert; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.openai.OpenAiChatOptions; -import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata; -import org.springframework.ai.retry.RetryUtils; -import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.ai.openai.OpenAiChatModel; import reactor.core.publisher.Flux; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions.MODEL_DEFAULT; - /** * DeepSeek {@link ChatModel} 实现类 * * @author fansili */ @Slf4j +@RequiredArgsConstructor public class DeepSeekChatModel implements ChatModel { - private static final String BASE_URL = "https://api.deepseek.com"; + public static final String BASE_URL = "https://api.deepseek.com"; - private final DeepSeekChatOptions defaultOptions; - private final RetryTemplate retryTemplate; + public static final String MODEL_DEFAULT = "deepseek-chat"; /** - * DeepSeek 兼容 OpenAI 的 HTTP 接口,所以复用它的实现,简化接入成本 - * - * 不过要注意,DeepSeek 没有完全兼容,所以不能使用 {@link org.springframework.ai.openai.OpenAiChatModel} 调用,但是实现会参考它 + * 兼容 OpenAI 接口,进行复用 */ - private final OpenAiApi openAiApi; - - public DeepSeekChatModel(String apiKey) { - this(apiKey, DeepSeekChatOptions.builder().model(MODEL_DEFAULT).temperature(0.7F).build()); - } - - public DeepSeekChatModel(String apiKey, DeepSeekChatOptions options) { - this(apiKey, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); - } - - public DeepSeekChatModel(String apiKey, DeepSeekChatOptions options, RetryTemplate retryTemplate) { - Assert.notEmpty(apiKey, "apiKey 不能为空"); - Assert.notNull(options, "options 不能为空"); - Assert.notNull(retryTemplate, "retryTemplate 不能为空"); - this.openAiApi = new OpenAiApi(BASE_URL, apiKey); - this.defaultOptions = options; - this.retryTemplate = retryTemplate; - } + private final OpenAiChatModel openAiChatModel; @Override public ChatResponse call(Prompt prompt) { - OpenAiApi.ChatCompletionRequest request = createRequest(prompt, false); - return this.retryTemplate.execute(ctx -> { - // 1.1 发起调用 - ResponseEntity completionEntity = openAiApi.chatCompletionEntity(request); - // 1.2 校验结果 - OpenAiApi.ChatCompletion chatCompletion = completionEntity.getBody(); - if (chatCompletion == null) { - log.warn("No chat completion returned for prompt: {}", prompt); - return new ChatResponse(ListUtil.of()); - } - List choices = chatCompletion.choices(); - if (choices == null) { - log.warn("No choices returned for prompt: {}", prompt); - return new ChatResponse(ListUtil.of()); - } - - // 2. 转换 ChatResponse 返回 - List generations = choices.stream().map(choice -> { - Generation generation = new Generation(choice.message().content(), toMap(chatCompletion.id(), choice)); - if (choice.finishReason() != null) { - generation.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)); - } - return generation; - }).toList(); - return new ChatResponse(generations, - OpenAiChatResponseMetadata.from(completionEntity.getBody())); - }); - } - - private Map toMap(String id, OpenAiApi.ChatCompletion.Choice choice) { - Map map = new HashMap<>(); - OpenAiApi.ChatCompletionMessage message = choice.message(); - if (message.role() != null) { - map.put("role", message.role().name()); - } - if (choice.finishReason() != null) { - map.put("finishReason", choice.finishReason().name()); - } - map.put("id", id); - return map; + return openAiChatModel.call(prompt); } @Override public Flux stream(Prompt prompt) { - OpenAiApi.ChatCompletionRequest request = createRequest(prompt, true); - return this.retryTemplate.execute(ctx -> { - // 1. 发起调用 - Flux response = this.openAiApi.chatCompletionStream(request); - return response.map(chatCompletion -> { - String id = chatCompletion.id(); - // 2. 转换 ChatResponse 返回 - List generations = chatCompletion.choices().stream().map(choice -> { - String finish = (choice.finishReason() != null ? choice.finishReason().name() : ""); - String role = (choice.delta().role() != null ? choice.delta().role().name() : ""); - if (choice.finishReason() == OpenAiApi.ChatCompletionFinishReason.STOP) { - // 兜底处理 DeepSeek 返回 STOP 时,role 为空的情况 - role = OpenAiApi.ChatCompletionMessage.Role.ASSISTANT.name(); - } - Generation generation = new Generation(choice.delta().content(), - Map.of("id", id, "role", role, "finishReason", finish)); - if (choice.finishReason() != null) { - generation = generation.withGenerationMetadata( - ChatGenerationMetadata.from(choice.finishReason().name(), null)); - } - return generation; - }).toList(); - return new ChatResponse(generations); - }); - }); - } - - OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { - // 1. 构建 ChatCompletionMessage 对象 - List chatCompletionMessages = prompt.getInstructions().stream().map(m -> - new OpenAiApi.ChatCompletionMessage(m.getContent(), OpenAiApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))).toList(); - OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream); - - // 2.1 补充 prompt 内置的 options - if (prompt.getOptions() != null) { - if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { - OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, - ChatOptions.class, OpenAiChatOptions.class); - request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.ChatCompletionRequest.class); - } else { - throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " - + prompt.getOptions().getClass().getSimpleName()); - } - } - // 2.2 补充默认 options - if (this.defaultOptions != null) { - request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.ChatCompletionRequest.class); - } - return request; + return openAiChatModel.stream(prompt); } @Override public ChatOptions getDefaultOptions() { - return DeepSeekChatOptions.fromOptions(defaultOptions); + return openAiChatModel.getDefaultOptions(); } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/deepseek/DeepSeekChatOptions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/deepseek/DeepSeekChatOptions.java deleted file mode 100644 index e07e3f0865..0000000000 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/deepseek/DeepSeekChatOptions.java +++ /dev/null @@ -1,55 +0,0 @@ -package cn.iocoder.yudao.framework.ai.core.model.deepseek; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.springframework.ai.chat.prompt.ChatOptions; - -/** - * DeepSeek {@link ChatOptions} 实现类 - * - * 参考文档:快速开始 - * - * @author fansili - */ -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -public class DeepSeekChatOptions implements ChatOptions { - - public static final String MODEL_DEFAULT = "deepseek-chat"; - - /** - * 模型 - */ - private String model; - /** - * 温度 - */ - private Float temperature; - /** - * 最大 Token - */ - private Integer maxTokens; - /** - * topP - */ - private Float topP; - - @Override - public Integer getTopK() { - return null; - } - - public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) { - return DeepSeekChatOptions.builder() - .model(fromOptions.getModel()) - .temperature(fromOptions.getTemperature()) - .maxTokens(fromOptions.getMaxTokens()) - .topP(fromOptions.getTopP()) - .build(); - } - -} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoChatOptions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoChatOptions.java deleted file mode 100644 index e3287b613a..0000000000 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoChatOptions.java +++ /dev/null @@ -1,55 +0,0 @@ -package cn.iocoder.yudao.framework.ai.core.model.xinghuo; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.springframework.ai.chat.prompt.ChatOptions; - -/** - * 讯飞星火 {@link ChatOptions} 实现类 - * - * 参考文档:HTTP 调用 - * - * @author fansili - */ -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -public class XingHuoChatOptions implements ChatOptions { - - public static final String MODEL_DEFAULT = "generalv3.5"; - - /** - * 模型 - */ - private String model; - /** - * 温度 - */ - private Float temperature; - /** - * 最大 Token - */ - private Integer maxTokens; - /** - * K 个候选 - */ - private Integer topK; - - @Override - public Float getTopP() { - return null; - } - - public static XingHuoChatOptions fromOptions(XingHuoChatOptions fromOptions) { - return XingHuoChatOptions.builder() - .model(fromOptions.getModel()) - .temperature(fromOptions.getTemperature()) - .maxTokens(fromOptions.getMaxTokens()) - .topK(fromOptions.getTopK()) - .build(); - } - -} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/DeepSeekChatModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/DeepSeekChatModelTests.java index f66c548176..b3e12bfd16 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/DeepSeekChatModelTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/DeepSeekChatModelTests.java @@ -8,6 +8,9 @@ import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.api.OpenAiApi; import reactor.core.publisher.Flux; import java.util.ArrayList; @@ -20,7 +23,18 @@ import java.util.List; */ public class DeepSeekChatModelTests { - private final DeepSeekChatModel chatModel = new DeepSeekChatModel("sk-e94db327cc7d457d99a8de8810fc6b12"); + private static final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder() + .openAiApi(OpenAiApi.builder() + .baseUrl(DeepSeekChatModel.BASE_URL) + .apiKey("sk-e52047409b144d97b791a6a46a2d") // apiKey + .build()) + .defaultOptions(OpenAiChatOptions.builder() + .model("deepseek-chat") // 模型 + .temperature(0.7) + .build()) + .build(); + + private final DeepSeekChatModel chatModel = new DeepSeekChatModel(openAiChatModel); @Test @Disabled