【代码优化】AI:适配 Spring AI 1.0.6 对 Ollama 的逻辑
This commit is contained in:
parent
5655ae925c
commit
d05a7bd59a
|
@ -20,6 +20,7 @@ import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
|
||||||
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
|
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||||
import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
||||||
|
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.image.ImageModel;
|
import org.springframework.ai.image.ImageModel;
|
||||||
|
@ -133,14 +134,14 @@ public class AiImageServiceImpl implements AiImageService {
|
||||||
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
|
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
|
||||||
// https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
|
// https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
|
||||||
// https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
|
// https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
|
||||||
return StabilityAiImageOptions.builder().withModel(draw.getModel())
|
return StabilityAiImageOptions.builder().model(draw.getModel())
|
||||||
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
|
.height(draw.getHeight()).width(draw.getWidth())
|
||||||
.withSeed(Long.valueOf(draw.getOptions().get("seed")))
|
.seed(Long.valueOf(draw.getOptions().get("seed")))
|
||||||
.withCfgScale(Float.valueOf(draw.getOptions().get("scale")))
|
.cfgScale(Float.valueOf(draw.getOptions().get("scale")))
|
||||||
.withSteps(Integer.valueOf(draw.getOptions().get("steps")))
|
.steps(Integer.valueOf(draw.getOptions().get("steps")))
|
||||||
.withSampler(String.valueOf(draw.getOptions().get("sampler")))
|
.sampler(String.valueOf(draw.getOptions().get("sampler")))
|
||||||
.withStylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
|
.stylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
|
||||||
.withClipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
|
.clipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
|
||||||
.build();
|
.build();
|
||||||
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
|
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
|
||||||
return DashScopeImageOptions.builder()
|
return DashScopeImageOptions.builder()
|
||||||
|
@ -149,12 +150,12 @@ public class AiImageServiceImpl implements AiImageService {
|
||||||
.build();
|
.build();
|
||||||
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
|
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
|
||||||
return QianFanImageOptions.builder()
|
return QianFanImageOptions.builder()
|
||||||
.withModel(draw.getModel()).withN(1)
|
.model(draw.getModel()).N(1)
|
||||||
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
|
.height(draw.getHeight()).width(draw.getWidth())
|
||||||
.build();
|
.build();
|
||||||
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
|
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
|
||||||
return ZhiPuAiImageOptions.builder()
|
return ZhiPuAiImageOptions.builder()
|
||||||
.withModel(draw.getModel())
|
.model(draw.getModel())
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
|
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
|
||||||
|
|
|
@ -35,7 +35,6 @@ public class YudaoAiAutoConfiguration {
|
||||||
return new AiModelFactoryImpl();
|
return new AiModelFactoryImpl();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// ========== 各种 AI Client 创建 ==========
|
// ========== 各种 AI Client 创建 ==========
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
|
|
|
@ -1,6 +1,20 @@
|
||||||
package cn.iocoder.yudao.framework.ai.chat;
|
package cn.iocoder.yudao.framework.ai.chat;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.springframework.ai.chat.messages.Message;
|
||||||
|
import org.springframework.ai.chat.messages.SystemMessage;
|
||||||
|
import org.springframework.ai.chat.messages.UserMessage;
|
||||||
|
import org.springframework.ai.chat.model.ChatResponse;
|
||||||
|
import org.springframework.ai.chat.prompt.Prompt;
|
||||||
import org.springframework.ai.ollama.OllamaChatModel;
|
import org.springframework.ai.ollama.OllamaChatModel;
|
||||||
|
import org.springframework.ai.ollama.api.OllamaApi;
|
||||||
|
import org.springframework.ai.ollama.api.OllamaModel;
|
||||||
|
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@link OllamaChatModel} 集成测试
|
* {@link OllamaChatModel} 集成测试
|
||||||
|
@ -9,41 +23,43 @@ import org.springframework.ai.ollama.OllamaChatModel;
|
||||||
*/
|
*/
|
||||||
public class LlamaChatModelTests {
|
public class LlamaChatModelTests {
|
||||||
|
|
||||||
// private final OllamaApi ollamaApi = new OllamaApi(
|
private final OllamaChatModel chatModel = OllamaChatModel.builder()
|
||||||
// "http://127.0.0.1:11434");
|
.ollamaApi(new OllamaApi("http://127.0.0.1:11434")) // Ollama 服务地址
|
||||||
// private final OllamaChatModel chatModel = new OllamaChatModel(ollamaApi,
|
.defaultOptions(OllamaOptions.builder()
|
||||||
// OllamaOptions.create().withModel(OllamaModel.LLAMA3.getModelName()));
|
.model(OllamaModel.LLAMA3.getName()) // 模型
|
||||||
//
|
.build())
|
||||||
// @Test
|
.build();
|
||||||
// @Disabled
|
|
||||||
// public void testCall() {
|
@Test
|
||||||
// // 准备参数
|
@Disabled
|
||||||
// List<Message> messages = new ArrayList<>();
|
public void testCall() {
|
||||||
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
// 准备参数
|
||||||
// messages.add(new UserMessage("1 + 1 = ?"));
|
List<Message> messages = new ArrayList<>();
|
||||||
//
|
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||||
// // 调用
|
messages.add(new UserMessage("1 + 1 = ?"));
|
||||||
// ChatResponse response = chatModel.call(new Prompt(messages));
|
|
||||||
// // 打印结果
|
// 调用
|
||||||
|
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||||
|
// 打印结果
|
||||||
|
System.out.println(response);
|
||||||
|
System.out.println(response.getResult().getOutput());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Disabled
|
||||||
|
public void testStream() {
|
||||||
|
// 准备参数
|
||||||
|
List<Message> messages = new ArrayList<>();
|
||||||
|
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||||
|
messages.add(new UserMessage("1 + 1 = ?"));
|
||||||
|
|
||||||
|
// 调用
|
||||||
|
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||||
|
// 打印结果
|
||||||
|
flux.doOnNext(response -> {
|
||||||
// System.out.println(response);
|
// System.out.println(response);
|
||||||
// System.out.println(response.getResult().getOutput());
|
System.out.println(response.getResult().getOutput());
|
||||||
// }
|
}).then().block();
|
||||||
//
|
}
|
||||||
// @Test
|
|
||||||
// @Disabled
|
|
||||||
// public void testStream() {
|
|
||||||
// // 准备参数
|
|
||||||
// List<Message> messages = new ArrayList<>();
|
|
||||||
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
|
||||||
// messages.add(new UserMessage("1 + 1 = ?"));
|
|
||||||
//
|
|
||||||
// // 调用
|
|
||||||
// Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
|
||||||
// // 打印结果
|
|
||||||
// flux.doOnNext(response -> {
|
|
||||||
//// System.out.println(response);
|
|
||||||
// System.out.println(response.getResult().getOutput());
|
|
||||||
// }).then().block();
|
|
||||||
// }
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue