【代码优化】AI:硅基流动的图片生成

This commit is contained in:
YunaiV 2025-03-23 10:41:51 +08:00
parent 813e7af846
commit 59c744520a
10 changed files with 179 additions and 364 deletions

View File

@ -11,7 +11,7 @@ import cn.hutool.extra.spring.SpringUtil;
import cn.hutool.http.HttpUtil; import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconflowImageOptions; import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageOptions;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
@ -147,8 +147,8 @@ public class AiImageServiceImpl implements AiImageService {
.build(); .build();
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.SILICON_FLOW.getPlatform())) { } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.SILICON_FLOW.getPlatform())) {
// https://docs.siliconflow.cn/cn/api-reference/images/images-generations // https://docs.siliconflow.cn/cn/api-reference/images/images-generations
return SiliconflowImageOptions.builder().model(model.getModel()) return SiliconFlowImageOptions.builder().model(model.getModel())
.withHeight(draw.getHeight()).withHeight(draw.getWidth()) .height(draw.getHeight()).width(draw.getWidth())
.build(); .build();
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) { } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
// https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage // https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage

View File

@ -8,7 +8,7 @@ import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel; import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel; import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiiconflowApiConstants; import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowApiConstants;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel; import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
@ -114,11 +114,11 @@ public class YudaoAiAutoConfiguration {
public SiliconFlowChatModel buildSiliconFlowChatClient(YudaoAiProperties.SiliconFlowProperties properties) { public SiliconFlowChatModel buildSiliconFlowChatClient(YudaoAiProperties.SiliconFlowProperties properties) {
if (StrUtil.isEmpty(properties.getModel())) { if (StrUtil.isEmpty(properties.getModel())) {
properties.setModel(SiiconflowApiConstants.MODEL_DEFAULT); properties.setModel(SiliconFlowApiConstants.MODEL_DEFAULT);
} }
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder() OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder() .openAiApi(OpenAiApi.builder()
.baseUrl(SiiconflowApiConstants.DEFAULT_BASE_URL) .baseUrl(SiliconFlowApiConstants.DEFAULT_BASE_URL)
.apiKey(properties.getApiKey()) .apiKey(properties.getApiKey())
.build()) .build())
.defaultOptions(OpenAiChatOptions.builder() .defaultOptions(OpenAiChatOptions.builder()

View File

@ -15,10 +15,10 @@ import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel; import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel; import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiiconflowApiConstants; import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowApiConstants;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiiconflowmageApi;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel; import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconflowImageModel; import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageApi;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageModel;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.common.util.spring.SpringUtils; import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
@ -45,6 +45,7 @@ import org.springframework.ai.autoconfigure.moonshot.MoonshotAutoConfiguration;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
import org.springframework.ai.autoconfigure.stabilityai.StabilityAiImageAutoConfiguration;
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientConnectionDetails; import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientConnectionDetails;
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientProperties; import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientProperties;
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusVectorStoreAutoConfiguration; import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusVectorStoreAutoConfiguration;
@ -228,7 +229,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
case OPENAI: case OPENAI:
return buildOpenAiImageModel(apiKey, url); return buildOpenAiImageModel(apiKey, url);
case SILICON_FLOW: case SILICON_FLOW:
return buildSiiconflowImageModel(apiKey,url); return buildSiliconFlowImageModel(apiKey,url);
case STABLE_DIFFUSION: case STABLE_DIFFUSION:
return buildStabilityAiImageModel(apiKey, url); return buildStabilityAiImageModel(apiKey, url);
default: default:
@ -474,12 +475,12 @@ public class AiModelFactoryImpl implements AiModelFactory {
} }
/** /**
* Siiconflow * 创建 SiliconFlowImageModel 对象
*/ */
private SiliconflowImageModel buildSiiconflowImageModel(String apiToken, String url) { private SiliconFlowImageModel buildSiliconFlowImageModel(String apiToken, String url) {
url = StrUtil.blankToDefault(url, SiiconflowApiConstants.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, SiliconFlowApiConstants.DEFAULT_BASE_URL);
SiiconflowmageApi openAiApi = SiiconflowmageApi.builder().baseUrl(url).apiKey(apiToken).build(); SiliconFlowImageApi openAiApi = new SiliconFlowImageApi(url, apiToken);
return new SiliconflowImageModel(openAiApi); return new SiliconFlowImageModel(openAiApi);
} }
/** /**
@ -490,6 +491,9 @@ public class AiModelFactoryImpl implements AiModelFactory {
return OllamaChatModel.builder().ollamaApi(ollamaApi).toolCallingManager(getToolCallingManager()).build(); return OllamaChatModel.builder().ollamaApi(ollamaApi).toolCallingManager(getToolCallingManager()).build();
} }
/**
* 可参考 {@link StabilityAiImageAutoConfiguration} stabilityAiImageModel 方法
*/
private StabilityAiImageModel buildStabilityAiImageModel(String apiKey, String url) { private StabilityAiImageModel buildStabilityAiImageModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url); StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);

View File

@ -1,207 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cn.iocoder.yudao.framework.ai.core.model.siliconflow;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.model.ApiKey;
import org.springframework.ai.model.NoopApiKey;
import org.springframework.ai.model.SimpleApiKey;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import java.util.Map;
/**
* Siiconflow Image API.
*
* @see <a href= "https://docs.siliconflow.cn/cn/api-reference/images/images-generations">Images</a>
*
* @author zzt
*/
public class SiiconflowmageApi {
private final RestClient restClient;
/**
* Create a new Siiconflow Image api with base URL set.
* @param aiToken OpenAI apiKey.
*/
public SiiconflowmageApi(String aiToken) {
this(SiiconflowApiConstants.DEFAULT_BASE_URL, aiToken, RestClient.builder());
}
/**
* Create a new Siiconflow Image API with the provided base URL.
* @param baseUrl the base URL for the OpenAI API.
* @param openAiToken Siiconflow apiKey.
*/
public SiiconflowmageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
this(baseUrl, openAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}
/**
* Create a new OpenAI Image API with the provided base URL.
* @param baseUrl the base URL for the OpenAI API.
* @param apiKey OpenAI apiKey.
* @param restClientBuilder the rest client builder to use.
*/
public SiiconflowmageApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
ResponseErrorHandler responseErrorHandler) {
this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), restClientBuilder, responseErrorHandler);
}
/**
* Create a new OpenAI Image API with the provided base URL.
* @param baseUrl the base URL for the OpenAI API.
* @param apiKey OpenAI apiKey.
* @param headers the http headers to use.
* @param restClientBuilder the rest client builder to use.
* @param responseErrorHandler the response error handler to use.
*/
public SiiconflowmageApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers,
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
this(baseUrl, new SimpleApiKey(apiKey), headers, restClientBuilder, responseErrorHandler);
}
/**
* Create a new OpenAI Image API with the provided base URL.
* @param baseUrl the base URL for the OpenAI API.
* @param apiKey OpenAI apiKey.
* @param headers the http headers to use.
* @param restClientBuilder the rest client builder to use.
* @param responseErrorHandler the response error handler to use.
*/
public SiiconflowmageApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers,
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
// @formatter:off
this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultHeaders(h -> {
if(!(apiKey instanceof NoopApiKey)) {
h.setBearerAuth(apiKey.getValue());
}
h.setContentType(MediaType.APPLICATION_JSON);
h.addAll(headers);
})
.defaultStatusHandler(responseErrorHandler)
.build();
// @formatter:on
}
public ResponseEntity<OpenAiImageApi.OpenAiImageResponse> createImage(SiliconflowImageRequest siliconflowImageRequest) {
Assert.notNull(siliconflowImageRequest, "Image request cannot be null.");
Assert.hasLength(siliconflowImageRequest.prompt(), "Prompt cannot be empty.");
return this.restClient.post()
.uri("v1/images/generations")
.body(siliconflowImageRequest)
.retrieve()
.toEntity(OpenAiImageApi.OpenAiImageResponse.class);
}
// @formatter:off
@JsonInclude(JsonInclude.Include.NON_NULL)
public record SiliconflowImageRequest (
@JsonProperty("prompt") String prompt,
@JsonProperty("model") String model,
@JsonProperty("batch_size") Integer batchSize,
@JsonProperty("negative_prompt") String negativePrompt,
@JsonProperty("seed") Integer seed,
@JsonProperty("num_inference_steps") Integer numInferenceSteps,
@JsonProperty("guidance_scale") Float guidanceScale,
@JsonProperty("image") String image) {
public SiliconflowImageRequest(String prompt, String model) {
this(prompt, model, null, null, null, null, null, null);
}
}
public static Builder builder() {
return new Builder();
}
/**
* Builder to construct {@link SiiconflowmageApi} instance.
*/
public static class Builder {
private String baseUrl = SiiconflowApiConstants.DEFAULT_BASE_URL;
private ApiKey apiKey;
private MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();
private RestClient.Builder restClientBuilder = RestClient.builder();
private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER;
public Builder baseUrl(String baseUrl) {
Assert.hasText(baseUrl, "baseUrl cannot be null or empty");
this.baseUrl = baseUrl;
return this;
}
public Builder apiKey(ApiKey apiKey) {
Assert.notNull(apiKey, "apiKey cannot be null");
this.apiKey = apiKey;
return this;
}
public Builder apiKey(String simpleApiKey) {
Assert.notNull(simpleApiKey, "simpleApiKey cannot be null");
this.apiKey = new SimpleApiKey(simpleApiKey);
return this;
}
public Builder headers(MultiValueMap<String, String> headers) {
Assert.notNull(headers, "headers cannot be null");
this.headers = headers;
return this;
}
public Builder restClientBuilder(RestClient.Builder restClientBuilder) {
Assert.notNull(restClientBuilder, "restClientBuilder cannot be null");
this.restClientBuilder = restClientBuilder;
return this;
}
public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) {
Assert.notNull(responseErrorHandler, "responseErrorHandler cannot be null");
this.responseErrorHandler = responseErrorHandler;
return this;
}
public SiiconflowmageApi build() {
Assert.notNull(this.apiKey, "apiKey must be set");
return new SiiconflowmageApi(this.baseUrl, this.apiKey, this.headers, this.restClientBuilder,
this.responseErrorHandler);
}
}
}

View File

@ -17,11 +17,11 @@
package cn.iocoder.yudao.framework.ai.core.model.siliconflow; package cn.iocoder.yudao.framework.ai.core.model.siliconflow;
/** /**
* Common value constants for Siiconflow api. * SiliconFlow API 枚举类
* *
* @author zzt * @author zzt
*/ */
public final class SiiconflowApiConstants { public final class SiliconFlowApiConstants {
public static final String DEFAULT_BASE_URL = "https://api.siliconflow.cn"; public static final String DEFAULT_BASE_URL = "https://api.siliconflow.cn";
@ -29,8 +29,4 @@ public final class SiiconflowApiConstants {
public static final String PROVIDER_NAME = "Siiconflow"; public static final String PROVIDER_NAME = "Siiconflow";
private SiiconflowApiConstants() {
}
} }

View File

@ -0,0 +1,115 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cn.iocoder.yudao.framework.ai.core.model.siliconflow;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.model.ApiKey;
import org.springframework.ai.model.NoopApiKey;
import org.springframework.ai.model.SimpleApiKey;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import java.util.Map;
/**
* 硅基流动 Image API
*
* @see <a href= "https://docs.siliconflow.cn/cn/api-reference/images/images-generations">Images</a>
*
* @author zzt
*/
public class SiliconFlowImageApi {
private final RestClient restClient;
public SiliconFlowImageApi(String aiToken) {
this(SiliconFlowApiConstants.DEFAULT_BASE_URL, aiToken, RestClient.builder());
}
public SiliconFlowImageApi(String baseUrl, String openAiToken) {
this(baseUrl, openAiToken, RestClient.builder());
}
public SiliconFlowImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
this(baseUrl, openAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}
public SiliconFlowImageApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
ResponseErrorHandler responseErrorHandler) {
this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), restClientBuilder, responseErrorHandler);
}
public SiliconFlowImageApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers,
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
this(baseUrl, new SimpleApiKey(apiKey), headers, restClientBuilder, responseErrorHandler);
}
public SiliconFlowImageApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers,
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
// @formatter:off
this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultHeaders(h -> {
if(!(apiKey instanceof NoopApiKey)) {
h.setBearerAuth(apiKey.getValue());
}
h.setContentType(MediaType.APPLICATION_JSON);
h.addAll(headers);
})
.defaultStatusHandler(responseErrorHandler)
.build();
// @formatter:on
}
public ResponseEntity<OpenAiImageApi.OpenAiImageResponse> createImage(SiliconflowImageRequest siliconflowImageRequest) {
Assert.notNull(siliconflowImageRequest, "Image request cannot be null.");
Assert.hasLength(siliconflowImageRequest.prompt(), "Prompt cannot be empty.");
return this.restClient.post()
.uri("v1/images/generations")
.body(siliconflowImageRequest)
.retrieve()
.toEntity(OpenAiImageApi.OpenAiImageResponse.class);
}
// @formatter:off
@JsonInclude(JsonInclude.Include.NON_NULL)
public record SiliconflowImageRequest (
@JsonProperty("prompt") String prompt,
@JsonProperty("model") String model,
@JsonProperty("batch_size") Integer batchSize,
@JsonProperty("negative_prompt") String negativePrompt,
@JsonProperty("seed") Integer seed,
@JsonProperty("num_inference_steps") Integer numInferenceSteps,
@JsonProperty("guidance_scale") Float guidanceScale,
@JsonProperty("image") String image) {
public SiliconflowImageRequest(String prompt, String model) {
this(prompt, model, null, null, null, null, null, null);
}
}
}

