【代码重构】AI:spring ai 依赖,升级到 1.0.0-M6
This commit is contained in:
parent
3b7b81829d
commit
7ef73b7d09
|
@ -1,4 +1,4 @@
|
||||||
/**
|
/**
|
||||||
* crm 模块的 web 拓展封装
|
* ai 模块的 web 拓展封装
|
||||||
*/
|
*/
|
||||||
package cn.iocoder.yudao.module.crm.framework.web;
|
package cn.iocoder.yudao.module.ai.framework.web;
|
||||||
|
|
|
@ -20,7 +20,6 @@ import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
|
||||||
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
|
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||||
import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
||||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesOptions;
|
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.image.ImageModel;
|
import org.springframework.ai.image.ImageModel;
|
||||||
|
@ -144,7 +143,7 @@ public class AiImageServiceImpl implements AiImageService {
|
||||||
.withClipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
|
.withClipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
|
||||||
.build();
|
.build();
|
||||||
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
|
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
|
||||||
return TongYiImagesOptions.builder()
|
return DashScopeImageOptions.builder()
|
||||||
.withModel(draw.getModel()).withN(1)
|
.withModel(draw.getModel()).withN(1)
|
||||||
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
|
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
<name>${project.artifactId}</name>
|
<name>${project.artifactId}</name>
|
||||||
<description>AI 大模型拓展,接入国内外大模型</description>
|
<description>AI 大模型拓展,接入国内外大模型</description>
|
||||||
<properties>
|
<properties>
|
||||||
<spring-ai.groupId>group.springframework.ai</spring-ai.groupId>
|
<spring-ai.groupId>org.springframework.ai</spring-ai.groupId>
|
||||||
<spring-ai.version>1.1.0</spring-ai.version>
|
<spring-ai.version>1.0.0-M6</spring-ai.version>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
|
@ -90,6 +90,11 @@
|
||||||
<artifactId>spring-ai-qianfan-spring-boot-starter</artifactId>
|
<artifactId>spring-ai-qianfan-spring-boot-starter</artifactId>
|
||||||
<version>${spring-ai.version}</version>
|
<version>${spring-ai.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.alibaba.cloud.ai</groupId>
|
||||||
|
<artifactId>spring-ai-alibaba-starter</artifactId>
|
||||||
|
<version>1.0.0-M5.1</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<!-- 阿里云 通义千问 -->
|
<!-- 阿里云 通义千问 -->
|
||||||
<!-- TODO 芋艿:等 spring cloud alibaba ai 发布最新的时候,可以替换掉这个依赖,并且删除我们直接 cv 的代码 -->
|
<!-- TODO 芋艿:等 spring cloud alibaba ai 发布最新的时候,可以替换掉这个依赖,并且删除我们直接 cv 的代码 -->
|
||||||
|
|
|
@ -1,15 +1,16 @@
|
||||||
package cn.iocoder.yudao.framework.ai.config;
|
package cn.iocoder.yudao.framework.ai.config;
|
||||||
|
|
||||||
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
|
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
|
||||||
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
|
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
|
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
|
|
||||||
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.ai.openai.OpenAiChatModel;
|
||||||
|
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||||
|
import org.springframework.ai.openai.api.OpenAiApi;
|
||||||
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
|
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
|
||||||
import org.springframework.ai.tokenizer.TokenCountEstimator;
|
import org.springframework.ai.tokenizer.TokenCountEstimator;
|
||||||
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
||||||
|
@ -17,7 +18,6 @@ import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||||
import org.springframework.context.annotation.Bean;
|
import org.springframework.context.annotation.Bean;
|
||||||
import org.springframework.context.annotation.Import;
|
|
||||||
import org.springframework.context.annotation.Lazy;
|
import org.springframework.context.annotation.Lazy;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -28,7 +28,6 @@ import org.springframework.context.annotation.Lazy;
|
||||||
@AutoConfiguration
|
@AutoConfiguration
|
||||||
@EnableConfigurationProperties(YudaoAiProperties.class)
|
@EnableConfigurationProperties(YudaoAiProperties.class)
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Import(TongYiAutoConfiguration.class)
|
|
||||||
public class YudaoAiAutoConfiguration {
|
public class YudaoAiAutoConfiguration {
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
|
@ -43,26 +42,52 @@ public class YudaoAiAutoConfiguration {
|
||||||
@ConditionalOnProperty(value = "yudao.ai.deepseek.enable", havingValue = "true")
|
@ConditionalOnProperty(value = "yudao.ai.deepseek.enable", havingValue = "true")
|
||||||
public DeepSeekChatModel deepSeekChatModel(YudaoAiProperties yudaoAiProperties) {
|
public DeepSeekChatModel deepSeekChatModel(YudaoAiProperties yudaoAiProperties) {
|
||||||
YudaoAiProperties.DeepSeekProperties properties = yudaoAiProperties.getDeepSeek();
|
YudaoAiProperties.DeepSeekProperties properties = yudaoAiProperties.getDeepSeek();
|
||||||
DeepSeekChatOptions options = DeepSeekChatOptions.builder()
|
return buildDeepSeekChatModel(properties);
|
||||||
.model(properties.getModel())
|
}
|
||||||
.temperature(properties.getTemperature())
|
|
||||||
.maxTokens(properties.getMaxTokens())
|
public DeepSeekChatModel buildDeepSeekChatModel(YudaoAiProperties.DeepSeekProperties properties) {
|
||||||
.topP(properties.getTopP())
|
if (StrUtil.isEmpty(properties.getModel())) {
|
||||||
|
properties.setModel(DeepSeekChatModel.MODEL_DEFAULT);
|
||||||
|
}
|
||||||
|
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||||
|
.openAiApi(OpenAiApi.builder()
|
||||||
|
.baseUrl(DeepSeekChatModel.BASE_URL)
|
||||||
|
.apiKey(properties.getApiKey())
|
||||||
|
.build())
|
||||||
|
.defaultOptions(OpenAiChatOptions.builder()
|
||||||
|
.model(properties.getModel())
|
||||||
|
.temperature(properties.getTemperature())
|
||||||
|
.maxTokens(properties.getMaxTokens())
|
||||||
|
.topP(properties.getTopP())
|
||||||
|
.build())
|
||||||
.build();
|
.build();
|
||||||
return new DeepSeekChatModel(properties.getApiKey(), options);
|
return new DeepSeekChatModel(openAiChatModel);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
@ConditionalOnProperty(value = "yudao.ai.xinghuo.enable", havingValue = "true")
|
@ConditionalOnProperty(value = "yudao.ai.xinghuo.enable", havingValue = "true")
|
||||||
public XingHuoChatModel xingHuoChatClient(YudaoAiProperties yudaoAiProperties) {
|
public XingHuoChatModel xingHuoChatClient(YudaoAiProperties yudaoAiProperties) {
|
||||||
YudaoAiProperties.XingHuoProperties properties = yudaoAiProperties.getXinghuo();
|
YudaoAiProperties.XingHuoProperties properties = yudaoAiProperties.getXinghuo();
|
||||||
XingHuoChatOptions options = XingHuoChatOptions.builder()
|
return buildXingHuoChatClient(properties);
|
||||||
.model(properties.getModel())
|
}
|
||||||
.temperature(properties.getTemperature())
|
|
||||||
.maxTokens(properties.getMaxTokens())
|
public XingHuoChatModel buildXingHuoChatClient(YudaoAiProperties.XingHuoProperties properties) {
|
||||||
.topK(properties.getTopK())
|
if (StrUtil.isEmpty(properties.getModel())) {
|
||||||
|
properties.setModel(XingHuoChatModel.MODEL_DEFAULT);
|
||||||
|
}
|
||||||
|
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||||
|
.openAiApi(OpenAiApi.builder()
|
||||||
|
.baseUrl(XingHuoChatModel.BASE_URL)
|
||||||
|
.apiKey(properties.getAppKey() + ":" + properties.getSecretKey())
|
||||||
|
.build())
|
||||||
|
.defaultOptions(OpenAiChatOptions.builder()
|
||||||
|
.model(properties.getModel())
|
||||||
|
.temperature(properties.getTemperature())
|
||||||
|
.maxTokens(properties.getMaxTokens())
|
||||||
|
.topP(properties.getTopP())
|
||||||
|
.build())
|
||||||
.build();
|
.build();
|
||||||
return new XingHuoChatModel(properties.getAppKey(), properties.getSecretKey(), options);
|
return new XingHuoChatModel(openAiChatModel);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
|
|
|
@ -42,9 +42,9 @@ public class YudaoAiProperties {
|
||||||
private String secretKey;
|
private String secretKey;
|
||||||
|
|
||||||
private String model;
|
private String model;
|
||||||
private Float temperature;
|
private Double temperature;
|
||||||
private Integer maxTokens;
|
private Integer maxTokens;
|
||||||
private Integer topK;
|
private Double topP;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,9 +55,9 @@ public class YudaoAiProperties {
|
||||||
private String apiKey;
|
private String apiKey;
|
||||||
|
|
||||||
private String model;
|
private String model;
|
||||||
private Float temperature;
|
private Double temperature;
|
||||||
private Integer maxTokens;
|
private Integer maxTokens;
|
||||||
private Float topP;
|
private Double topP;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,60 +13,45 @@ import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||||
import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
|
import com.alibaba.cloud.ai.autoconfigure.dashscope.DashScopeAutoConfiguration;
|
||||||
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
|
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
|
||||||
import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
|
import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
|
||||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
|
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
|
||||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
|
import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingModel;
|
||||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
|
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
|
||||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
|
import com.azure.ai.openai.OpenAIClientBuilder;
|
||||||
import com.alibaba.dashscope.aigc.generation.Generation;
|
|
||||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
|
||||||
import com.alibaba.dashscope.embeddings.TextEmbedding;
|
|
||||||
import com.azure.ai.openai.OpenAIClient;
|
|
||||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
|
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
|
||||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
|
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
|
||||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
|
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
|
||||||
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
|
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
|
||||||
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
||||||
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
|
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
|
||||||
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
|
|
||||||
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
|
|
||||||
import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
|
|
||||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
|
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
|
||||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
|
|
||||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
|
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
|
||||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
|
|
||||||
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
|
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
|
||||||
import org.springframework.ai.chat.model.ChatModel;
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
import org.springframework.ai.embedding.EmbeddingModel;
|
import org.springframework.ai.embedding.EmbeddingModel;
|
||||||
import org.springframework.ai.image.ImageModel;
|
import org.springframework.ai.image.ImageModel;
|
||||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
|
||||||
import org.springframework.ai.ollama.OllamaChatModel;
|
import org.springframework.ai.ollama.OllamaChatModel;
|
||||||
import org.springframework.ai.ollama.api.OllamaApi;
|
import org.springframework.ai.ollama.api.OllamaApi;
|
||||||
import org.springframework.ai.openai.OpenAiChatModel;
|
import org.springframework.ai.openai.OpenAiChatModel;
|
||||||
import org.springframework.ai.openai.OpenAiImageModel;
|
import org.springframework.ai.openai.OpenAiImageModel;
|
||||||
import org.springframework.ai.openai.api.ApiUtils;
|
|
||||||
import org.springframework.ai.openai.api.OpenAiApi;
|
import org.springframework.ai.openai.api.OpenAiApi;
|
||||||
import org.springframework.ai.openai.api.OpenAiImageApi;
|
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||||
|
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
|
||||||
import org.springframework.ai.qianfan.QianFanChatModel;
|
import org.springframework.ai.qianfan.QianFanChatModel;
|
||||||
import org.springframework.ai.qianfan.QianFanImageModel;
|
import org.springframework.ai.qianfan.QianFanImageModel;
|
||||||
import org.springframework.ai.qianfan.api.QianFanApi;
|
import org.springframework.ai.qianfan.api.QianFanApi;
|
||||||
import org.springframework.ai.qianfan.api.QianFanImageApi;
|
import org.springframework.ai.qianfan.api.QianFanImageApi;
|
||||||
import org.springframework.ai.stabilityai.StabilityAiImageModel;
|
import org.springframework.ai.stabilityai.StabilityAiImageModel;
|
||||||
import org.springframework.ai.stabilityai.api.StabilityAiApi;
|
import org.springframework.ai.stabilityai.api.StabilityAiApi;
|
||||||
import org.springframework.ai.vectorstore.RedisVectorStore;
|
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
||||||
import org.springframework.ai.vectorstore.VectorStore;
|
import org.springframework.ai.vectorstore.VectorStore;
|
||||||
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
|
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
|
||||||
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
|
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
|
||||||
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
|
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
|
||||||
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
|
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
|
||||||
import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
|
|
||||||
import org.springframework.retry.support.RetryTemplate;
|
|
||||||
import org.springframework.web.client.ResponseErrorHandler;
|
|
||||||
import org.springframework.web.client.RestClient;
|
import org.springframework.web.client.RestClient;
|
||||||
import redis.clients.jedis.JedisPooled;
|
|
||||||
import redis.clients.jedis.search.Schema;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -110,7 +95,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||||
//noinspection EnhancedSwitchMigration
|
//noinspection EnhancedSwitchMigration
|
||||||
switch (platform) {
|
switch (platform) {
|
||||||
case TONG_YI:
|
case TONG_YI:
|
||||||
return SpringUtil.getBean(TongYiChatModel.class);
|
return SpringUtil.getBean(DashScopeChatModel.class);
|
||||||
case YI_YAN:
|
case YI_YAN:
|
||||||
return SpringUtil.getBean(QianFanChatModel.class);
|
return SpringUtil.getBean(QianFanChatModel.class);
|
||||||
case DEEP_SEEK:
|
case DEEP_SEEK:
|
||||||
|
@ -135,7 +120,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||||
//noinspection EnhancedSwitchMigration
|
//noinspection EnhancedSwitchMigration
|
||||||
switch (platform) {
|
switch (platform) {
|
||||||
case TONG_YI:
|
case TONG_YI:
|
||||||
return SpringUtil.getBean(TongYiImagesModel.class);
|
return SpringUtil.getBean(DashScopeImageModel.class);
|
||||||
case YI_YAN:
|
case YI_YAN:
|
||||||
return SpringUtil.getBean(QianFanImageModel.class);
|
return SpringUtil.getBean(QianFanImageModel.class);
|
||||||
case ZHI_PU:
|
case ZHI_PU:
|
||||||
|
@ -202,17 +187,20 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||||
String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
|
String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
|
||||||
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
|
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
|
||||||
String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
|
String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
|
||||||
var config = RedisVectorStore.RedisVectorStoreConfig.builder()
|
// TODO @芋艿:先临时使用 store
|
||||||
.withIndexName(cacheKey)
|
return SimpleVectorStore.builder(embeddingModel).build();
|
||||||
.withPrefix(prefix)
|
// TODO @芋艿:@xin:后续看看,是不是切到阿里云之类的
|
||||||
.withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
|
// var config = RedisVectorStore.RedisVectorStoreConfig.builder()
|
||||||
.build();
|
// .withIndexName(cacheKey)
|
||||||
RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
|
// .withPrefix(prefix)
|
||||||
RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
|
// .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
|
||||||
new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
|
// .build();
|
||||||
true);
|
// RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
|
||||||
redisVectorStore.afterPropertiesSet();
|
// RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
|
||||||
return redisVectorStore;
|
// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
|
||||||
|
// true);
|
||||||
|
// redisVectorStore.afterPropertiesSet();
|
||||||
|
// return redisVectorStore;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -226,29 +214,23 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||||
// ========== 各种创建 spring-ai 客户端的方法 ==========
|
// ========== 各种创建 spring-ai 客户端的方法 ==========
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link TongYiAutoConfiguration#tongYiChatClient(Generation, TongYiChatProperties, TongYiConnectionProperties)}
|
* 可参考 {@link DashScopeAutoConfiguration} 的 dashscopeChatModel 方法
|
||||||
*/
|
*/
|
||||||
private static TongYiChatModel buildTongYiChatModel(String key) {
|
private static DashScopeChatModel buildTongYiChatModel(String key) {
|
||||||
com.alibaba.dashscope.aigc.generation.Generation generation = SpringUtil.getBean(Generation.class);
|
DashScopeApi dashScopeApi = new DashScopeApi(key);
|
||||||
TongYiChatProperties chatOptions = SpringUtil.getBean(TongYiChatProperties.class);
|
return new DashScopeChatModel(dashScopeApi);
|
||||||
// TODO @芋艿:貌似 apiKey 是全局唯一的???得测试下
|
|
||||||
// TODO @芋艿:貌似阿里云不是增量返回的
|
|
||||||
// 该 issue 进行跟进中 https://github.com/alibaba/spring-cloud-alibaba/issues/3790
|
|
||||||
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
|
|
||||||
connectionProperties.setApiKey(key);
|
|
||||||
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static TongYiImagesModel buildTongYiImagesModel(String key) {
|
|
||||||
ImageSynthesis imageSynthesis = SpringUtil.getBean(ImageSynthesis.class);
|
|
||||||
TongYiImagesProperties imagesOptions = SpringUtil.getBean(TongYiImagesProperties.class);
|
|
||||||
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
|
|
||||||
connectionProperties.setApiKey(key);
|
|
||||||
return new TongYiAutoConfiguration().tongYiImagesClient(imageSynthesis, imagesOptions, connectionProperties);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
|
* 可参考 {@link DashScopeAutoConfiguration} 的 dashScopeImageModel 方法
|
||||||
|
*/
|
||||||
|
private static DashScopeImageModel buildTongYiImagesModel(String key) {
|
||||||
|
DashScopeImageApi dashScopeImageApi = new DashScopeImageApi(key);
|
||||||
|
return new DashScopeImageModel(dashScopeImageApi);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 可参考 {@link QianFanAutoConfiguration} 的 qianFanChatModel 方法
|
||||||
*/
|
*/
|
||||||
private static QianFanChatModel buildYiYanChatModel(String key) {
|
private static QianFanChatModel buildYiYanChatModel(String key) {
|
||||||
List<String> keys = StrUtil.split(key, '|');
|
List<String> keys = StrUtil.split(key, '|');
|
||||||
|
@ -260,7 +242,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link QianFanAutoConfiguration#qianFanImageModel(QianFanConnectionProperties, QianFanImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
|
* 可参考 {@link QianFanAutoConfiguration} 的 qianFanImageModel 方法
|
||||||
*/
|
*/
|
||||||
private QianFanImageModel buildQianFanImageModel(String key) {
|
private QianFanImageModel buildQianFanImageModel(String key) {
|
||||||
List<String> keys = StrUtil.split(key, '|');
|
List<String> keys = StrUtil.split(key, '|');
|
||||||
|
@ -275,11 +257,13 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||||
* 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)}
|
* 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)}
|
||||||
*/
|
*/
|
||||||
private static DeepSeekChatModel buildDeepSeekChatModel(String apiKey) {
|
private static DeepSeekChatModel buildDeepSeekChatModel(String apiKey) {
|
||||||
return new DeepSeekChatModel(apiKey);
|
YudaoAiProperties.DeepSeekProperties properties = new YudaoAiProperties.DeepSeekProperties()
|
||||||
|
.setApiKey(apiKey);
|
||||||
|
return new YudaoAiAutoConfiguration().buildDeepSeekChatModel(properties);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
|
* 可参考 {@link ZhiPuAiAutoConfiguration} 的 zhiPuAiChatModel 方法
|
||||||
*/
|
*/
|
||||||
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
|
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
|
||||||
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
|
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
|
||||||
|
@ -288,7 +272,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
|
* 可参考 {@link ZhiPuAiAutoConfiguration} 的 zhiPuAiImageModel 方法
|
||||||
*/
|
*/
|
||||||
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
|
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
|
||||||
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
|
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
|
||||||
|
@ -301,21 +285,22 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||||
*/
|
*/
|
||||||
private static XingHuoChatModel buildXingHuoChatModel(String key) {
|
private static XingHuoChatModel buildXingHuoChatModel(String key) {
|
||||||
List<String> keys = StrUtil.split(key, '|');
|
List<String> keys = StrUtil.split(key, '|');
|
||||||
Assert.equals(keys.size(), 3, "XingHuoChatClient 的密钥需要 (appid|appKey|secretKey) 格式");
|
Assert.equals(keys.size(), 2, "XingHuoChatClient 的密钥需要 (appKey|secretKey) 格式");
|
||||||
String appKey = keys.get(1);
|
YudaoAiProperties.XingHuoProperties properties = new YudaoAiProperties.XingHuoProperties()
|
||||||
String secretKey = keys.get(2);
|
.setAppKey(keys.get(0)).setSecretKey(keys.get(1));
|
||||||
return new XingHuoChatModel(appKey, secretKey);
|
return new YudaoAiAutoConfiguration().buildXingHuoChatClient(properties);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link OpenAiAutoConfiguration}
|
* 可参考 {@link OpenAiAutoConfiguration} 的 openAiChatModel 方法
|
||||||
*/
|
*/
|
||||||
private static OpenAiChatModel buildOpenAiChatModel(String openAiToken, String url) {
|
private static OpenAiChatModel buildOpenAiChatModel(String openAiToken, String url) {
|
||||||
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
|
url = StrUtil.blankToDefault(url, OpenAiApiConstants.DEFAULT_BASE_URL);
|
||||||
OpenAiApi openAiApi = new OpenAiApi(url, openAiToken);
|
OpenAiApi openAiApi = OpenAiApi.builder().baseUrl(url).apiKey(openAiToken).build();
|
||||||
return new OpenAiChatModel(openAiApi);
|
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO @芋艿:手头暂时没密钥,使用建议再测试下
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link AzureOpenAiAutoConfiguration}
|
* 可参考 {@link AzureOpenAiAutoConfiguration}
|
||||||
*/
|
*/
|
||||||
|
@ -325,27 +310,28 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||||
AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties();
|
AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties();
|
||||||
connectionProperties.setApiKey(apiKey);
|
connectionProperties.setApiKey(apiKey);
|
||||||
connectionProperties.setEndpoint(url);
|
connectionProperties.setEndpoint(url);
|
||||||
OpenAIClient openAIClient = azureOpenAiAutoConfiguration.openAIClient(connectionProperties);
|
OpenAIClientBuilder openAIClient = azureOpenAiAutoConfiguration.openAIClientBuilder(connectionProperties, null);
|
||||||
// 获取 AzureOpenAiChatProperties 对象
|
// 获取 AzureOpenAiChatProperties 对象
|
||||||
AzureOpenAiChatProperties chatProperties = SpringUtil.getBean(AzureOpenAiChatProperties.class);
|
AzureOpenAiChatProperties chatProperties = SpringUtil.getBean(AzureOpenAiChatProperties.class);
|
||||||
return azureOpenAiAutoConfiguration.azureOpenAiChatModel(openAIClient, chatProperties, null, null);
|
return azureOpenAiAutoConfiguration.azureOpenAiChatModel(openAIClient, chatProperties,
|
||||||
|
null, null, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link OpenAiAutoConfiguration}
|
* 可参考 {@link OpenAiAutoConfiguration} 的 openAiImageModel 方法
|
||||||
*/
|
*/
|
||||||
private OpenAiImageModel buildOpenAiImageModel(String openAiToken, String url) {
|
private OpenAiImageModel buildOpenAiImageModel(String openAiToken, String url) {
|
||||||
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
|
url = StrUtil.blankToDefault(url, OpenAiApiConstants.DEFAULT_BASE_URL);
|
||||||
OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
|
OpenAiImageApi openAiApi = OpenAiImageApi.builder().baseUrl(url).apiKey(openAiToken).build();
|
||||||
return new OpenAiImageModel(openAiApi);
|
return new OpenAiImageModel(openAiApi);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link OllamaAutoConfiguration}
|
* 可参考 {@link OllamaAutoConfiguration} 的 ollamaApi 方法
|
||||||
*/
|
*/
|
||||||
private static OllamaChatModel buildOllamaChatModel(String url) {
|
private static OllamaChatModel buildOllamaChatModel(String url) {
|
||||||
OllamaApi ollamaApi = new OllamaApi(url);
|
OllamaApi ollamaApi = new OllamaApi(url);
|
||||||
return new OllamaChatModel(ollamaApi);
|
return OllamaChatModel.builder().ollamaApi(ollamaApi).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
private StabilityAiImageModel buildStabilityAiImageModel(String apiKey, String url) {
|
private StabilityAiImageModel buildStabilityAiImageModel(String apiKey, String url) {
|
||||||
|
@ -356,13 +342,13 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||||
|
|
||||||
// ========== 各种创建 EmbeddingModel 的方法 ==========
|
// ========== 各种创建 EmbeddingModel 的方法 ==========
|
||||||
|
|
||||||
|
// TODO @芋艿:需要测试下
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link TongYiAutoConfiguration#tongYiTextEmbeddingClient(TextEmbedding, TongYiConnectionProperties)}
|
* 可参考 {@link DashScopeAutoConfiguration} 的 dashscopeEmbeddingModel 方法
|
||||||
*/
|
*/
|
||||||
private EmbeddingModel buildTongYiEmbeddingModel(String apiKey) {
|
private EmbeddingModel buildTongYiEmbeddingModel(String apiKey) {
|
||||||
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
|
DashScopeApi dashScopeApi = new DashScopeApi(apiKey);
|
||||||
connectionProperties.setApiKey(apiKey);
|
return new DashScopeEmbeddingModel(dashScopeApi);
|
||||||
return new TongYiAutoConfiguration().tongYiTextEmbeddingClient(SpringUtil.getBean(TextEmbedding.class), connectionProperties);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,9 +8,10 @@ import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.openai.api.ApiUtils;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.HttpRequest;
|
import org.springframework.http.HttpRequest;
|
||||||
import org.springframework.http.HttpStatusCode;
|
import org.springframework.http.HttpStatusCode;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
import org.springframework.web.reactive.function.client.ClientResponse;
|
import org.springframework.web.reactive.function.client.ClientResponse;
|
||||||
import org.springframework.web.reactive.function.client.WebClient;
|
import org.springframework.web.reactive.function.client.WebClient;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
@ -18,6 +19,7 @@ import reactor.core.publisher.Mono;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.function.Consumer;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
import java.util.function.Predicate;
|
import java.util.function.Predicate;
|
||||||
|
|
||||||
|
@ -50,11 +52,19 @@ public class MidjourneyApi {
|
||||||
public MidjourneyApi(String baseUrl, String apiKey, String notifyUrl) {
|
public MidjourneyApi(String baseUrl, String apiKey, String notifyUrl) {
|
||||||
this.webClient = WebClient.builder()
|
this.webClient = WebClient.builder()
|
||||||
.baseUrl(baseUrl)
|
.baseUrl(baseUrl)
|
||||||
.defaultHeaders(ApiUtils.getJsonContentHeaders(apiKey))
|
.defaultHeaders(getJsonContentHeaders(apiKey))
|
||||||
.build();
|
.build();
|
||||||
this.notifyUrl = notifyUrl;
|
this.notifyUrl = notifyUrl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO @芋艿:这里,看看怎么调整下???https://github.com/spring-projects/spring-ai/issues/741
|
||||||
|
public static Consumer<HttpHeaders> getJsonContentHeaders(String apiKey) {
|
||||||
|
return (headers) -> {
|
||||||
|
headers.setBearerAuth(apiKey);
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* imagine - 根据提示词提交绘画任务
|
* imagine - 根据提示词提交绘画任务
|
||||||
*
|
*
|
||||||
|
|
|
@ -2,9 +2,7 @@ package cn.iocoder.yudao.framework.ai.core.util;
|
||||||
|
|
||||||
import cn.hutool.core.util.StrUtil;
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
|
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
|
|
||||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
|
||||||
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
|
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
|
||||||
import org.springframework.ai.chat.messages.*;
|
import org.springframework.ai.chat.messages.*;
|
||||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||||
|
@ -21,26 +19,24 @@ import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
||||||
public class AiUtils {
|
public class AiUtils {
|
||||||
|
|
||||||
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
||||||
Float temperatureF = temperature != null ? temperature.floatValue() : null;
|
|
||||||
//noinspection EnhancedSwitchMigration
|
//noinspection EnhancedSwitchMigration
|
||||||
switch (platform) {
|
switch (platform) {
|
||||||
case TONG_YI:
|
case TONG_YI:
|
||||||
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
|
// TODO @芋艿:tongyi 暂时没 maxTokens 选项
|
||||||
|
return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).build();
|
||||||
case YI_YAN:
|
case YI_YAN:
|
||||||
return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
|
||||||
case DEEP_SEEK:
|
|
||||||
return DeepSeekChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
|
|
||||||
case ZHI_PU:
|
case ZHI_PU:
|
||||||
return ZhiPuAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
|
||||||
case XING_HUO:
|
|
||||||
return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
|
|
||||||
case OPENAI:
|
case OPENAI:
|
||||||
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
case DEEP_SEEK: // 复用 OpenAI 客户端
|
||||||
|
case XING_HUO: // 复用 OpenAI 客户端
|
||||||
|
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
|
||||||
case AZURE_OPENAI:
|
case AZURE_OPENAI:
|
||||||
// TODO 芋艿:貌似没 model 字段???!
|
// TODO 芋艿:貌似没 model 字段???!
|
||||||
return AzureOpenAiChatOptions.builder().withDeploymentName(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens).build();
|
||||||
case OLLAMA:
|
case OLLAMA:
|
||||||
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens).build();
|
||||||
default:
|
default:
|
||||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||||
}
|
}
|
||||||
|
@ -56,8 +52,8 @@ public class AiUtils {
|
||||||
if (MessageType.SYSTEM.getValue().equals(type)) {
|
if (MessageType.SYSTEM.getValue().equals(type)) {
|
||||||
return new SystemMessage(content);
|
return new SystemMessage(content);
|
||||||
}
|
}
|
||||||
if (MessageType.FUNCTION.getValue().equals(type)) {
|
if (MessageType.TOOL.getValue().equals(type)) {
|
||||||
return new FunctionMessage(content);
|
throw new UnsupportedOperationException("暂不支持 tool 消息:" + content);
|
||||||
}
|
}
|
||||||
throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
|
throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,61 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2023 - 2024 the original author or authors.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
package org.springframework.ai.autoconfigure.vectorstore.redis;
|
|
||||||
|
|
||||||
import org.springframework.ai.embedding.EmbeddingModel;
|
|
||||||
import org.springframework.ai.vectorstore.RedisVectorStore;
|
|
||||||
import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig;
|
|
||||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
|
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
|
||||||
import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration;
|
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
|
||||||
import org.springframework.context.annotation.Bean;
|
|
||||||
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
|
|
||||||
import redis.clients.jedis.JedisPooled;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* TODO @xin 先拿 spring-ai 最新代码覆盖,1.0.0-M1 跟 redis 自动配置会冲突
|
|
||||||
*
|
|
||||||
* TODO 这个官方,有说啥时候 fix 哇?
|
|
||||||
* TODO 看着是列在1.0.0-M2版本
|
|
||||||
*
|
|
||||||
* @author Christian Tzolov
|
|
||||||
* @author Eddú Meléndez
|
|
||||||
*/
|
|
||||||
@AutoConfiguration(after = RedisAutoConfiguration.class)
|
|
||||||
@ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class})
|
|
||||||
@ConditionalOnBean(JedisConnectionFactory.class)
|
|
||||||
@EnableConfigurationProperties(RedisVectorStoreProperties.class)
|
|
||||||
public class RedisVectorStoreAutoConfiguration {
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnMissingBean
|
|
||||||
public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorStoreProperties properties,
|
|
||||||
JedisConnectionFactory jedisConnectionFactory) {
|
|
||||||
|
|
||||||
var config = RedisVectorStoreConfig.builder()
|
|
||||||
.withIndexName(properties.getIndex())
|
|
||||||
.withPrefix(properties.getPrefix())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
return new RedisVectorStore(config, embeddingModel,
|
|
||||||
new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()),
|
|
||||||
properties.isInitializeSchema());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,456 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2023 - 2024 the original author or authors.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
package org.springframework.ai.vectorstore;
|
|
||||||
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
import org.springframework.ai.document.Document;
|
|
||||||
import org.springframework.ai.embedding.EmbeddingModel;
|
|
||||||
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
|
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
|
||||||
import org.springframework.util.Assert;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
import redis.clients.jedis.JedisPooled;
|
|
||||||
import redis.clients.jedis.Pipeline;
|
|
||||||
import redis.clients.jedis.json.Path2;
|
|
||||||
import redis.clients.jedis.search.*;
|
|
||||||
import redis.clients.jedis.search.Schema.FieldType;
|
|
||||||
import redis.clients.jedis.search.schemafields.*;
|
|
||||||
import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm;
|
|
||||||
|
|
||||||
import java.text.MessageFormat;
|
|
||||||
import java.util.*;
|
|
||||||
import java.util.function.Function;
|
|
||||||
import java.util.function.Predicate;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The RedisVectorStore is for managing and querying vector data in a Redis database. It
|
|
||||||
* offers functionalities like adding, deleting, and performing similarity searches on
|
|
||||||
* documents.
|
|
||||||
*
|
|
||||||
* The store utilizes RedisJSON and RedisSearch to handle JSON documents and to index and
|
|
||||||
* search vector data. It supports various vector algorithms (e.g., FLAT, HSNW) for
|
|
||||||
* efficient similarity searches. Additionally, it allows for custom metadata fields in
|
|
||||||
* the documents to be stored alongside the vector and content data.
|
|
||||||
*
|
|
||||||
* This class requires a RedisVectorStoreConfig configuration object for initialization,
|
|
||||||
* which includes settings like Redis URI, index name, field names, and vector algorithms.
|
|
||||||
* It also requires an EmbeddingModel to convert documents into embeddings before storing
|
|
||||||
* them.
|
|
||||||
*
|
|
||||||
* @author Julien Ruaux
|
|
||||||
* @author Christian Tzolov
|
|
||||||
* @author Eddú Meléndez
|
|
||||||
* @see VectorStore
|
|
||||||
* @see RedisVectorStoreConfig
|
|
||||||
* @see EmbeddingModel
|
|
||||||
*/
|
|
||||||
public class RedisVectorStore implements VectorStore, InitializingBean {
|
|
||||||
|
|
||||||
public enum Algorithm {
|
|
||||||
|
|
||||||
FLAT, HSNW
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public record MetadataField(String name, FieldType fieldType) {
|
|
||||||
|
|
||||||
public static MetadataField text(String name) {
|
|
||||||
return new MetadataField(name, FieldType.TEXT);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static MetadataField numeric(String name) {
|
|
||||||
return new MetadataField(name, FieldType.NUMERIC);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static MetadataField tag(String name) {
|
|
||||||
return new MetadataField(name, FieldType.TAG);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Configuration for the Redis vector store.
|
|
||||||
*/
|
|
||||||
public static final class RedisVectorStoreConfig {
|
|
||||||
|
|
||||||
private final String indexName;
|
|
||||||
|
|
||||||
private final String prefix;
|
|
||||||
|
|
||||||
private final String contentFieldName;
|
|
||||||
|
|
||||||
private final String embeddingFieldName;
|
|
||||||
|
|
||||||
private final Algorithm vectorAlgorithm;
|
|
||||||
|
|
||||||
private final List<MetadataField> metadataFields;
|
|
||||||
|
|
||||||
private RedisVectorStoreConfig() {
|
|
||||||
this(builder());
|
|
||||||
}
|
|
||||||
|
|
||||||
private RedisVectorStoreConfig(Builder builder) {
|
|
||||||
this.indexName = builder.indexName;
|
|
||||||
this.prefix = builder.prefix;
|
|
||||||
this.contentFieldName = builder.contentFieldName;
|
|
||||||
this.embeddingFieldName = builder.embeddingFieldName;
|
|
||||||
this.vectorAlgorithm = builder.vectorAlgorithm;
|
|
||||||
this.metadataFields = builder.metadataFields;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Start building a new configuration.
|
|
||||||
* @return The entry point for creating a new configuration.
|
|
||||||
*/
|
|
||||||
public static Builder builder() {
|
|
||||||
|
|
||||||
return new Builder();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@return the default config}
|
|
||||||
*/
|
|
||||||
public static RedisVectorStoreConfig defaultConfig() {
|
|
||||||
|
|
||||||
return builder().build();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class Builder {
|
|
||||||
|
|
||||||
private String indexName = DEFAULT_INDEX_NAME;
|
|
||||||
|
|
||||||
private String prefix = DEFAULT_PREFIX;
|
|
||||||
|
|
||||||
private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME;
|
|
||||||
|
|
||||||
private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME;
|
|
||||||
|
|
||||||
private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM;
|
|
||||||
|
|
||||||
private List<MetadataField> metadataFields = new ArrayList<>();
|
|
||||||
|
|
||||||
private Builder() {
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Configures the Redis index name to use.
|
|
||||||
* @param name the index name to use
|
|
||||||
* @return this builder
|
|
||||||
*/
|
|
||||||
public Builder withIndexName(String name) {
|
|
||||||
this.indexName = name;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Configures the Redis key prefix to use (default: "embedding:").
|
|
||||||
* @param prefix the prefix to use
|
|
||||||
* @return this builder
|
|
||||||
*/
|
|
||||||
public Builder withPrefix(String prefix) {
|
|
||||||
this.prefix = prefix;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Configures the Redis content field name to use.
|
|
||||||
* @param name the content field name to use
|
|
||||||
* @return this builder
|
|
||||||
*/
|
|
||||||
public Builder withContentFieldName(String name) {
|
|
||||||
this.contentFieldName = name;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Configures the Redis embedding field name to use.
|
|
||||||
* @param name the embedding field name to use
|
|
||||||
* @return this builder
|
|
||||||
*/
|
|
||||||
public Builder withEmbeddingFieldName(String name) {
|
|
||||||
this.embeddingFieldName = name;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Configures the Redis vector algorithmto use.
|
|
||||||
* @param algorithm the vector algorithm to use
|
|
||||||
* @return this builder
|
|
||||||
*/
|
|
||||||
public Builder withVectorAlgorithm(Algorithm algorithm) {
|
|
||||||
this.vectorAlgorithm = algorithm;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder withMetadataFields(MetadataField... fields) {
|
|
||||||
return withMetadataFields(Arrays.asList(fields));
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder withMetadataFields(List<MetadataField> fields) {
|
|
||||||
this.metadataFields = fields;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@return the immutable configuration}
|
|
||||||
*/
|
|
||||||
public RedisVectorStoreConfig build() {
|
|
||||||
|
|
||||||
return new RedisVectorStoreConfig(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private final boolean initializeSchema;
|
|
||||||
|
|
||||||
public static final String DEFAULT_INDEX_NAME = "spring-ai-index";
|
|
||||||
|
|
||||||
public static final String DEFAULT_CONTENT_FIELD_NAME = "content";
|
|
||||||
|
|
||||||
public static final String DEFAULT_EMBEDDING_FIELD_NAME = "embedding";
|
|
||||||
|
|
||||||
public static final String DEFAULT_PREFIX = "embedding:";
|
|
||||||
|
|
||||||
public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HSNW;
|
|
||||||
|
|
||||||
private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]";
|
|
||||||
|
|
||||||
private static final Path2 JSON_SET_PATH = Path2.of("$");
|
|
||||||
|
|
||||||
private static final String JSON_PATH_PREFIX = "$.";
|
|
||||||
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(RedisVectorStore.class);
|
|
||||||
|
|
||||||
private static final Predicate<Object> RESPONSE_OK = Predicate.isEqual("OK");
|
|
||||||
|
|
||||||
private static final Predicate<Object> RESPONSE_DEL_OK = Predicate.isEqual(1l);
|
|
||||||
|
|
||||||
private static final String VECTOR_TYPE_FLOAT32 = "FLOAT32";
|
|
||||||
|
|
||||||
private static final String EMBEDDING_PARAM_NAME = "BLOB";
|
|
||||||
|
|
||||||
public static final String DISTANCE_FIELD_NAME = "vector_score";
|
|
||||||
|
|
||||||
private static final String DEFAULT_DISTANCE_METRIC = "COSINE";
|
|
||||||
|
|
||||||
private final JedisPooled jedis;
|
|
||||||
|
|
||||||
private final EmbeddingModel embeddingModel;
|
|
||||||
|
|
||||||
private final RedisVectorStoreConfig config;
|
|
||||||
|
|
||||||
private FilterExpressionConverter filterExpressionConverter;
|
|
||||||
|
|
||||||
public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis,
|
|
||||||
boolean initializeSchema) {
|
|
||||||
|
|
||||||
Assert.notNull(config, "Config must not be null");
|
|
||||||
Assert.notNull(embeddingModel, "Embedding model must not be null");
|
|
||||||
this.initializeSchema = initializeSchema;
|
|
||||||
|
|
||||||
this.jedis = jedis;
|
|
||||||
this.embeddingModel = embeddingModel;
|
|
||||||
this.config = config;
|
|
||||||
this.filterExpressionConverter = new RedisFilterExpressionConverter(this.config.metadataFields);
|
|
||||||
}
|
|
||||||
|
|
||||||
public JedisPooled getJedis() {
|
|
||||||
return this.jedis;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void add(List<Document> documents) {
|
|
||||||
try (Pipeline pipeline = this.jedis.pipelined()) {
|
|
||||||
for (Document document : documents) {
|
|
||||||
var embedding = this.embeddingModel.embed(document);
|
|
||||||
document.setEmbedding(embedding);
|
|
||||||
|
|
||||||
var fields = new HashMap<String, Object>();
|
|
||||||
fields.put(this.config.embeddingFieldName, embedding);
|
|
||||||
fields.put(this.config.contentFieldName, document.getContent());
|
|
||||||
fields.putAll(document.getMetadata());
|
|
||||||
pipeline.jsonSetWithEscape(key(document.getId()), JSON_SET_PATH, fields);
|
|
||||||
}
|
|
||||||
List<Object> responses = pipeline.syncAndReturnAll();
|
|
||||||
Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_OK)).findAny();
|
|
||||||
if (errResponse.isPresent()) {
|
|
||||||
String message = MessageFormat.format("Could not add document: {0}", errResponse.get());
|
|
||||||
if (logger.isErrorEnabled()) {
|
|
||||||
logger.error(message);
|
|
||||||
}
|
|
||||||
throw new RuntimeException(message);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private String key(String id) {
|
|
||||||
return this.config.prefix + id;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Optional<Boolean> delete(List<String> idList) {
|
|
||||||
try (Pipeline pipeline = this.jedis.pipelined()) {
|
|
||||||
for (String id : idList) {
|
|
||||||
pipeline.jsonDel(key(id));
|
|
||||||
}
|
|
||||||
List<Object> responses = pipeline.syncAndReturnAll();
|
|
||||||
Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_DEL_OK)).findAny();
|
|
||||||
if (errResponse.isPresent()) {
|
|
||||||
if (logger.isErrorEnabled()) {
|
|
||||||
logger.error("Could not delete document: {}", errResponse.get());
|
|
||||||
}
|
|
||||||
return Optional.of(false);
|
|
||||||
}
|
|
||||||
return Optional.of(true);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<Document> similaritySearch(SearchRequest request) {
|
|
||||||
|
|
||||||
Assert.isTrue(request.getTopK() > 0, "The number of documents to returned must be greater than zero");
|
|
||||||
Assert.isTrue(request.getSimilarityThreshold() >= 0 && request.getSimilarityThreshold() <= 1,
|
|
||||||
"The similarity score is bounded between 0 and 1; least to most similar respectively.");
|
|
||||||
|
|
||||||
String filter = nativeExpressionFilter(request);
|
|
||||||
|
|
||||||
String queryString = String.format(QUERY_FORMAT, filter, request.getTopK(), this.config.embeddingFieldName,
|
|
||||||
EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME);
|
|
||||||
|
|
||||||
List<String> returnFields = new ArrayList<>();
|
|
||||||
this.config.metadataFields.stream().map(MetadataField::name).forEach(returnFields::add);
|
|
||||||
returnFields.add(this.config.embeddingFieldName);
|
|
||||||
returnFields.add(this.config.contentFieldName);
|
|
||||||
returnFields.add(DISTANCE_FIELD_NAME);
|
|
||||||
var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery()));
|
|
||||||
Query query = new Query(queryString).addParam(EMBEDDING_PARAM_NAME, RediSearchUtil.toByteArray(embedding))
|
|
||||||
.returnFields(returnFields.toArray(new String[0]))
|
|
||||||
.setSortBy(DISTANCE_FIELD_NAME, true)
|
|
||||||
.dialect(2);
|
|
||||||
|
|
||||||
SearchResult result = this.jedis.ftSearch(this.config.indexName, query);
|
|
||||||
return result.getDocuments()
|
|
||||||
.stream()
|
|
||||||
.filter(d -> similarityScore(d) >= request.getSimilarityThreshold())
|
|
||||||
.map(this::toDocument)
|
|
||||||
.toList();
|
|
||||||
}
|
|
||||||
|
|
||||||
private Document toDocument(redis.clients.jedis.search.Document doc) {
|
|
||||||
var id = doc.getId().substring(this.config.prefix.length());
|
|
||||||
var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName)
|
|
||||||
: null;
|
|
||||||
Map<String, Object> metadata = this.config.metadataFields.stream()
|
|
||||||
.map(MetadataField::name)
|
|
||||||
.filter(doc::hasProperty)
|
|
||||||
.collect(Collectors.toMap(Function.identity(), doc::getString));
|
|
||||||
metadata.put(DISTANCE_FIELD_NAME, 1 - similarityScore(doc));
|
|
||||||
return new Document(id, content, metadata);
|
|
||||||
}
|
|
||||||
|
|
||||||
private float similarityScore(redis.clients.jedis.search.Document doc) {
|
|
||||||
return (2 - Float.parseFloat(doc.getString(DISTANCE_FIELD_NAME))) / 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
private String nativeExpressionFilter(SearchRequest request) {
|
|
||||||
if (request.getFilterExpression() == null) {
|
|
||||||
return "*";
|
|
||||||
}
|
|
||||||
return "(" + this.filterExpressionConverter.convertExpression(request.getFilterExpression()) + ")";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void afterPropertiesSet() {
|
|
||||||
|
|
||||||
if (!this.initializeSchema) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If index already exists don't do anything
|
|
||||||
if (this.jedis.ftList().contains(this.config.indexName)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
String response = this.jedis.ftCreate(this.config.indexName,
|
|
||||||
FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.config.prefix), schemaFields());
|
|
||||||
if (!RESPONSE_OK.test(response)) {
|
|
||||||
String message = MessageFormat.format("Could not create index: {0}", response);
|
|
||||||
throw new RuntimeException(message);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private Iterable<SchemaField> schemaFields() {
|
|
||||||
Map<String, Object> vectorAttrs = new HashMap<>();
|
|
||||||
vectorAttrs.put("DIM", this.embeddingModel.dimensions());
|
|
||||||
vectorAttrs.put("DISTANCE_METRIC", DEFAULT_DISTANCE_METRIC);
|
|
||||||
vectorAttrs.put("TYPE", VECTOR_TYPE_FLOAT32);
|
|
||||||
List<SchemaField> fields = new ArrayList<>();
|
|
||||||
fields.add(TextField.of(jsonPath(this.config.contentFieldName)).as(this.config.contentFieldName).weight(1.0));
|
|
||||||
fields.add(VectorField.builder()
|
|
||||||
.fieldName(jsonPath(this.config.embeddingFieldName))
|
|
||||||
.algorithm(vectorAlgorithm())
|
|
||||||
.attributes(vectorAttrs)
|
|
||||||
.as(this.config.embeddingFieldName)
|
|
||||||
.build());
|
|
||||||
|
|
||||||
if (!CollectionUtils.isEmpty(this.config.metadataFields)) {
|
|
||||||
for (MetadataField field : this.config.metadataFields) {
|
|
||||||
fields.add(schemaField(field));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fields;
|
|
||||||
}
|
|
||||||
|
|
||||||
private SchemaField schemaField(MetadataField field) {
|
|
||||||
String fieldName = jsonPath(field.name);
|
|
||||||
switch (field.fieldType) {
|
|
||||||
case NUMERIC:
|
|
||||||
return NumericField.of(fieldName).as(field.name);
|
|
||||||
case TAG:
|
|
||||||
return TagField.of(fieldName).as(field.name);
|
|
||||||
case TEXT:
|
|
||||||
return TextField.of(fieldName).as(field.name);
|
|
||||||
default:
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
MessageFormat.format("Field {0} has unsupported type {1}", field.name, field.fieldType));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private VectorAlgorithm vectorAlgorithm() {
|
|
||||||
if (config.vectorAlgorithm == Algorithm.HSNW) {
|
|
||||||
return VectorAlgorithm.HNSW;
|
|
||||||
}
|
|
||||||
return VectorAlgorithm.FLAT;
|
|
||||||
}
|
|
||||||
|
|
||||||
private String jsonPath(String field) {
|
|
||||||
return JSON_PATH_PREFIX + field;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static float[] toFloatArray(List<Double> embeddingDouble) {
|
|
||||||
float[] embeddingFloat = new float[embeddingDouble.size()];
|
|
||||||
int i = 0;
|
|
||||||
for (Double d : embeddingDouble) {
|
|
||||||
embeddingFloat[i++] = d.floatValue();
|
|
||||||
}
|
|
||||||
return embeddingFloat;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,6 +1,5 @@
|
||||||
package cn.iocoder.yudao.framework.ai.chat;
|
package cn.iocoder.yudao.framework.ai.chat;
|
||||||
|
|
||||||
import com.azure.ai.openai.OpenAIClient;
|
|
||||||
import com.azure.ai.openai.OpenAIClientBuilder;
|
import com.azure.ai.openai.OpenAIClientBuilder;
|
||||||
import com.azure.core.credential.AzureKeyCredential;
|
import com.azure.core.credential.AzureKeyCredential;
|
||||||
import com.azure.core.util.ClientOptions;
|
import com.azure.core.util.ClientOptions;
|
||||||
|
@ -27,13 +26,13 @@ import static org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatP
|
||||||
*/
|
*/
|
||||||
public class AzureOpenAIChatModelTests {
|
public class AzureOpenAIChatModelTests {
|
||||||
|
|
||||||
private final OpenAIClient openAiApi = (new OpenAIClientBuilder())
|
// TODO @芋艿:晚点在调整
|
||||||
|
private final OpenAIClientBuilder openAiApi = new OpenAIClientBuilder()
|
||||||
.endpoint("https://eastusprejade.openai.azure.com")
|
.endpoint("https://eastusprejade.openai.azure.com")
|
||||||
.credential(new AzureKeyCredential("xxx"))
|
.credential(new AzureKeyCredential("xxx"))
|
||||||
.clientOptions((new ClientOptions()).setApplicationId("spring-ai"))
|
.clientOptions((new ClientOptions()).setApplicationId("spring-ai"));
|
||||||
.buildClient();
|
|
||||||
private final AzureOpenAiChatModel chatModel = new AzureOpenAiChatModel(openAiApi,
|
private final AzureOpenAiChatModel chatModel = new AzureOpenAiChatModel(openAiApi,
|
||||||
AzureOpenAiChatOptions.builder().withDeploymentName(DEFAULT_DEPLOYMENT_NAME).build());
|
AzureOpenAiChatOptions.builder().deploymentName(DEFAULT_DEPLOYMENT_NAME).build());
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Disabled
|
@Disabled
|
||||||
|
|
|
@ -1,20 +1,6 @@
|
||||||
package cn.iocoder.yudao.framework.ai.chat;
|
package cn.iocoder.yudao.framework.ai.chat;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Disabled;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.springframework.ai.chat.messages.Message;
|
|
||||||
import org.springframework.ai.chat.messages.SystemMessage;
|
|
||||||
import org.springframework.ai.chat.messages.UserMessage;
|
|
||||||
import org.springframework.ai.chat.model.ChatResponse;
|
|
||||||
import org.springframework.ai.chat.prompt.Prompt;
|
|
||||||
import org.springframework.ai.ollama.OllamaChatModel;
|
import org.springframework.ai.ollama.OllamaChatModel;
|
||||||
import org.springframework.ai.ollama.api.OllamaApi;
|
|
||||||
import org.springframework.ai.ollama.api.OllamaModel;
|
|
||||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
|
||||||
import reactor.core.publisher.Flux;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@link OllamaChatModel} 集成测试
|
* {@link OllamaChatModel} 集成测试
|
||||||
|
@ -23,41 +9,41 @@ import java.util.List;
|
||||||
*/
|
*/
|
||||||
public class LlamaChatModelTests {
|
public class LlamaChatModelTests {
|
||||||
|
|
||||||
private final OllamaApi ollamaApi = new OllamaApi(
|
// private final OllamaApi ollamaApi = new OllamaApi(
|
||||||
"http://127.0.0.1:11434");
|
// "http://127.0.0.1:11434");
|
||||||
private final OllamaChatModel chatModel = new OllamaChatModel(ollamaApi,
|
// private final OllamaChatModel chatModel = new OllamaChatModel(ollamaApi,
|
||||||
OllamaOptions.create().withModel(OllamaModel.LLAMA3.getModelName()));
|
// OllamaOptions.create().withModel(OllamaModel.LLAMA3.getModelName()));
|
||||||
|
//
|
||||||
@Test
|
// @Test
|
||||||
@Disabled
|
// @Disabled
|
||||||
public void testCall() {
|
// public void testCall() {
|
||||||
// 准备参数
|
// // 准备参数
|
||||||
List<Message> messages = new ArrayList<>();
|
// List<Message> messages = new ArrayList<>();
|
||||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||||
messages.add(new UserMessage("1 + 1 = ?"));
|
// messages.add(new UserMessage("1 + 1 = ?"));
|
||||||
|
//
|
||||||
// 调用
|
// // 调用
|
||||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
// ChatResponse response = chatModel.call(new Prompt(messages));
|
||||||
// 打印结果
|
// // 打印结果
|
||||||
System.out.println(response);
|
// System.out.println(response);
|
||||||
System.out.println(response.getResult().getOutput());
|
// System.out.println(response.getResult().getOutput());
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
@Test
|
// @Test
|
||||||
@Disabled
|
// @Disabled
|
||||||
public void testStream() {
|
// public void testStream() {
|
||||||
// 准备参数
|
// // 准备参数
|
||||||
List<Message> messages = new ArrayList<>();
|
// List<Message> messages = new ArrayList<>();
|
||||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||||
messages.add(new UserMessage("1 + 1 = ?"));
|
// messages.add(new UserMessage("1 + 1 = ?"));
|
||||||
|
//
|
||||||
// 调用
|
// // 调用
|
||||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
// Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||||
// 打印结果
|
// // 打印结果
|
||||||
flux.doOnNext(response -> {
|
// flux.doOnNext(response -> {
|
||||||
// System.out.println(response);
|
//// System.out.println(response);
|
||||||
System.out.println(response.getResult().getOutput());
|
// System.out.println(response.getResult().getOutput());
|
||||||
}).then().block();
|
// }).then().block();
|
||||||
}
|
// }
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ public class OpenAIChatModelTests {
|
||||||
"https://api.holdai.top",
|
"https://api.holdai.top",
|
||||||
"sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf");
|
"sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf");
|
||||||
private final OpenAiChatModel chatModel = new OpenAiChatModel(openAiApi,
|
private final OpenAiChatModel chatModel = new OpenAiChatModel(openAiApi,
|
||||||
OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O).build());
|
OpenAiChatOptions.builder().model(OpenAiApi.ChatModel.GPT_4_O).build());
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Disabled
|
@Disabled
|
||||||
|
|
|
@ -25,7 +25,7 @@ public class YiYanChatModelTests {
|
||||||
"qS8k8dYr2nXunagK4SSU8Xjj",
|
"qS8k8dYr2nXunagK4SSU8Xjj",
|
||||||
"pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
|
"pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
|
||||||
private final QianFanChatModel chatModel = new QianFanChatModel(qianFanApi,
|
private final QianFanChatModel chatModel = new QianFanChatModel(qianFanApi,
|
||||||
QianFanChatOptions.builder().withModel(QianFanApi.ChatModel.ERNIE_Tiny_8K.getValue()).build()
|
QianFanChatOptions.builder().model(QianFanApi.ChatModel.ERNIE_Tiny_8K.getValue()).build()
|
||||||
);
|
);
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -24,7 +24,7 @@ public class ZhiPuAiChatModelTests {
|
||||||
|
|
||||||
private final ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi("32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs");
|
private final ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi("32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs");
|
||||||
private final ZhiPuAiChatModel chatModel = new ZhiPuAiChatModel(zhiPuAiApi,
|
private final ZhiPuAiChatModel chatModel = new ZhiPuAiChatModel(zhiPuAiApi,
|
||||||
ZhiPuAiChatOptions.builder().withModel(ZhiPuAiApi.ChatModel.GLM_4.getModelName()).build());
|
ZhiPuAiChatOptions.builder().model(ZhiPuAiApi.ChatModel.GLM_4.getName()).build());
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Disabled
|
@Disabled
|
||||||
|
|
|
@ -25,9 +25,9 @@ public class QianFanImageTests {
|
||||||
// 准备参数
|
// 准备参数
|
||||||
// 只支持 1024x1024、768x768、768x1024、1024x768、576x1024、1024x576
|
// 只支持 1024x1024、768x768、768x1024、1024x768、576x1024、1024x576
|
||||||
QianFanImageOptions imageOptions = QianFanImageOptions.builder()
|
QianFanImageOptions imageOptions = QianFanImageOptions.builder()
|
||||||
.withModel(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
|
.model(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
|
||||||
.withWidth(1024).withHeight(1024)
|
.width(1024).height(1024)
|
||||||
.withN(1)
|
.N(1)
|
||||||
.build();
|
.build();
|
||||||
ImagePrompt prompt = new ImagePrompt("good", imageOptions);
|
ImagePrompt prompt = new ImagePrompt("good", imageOptions);
|
||||||
|
|
||||||
|
|
|
@ -1,34 +1,30 @@
|
||||||
package cn.iocoder.yudao.framework.ai.image;
|
package cn.iocoder.yudao.framework.ai.image;
|
||||||
|
|
||||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
|
import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
|
||||||
|
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
|
||||||
|
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
|
||||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||||
import com.alibaba.dashscope.utils.Constants;
|
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.springframework.ai.image.ImageOptions;
|
import org.springframework.ai.image.ImageOptions;
|
||||||
import org.springframework.ai.image.ImagePrompt;
|
import org.springframework.ai.image.ImagePrompt;
|
||||||
import org.springframework.ai.image.ImageResponse;
|
import org.springframework.ai.image.ImageResponse;
|
||||||
import org.springframework.ai.openai.OpenAiImageOptions;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@link com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel} 集成测试类
|
* {@link DashScopeImageModel} 集成测试类
|
||||||
*
|
*
|
||||||
* @author fansili
|
* @author fansili
|
||||||
*/
|
*/
|
||||||
public class TongYiImagesModelTest {
|
public class TongYiImagesModelTest {
|
||||||
|
|
||||||
private final ImageSynthesis imageApi = new ImageSynthesis();
|
private final DashScopeImageModel imageModel = new DashScopeImageModel(
|
||||||
private final TongYiImagesModel imageModel = new TongYiImagesModel(imageApi);
|
new DashScopeImageApi("sk-7d903764249848cfa912733146da12d1"));
|
||||||
|
|
||||||
static {
|
|
||||||
Constants.apiKey = "sk-Zsd81gZYg7";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Disabled
|
@Disabled
|
||||||
public void imageCallTest() {
|
public void imageCallTest() {
|
||||||
// 准备参数
|
// 准备参数
|
||||||
ImageOptions options = OpenAiImageOptions.builder()
|
ImageOptions options = DashScopeImageOptions.builder()
|
||||||
.withModel(ImageSynthesis.Models.WANX_V1)
|
.withModel(ImageSynthesis.Models.WANX_V1)
|
||||||
.withHeight(256).withWidth(256)
|
.withHeight(256).withWidth(256)
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -22,7 +22,7 @@ public class ZhiPuAiImageModelTests {
|
||||||
public void testCall() {
|
public void testCall() {
|
||||||
// 准备参数
|
// 准备参数
|
||||||
ZhiPuAiImageOptions imageOptions = ZhiPuAiImageOptions.builder()
|
ZhiPuAiImageOptions imageOptions = ZhiPuAiImageOptions.builder()
|
||||||
.withModel(ZhiPuAiImageApi.ImageModel.CogView_3.getValue())
|
.model(ZhiPuAiImageApi.ImageModel.CogView_3.getValue())
|
||||||
.build();
|
.build();
|
||||||
ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);
|
ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue