diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index 4bd519555c..b588027cde 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -35,6 +35,7 @@ import lombok.SneakyThrows; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties; +import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiEmbeddingProperties; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; @@ -49,6 +50,7 @@ import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStorePr import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; +import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.BatchingStrategy; @@ -59,6 +61,8 @@ import org.springframework.ai.ollama.OllamaEmbeddingModel; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.OpenAiEmbeddingOptions; import org.springframework.ai.openai.OpenAiImageModel; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiImageApi; @@ -227,6 +231,7 @@ public class AiModelFactoryImpl implements AiModelFactory { } @Override + @SuppressWarnings("EnhancedSwitchMigration") public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url, String model) { String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url, model); return Singleton.get(cacheKey, (Func0) () -> { @@ -237,10 +242,10 @@ public class AiModelFactoryImpl implements AiModelFactory { return buildYiYanEmbeddingModel(apiKey, model); case ZHI_PU: return buildZhiPuEmbeddingModel(apiKey, url, model); -// case OPENAI: -// return buildOpenAiChatModel(apiKey, url); -// case AZURE_OPENAI: -// return buildAzureOpenAiChatModel(apiKey, url); + case OPENAI: + return buildOpenAiEmbeddingModel(apiKey, url, model); + case AZURE_OPENAI: + return buildAzureOpenAiEmbeddingModel(apiKey, url, model); case OLLAMA: return buildOllamaEmbeddingModel(url, model); default: @@ -474,6 +479,33 @@ public class AiModelFactoryImpl implements AiModelFactory { return OllamaEmbeddingModel.builder().ollamaApi(ollamaApi).defaultOptions(ollamaOptions).build(); } + /** + * 可参考 {@link OpenAiAutoConfiguration} 的 openAiEmbeddingModel 方法 + */ + private OpenAiEmbeddingModel buildOpenAiEmbeddingModel(String openAiToken, String url, String model) { + url = StrUtil.blankToDefault(url, OpenAiApiConstants.DEFAULT_BASE_URL); + OpenAiApi openAiApi = OpenAiApi.builder().baseUrl(url).apiKey(openAiToken).build(); + OpenAiEmbeddingOptions openAiEmbeddingProperties = OpenAiEmbeddingOptions.builder().model(model).build(); + return new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, openAiEmbeddingProperties); + } + + // TODO @芋艿:手头暂时没密钥,使用建议再测试下 + /** + * 可参考 {@link AzureOpenAiAutoConfiguration} 的 azureOpenAiEmbeddingModel 方法 + */ + private AzureOpenAiEmbeddingModel buildAzureOpenAiEmbeddingModel(String apiKey, String url, String model) { + AzureOpenAiAutoConfiguration azureOpenAiAutoConfiguration = new AzureOpenAiAutoConfiguration(); + // 创建 OpenAIClient 对象 + AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties(); + connectionProperties.setApiKey(apiKey); + connectionProperties.setEndpoint(url); + OpenAIClientBuilder openAIClient = azureOpenAiAutoConfiguration.openAIClientBuilder(connectionProperties, null); + // 获取 AzureOpenAiChatProperties 对象 + AzureOpenAiEmbeddingProperties embeddingProperties = SpringUtil.getBean(AzureOpenAiEmbeddingProperties.class); + return azureOpenAiAutoConfiguration.azureOpenAiEmbeddingModel(openAIClient, embeddingProperties, + null, null); + } + // ========== 各种创建 VectorStore 的方法 ========== /**