View File

@ -17,6 +17,7 @@
package cn.iocoder.yudao.framework.ai.core.model.siliconflow; package cn.iocoder.yudao.framework.ai.core.model.siliconflow;
import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.ObservationRegistry;
import lombok.Setter;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.ai.image.*; import org.springframework.ai.image.*;
@ -25,8 +26,8 @@ import org.springframework.ai.image.observation.ImageModelObservationContext;
import org.springframework.ai.image.observation.ImageModelObservationConvention; import org.springframework.ai.image.observation.ImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationDocumentation; import org.springframework.ai.image.observation.ImageModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.OpenAiImageApi; import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata; import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
@ -36,77 +37,44 @@ import org.springframework.util.Assert;
import java.util.List; import java.util.List;
/** /**
* cv openapi图片模型方法 * 硅基流动 {@link ImageModel} 实现类
*
* 参考 {@link OpenAiImageModel} 实现
* *
* @author zzt * @author zzt
*/ */
public class SiliconflowImageModel implements ImageModel { public class SiliconFlowImageModel implements ImageModel {
private static final Logger logger = LoggerFactory.getLogger(SiliconflowImageModel.class); private static final Logger logger = LoggerFactory.getLogger(SiliconFlowImageModel.class);
private static final ImageModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultImageModelObservationConvention(); private static final ImageModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultImageModelObservationConvention();
/** private final SiliconFlowImageOptions defaultOptions;
* The default options used for the image completion requests.
*/
private final SiliconflowImageOptions defaultOptions;
/**
* The retry template used to retry the OpenAI Image API calls.
*/
private final RetryTemplate retryTemplate; private final RetryTemplate retryTemplate;
/** private final SiliconFlowImageApi siliconFlowImageApi;
* Low-level access to the OpenAI Image API.
*/
private final SiiconflowmageApi siiconflowmageApi;
/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry; private final ObservationRegistry observationRegistry;
/** @Setter
* Conventions to use for generating observations.
*/
private ImageModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; private ImageModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
/** public SiliconFlowImageModel(SiliconFlowImageApi siliconFlowImageApi) {
* Creates an instance of the OpenAiImageModel. this(siliconFlowImageApi, SiliconFlowImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
* @param siiconflowmageApi The OpenAiImageApi instance to be used for interacting with
* the OpenAI Image API.
* @throws IllegalArgumentException if openAiImageApi is null
*/
public SiliconflowImageModel(SiiconflowmageApi siiconflowmageApi) {
this(siiconflowmageApi, SiliconflowImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
} }
/** public SiliconFlowImageModel(SiliconFlowImageApi siliconFlowImageApi, SiliconFlowImageOptions options, RetryTemplate retryTemplate) {
* Initializes a new instance of the OpenAiImageModel. this(siliconFlowImageApi, options, retryTemplate, ObservationRegistry.NOOP);
* @param siiconflowmageApi The OpenAiImageApi instance to be used for interacting with
* the OpenAI Image API.
* @param options The OpenAiImageOptions to configure the image model.
* @param retryTemplate The retry template.
*/
public SiliconflowImageModel(SiiconflowmageApi siiconflowmageApi, SiliconflowImageOptions options, RetryTemplate retryTemplate) {
this(siiconflowmageApi, options, retryTemplate, ObservationRegistry.NOOP);
} }
/** public SiliconFlowImageModel(SiliconFlowImageApi siliconFlowImageApi, SiliconFlowImageOptions options, RetryTemplate retryTemplate,
* Initializes a new instance of the OpenAiImageModel.
* @param siiconflowmageApi The OpenAiImageApi instance to be used for interacting with
* the OpenAI Image API.
* @param options The OpenAiImageOptions to configure the image model.
* @param retryTemplate The retry template.
* @param observationRegistry The ObservationRegistry used for instrumentation.
*/
public SiliconflowImageModel(SiiconflowmageApi siiconflowmageApi, SiliconflowImageOptions options, RetryTemplate retryTemplate,
ObservationRegistry observationRegistry) { ObservationRegistry observationRegistry) {
Assert.notNull(siiconflowmageApi, "OpenAiImageApi must not be null"); Assert.notNull(siliconFlowImageApi, "OpenAiImageApi must not be null");
Assert.notNull(options, "options must not be null"); Assert.notNull(options, "options must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null"); Assert.notNull(retryTemplate, "retryTemplate must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null"); Assert.notNull(observationRegistry, "observationRegistry must not be null");
this.siiconflowmageApi = siiconflowmageApi; this.siliconFlowImageApi = siliconFlowImageApi;
this.defaultOptions = options; this.defaultOptions = options;
this.retryTemplate = retryTemplate; this.retryTemplate = retryTemplate;
this.observationRegistry = observationRegistry; this.observationRegistry = observationRegistry;
@ -114,11 +82,11 @@ public class SiliconflowImageModel implements ImageModel {
@Override @Override
public ImageResponse call(ImagePrompt imagePrompt) { public ImageResponse call(ImagePrompt imagePrompt) {
SiiconflowmageApi.SiliconflowImageRequest imageRequest = createRequest(imagePrompt); SiliconFlowImageApi.SiliconflowImageRequest imageRequest = createRequest(imagePrompt);
var observationContext = ImageModelObservationContext.builder() var observationContext = ImageModelObservationContext.builder()
.imagePrompt(imagePrompt) .imagePrompt(imagePrompt)
.provider(OpenAiApiConstants.PROVIDER_NAME) .provider(SiliconFlowApiConstants.PROVIDER_NAME)
.requestOptions(imagePrompt.getOptions()) .requestOptions(imagePrompt.getOptions())
.build(); .build();
@ -127,7 +95,7 @@ public class SiliconflowImageModel implements ImageModel {
this.observationRegistry) this.observationRegistry)
.observe(() -> { .observe(() -> {
ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity = this.retryTemplate ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity = this.retryTemplate
.execute(ctx -> this.siiconflowmageApi.createImage(imageRequest)); .execute(ctx -> this.siliconFlowImageApi.createImage(imageRequest));
ImageResponse imageResponse = convertResponse(imageResponseEntity, imageRequest); ImageResponse imageResponse = convertResponse(imageResponseEntity, imageRequest);
@ -137,17 +105,17 @@ public class SiliconflowImageModel implements ImageModel {
}); });
} }
private SiiconflowmageApi.SiliconflowImageRequest createRequest(ImagePrompt imagePrompt) { private SiliconFlowImageApi.SiliconflowImageRequest createRequest(ImagePrompt imagePrompt) {
String instructions = imagePrompt.getInstructions().get(0).getText(); String instructions = imagePrompt.getInstructions().get(0).getText();
SiiconflowmageApi.SiliconflowImageRequest imageRequest = new SiiconflowmageApi.SiliconflowImageRequest(instructions, SiliconFlowImageApi.SiliconflowImageRequest imageRequest = new SiliconFlowImageApi.SiliconflowImageRequest(instructions,
imagePrompt.getOptions().getModel()); imagePrompt.getOptions().getModel());
return ModelOptionsUtils.merge(imagePrompt.getOptions(), imageRequest, SiiconflowmageApi.SiliconflowImageRequest.class); return ModelOptionsUtils.merge(imagePrompt.getOptions(), imageRequest, SiliconFlowImageApi.SiliconflowImageRequest.class);
} }
private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity, private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity,
SiiconflowmageApi.SiliconflowImageRequest siliconflowImageRequest) { SiliconFlowImageApi.SiliconflowImageRequest siliconflowImageRequest) {
OpenAiImageApi.OpenAiImageResponse imageApiResponse = imageResponseEntity.getBody(); OpenAiImageApi.OpenAiImageResponse imageApiResponse = imageResponseEntity.getBody();
if (imageApiResponse == null) { if (imageApiResponse == null) {
logger.warn("No image response returned for request: {}", siliconflowImageRequest); logger.warn("No image response returned for request: {}", siliconflowImageRequest);

View File

@ -1,17 +1,22 @@
package cn.iocoder.yudao.framework.ai.core.model.siliconflow; package cn.iocoder.yudao.framework.ai.core.model.siliconflow;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.ai.image.ImageOptions; import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.openai.OpenAiImageOptions;
/** /**
* 硅基流动画图能力 * 硅基流动 {@link ImageOptions}
* *
* @author zzt * @author zzt
*/ */
@Data @Data
public class SiliconflowImageOptions implements ImageOptions { @Builder
@AllArgsConstructor
@NoArgsConstructor
public class SiliconFlowImageOptions implements ImageOptions {
@JsonProperty("model") @JsonProperty("model")
private String model; private String model;
@ -37,7 +42,6 @@ public class SiliconflowImageOptions implements ImageOptions {
@JsonProperty("num_inference_steps") @JsonProperty("num_inference_steps")
private Integer numInferenceSteps = 25; private Integer numInferenceSteps = 25;
/** /**
* This value is used to control the degree of match between the generated image and the given prompt. The higher the value, the more the generated image will tend to strictly match the text prompt. The lower the value, the more creative and diverse the generated image will be, potentially containing more unexpected elements. * This value is used to control the degree of match between the generated image and the given prompt. The higher the value, the more the generated image will tend to strictly match the text prompt. The lower the value, the more creative and diverse the generated image will be, potentially containing more unexpected elements.
* *
@ -47,7 +51,7 @@ public class SiliconflowImageOptions implements ImageOptions {
private Float guidanceScale = 0.75F; private Float guidanceScale = 0.75F;
/** /**
* 如果想要每次都生成固定的图片可以把seed设置为固定值 * 如果想要每次都生成固定的图片可以把 seed 设置为固定值
* *
*/ */
@JsonProperty("seed") @JsonProperty("seed")
@ -59,13 +63,11 @@ public class SiliconflowImageOptions implements ImageOptions {
@JsonProperty("image") @JsonProperty("image")
private String image; private String image;
/** /**
* *
*/ */
private Integer width; private Integer width;
/** /**
* *
*/ */
@ -85,21 +87,6 @@ public class SiliconflowImageOptions implements ImageOptions {
} }
} }
/**
* 硅基流动
* @return
*/
public static SiliconflowImageOptions.Builder builder() {
return new SiliconflowImageOptions.Builder();
}
@Override
public String toString() {
return "SiliconflowImageOptions{" + "model='" + getModel() + '\'' + ", batch_size=" + batchSize + ", imageSize=" + imageSize + ", negativePrompt='"
+ negativePrompt + '\'' + '}';
}
@Override @Override
public Integer getN() { public Integer getN() {
return null; return null;
@ -115,52 +102,4 @@ public class SiliconflowImageOptions implements ImageOptions {
return null; return null;
} }
public static class Builder extends OpenAiImageOptions{
private final SiliconflowImageOptions options;
private Builder() {
this.options = new SiliconflowImageOptions();
}
public SiliconflowImageOptions.Builder model(String model) {
this.options.setModel(model);
return this;
}
public SiliconflowImageOptions.Builder withBatchSize(Integer batchSize) {
options.setBatchSize(batchSize);
return this;
}
public SiliconflowImageOptions.Builder withModel(String model) {
options.setModel(model);
return this;
}
public SiliconflowImageOptions.Builder withWidth(Integer width) {
options.setWidth(width);
return this;
}
public SiliconflowImageOptions.Builder withHeight(Integer height) {
options.setHeight(height);
return this;
}
public SiliconflowImageOptions.Builder withSeed(Integer seed) {
options.setSeed(seed);
return this;
}
public SiliconflowImageOptions.Builder withNegativePrompt(String negativePrompt) {
options.setNegativePrompt(negativePrompt);
return this;
}
public SiliconflowImageOptions build() {
return options;
}
}
} }

View File

@ -1,6 +1,6 @@
package cn.iocoder.yudao.framework.ai.chat; package cn.iocoder.yudao.framework.ai.chat;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiiconflowApiConstants; import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowApiConstants;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel; import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -26,11 +26,11 @@ public class SiliconFlowChatModelTests {
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder() private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder() .openAiApi(OpenAiApi.builder()
.baseUrl(SiiconflowApiConstants.DEFAULT_BASE_URL) .baseUrl(SiliconFlowApiConstants.DEFAULT_BASE_URL)
.apiKey("sk-epsakfenqnyzoxhmbucsxlhkdqlcbnimslqoivkshalvdozz") // apiKey .apiKey("sk-epsakfenqnyzoxhmbucsxlhkdqlcbnimslqoivkshalvdozz") // apiKey
.build()) .build())
.defaultOptions(OpenAiChatOptions.builder() .defaultOptions(OpenAiChatOptions.builder()
.model(SiiconflowApiConstants.MODEL_DEFAULT) // 模型 .model(SiliconFlowApiConstants.MODEL_DEFAULT) // 模型
// .model("deepseek-ai/DeepSeek-R1") // 模型deepseek-ai/DeepSeek-R1可用赠费 // .model("deepseek-ai/DeepSeek-R1") // 模型deepseek-ai/DeepSeek-R1可用赠费
// .model("Pro/deepseek-ai/DeepSeek-R1") // 模型Pro/deepseek-ai/DeepSeek-R1需要付费 // .model("Pro/deepseek-ai/DeepSeek-R1") // 模型Pro/deepseek-ai/DeepSeek-R1需要付费
.temperature(0.7) .temperature(0.7)

View File

@ -1,27 +1,27 @@
package cn.iocoder.yudao.framework.ai.image; package cn.iocoder.yudao.framework.ai.image;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiiconflowmageApi; import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageApi;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconflowImageModel; import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageModel;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconflowImageOptions; import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageOptions;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.ImageResponse;
/** /**
* {@link SiliconflowImageModel} 集成测试 * {@link SiliconFlowImageModel} 集成测试
*/ */
public class SiliconFlowImageModelTests { public class SiliconFlowImageModelTests {
private final SiliconflowImageModel imageModel = new SiliconflowImageModel( private final SiliconFlowImageModel imageModel = new SiliconFlowImageModel(
new SiiconflowmageApi("sk-epsakfenqnyzoxhmbucsxlhkdqlcbnimslqoivkshalvdozz") // 密钥 new SiliconFlowImageApi("sk-epsakfenqnyzoxhmbucsxlhkdqlcbnimslqoivkshalvdozz") // 密钥
); );
@Test @Test
@Disabled @Disabled
public void testCall() { public void testCall() {
// 准备参数 // 准备参数
SiliconflowImageOptions imageOptions = SiliconflowImageOptions.builder() SiliconFlowImageOptions imageOptions = SiliconFlowImageOptions.builder()
.model("Kwai-Kolors/Kolors") .model("Kwai-Kolors/Kolors")
.build(); .build();
ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions); ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);