【代码优化】AI:硅基流动的图片生成

This commit is contained in:
YunaiV 2025-03-23 11:22:46 +08:00
parent 59c744520a
commit ef5e56d560
4 changed files with 35 additions and 6 deletions

View File

@ -207,6 +207,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return SpringUtil.getBean(QianFanImageModel.class); return SpringUtil.getBean(QianFanImageModel.class);
case ZHI_PU: case ZHI_PU:
return SpringUtil.getBean(ZhiPuAiImageModel.class); return SpringUtil.getBean(ZhiPuAiImageModel.class);
case SILICON_FLOW:
return SpringUtil.getBean(SiliconFlowImageModel.class);
case OPENAI: case OPENAI:
return SpringUtil.getBean(OpenAiImageModel.class); return SpringUtil.getBean(OpenAiImageModel.class);
case STABLE_DIFFUSION: case STABLE_DIFFUSION:

View File

@ -27,6 +27,8 @@ public final class SiliconFlowApiConstants {
public static final String MODEL_DEFAULT = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"; public static final String MODEL_DEFAULT = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B";
public static final String DEFAULT_IMAGE_MODEL = "Kwai-Kolors/Kolors";
public static final String PROVIDER_NAME = "Siiconflow"; public static final String PROVIDER_NAME = "Siiconflow";
} }

View File

@ -31,6 +31,7 @@ import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata; import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.lang.Nullable;
import org.springframework.retry.support.RetryTemplate; import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -82,7 +83,8 @@ public class SiliconFlowImageModel implements ImageModel {
@Override @Override
public ImageResponse call(ImagePrompt imagePrompt) { public ImageResponse call(ImagePrompt imagePrompt) {
SiliconFlowImageApi.SiliconflowImageRequest imageRequest = createRequest(imagePrompt); SiliconFlowImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions);
SiliconFlowImageApi.SiliconflowImageRequest imageRequest = createRequest(imagePrompt, requestImageOptions);
var observationContext = ImageModelObservationContext.builder() var observationContext = ImageModelObservationContext.builder()
.imagePrompt(imagePrompt) .imagePrompt(imagePrompt)
@ -105,13 +107,14 @@ public class SiliconFlowImageModel implements ImageModel {
}); });
} }
private SiliconFlowImageApi.SiliconflowImageRequest createRequest(ImagePrompt imagePrompt) { private SiliconFlowImageApi.SiliconflowImageRequest createRequest(ImagePrompt imagePrompt,
SiliconFlowImageOptions requestImageOptions) {
String instructions = imagePrompt.getInstructions().get(0).getText(); String instructions = imagePrompt.getInstructions().get(0).getText();
SiliconFlowImageApi.SiliconflowImageRequest imageRequest = new SiliconFlowImageApi.SiliconflowImageRequest(instructions, SiliconFlowImageApi.SiliconflowImageRequest imageRequest = new SiliconFlowImageApi.SiliconflowImageRequest(instructions,
imagePrompt.getOptions().getModel()); SiliconFlowApiConstants.DEFAULT_IMAGE_MODEL);
return ModelOptionsUtils.merge(imagePrompt.getOptions(), imageRequest, SiliconFlowImageApi.SiliconflowImageRequest.class); return ModelOptionsUtils.merge(requestImageOptions, imageRequest, SiliconFlowImageApi.SiliconflowImageRequest.class);
} }
private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity, private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity,
@ -131,4 +134,26 @@ public class SiliconFlowImageModel implements ImageModel {
ImageResponseMetadata openAiImageResponseMetadata = new ImageResponseMetadata(imageApiResponse.created()); ImageResponseMetadata openAiImageResponseMetadata = new ImageResponseMetadata(imageApiResponse.created());
return new ImageResponse(imageGenerationList, openAiImageResponseMetadata); return new ImageResponse(imageGenerationList, openAiImageResponseMetadata);
} }
private SiliconFlowImageOptions mergeOptions(@Nullable ImageOptions runtimeOptions, SiliconFlowImageOptions defaultOptions) {
var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, ImageOptions.class,
SiliconFlowImageOptions.class);
if (runtimeOptionsForProvider == null) {
return defaultOptions;
}
return SiliconFlowImageOptions.builder()
// Handle portable image options
.model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
.batchSize(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getN(), defaultOptions.getN()))
.width(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getWidth(), defaultOptions.getWidth()))
.height(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getHeight(), defaultOptions.getHeight()))
// Handle OpenAI specific image options
.negativePrompt(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNegativePrompt(), defaultOptions.getNegativePrompt()))
.numInferenceSteps(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNumInferenceSteps(), defaultOptions.getNumInferenceSteps()))
.guidanceScale(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getGuidanceScale(), defaultOptions.getGuidanceScale()))
.seed(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getSeed(), defaultOptions.getSeed()))
.build();
}
} }

View File

@ -89,12 +89,12 @@ public class SiliconFlowImageOptions implements ImageOptions {
@Override @Override
public Integer getN() { public Integer getN() {
return null; return batchSize;
} }
@Override @Override
public String getResponseFormat() { public String getResponseFormat() {
return null; return "url";
} }
@Override @Override