From b9e8495712de03b9c7e215b43f772c0c9b85ef0a Mon Sep 17 00:00:00 2001 From: YunaiV Date: Tue, 11 Mar 2025 09:42:24 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E5=8A=9F=E8=83=BD=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E3=80=91AI=EF=BC=9A=E6=96=B0=E5=A2=9E=20AzureOpenAiEmbeddingMo?= =?UTF-8?q?del=E3=80=81OpenAiEmbeddingModel=20=E5=90=91=E9=87=8F=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/core/factory/AiModelFactoryImpl.java | 40 +++++++++++++++++-- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index 4bd519555c..b588027cde 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -35,6 +35,7 @@ import lombok.SneakyThrows; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties; +import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiEmbeddingProperties; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; @@ -49,6 +50,7 @@ import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStorePr import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; +import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.BatchingStrategy; @@ -59,6 +61,8 @@ import org.springframework.ai.ollama.OllamaEmbeddingModel; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.OpenAiEmbeddingOptions; import org.springframework.ai.openai.OpenAiImageModel; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiImageApi; @@ -227,6 +231,7 @@ public class AiModelFactoryImpl implements AiModelFactory { } @Override + @SuppressWarnings("EnhancedSwitchMigration") public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url, String model) { String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url, model); return Singleton.get(cacheKey, (Func0) () -> { @@ -237,10 +242,10 @@ public class AiModelFactoryImpl implements AiModelFactory { return buildYiYanEmbeddingModel(apiKey, model); case ZHI_PU: return buildZhiPuEmbeddingModel(apiKey, url, model); -// case OPENAI: -// return buildOpenAiChatModel(apiKey, url); -// case AZURE_OPENAI: -// return buildAzureOpenAiChatModel(apiKey, url); + case OPENAI: + return buildOpenAiEmbeddingModel(apiKey, url, model); + case AZURE_OPENAI: + return buildAzureOpenAiEmbeddingModel(apiKey, url, model); case OLLAMA: return buildOllamaEmbeddingModel(url, model); default: @@ -474,6 +479,33 @@ public class AiModelFactoryImpl implements AiModelFactory { return OllamaEmbeddingModel.builder().ollamaApi(ollamaApi).defaultOptions(ollamaOptions).build(); } + /** + * 可参考 {@link OpenAiAutoConfiguration} 的 openAiEmbeddingModel 方法 + */ + private OpenAiEmbeddingModel buildOpenAiEmbeddingModel(String openAiToken, String url, String model) { + url = StrUtil.blankToDefault(url, OpenAiApiConstants.DEFAULT_BASE_URL); + OpenAiApi openAiApi = OpenAiApi.builder().baseUrl(url).apiKey(openAiToken).build(); + OpenAiEmbeddingOptions openAiEmbeddingProperties = OpenAiEmbeddingOptions.builder().model(model).build(); + return new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, openAiEmbeddingProperties); + } + + // TODO @芋艿:手头暂时没密钥,使用建议再测试下 + /** + * 可参考 {@link AzureOpenAiAutoConfiguration} 的 azureOpenAiEmbeddingModel 方法 + */ + private AzureOpenAiEmbeddingModel buildAzureOpenAiEmbeddingModel(String apiKey, String url, String model) { + AzureOpenAiAutoConfiguration azureOpenAiAutoConfiguration = new AzureOpenAiAutoConfiguration(); + // 创建 OpenAIClient 对象 + AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties(); + connectionProperties.setApiKey(apiKey); + connectionProperties.setEndpoint(url); + OpenAIClientBuilder openAIClient = azureOpenAiAutoConfiguration.openAIClientBuilder(connectionProperties, null); + // 获取 AzureOpenAiChatProperties 对象 + AzureOpenAiEmbeddingProperties embeddingProperties = SpringUtil.getBean(AzureOpenAiEmbeddingProperties.class); + return azureOpenAiAutoConfiguration.azureOpenAiEmbeddingModel(openAIClient, embeddingProperties, + null, null); + } + // ========== 各种创建 VectorStore 的方法 ========== /**