From 489767db09dd36319e22c37cddd86e5d6f529de0 Mon Sep 17 00:00:00 2001 From: Marzouki-Sami Date: Wed, 29 Jan 2025 08:02:36 +0100 Subject: [PATCH 1/9] Adding Vertex AI Imagen models support Closes gh-2133 Signed-off-by: Marzouki-Sami samymarzouki502@gmail.com Signed-off-by: Marzouki-Sami --- models/spring-ai-vertex-ai-imagen/pom.xml | 109 +++++ .../VertexAiImagenConnectionDetails.java | 181 ++++++++ .../imagen/VertexAiImagenImageModel.java | 256 +++++++++++ .../imagen/VertexAiImagenImageModelName.java | 48 ++ .../imagen/VertexAiImagenImageOptions.java | 435 ++++++++++++++++++ .../vertexai/imagen/VertexAiImagenUtils.java | 230 +++++++++ ...VertexAiImagenImageGenerationMetadata.java | 81 ++++ .../imagen/TestVertexAiImagenImageModel.java | 76 +++ .../imagen/VertexAiImagenImageModelIT.java | 87 ++++ ...VertexAiImagenImageModelObservationIT.java | 122 +++++ .../imagen/VertexAiImagenImageRetryTests.java | 148 ++++++ .../VertexAiImagenAutoConfiguration.java | 86 ++++ .../VertexAiImagenConnectionProperties.java | 84 ++++ .../imagen/VertexAiImagenImageProperties.java | 57 +++ ...ertexAiImagenModelAutoConfigurationIT.java | 86 ++++ .../pom.xml | 61 +++ 16 files changed, 2147 insertions(+) create mode 100644 models/spring-ai-vertex-ai-imagen/pom.xml create mode 100644 models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenConnectionDetails.java create mode 100644 models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java create mode 100644 models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModelName.java create mode 100644 models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageOptions.java create mode 100644 models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenUtils.java create mode 100644 models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java create mode 100644 models/spring-ai-vertex-ai-imagen/src/test/java/imagen/TestVertexAiImagenImageModel.java create mode 100644 models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelIT.java create mode 100644 models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelObservationIT.java create mode 100644 models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageRetryTests.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenAutoConfiguration.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenConnectionProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenImageProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenModelAutoConfigurationIT.java create mode 100644 spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml diff --git a/models/spring-ai-vertex-ai-imagen/pom.xml b/models/spring-ai-vertex-ai-imagen/pom.xml new file mode 100644 index 0000000000..973c053c0d --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/pom.xml @@ -0,0 +1,109 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-M5 + ../../pom.xml + + + spring-ai-vertex-ai-imagen + jar + Spring AI Model - Vertex AI Imagen + Vertex AI Imagen models support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + + + com.google.cloud + libraries-bom + ${com.google.cloud.version} + pom + import + + + + + + + + com.google.cloud + google-cloud-aiplatform + + + commons-logging + commons-logging + + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + + org.springframework + spring-context-support + + + + org.springframework.boot + spring-boot-starter-logging + + + + io.micrometer + micrometer-observation-test + test + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + + diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenConnectionDetails.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenConnectionDetails.java new file mode 100644 index 0000000000..a466b95f57 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenConnectionDetails.java @@ -0,0 +1,181 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.imagen; + +import java.io.IOException; + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; + +import org.springframework.util.StringUtils; + + +/** + * VertexAiImagenConnectionDetails represents the details of a connection to the Vertex AI imagen service. + * It provides methods to access the project ID, location, publisher, and PredictionServiceSettings. + * + * @author Sami Marzouki + */ +public class VertexAiImagenConnectionDetails { + + public static final String DEFAULT_ENDPOINT = "us-central1-aiplatform.googleapis.com:443"; + + public static final String DEFAULT_ENDPOINT_SUFFIX = "-aiplatform.googleapis.com:443"; + + public static final String DEFAULT_PUBLISHER = "google"; + + private static final String DEFAULT_LOCATION = "us-central1"; + + /** + * Your project ID. + */ + private final String projectId; + + /** + * A location is a region + * you can specify in a request to control where data is stored at rest. For a list of + * available regions, see Generative + * AI on Vertex AI locations. + */ + private final String location; + + private final String publisher; + + private final PredictionServiceSettings predictionServiceSettings; + + public VertexAiImagenConnectionDetails(String projectId, String location, String publisher, + PredictionServiceSettings predictionServiceSettings) { + this.projectId = projectId; + this.location = location; + this.publisher = publisher; + this.predictionServiceSettings = predictionServiceSettings; + } + + public static Builder builder() { + return new Builder(); + } + + public String getProjectId() { + return this.projectId; + } + + public String getLocation() { + return this.location; + } + + public String getPublisher() { + return this.publisher; + } + + public EndpointName getEndpointName(String modelName) { + return EndpointName.ofProjectLocationPublisherModelName(this.projectId, this.location, this.publisher, + modelName); + } + + public com.google.cloud.aiplatform.v1.PredictionServiceSettings getPredictionServiceSettings() { + return this.predictionServiceSettings; + } + + public static class Builder { + + /** + * The Vertex AI embedding endpoint. + */ + private String endpoint; + + /** + * Your project ID. + */ + private String projectId; + + /** + * A location is a + * region you can + * specify in a request to control where data is stored at rest. For a list of + * available regions, see Generative + * AI on Vertex AI locations. + */ + private String location; + + /** + * + */ + private String publisher; + + /** + * Allows the connection settings to be customised + */ + private PredictionServiceSettings predictionServiceSettings; + + public Builder apiEndpoint(String endpoint) { + this.endpoint = endpoint; + return this; + } + + public Builder projectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder location(String location) { + this.location = location; + return this; + } + + public Builder publisher(String publisher) { + this.publisher = publisher; + return this; + } + + public Builder predictionServiceSettings(PredictionServiceSettings predictionServiceSettings) { + this.predictionServiceSettings = predictionServiceSettings; + return this; + } + + public VertexAiImagenConnectionDetails build() { + if (!StringUtils.hasText(this.endpoint)) { + if (!StringUtils.hasText(this.location)) { + this.endpoint = DEFAULT_ENDPOINT; + this.location = DEFAULT_LOCATION; + } else { + this.endpoint = this.location + DEFAULT_ENDPOINT_SUFFIX; + } + } + + if (!StringUtils.hasText(this.publisher)) { + this.publisher = DEFAULT_PUBLISHER; + } + + if (this.predictionServiceSettings == null) { + try { + this.predictionServiceSettings = PredictionServiceSettings.newBuilder() + .setEndpoint(this.endpoint) + .build(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + return new VertexAiImagenConnectionDetails(this.projectId, this.location, this.publisher, + this.predictionServiceSettings); + } + + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java new file mode 100644 index 0000000000..837bc3747c --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java @@ -0,0 +1,256 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.imagen; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.protobuf.Value; +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.image.Image; +import org.springframework.ai.image.ImageGeneration; +import org.springframework.ai.image.ImageGenerationMetadata; +import org.springframework.ai.image.ImageModel; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.image.observation.DefaultImageModelObservationConvention; +import org.springframework.ai.image.observation.ImageModelObservationContext; +import org.springframework.ai.image.observation.ImageModelObservationConvention; +import org.springframework.ai.image.observation.ImageModelObservationDocumentation; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.vertexai.imagen.VertexAiImagenUtils.ImageInstanceBuilder; +import org.springframework.ai.vertexai.imagen.VertexAiImagenUtils.ImageParametersBuilder; +import org.springframework.ai.vertexai.imagen.metadata.VertexAiImagenImageGenerationMetadata; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +/** + * VertexAiImagenImageModel is a class that implements the ImageModel interface. It + * provides a client for calling the Imagen on Vertex AI image generation API. + * + * @author Sami Marzouki + */ +public class VertexAiImagenImageModel implements ImageModel { + + private static final ImageModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultImageModelObservationConvention(); + + /** + * The default options used for the image completion requests. + */ + private final VertexAiImagenImageOptions defaultOptions; + + /** + * The connection details for Imagen on Vertex AI. + */ + private final VertexAiImagenConnectionDetails connectionDetails; + + /** + * The retry template used to retry the Imagen on Vertex AI Image API calls. + */ + private final RetryTemplate retryTemplate; + + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. + */ + private ImageModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + public VertexAiImagenImageModel(VertexAiImagenConnectionDetails connectionDetails, + VertexAiImagenImageOptions defaultOptions) { + this(connectionDetails, defaultOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public VertexAiImagenImageModel(VertexAiImagenConnectionDetails connectionDetails, + VertexAiImagenImageOptions defaultOptions, RetryTemplate retryTemplate) { + this(connectionDetails, defaultOptions, retryTemplate, ObservationRegistry.NOOP); + } + + public VertexAiImagenImageModel(VertexAiImagenConnectionDetails connectionDetails, + VertexAiImagenImageOptions defaultOptions, RetryTemplate retryTemplate, + ObservationRegistry observationRegistry) { + Assert.notNull(defaultOptions, "options must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); + Assert.notNull(observationRegistry, "observationRegistry must not be null"); + this.connectionDetails = connectionDetails; + this.defaultOptions = defaultOptions; + this.retryTemplate = retryTemplate; + this.observationRegistry = observationRegistry; + } + + private static ImageParametersBuilder getImageParametersBuilder(VertexAiImagenImageOptions finalOptions) { + ImageParametersBuilder parametersBuilder = ImageParametersBuilder.of(); + + if (finalOptions.getN() != null) { + parametersBuilder.sampleCount(finalOptions.getN()); + } + if (finalOptions.getSeed() != null) { + parametersBuilder.seed(finalOptions.getSeed()); + } + if (finalOptions.getNegativePrompt() != null) { + parametersBuilder.negativePrompt(finalOptions.getNegativePrompt()); + } + if (finalOptions.getAspectRatio() != null) { + parametersBuilder.aspectRatio(finalOptions.getAspectRatio()); + } + if (finalOptions.getAddWatermark() != null) { + parametersBuilder.addWatermark(finalOptions.getAddWatermark()); + } + if (finalOptions.getStorageUri() != null) { + parametersBuilder.storageUri(finalOptions.getStorageUri()); + } + if (finalOptions.getPersonGeneration() != null) { + parametersBuilder.personGeneration(finalOptions.getPersonGeneration()); + } + if (finalOptions.getSafetySetting() != null) { + parametersBuilder.safetySetting(finalOptions.getSafetySetting()); + } + if (finalOptions.getOutputOptions() != null) { + + ImageParametersBuilder.OutputOptions outputOptions = ImageParametersBuilder.OutputOptions.of(); + if (finalOptions.getOutputOptions().getMimeType() != null) { + outputOptions.mimeType(finalOptions.getOutputOptions().getMimeType()); + } + if (finalOptions.getOutputOptions().getCompressionQuality() != null) { + outputOptions.compressionQuality(finalOptions.getOutputOptions().getCompressionQuality()); + } + + parametersBuilder.outputOptions(outputOptions.build()); + } + + return parametersBuilder; + } + + @Override + public ImageResponse call(ImagePrompt imagePrompt) { + VertexAiImagenImageOptions finalOptions = mergedOptions(imagePrompt); + + var observationContext = ImageModelObservationContext.builder() + .imagePrompt(imagePrompt) + .provider(AiProvider.VERTEX_AI.value()) + .requestOptions(finalOptions) + .build(); + + return ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + PredictionServiceClient client = createPredictionServiceClient(); + + EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); + + PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(imagePrompt, endpointName, + finalOptions); + + PredictResponse imageResponse = this.retryTemplate + .execute(context -> getPredictResponse(client, predictRequestBuilder)); + + List imageGenerationList = new ArrayList<>(); + for (Value prediction : imageResponse.getPredictionsList()) { + Value bytesBase64Encoded = prediction.getStructValue().getFieldsOrThrow("bytesBase64Encoded"); + Value mimeType = prediction.getStructValue().getFieldsOrThrow("mimeType"); + ImageGenerationMetadata metadata = new VertexAiImagenImageGenerationMetadata( + imagePrompt.getInstructions().get(0).getText(), finalOptions.getModel(), + mimeType.getStringValue()); + Image image = new Image(null, bytesBase64Encoded.getStringValue()); + imageGenerationList.add(new ImageGeneration(image, metadata)); + } + ImageResponse response = new ImageResponse(imageGenerationList); + + observationContext.setResponse(response); + + return response; + + }); + } + + private VertexAiImagenImageOptions mergedOptions(ImagePrompt imagePrompt) { + + VertexAiImagenImageOptions mergedOptions = this.defaultOptions; + + if (imagePrompt.getOptions() != null) { + var defaultOptionsCopy = VertexAiImagenImageOptions.builder().from(this.defaultOptions).build(); + mergedOptions = ModelOptionsUtils.merge(imagePrompt.getOptions(), defaultOptionsCopy, + VertexAiImagenImageOptions.class); + } + + return mergedOptions; + } + + protected PredictRequest.Builder getPredictRequestBuilder(ImagePrompt imagePrompt, EndpointName endpointName, + VertexAiImagenImageOptions finalOptions) { + PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder().setEndpoint(endpointName.toString()); + + ImageParametersBuilder parametersBuilder = getImageParametersBuilder(finalOptions); + if (finalOptions.getOutputOptions() != null) { + ImageParametersBuilder.OutputOptions outputOptionsBuilder = ImageParametersBuilder.OutputOptions.of(); + if (finalOptions.getResponseFormat() != null) { + outputOptionsBuilder.mimeType(finalOptions.getResponseFormat()); + } + if (finalOptions.getCompressionQuality() != null) { + outputOptionsBuilder.compressionQuality(finalOptions.getCompressionQuality()); + } + parametersBuilder.outputOptions(outputOptionsBuilder.build()); + } + + predictRequestBuilder.setParameters(VertexAiImagenUtils.valueOf(parametersBuilder.build())); + + for (int i = 0; i < imagePrompt.getInstructions().size(); i++) { + + ImageInstanceBuilder instanceBuilder = ImageInstanceBuilder + .of(imagePrompt.getInstructions().get(i).getText()); + predictRequestBuilder.addInstances(VertexAiImagenUtils.valueOf(instanceBuilder.build())); + } + return predictRequestBuilder; + } + + // for testing + protected PredictionServiceClient createPredictionServiceClient() { + try { + return PredictionServiceClient.create(this.connectionDetails.getPredictionServiceSettings()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + // for testing + protected PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) { + return client.predict(predictRequestBuilder.build()); + } + + /** + * Use the provided convention for reporting observation data. + * + * @param observationConvention The provided convention + */ + public void setObservationConvention(ImageModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModelName.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModelName.java new file mode 100644 index 0000000000..f44cadd4d5 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModelName.java @@ -0,0 +1,48 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.imagen; + +/** + * Imagen on VertexAI Models: + * - Image generation + * + * @author Sami Marzouki + */ +public enum VertexAiImagenImageModelName { + + IMAGEN_3("imagen-3.0-generate-001"), + + IMAGEN_3_FAST("imagen-3.0-fast-generate-001"), + + IMAGEN_3_CUSTOMIZATION_AND_EDITING("imagen-3.0-capability-001"), + + IMAGEN_2_V006("imagegeneration@006"), + + IMAGEN_2_V005("imagegeneration@005"), + + IMAGEN_1_V002("imagegeneration@002"); + + private final String value; + + VertexAiImagenImageModelName(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } +} diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageOptions.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageOptions.java new file mode 100644 index 0000000000..2e0cabaef4 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageOptions.java @@ -0,0 +1,435 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vertexai.imagen; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.image.ImageOptions; + +import static org.springframework.ai.vertexai.imagen.VertexAiImagenUtils.calculateSizeFromAspectRatio; + +/** + * Options for the Vertex AI Image service. + * + * @author Sami Marzouki + */ +public class VertexAiImagenImageOptions implements ImageOptions { + + public static final String DEFAULT_MODEL_NAME = VertexAiImagenImageModelName.IMAGEN_2_V006.getValue(); + + /** + * Required: int + * The number of images to generate. The default value is 4. The + * imagen-3.0-generate-001 model supports values 1 through 4. The + * imagen-3.0-fast-generate-001 model supports values 1 through 4. The + * imagegeneration@006 model supports values 1 through 4. The imagegeneration@005 + * model supports values 1 through 4. The imagegeneration@002 model supports values 1 + * through 8. + */ + @JsonProperty("sampleCount") + private Integer n; + + /** + * The model to use for image generation. + */ + @JsonProperty("model") + private String model; + + /** + * Optional: Uint32 + * The random seed for image generation. This is not available when addWatermark is set to true. + */ + @JsonProperty("seed") + private Integer seed; + + /** + * Optional: string + * A description of what to discourage in the generated images. + * The imagen-3.0-generate-001 model supports up to 480 tokens. + * The imagen-3.0-fast-generate-001 model supports up to 480 tokens. + * The imagegeneration@006 model supports up to 128 tokens. + * The imagegeneration@005 model supports up to 128 tokens. + * The imagegeneration@002 model supports up to 64 tokens. + */ + @JsonProperty("negativePrompt") + private String negativePrompt; + + /** + * Optional: string + * The aspect ratio for the image. The default value is "1:1". + * The imagen-3.0-generate-001 model supports "1:1", "9:16", "16:9", "3:4", or "4:3". + * The imagen-3.0-fast-generate-001 model supports "1:1", "9:16", "16:9", "3:4", or "4:3". + * The imagegeneration@006 model supports "1:1", "9:16", "16:9", "3:4", or "4:3". + * The imagegeneration@005 model supports "1:1" or "9:16". + * The imagegeneration@002 model supports "1:1". + */ + @JsonProperty("aspectRatio") + private String aspectRatio; + + /** + * Optional: outputOptions + * Describes the output image format in an outputOptions object. + * + * @see OutputOptions + */ + @JsonProperty("outputOptions") + private OutputOptions outputOptions; + + /** + * Optional: string (imagegeneration@002 only) + * Describes the style for the generated images. The following values are supported: + * "photograph", "digital_art", "landscape", "sketch", "watercolor", "cyberpunk", "pop_art". + */ + @JsonProperty("sampleImageStyle") + private String style; + + /** + * Optional: string (imagen-3.0-generate-001, imagen-3.0-fast-generate-001, and imagegeneration@006 only) + * Allow generation of people by the model. The following values are supported: + * "dont_allow": Disallow the inclusion of people or faces in images. + * "allow_adult": Allow generation of adults only. + * "allow_all": Allow generation of people of all ages. + * The default value is "allow_adult". + */ + @JsonProperty("personGeneration") + private String personGeneration; + + /** + * Optional: string (imagen-3.0-generate-001, imagen-3.0-fast-generate-001, and imagegeneration@006 only) + * Adds a filter level to safety filtering. The following values are supported: + * "block_low_and_above": Strongest filtering level, most strict blocking. Deprecated value: "block_most". + * "block_medium_and_above": Block some problematic prompts and responses. Deprecated value: "block_some". + * "block_only_high": Reduces the number of requests blocked due to safety filters. May increase objectionable + * content generated by Imagen. Deprecated value: "block_few". + * "block_none": Block very few problematic prompts and responses. Access to this feature is restricted. + * Previous field value: "block_fewest". + * The default value is "block_medium_and_above". + */ + @JsonProperty("safetySetting") + private String safetySetting; + + /** + * Optional: bool + * Add an invisible watermark to the generated images. + * The default value is false for the imagegeneration@002 and imagegeneration@005 models, + * and true for the imagen-3.0-fast-generate-001, imagegeneration@006, and imagegeneration@006 models. + */ + @JsonProperty("addWatermark") + private Boolean addWatermark; + + /** + * Optional: string + * Cloud Storage URI to store the generated images. + */ + @JsonProperty("storageUri") + private String storageUri; + + private List size; + + public static Builder builder() { + return new Builder(); + } + + @Override + public Integer getN() { + return this.n; + } + + public void setN(Integer n) { + this.n = n; + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Integer getWidth() { + if (this.size == null || this.size.isEmpty()) { + return null; + } + return this.size.get(0); + } + + @Override + public Integer getHeight() { + if (this.size == null || this.size.isEmpty()) { + return null; + } + return this.size.get(1); + } + + @Override + public String getStyle() { + return this.style; + } + + public void setStyle(String style) { + this.style = style; + } + + @Override + public String getResponseFormat() { + if (this.outputOptions == null) { + return null; + } + return this.outputOptions.mimeType; + } + + public Integer getCompressionQuality() { + if (this.outputOptions == null) { + return null; + } + return this.outputOptions.compressionQuality; + } + + public OutputOptions getOutputOptions() { + return this.outputOptions; + } + + public Integer getSeed() { + return seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + public String getNegativePrompt() { + return negativePrompt; + } + + public void setNegativePrompt(String negativePrompt) { + this.negativePrompt = negativePrompt; + } + + public String getAspectRatio() { + return aspectRatio; + } + + public void setAspectRatio(String aspectRatio) { + this.aspectRatio = aspectRatio; + } + + public String getPersonGeneration() { + return personGeneration; + } + + public void setPersonGeneration(String personGeneration) { + this.personGeneration = personGeneration; + } + + public String getSafetySetting() { + return safetySetting; + } + + public void setSafetySetting(String safetySetting) { + this.safetySetting = safetySetting; + } + + public Boolean getAddWatermark() { + return addWatermark; + } + + public void setAddWatermark(Boolean addWatermark) { + this.addWatermark = addWatermark; + } + + public String getStorageUri() { + return storageUri; + } + + public void setStorageUri(String storageUri) { + this.storageUri = storageUri; + } + + public void setSize(List size) { + this.size = size; + } + + public static final class OutputOptions { + + @JsonProperty("mimeType") + private String mimeType; + + @JsonProperty("compressionQuality") + private Integer compressionQuality; + + public static Builder builder() { + return new Builder(); + } + + public String getMimeType() { + return mimeType; + } + + public void setMimeType(String mimeType) { + this.mimeType = mimeType; + } + + public Integer getCompressionQuality() { + return compressionQuality; + } + + public void setCompressionQuality(Integer compressionQuality) { + this.compressionQuality = compressionQuality; + } + + public static final class Builder { + + private final OutputOptions options; + + private Builder() { + this.options = new OutputOptions(); + } + + public Builder mimeType(String format) { + this.options.setMimeType(format); + return this; + } + + public Builder compressionQuality(Integer compressionQuality) { + this.options.setCompressionQuality(compressionQuality); + return this; + } + + public OutputOptions build() { + return this.options; + } + } + } + + public static final class Builder { + + private final VertexAiImagenImageOptions options; + + private Builder() { + this.options = new VertexAiImagenImageOptions(); + } + + public Builder from(VertexAiImagenImageOptions fromOptions) { + if (fromOptions.getN() != null) { + this.options.setN(fromOptions.getN()); + } + if (fromOptions.getModel() != null) { + this.options.setModel(fromOptions.getModel()); + } + if (fromOptions.getAspectRatio() != null) { + this.options.setAspectRatio(fromOptions.getAspectRatio()); + this.options.setSize(calculateSizeFromAspectRatio(fromOptions.getAspectRatio())); + } + if (fromOptions.getStyle() != null) { + this.options.setStyle(fromOptions.getStyle()); + } + if (fromOptions.getOutputOptions() != null) { + if (fromOptions.getResponseFormat() != null) { + this.options.outputOptions.setMimeType(fromOptions.getResponseFormat()); + } + if (fromOptions.getCompressionQuality() != null) { + this.options.outputOptions.setCompressionQuality(fromOptions.getCompressionQuality()); + } + } + if (fromOptions.getSeed() != null) { + this.options.setSeed(fromOptions.getSeed()); + } + if (fromOptions.getNegativePrompt() != null) { + this.options.setNegativePrompt(fromOptions.getNegativePrompt()); + } + if (fromOptions.getPersonGeneration() != null) { + this.options.setPersonGeneration(fromOptions.getPersonGeneration()); + } + if (fromOptions.getSafetySetting() != null) { + this.options.setSafetySetting(fromOptions.getSafetySetting()); + } + if (fromOptions.getAddWatermark() != null) { + this.options.setAddWatermark(fromOptions.getAddWatermark()); + } + if (fromOptions.getStorageUri() != null) { + this.options.setStorageUri(fromOptions.getStorageUri()); + } + + return this; + } + + public Builder N(Integer n) { + this.options.setN(n); + return this; + } + + public Builder model(String model) { + this.options.setModel(model); + return this; + } + + public Builder seed(Integer seed) { + this.options.setSeed(seed); + return this; + } + + public Builder negativePrompt(String negativePrompt) { + this.options.setNegativePrompt(negativePrompt); + return this; + } + + public Builder aspectRatio(String aspectRatio) { + this.options.setAspectRatio(aspectRatio); + this.options.setSize(calculateSizeFromAspectRatio(aspectRatio)); + return this; + } + + public Builder outputOptions(OutputOptions outputOptions) { + this.options.outputOptions = outputOptions; + return this; + } + + public Builder personGeneration(String personGeneration) { + this.options.setPersonGeneration(personGeneration); + return this; + } + + public Builder safetySetting(String safetySetting) { + this.options.setSafetySetting(safetySetting); + return this; + } + + public Builder addWatermark(Boolean addWatermark) { + this.options.setAddWatermark(addWatermark); + return this; + } + + public Builder storageUri(String storageUri) { + this.options.setStorageUri(storageUri); + return this; + } + + public Builder style(String style) { + this.options.setStyle(style); + return this; + } + + public VertexAiImagenImageOptions build() { + return this.options; + } + + } +} diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenUtils.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenUtils.java new file mode 100644 index 0000000000..d4650105e7 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenUtils.java @@ -0,0 +1,230 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.imagen; + +import java.util.Arrays; +import java.util.List; + +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; + +import org.springframework.util.Assert; + +/** + * Utility class for constructing parameter objects for Imagen on Vertex AI requests. + * + * @author Sami Marzouki + */ +public abstract class VertexAiImagenUtils { + + public static Value valueOf(boolean n) { + return Value.newBuilder().setBoolValue(n).build(); + } + + public static Value valueOf(String s) { + return Value.newBuilder().setStringValue(s).build(); + } + + public static Value valueOf(int n) { + return Value.newBuilder().setNumberValue(n).build(); + } + + public static Value valueOf(Struct struct) { + return Value.newBuilder().setStructValue(struct).build(); + } + + public static Value jsonToValue(String json) throws InvalidProtocolBufferException { + Value.Builder builder = Value.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } + + public static List calculateSizeFromAspectRatio(String aspectRatio) { + if (aspectRatio != null) { + return switch (aspectRatio) { + case "1:1" -> List.of(1024, 1024); + case "9:16" -> List.of(900, 1600); + case "16:9" -> List.of(1600, 900); + case "3:4" -> List.of(750, 1000); + case "4:3" -> List.of(1000, 750); + default -> throw new IllegalStateException("Unexpected value: " + aspectRatio + + " aspect ratio must be one of these values : ['1:1', '9:16', '16:9', '3:4', or '4:3']"); + }; + } + return Arrays.asList(1024, 1024); + } + + public static class ImageInstanceBuilder { + + public String prompt; + + public static ImageInstanceBuilder of(String prompt) { + Assert.hasText(prompt, "Prompt must not be empty"); + var builder = new ImageInstanceBuilder(); + builder.prompt = prompt; + return builder; + } + + public Struct build() { + Struct.Builder textBuilder = Struct.newBuilder(); + textBuilder.putFields("prompt", valueOf(this.prompt)); + return textBuilder.build(); + } + } + + public static class ImageParametersBuilder { + + public Integer sampleCount; + public Integer seed; + public String negativePrompt; + public String aspectRatio; + public Boolean addWatermark; + public String storageUri; + public String personGeneration; + public String safetySetting; + public Struct outputOptions; + + public static ImageParametersBuilder of() { + return new ImageParametersBuilder(); + } + + public ImageParametersBuilder sampleCount(Integer sampleCount) { + Assert.notNull(sampleCount, "Sample count must not be null"); + this.sampleCount = sampleCount; + return this; + } + + public ImageParametersBuilder seed(Integer seed) { + Assert.notNull(seed, "Seed must not be null"); + this.seed = seed; + return this; + } + + public ImageParametersBuilder negativePrompt(String negativePrompt) { + Assert.notNull(negativePrompt, "Negative prompt must not be null"); + this.negativePrompt = negativePrompt; + return this; + } + + public ImageParametersBuilder aspectRatio(String aspectRatio) { + Assert.notNull(aspectRatio, "Aspect ratio must not be null"); + this.aspectRatio = aspectRatio; + return this; + } + + public ImageParametersBuilder addWatermark(Boolean addWatermark) { + Assert.notNull(addWatermark, "Add watermark must not be null"); + this.addWatermark = addWatermark; + return this; + } + + public ImageParametersBuilder storageUri(String storageUri) { + Assert.notNull(storageUri, "Storage URI must not be null"); + this.storageUri = storageUri; + return this; + } + + public ImageParametersBuilder personGeneration(String personGeneration) { + Assert.notNull(personGeneration, "Person generation must not be null"); + this.personGeneration = personGeneration; + return this; + } + + public ImageParametersBuilder safetySetting(String safetySetting) { + Assert.notNull(safetySetting, "Safety setting must not be null"); + this.safetySetting = safetySetting; + return this; + } + + public ImageParametersBuilder outputOptions(Struct outputOptions) { + Assert.notNull(outputOptions, "Output options must not be null"); + this.outputOptions = outputOptions; + return this; + } + + public Struct build() { + Struct.Builder imageParametersBuilder = Struct.newBuilder(); + + if (this.sampleCount != null) { + imageParametersBuilder.putFields("sampleCount", valueOf(this.sampleCount)); + } + if (this.seed != null) { + imageParametersBuilder.putFields("seed", valueOf(this.seed)); + } + if (this.negativePrompt != null) { + imageParametersBuilder.putFields("negativePrompt", valueOf(this.negativePrompt)); + } + if (this.aspectRatio != null) { + imageParametersBuilder.putFields("aspectRatio", valueOf(this.aspectRatio)); + } + if (this.addWatermark != null) { + imageParametersBuilder.putFields("addWatermark", valueOf(this.addWatermark)); + } + if (this.storageUri != null) { + imageParametersBuilder.putFields("storageUri", valueOf(this.storageUri)); + } + if (this.personGeneration != null) { + imageParametersBuilder.putFields("personGeneration", valueOf(this.personGeneration)); + } + if (this.safetySetting != null) { + imageParametersBuilder.putFields("safetySetting", valueOf(this.safetySetting)); + } + if (this.outputOptions != null) { + imageParametersBuilder.putFields("outputOptions", Value.newBuilder().setStructValue(this.outputOptions).build()); + } + return imageParametersBuilder.build(); + } + + public static class OutputOptions { + public String mimeType; + public Integer compressionQuality; + + public static OutputOptions of() { + return new OutputOptions(); + } + + public OutputOptions mimeType(String mimeType) { + Assert.notNull(mimeType, "MIME type must not be null"); + this.mimeType = mimeType; + return this; + } + + public OutputOptions compressionQuality(Integer compressionQuality) { + Assert.notNull(compressionQuality, "Compression quality must not be null"); + this.compressionQuality = compressionQuality; + return this; + } + + public Struct build() { + Struct.Builder outputOptionsBuilder = Struct.newBuilder(); + + if (this.mimeType != null) { + outputOptionsBuilder.putFields("mimeType", valueOf(this.mimeType)); + } + if (this.compressionQuality != null) { + outputOptionsBuilder.putFields("compressionQuality", valueOf(this.compressionQuality)); + } + return outputOptionsBuilder.build(); + } + + } + + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java new file mode 100644 index 0000000000..ee12bd1c97 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java @@ -0,0 +1,81 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.imagen.metadata; + +import java.util.Objects; + +import org.springframework.ai.image.ImageGenerationMetadata; + +/** + * VertexAiImagenImageGenerationMetadata is a class that defines the metadata for Imagen + * on Vertex AI. + * + * @author Sami Marzouki + */ +public class VertexAiImagenImageGenerationMetadata implements ImageGenerationMetadata { + + private final String prompt; + + private final String model; + + private final String mimeType; + + public VertexAiImagenImageGenerationMetadata(String revisedPrompt, String mimeType, String model) { + this.prompt = revisedPrompt; + this.model = model; + this.mimeType = mimeType; + } + + public String getPrompt() { + return prompt; + } + + public String getModel() { + return model; + } + + public String getMimeType() { + return mimeType; + } + + @Override + public String toString() { + return "VertexAiImagenImageGenerationMetadata{" + + "prompt='" + prompt + '\'' + + ", model='" + model + '\'' + + ", mimeType='" + mimeType + '\'' + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + VertexAiImagenImageGenerationMetadata that = (VertexAiImagenImageGenerationMetadata) o; + return Objects.equals(prompt, that.prompt) + && Objects.equals(model, that.model) + && Objects.equals(mimeType, that.mimeType); + } + + @Override + public int hashCode() { + return Objects.hash(prompt, model, mimeType); + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/TestVertexAiImagenImageModel.java b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/TestVertexAiImagenImageModel.java new file mode 100644 index 0000000000..4c84c25419 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/TestVertexAiImagenImageModel.java @@ -0,0 +1,76 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package imagen; + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; + +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.vertexai.imagen.VertexAiImagenConnectionDetails; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageModel; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageOptions; +import org.springframework.retry.support.RetryTemplate; + +/** + * @author Sami Marzouki + */ +public class TestVertexAiImagenImageModel extends VertexAiImagenImageModel { + + private PredictionServiceClient mockPredictionServiceClient; + + private PredictRequest.Builder mockPredictRequestBuilder; + + public TestVertexAiImagenImageModel(VertexAiImagenConnectionDetails connectionDetails, + VertexAiImagenImageOptions defaultOptions, RetryTemplate retryTemplate) { + super(connectionDetails, defaultOptions, retryTemplate); + } + + public void setMockPredictionServiceClient(PredictionServiceClient mockPredictionServiceClient) { + this.mockPredictionServiceClient = mockPredictionServiceClient; + } + + @Override + public PredictionServiceClient createPredictionServiceClient() { + if (this.mockPredictionServiceClient != null) { + return this.mockPredictionServiceClient; + } + return super.createPredictionServiceClient(); + } + + @Override + public PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) { + if (this.mockPredictionServiceClient != null) { + return this.mockPredictionServiceClient.predict(predictRequestBuilder.build()); + } + return super.getPredictResponse(client, predictRequestBuilder); + } + + public void setMockPredictRequestBuilder(PredictRequest.Builder mockPredictRequestBuilder) { + this.mockPredictRequestBuilder = mockPredictRequestBuilder; + } + + @Override + protected PredictRequest.Builder getPredictRequestBuilder(ImagePrompt imagePrompt, EndpointName endpointName, + VertexAiImagenImageOptions finalOptions) { + if (this.mockPredictRequestBuilder != null) { + return this.mockPredictRequestBuilder; + } + return super.getPredictRequestBuilder(imagePrompt, endpointName, finalOptions); + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelIT.java b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelIT.java new file mode 100644 index 0000000000..15465acee7 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelIT.java @@ -0,0 +1,87 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package imagen; + +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.vertexai.imagen.VertexAiImagenConnectionDetails; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageModel; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageOptions; +import org.springframework.ai.vertexai.imagen.metadata.VertexAiImagenImageGenerationMetadata; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +/** + * @author Marzouki Sami + */ +@SpringBootTest(classes = VertexAiImagenImageModelIT.Config.class) +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_IMAGEN_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_IMAGEN_LOCATION", matches = ".*") +public class VertexAiImagenImageModelIT { + + @Autowired + protected VertexAiImagenImageModel imageModel; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = {"imagen-3.0-generate-001", "imagen-3.0-fast-generate-001", "imagen-3.0-capability-001", + "imagegeneration@006", "imagegeneration@005", "imagegeneration@002"}) + void defaultImageGenerator(String modelName) { + Assertions.assertThat(this.imageModel).isNotNull(); + + var options = VertexAiImagenImageOptions.builder().model(modelName).N(1).build(); + + ImageResponse imageResponse = this.imageModel + .call(new ImagePrompt("little kitten sitting on a purple cushion", options)); + + Assertions.assertThat(imageResponse.getResults()).hasSize(2); + Assertions.assertThat(imageResponse.getResults().get(0).getOutput().getB64Json()).isNotEmpty(); + Assertions.assertThat(((VertexAiImagenImageGenerationMetadata) imageResponse.getResults().get(0).getMetadata()).getModel()).isNotEmpty(); + Assertions.assertThat(((VertexAiImagenImageGenerationMetadata) imageResponse.getResults().get(0).getMetadata()).getPrompt()).isNotEmpty(); + Assertions.assertThat(((VertexAiImagenImageGenerationMetadata) imageResponse.getResults().get(0).getMetadata()).getMimeType()).isNotEmpty(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public VertexAiImagenConnectionDetails connectionDetails() { + return VertexAiImagenConnectionDetails.builder() + .projectId(System.getenv("VERTEX_AI_IMAGEN_PROJECT_ID")) + .location(System.getenv("VERTEX_AI_IMAGEN_LOCATION")) + .build(); + } + + @Bean + public VertexAiImagenImageModel imageModel(VertexAiImagenConnectionDetails connectionDetails) { + + VertexAiImagenImageOptions options = VertexAiImagenImageOptions.builder() + .model(VertexAiImagenImageOptions.DEFAULT_MODEL_NAME) + .build(); + + return new VertexAiImagenImageModel(connectionDetails, options); + } + + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelObservationIT.java b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelObservationIT.java new file mode 100644 index 0000000000..8aa3adfe1a --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelObservationIT.java @@ -0,0 +1,122 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package imagen; + +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.image.ImageResponseMetadata; +import org.springframework.ai.image.observation.DefaultImageModelObservationConvention; +import org.springframework.ai.image.observation.ImageModelObservationDocumentation; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.vertexai.imagen.VertexAiImagenConnectionDetails; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageModel; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageModelName; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageOptions; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for observation instrumentation in {@link VertexAiImagenImageModel}. + * + * @author Sami Marzouki + */ +@SpringBootTest(classes = VertexAiImagenImageModelObservationIT.Config.class) +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_IMAGEN_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_IMAGEN_LOCATION", matches = ".*") +public class VertexAiImagenImageModelObservationIT { + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + VertexAiImagenImageModel imageModel; + + @Test + void observationForImageOperation() { + var options = VertexAiImagenImageOptions.builder() + .model(VertexAiImagenImageModelName.IMAGEN_2_V006.getValue()) + .N(1) + .build(); + + ImagePrompt imagePrompt = new ImagePrompt("Little kitten sitting on a purple cushion", options); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); + assertThat(imageResponse.getResults()).isNotEmpty(); + + ImageResponseMetadata responseMetadata = imageResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultImageModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("image " + VertexAiImagenImageModelName.IMAGEN_2_V006.getValue()) + .hasLowCardinalityKeyValue( + ImageModelObservationDocumentation.LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.IMAGE.value()) + .hasLowCardinalityKeyValue(ImageModelObservationDocumentation.LowCardinalityKeyNames.AI_PROVIDER.asString(), + AiProvider.VERTEX_AI.value()) + .hasLowCardinalityKeyValue( + ImageModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL.asString(), + VertexAiImagenImageModelName.IMAGEN_2_V006.getValue()) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public VertexAiImagenConnectionDetails connectionDetails() { + return VertexAiImagenConnectionDetails.builder() + .projectId(System.getenv("VERTEX_AI_IMAGEN_PROJECT_ID")) + .location(System.getenv("VERTEX_AI_IMAGEN_LOCATION")) + .build(); + } + + @Bean + public VertexAiImagenImageModel imageModel(VertexAiImagenConnectionDetails connectionDetails, + ObservationRegistry observationRegistry) { + + VertexAiImagenImageOptions options = VertexAiImagenImageOptions.builder() + .model(VertexAiImagenImageOptions.DEFAULT_MODEL_NAME) + .build(); + + return new VertexAiImagenImageModel(connectionDetails, options, RetryUtils.DEFAULT_RETRY_TEMPLATE, + observationRegistry); + } + + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageRetryTests.java b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageRetryTests.java new file mode 100644 index 0000000000..0086338969 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageRetryTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package imagen; + +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.ai.vertexai.imagen.VertexAiImagenConnectionDetails; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageOptions; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * @author Sami Marzouki + */ +@ExtendWith(MockitoExtension.class) +public class VertexAiImagenImageRetryTests { + + private TestRetryListener retryListener; + + @Mock + private PredictionServiceClient mockPredictionServiceClient; + + @Mock + private VertexAiImagenConnectionDetails mockConnectionDetails; + + @Mock + private PredictRequest.Builder mockPredictRequestBuilder; + + private TestVertexAiImagenImageModel imageModel; + + @BeforeEach + public void setUp() { + RetryTemplate retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + retryTemplate.registerListener(this.retryListener); + + this.imageModel = new TestVertexAiImagenImageModel(this.mockConnectionDetails, + VertexAiImagenImageOptions.builder().build(), retryTemplate); + this.imageModel.setMockPredictionServiceClient(this.mockPredictionServiceClient); + this.imageModel.setMockPredictRequestBuilder(this.mockPredictRequestBuilder); + given(this.mockPredictRequestBuilder.build()).willReturn(PredictRequest.getDefaultInstance()); + } + + @Test + public void vertexAiImageGeneratorTransientError() { + // Set up the mock PredictResponse + PredictResponse mockResponse = PredictResponse.newBuilder() + .addPredictions(Value.newBuilder() + .setStructValue(Struct.newBuilder() + .putFields("bytesBase64Encoded", Value.newBuilder().setStringValue("BASE64_IMG_BYTES").build()) + .putFields("mimeType", Value.newBuilder().setStringValue("image/png").build()) + .build()) + .build()) + .addPredictions(Value.newBuilder() + .setStructValue(Struct.newBuilder() + .putFields("mimeType", Value.newBuilder().setStringValue("image/png").build()) + .putFields("bytesBase64Encoded", Value.newBuilder().setStringValue("BASE64_IMG_BYTES").build()) + .build()) + .build()) + .build(); + + // Set up the mock PredictionServiceClient + given(this.mockPredictionServiceClient.predict(any())) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(mockResponse); + + ImageResponse result = this.imageModel.call(new ImagePrompt("text1", null)); + + assertThat(result).isNotNull(); + assertThat(result.getResults()).hasSize(2); + assertThat(result.getResults().get(0).getOutput().getB64Json()).isEqualTo("BASE64_IMG_BYTES"); + assertThat(result.getResults().get(1).getOutput().getB64Json()).isEqualTo("BASE64_IMG_BYTES"); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + + verify(this.mockPredictRequestBuilder, times(3)).build(); + } + + @Test + public void vertexAiImageGeneratorNonTransientError() { + // Set up the mock PredictionServiceClient to throw a non-transient error + given(this.mockPredictionServiceClient.predict(any())).willThrow(new RuntimeException("Non Transient Error")); + + // Assert that a RuntimeException is thrown and not retried + assertThatThrownBy(() -> this.imageModel.call(new ImagePrompt("text1", null))) + .isInstanceOf(RuntimeException.class); + + // Verify that predict was called only once (no retries for non-transient errors) + verify(this.mockPredictionServiceClient, times(1)).predict(any()); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenAutoConfiguration.java new file mode 100644 index 0000000000..810feecdc6 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenAutoConfiguration.java @@ -0,0 +1,86 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vertexai.imagen; + +import java.io.IOException; + +import com.google.cloud.vertexai.VertexAI; +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.image.observation.ImageModelObservationConvention; +import org.springframework.ai.vertexai.imagen.VertexAiImagenConnectionDetails; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageModel; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * AutoConfiguration for Vertex AI Imagen. + * + * @author Sami Marzouki + */ +@AutoConfiguration(after = {SpringAiRetryAutoConfiguration.class}) +@ConditionalOnClass({VertexAI.class, VertexAiImagenImageModel.class}) +@EnableConfigurationProperties({VertexAiImagenImageProperties.class, VertexAiImagenConnectionProperties.class}) +@ImportAutoConfiguration(classes = {SpringAiRetryAutoConfiguration.class}) +public class VertexAiImagenAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public VertexAiImagenConnectionDetails connectionDetails( + VertexAiImagenConnectionProperties connectionProperties) throws IOException { + Assert.hasText(connectionProperties.getProjectId(), "Vertex AI project-id must be set!"); + Assert.hasText(connectionProperties.getLocation(), "Vertex AI location must be set!"); + + var connectionBuilder = VertexAiImagenConnectionDetails.builder() + .projectId(connectionProperties.getProjectId()) + .location(connectionProperties.getLocation()); + + if (StringUtils.hasText(connectionProperties.getApiEndpoint())) { + connectionBuilder.apiEndpoint(connectionProperties.getApiEndpoint()); + } + + return connectionBuilder.build(); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = VertexAiImagenImageProperties.CONFIG_PREFIX, name = "enabled", + havingValue = "true", matchIfMissing = true) + public VertexAiImagenImageModel imageModel(VertexAiImagenConnectionDetails connectionDetails, + VertexAiImagenImageProperties properties, RetryTemplate retryTemplate, + ObjectProvider observationRegistry, + ObjectProvider observationConvention) { + + var imageModel = new VertexAiImagenImageModel(connectionDetails, properties.getOptions(), + retryTemplate, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); + + observationConvention.ifAvailable(imageModel::setObservationConvention); + + return imageModel; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenConnectionProperties.java new file mode 100644 index 0000000000..fa86bf0090 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenConnectionProperties.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vertexai.imagen; + +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.core.io.Resource; + +/** + * Configuration properties for Vertex AI Imagen. + * + * @author Sami Marzouki + */ +@ConfigurationProperties(VertexAiImagenConnectionProperties.CONFIG_PREFIX) +public class VertexAiImagenConnectionProperties { + + public static final String CONFIG_PREFIX = "spring.ai.vertex.ai.imagen"; + + /** + * Vertex AI Imagen project ID. + */ + private String projectId; + + /** + * Vertex AI Imagen location. + */ + private String location; + + /** + * URI to Vertex AI Imagen credentials (optional) + */ + private Resource credentialsUri; + + /** + * Vertex AI Imagen API endpoint. + */ + private String apiEndpoint; + + public String getProjectId() { + return this.projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public String getLocation() { + return this.location; + } + + public void setLocation(String location) { + this.location = location; + } + + public Resource getCredentialsUri() { + return this.credentialsUri; + } + + public void setCredentialsUri(Resource credentialsUri) { + this.credentialsUri = credentialsUri; + } + + public String getApiEndpoint() { + return this.apiEndpoint; + } + + public void setApiEndpoint(String apiEndpoint) { + this.apiEndpoint = apiEndpoint; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenImageProperties.java new file mode 100644 index 0000000000..3cba5ac320 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenImageProperties.java @@ -0,0 +1,57 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vertexai.imagen; + +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * Configuration properties for Vertex AI Imagen. + * + * @author Sami Marzouki + */ +@ConfigurationProperties(VertexAiImagenImageProperties.CONFIG_PREFIX) +public class VertexAiImagenImageProperties { + + public static final String CONFIG_PREFIX = "spring.ai.vertex.ai.imagen.generator"; + + private boolean enabled = true; + + /** + * Vertex AI Imagen API options. + */ + private VertexAiImagenImageOptions options = VertexAiImagenImageOptions.builder() + .model(VertexAiImagenImageOptions.DEFAULT_MODEL_NAME) + .build(); + + public VertexAiImagenImageOptions getOptions() { + return this.options; + } + + public void setOptions(VertexAiImagenImageOptions options) { + this.options = options; + } + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenModelAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenModelAutoConfigurationIT.java new file mode 100644 index 0000000000..e1211f851a --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenModelAutoConfigurationIT.java @@ -0,0 +1,86 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.vertexai.imagen; + +import java.io.File; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.io.TempDir; + +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageModel; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Sami Marzouki + */ +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_IMAGEN_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_IMAGEN_LOCATION", matches = ".*") +public class VertexAiImagenModelAutoConfigurationIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.vertex.ai.imagen.project-id=" + System.getenv("VERTEX_AI_IMAGEN_PROJECT_ID"), + "spring.ai.vertex.ai.imagen.location=" + System.getenv("VERTEX_AI_IMAGEN_LOCATION")) + .withConfiguration(AutoConfigurations.of(VertexAiImagenAutoConfiguration.class)); + + @TempDir + File tempDir; + + + @Test + public void imageGenerator() { + this.contextRunner.run(context -> { + var connectionProperties = context.getBean(VertexAiImagenConnectionProperties.class); + var imageProperties = context.getBean(VertexAiImagenImageProperties.class); + + assertThat(connectionProperties).isNotNull(); + assertThat(imageProperties.isEnabled()).isTrue(); + + VertexAiImagenImageModel imageModel = context.getBean(VertexAiImagenImageModel.class); + assertThat(imageModel).isInstanceOf(VertexAiImagenImageModel.class); + + ImageResponse imageResponse = imageModel.call(new ImagePrompt("Spring Framework, Spring AI")); + + assertThat(imageResponse.getResults().size()).isEqualTo(1); + assertThat(imageResponse.getResults().get(0).getOutput().getB64Json()).isNotEmpty(); + }); + } + + @Test + void imageGeneratorActivation() { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.imagen.generator.enabled=false").run(context -> { + assertThat(context.getBeansOfType(VertexAiImagenImageProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(VertexAiImagenImageModel.class)).isEmpty(); + }); + + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.imagen.generator.enabled=true").run(context -> { + assertThat(context.getBeansOfType(VertexAiImagenImageProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(VertexAiImagenImageModel.class)).isNotEmpty(); + }); + + this.contextRunner.run(context -> { + assertThat(context.getBeansOfType(VertexAiImagenImageProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(VertexAiImagenImageModel.class)).isNotEmpty(); + }); + + } + +} diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml new file mode 100644 index 0000000000..c789f590a3 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml @@ -0,0 +1,61 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-M5 + ../../pom.xml + + + spring-ai-vertex-ai-imagen-spring-boot-starter + jar + Spring AI Starter - VertexAI Imagen + Spring AI Vertex Imagen AI Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-vertex-ai-imagen + ${project.parent.version} + + + + From 0fe3fff9fdb7a6a99f016422f2fbc5f05496ab40 Mon Sep 17 00:00:00 2001 From: Sami Marzouki <73302274+Marzouki-Sami@users.noreply.github.com> Date: Thu, 5 Jun 2025 12:44:51 +0100 Subject: [PATCH 2/9] Update spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml Updating the version. Co-authored-by: ByungJun Signed-off-by: Sami Marzouki <73302274+Marzouki-Sami@users.noreply.github.com> --- .../spring-ai-starter-vertex-ai-imagen/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml index c789f590a3..d63256afef 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml @@ -22,7 +22,7 @@ org.springframework.ai spring-ai - 1.0.0-M5 + 1.0.0-SNAPSHOT ../../pom.xml From 67c46d14daf959dac7d5a5bb764d790089eb476b Mon Sep 17 00:00:00 2001 From: Sami Marzouki <73302274+Marzouki-Sami@users.noreply.github.com> Date: Thu, 5 Jun 2025 13:41:16 +0100 Subject: [PATCH 3/9] Update spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml changing artifactId Co-authored-by: ByungJun Signed-off-by: Sami Marzouki <73302274+Marzouki-Sami@users.noreply.github.com> --- .../spring-ai-starter-vertex-ai-imagen/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml index d63256afef..e0a4a38c38 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml @@ -26,7 +26,7 @@ ../../pom.xml - spring-ai-vertex-ai-imagen-spring-boot-starter + spring-ai-starter-model-vertex-ai-imagen jar Spring AI Starter - VertexAI Imagen Spring AI Vertex Imagen AI Auto Configuration From adafe0b9597063a81ab684300e56a7d07fd5509c Mon Sep 17 00:00:00 2001 From: Sami Marzouki <73302274+Marzouki-Sami@users.noreply.github.com> Date: Thu, 5 Jun 2025 13:42:01 +0100 Subject: [PATCH 4/9] Update spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml removing unnecessary spaces Co-authored-by: ByungJun Signed-off-by: Sami Marzouki <73302274+Marzouki-Sami@users.noreply.github.com> --- .../spring-ai-starter-vertex-ai-imagen/pom.xml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml index e0a4a38c38..86a71ccd4f 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml @@ -28,7 +28,8 @@ spring-ai-starter-model-vertex-ai-imagen jar - Spring AI Starter - VertexAI Imagen + Spring AI Starter - VertexAI Imagen + Spring AI Vertex Imagen AI Spring Boot Starter Spring AI Vertex Imagen AI Auto Configuration https://github.com/spring-projects/spring-ai From 73c324519b731e23cd8355476ef41535722f7126 Mon Sep 17 00:00:00 2001 From: Sami Marzouki <73302274+Marzouki-Sami@users.noreply.github.com> Date: Thu, 5 Jun 2025 13:42:31 +0100 Subject: [PATCH 5/9] updating artifactId Co-authored-by: ByungJun Signed-off-by: Sami Marzouki <73302274+Marzouki-Sami@users.noreply.github.com> --- .../spring-ai-starter-vertex-ai-imagen/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml index 86a71ccd4f..fd72af115e 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml @@ -48,7 +48,7 @@ org.springframework.ai - spring-ai-spring-boot-autoconfigure + spring-ai-autoconfigure-model-vertex-ai ${project.parent.version} From a242b4f7f516b41bc2b43403215eb585d549416d Mon Sep 17 00:00:00 2001 From: "sami.marzouki" Date: Thu, 5 Jun 2025 14:06:12 +0100 Subject: [PATCH 6/9] resolving conflict Signed-off-by: Marzouki-Sami --- .../spring-ai-starter-vertex-ai-imagen/pom.xml | 1 - 1 file changed, 1 deletion(-) diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml index fd72af115e..7c678a56cd 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml @@ -30,7 +30,6 @@ jar Spring AI Starter - VertexAI Imagen Spring AI Vertex Imagen AI Spring Boot Starter - Spring AI Vertex Imagen AI Auto Configuration https://github.com/spring-projects/spring-ai From 9b223f9496ec08794aa0cf6c23cb713c00b5f7c7 Mon Sep 17 00:00:00 2001 From: Marzouki-Sami Date: Tue, 17 Jun 2025 10:03:36 +0100 Subject: [PATCH 7/9] updating to version 1.1.0-SNAPSHOT Signed-off-by: Marzouki-Sami --- models/spring-ai-vertex-ai-imagen/pom.xml | 11 +++----- .../imagen/VertexAiImagenImageModel.java | 25 +++++++++---------- pom.xml | 3 +++ .../pom.xml | 4 +-- 4 files changed, 21 insertions(+), 22 deletions(-) rename spring-ai-spring-boot-starters/{spring-ai-starter-vertex-ai-imagen => spring-ai-starter-model-vertex-ai-imagen}/pom.xml (96%) diff --git a/models/spring-ai-vertex-ai-imagen/pom.xml b/models/spring-ai-vertex-ai-imagen/pom.xml index 973c053c0d..8e41ecec11 100644 --- a/models/spring-ai-vertex-ai-imagen/pom.xml +++ b/models/spring-ai-vertex-ai-imagen/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.springframework.ai - spring-ai - 1.0.0-M5 + spring-ai-parent + 1.1.0-SNAPSHOT ../../pom.xml @@ -67,9 +67,10 @@ + org.springframework.ai - spring-ai-core + spring-ai-model ${project.parent.version} @@ -80,10 +81,6 @@ - - org.springframework - spring-context-support - org.springframework.boot diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java index 837bc3747c..f37e791c0f 100644 --- a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java @@ -148,13 +148,13 @@ private static ImageParametersBuilder getImageParametersBuilder(VertexAiImagenIm @Override public ImageResponse call(ImagePrompt imagePrompt) { - VertexAiImagenImageOptions finalOptions = mergedOptions(imagePrompt); + ImagePrompt finalPrompt = mergedPrompt(imagePrompt); + VertexAiImagenImageOptions finalOptions = (VertexAiImagenImageOptions) finalPrompt.getOptions(); var observationContext = ImageModelObservationContext.builder() - .imagePrompt(imagePrompt) - .provider(AiProvider.VERTEX_AI.value()) - .requestOptions(finalOptions) - .build(); + .imagePrompt(finalPrompt) + .provider(AiProvider.VERTEX_AI.value()) + .build(); return ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, @@ -164,7 +164,7 @@ public ImageResponse call(ImagePrompt imagePrompt) { EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); - PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(imagePrompt, endpointName, + PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(finalPrompt, endpointName, finalOptions); PredictResponse imageResponse = this.retryTemplate @@ -175,7 +175,7 @@ public ImageResponse call(ImagePrompt imagePrompt) { Value bytesBase64Encoded = prediction.getStructValue().getFieldsOrThrow("bytesBase64Encoded"); Value mimeType = prediction.getStructValue().getFieldsOrThrow("mimeType"); ImageGenerationMetadata metadata = new VertexAiImagenImageGenerationMetadata( - imagePrompt.getInstructions().get(0).getText(), finalOptions.getModel(), + finalPrompt.getInstructions().get(0).getText(), finalOptions.getModel(), mimeType.getStringValue()); Image image = new Image(null, bytesBase64Encoded.getStringValue()); imageGenerationList.add(new ImageGeneration(image, metadata)); @@ -189,17 +189,16 @@ public ImageResponse call(ImagePrompt imagePrompt) { }); } - private VertexAiImagenImageOptions mergedOptions(ImagePrompt imagePrompt) { + private ImagePrompt mergedPrompt(ImagePrompt originalPrompt) { + VertexAiImagenImageOptions finalOptions = this.defaultOptions; - VertexAiImagenImageOptions mergedOptions = this.defaultOptions; - - if (imagePrompt.getOptions() != null) { + if (originalPrompt.getOptions() != null) { var defaultOptionsCopy = VertexAiImagenImageOptions.builder().from(this.defaultOptions).build(); - mergedOptions = ModelOptionsUtils.merge(imagePrompt.getOptions(), defaultOptionsCopy, + finalOptions = ModelOptionsUtils.merge(originalPrompt.getOptions(), defaultOptionsCopy, VertexAiImagenImageOptions.class); } - return mergedOptions; + return new ImagePrompt(originalPrompt.getInstructions(), finalOptions); } protected PredictRequest.Builder getPredictRequestBuilder(ImagePrompt imagePrompt, EndpointName endpointName, diff --git a/pom.xml b/pom.xml index f215c4f660..26a68acfa2 100644 --- a/pom.xml +++ b/pom.xml @@ -175,6 +175,7 @@ models/spring-ai-transformers models/spring-ai-vertex-ai-embedding models/spring-ai-vertex-ai-gemini + models/spring-ai-vertex-ai-imagen models/spring-ai-zhipuai models/spring-ai-deepseek @@ -194,6 +195,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-model-transformers spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-embedding spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-gemini + spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-imagen spring-ai-spring-boot-starters/spring-ai-starter-model-zhipuai spring-ai-spring-boot-starters/spring-ai-starter-model-deepseek @@ -727,6 +729,7 @@ org.springframework.ai.transformers/**/*IT.java org.springframework.ai.vertexai.embedding/**/*IT.java org.springframework.ai.vertexai.gemini/**/*IT.java + org.springframework.ai.vertexai.imagen/**/*IT.java org.springframework.ai.zhipuai/**/*IT.java diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-imagen/pom.xml similarity index 96% rename from spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml rename to spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-imagen/pom.xml index 7c678a56cd..e690ee815e 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-imagen/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.springframework.ai - spring-ai - 1.0.0-SNAPSHOT + spring-ai-parent + 1.1.0-SNAPSHOT ../../pom.xml From f341567d9aaa0f618fecbcab0e260da69ec12bb6 Mon Sep 17 00:00:00 2001 From: Sami Marzouki <73302274+Marzouki-Sami@users.noreply.github.com> Date: Tue, 22 Jul 2025 09:10:01 +0100 Subject: [PATCH 8/9] Update VertexAiImagenImageGenerationMetadata.java switching constructor params to match only call in VertexAiImagenImageModel.java Signed-off-by: Marzouki-Sami samymarzouki502@gmail.com Signed-off-by: Marzouki-Sami --- .../imagen/metadata/VertexAiImagenImageGenerationMetadata.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java index ee12bd1c97..774f210dba 100644 --- a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java @@ -34,7 +34,7 @@ public class VertexAiImagenImageGenerationMetadata implements ImageGenerationMet private final String mimeType; - public VertexAiImagenImageGenerationMetadata(String revisedPrompt, String mimeType, String model) { + public VertexAiImagenImageGenerationMetadata(String revisedPrompt, String model, String mimeType) { this.prompt = revisedPrompt; this.model = model; this.mimeType = mimeType; From 8973601f4acca16b598183da8a1dbb7f4c46c588 Mon Sep 17 00:00:00 2001 From: Marzouki-Sami Date: Thu, 23 Oct 2025 11:43:53 +0100 Subject: [PATCH 9/9] adding the new models + setting older models as deprecated + adding 3 missing imageOptions Author: Marzouki-Sami Signed-off-by: Marzouki-Sami samymarzouki502@gmail.com Signed-off-by: Marzouki-Sami --- .../VertexAiImagenConnectionDetails.java | 18 +- .../imagen/VertexAiImagenImageModel.java | 92 +++--- .../imagen/VertexAiImagenImageModelName.java | 19 +- .../imagen/VertexAiImagenImageOptions.java | 267 +++++++++++++++--- .../vertexai/imagen/VertexAiImagenUtils.java | 61 +++- ...VertexAiImagenImageGenerationMetadata.java | 10 +- .../imagen/TestVertexAiImagenImageModel.java | 7 +- .../imagen/VertexAiImagenImageModelIT.java | 33 ++- ...VertexAiImagenImageModelObservationIT.java | 46 +-- .../imagen/VertexAiImagenImageRetryTests.java | 38 +-- 10 files changed, 422 insertions(+), 169 deletions(-) diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenConnectionDetails.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenConnectionDetails.java index a466b95f57..b33452f435 100644 --- a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenConnectionDetails.java +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenConnectionDetails.java @@ -23,10 +23,10 @@ import org.springframework.util.StringUtils; - /** - * VertexAiImagenConnectionDetails represents the details of a connection to the Vertex AI imagen service. - * It provides methods to access the project ID, location, publisher, and PredictionServiceSettings. + * VertexAiImagenConnectionDetails represents the details of a connection to the + * Vertex AI imagen service. It provides methods to access the project ID, location, + * publisher, and PredictionServiceSettings. * * @author Sami Marzouki */ @@ -59,7 +59,7 @@ public class VertexAiImagenConnectionDetails { private final PredictionServiceSettings predictionServiceSettings; public VertexAiImagenConnectionDetails(String projectId, String location, String publisher, - PredictionServiceSettings predictionServiceSettings) { + PredictionServiceSettings predictionServiceSettings) { this.projectId = projectId; this.location = location; this.publisher = publisher; @@ -153,7 +153,8 @@ public VertexAiImagenConnectionDetails build() { if (!StringUtils.hasText(this.location)) { this.endpoint = DEFAULT_ENDPOINT; this.location = DEFAULT_LOCATION; - } else { + } + else { this.endpoint = this.location + DEFAULT_ENDPOINT_SUFFIX; } } @@ -165,9 +166,10 @@ public VertexAiImagenConnectionDetails build() { if (this.predictionServiceSettings == null) { try { this.predictionServiceSettings = PredictionServiceSettings.newBuilder() - .setEndpoint(this.endpoint) - .build(); - } catch (IOException e) { + .setEndpoint(this.endpoint) + .build(); + } + catch (IOException e) { throw new RuntimeException(e); } } diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java index f37e791c0f..2b93a926f6 100644 --- a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java @@ -16,17 +16,12 @@ package org.springframework.ai.vertexai.imagen; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - import com.google.cloud.aiplatform.v1.EndpointName; import com.google.cloud.aiplatform.v1.PredictRequest; import com.google.cloud.aiplatform.v1.PredictResponse; import com.google.cloud.aiplatform.v1.PredictionServiceClient; import com.google.protobuf.Value; import io.micrometer.observation.ObservationRegistry; - import org.springframework.ai.image.Image; import org.springframework.ai.image.ImageGeneration; import org.springframework.ai.image.ImageGenerationMetadata; @@ -46,8 +41,13 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + /** - * VertexAiImagenImageModel is a class that implements the ImageModel interface. It + * VertexAiImagenImageModel is a class that implements the ImageModel interface. It * provides a client for calling the Imagen on Vertex AI image generation API. * * @author Sami Marzouki @@ -82,18 +82,18 @@ public class VertexAiImagenImageModel implements ImageModel { private ImageModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; public VertexAiImagenImageModel(VertexAiImagenConnectionDetails connectionDetails, - VertexAiImagenImageOptions defaultOptions) { + VertexAiImagenImageOptions defaultOptions) { this(connectionDetails, defaultOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); } public VertexAiImagenImageModel(VertexAiImagenConnectionDetails connectionDetails, - VertexAiImagenImageOptions defaultOptions, RetryTemplate retryTemplate) { + VertexAiImagenImageOptions defaultOptions, RetryTemplate retryTemplate) { this(connectionDetails, defaultOptions, retryTemplate, ObservationRegistry.NOOP); } public VertexAiImagenImageModel(VertexAiImagenConnectionDetails connectionDetails, - VertexAiImagenImageOptions defaultOptions, RetryTemplate retryTemplate, - ObservationRegistry observationRegistry) { + VertexAiImagenImageOptions defaultOptions, RetryTemplate retryTemplate, + ObservationRegistry observationRegistry) { Assert.notNull(defaultOptions, "options must not be null"); Assert.notNull(retryTemplate, "retryTemplate must not be null"); Assert.notNull(observationRegistry, "observationRegistry must not be null"); @@ -130,6 +130,15 @@ private static ImageParametersBuilder getImageParametersBuilder(VertexAiImagenIm if (finalOptions.getSafetySetting() != null) { parametersBuilder.safetySetting(finalOptions.getSafetySetting()); } + if (finalOptions.getLanguage() != null) { + parametersBuilder.language(finalOptions.getLanguage()); + } + if (finalOptions.getEnhancePrompt() != null) { + parametersBuilder.enhancePrompt(finalOptions.getEnhancePrompt()); + } + if (finalOptions.getSampleImageSize() != null) { + parametersBuilder.sampleImageSize(finalOptions.getSampleImageSize()); + } if (finalOptions.getOutputOptions() != null) { ImageParametersBuilder.OutputOptions outputOptions = ImageParametersBuilder.OutputOptions.of(); @@ -156,37 +165,37 @@ public ImageResponse call(ImagePrompt imagePrompt) { .provider(AiProvider.VERTEX_AI.value()) .build(); - return ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { - PredictionServiceClient client = createPredictionServiceClient(); - - EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); - - PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(finalPrompt, endpointName, - finalOptions); - - PredictResponse imageResponse = this.retryTemplate + return Objects.requireNonNull( + ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + + PredictionServiceClient client = createPredictionServiceClient(); + EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); + PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(finalPrompt, + endpointName, finalOptions); + PredictResponse imageResponse = this.retryTemplate .execute(context -> getPredictResponse(client, predictRequestBuilder)); + List imageGenerationList = new ArrayList<>(); - List imageGenerationList = new ArrayList<>(); - for (Value prediction : imageResponse.getPredictionsList()) { - Value bytesBase64Encoded = prediction.getStructValue().getFieldsOrThrow("bytesBase64Encoded"); - Value mimeType = prediction.getStructValue().getFieldsOrThrow("mimeType"); - ImageGenerationMetadata metadata = new VertexAiImagenImageGenerationMetadata( - finalPrompt.getInstructions().get(0).getText(), finalOptions.getModel(), - mimeType.getStringValue()); - Image image = new Image(null, bytesBase64Encoded.getStringValue()); - imageGenerationList.add(new ImageGeneration(image, metadata)); - } - ImageResponse response = new ImageResponse(imageGenerationList); + for (Value prediction : imageResponse.getPredictionsList()) { + Value bytesBase64Encoded = prediction.getStructValue() + .getFieldsOrThrow("bytesBase64Encoded"); + Value mimeType = prediction.getStructValue().getFieldsOrThrow("mimeType"); + ImageGenerationMetadata metadata = new VertexAiImagenImageGenerationMetadata( + finalPrompt.getInstructions().get(0).getText(), finalOptions.getModel(), + mimeType.getStringValue()); + Image image = new Image(null, bytesBase64Encoded.getStringValue()); + imageGenerationList.add(new ImageGeneration(image, metadata)); + } + ImageResponse response = new ImageResponse(imageGenerationList); - observationContext.setResponse(response); + observationContext.setResponse(response); - return response; + return response; - }); + })); } private ImagePrompt mergedPrompt(ImagePrompt originalPrompt) { @@ -202,7 +211,7 @@ private ImagePrompt mergedPrompt(ImagePrompt originalPrompt) { } protected PredictRequest.Builder getPredictRequestBuilder(ImagePrompt imagePrompt, EndpointName endpointName, - VertexAiImagenImageOptions finalOptions) { + VertexAiImagenImageOptions finalOptions) { PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder().setEndpoint(endpointName.toString()); ImageParametersBuilder parametersBuilder = getImageParametersBuilder(finalOptions); @@ -222,7 +231,7 @@ protected PredictRequest.Builder getPredictRequestBuilder(ImagePrompt imagePromp for (int i = 0; i < imagePrompt.getInstructions().size(); i++) { ImageInstanceBuilder instanceBuilder = ImageInstanceBuilder - .of(imagePrompt.getInstructions().get(i).getText()); + .of(imagePrompt.getInstructions().get(i).getText()); predictRequestBuilder.addInstances(VertexAiImagenUtils.valueOf(instanceBuilder.build())); } return predictRequestBuilder; @@ -232,19 +241,20 @@ protected PredictRequest.Builder getPredictRequestBuilder(ImagePrompt imagePromp protected PredictionServiceClient createPredictionServiceClient() { try { return PredictionServiceClient.create(this.connectionDetails.getPredictionServiceSettings()); - } catch (IOException e) { + } + catch (IOException e) { throw new RuntimeException(e); } } // for testing - protected PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) { + protected PredictResponse getPredictResponse(PredictionServiceClient client, + PredictRequest.Builder predictRequestBuilder) { return client.predict(predictRequestBuilder.build()); } /** * Use the provided convention for reporting observation data. - * * @param observationConvention The provided convention */ public void setObservationConvention(ImageModelObservationConvention observationConvention) { diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModelName.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModelName.java index f44cadd4d5..6b60b33e05 100644 --- a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModelName.java +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModelName.java @@ -17,23 +17,35 @@ package org.springframework.ai.vertexai.imagen; /** - * Imagen on VertexAI Models: - * - Image generation + * Imagen on VertexAI Models: - Image + * generation * * @author Sami Marzouki */ public enum VertexAiImagenImageModelName { - IMAGEN_3("imagen-3.0-generate-001"), + IMAGEN_4_V001("imagen-4.0-generate-001"), + + IMAGEN_4_FAST("imagen-4.0-fast-generate-001"), + + IMAGEN_4_ULTRA("imagen-4.0-ultra-generate-001"), + + IMAGEN_3_V002("imagen-3.0-generate-002"), + + IMAGEN_3_V001("imagen-3.0-generate-001"), IMAGEN_3_FAST("imagen-3.0-fast-generate-001"), IMAGEN_3_CUSTOMIZATION_AND_EDITING("imagen-3.0-capability-001"), + @Deprecated IMAGEN_2_V006("imagegeneration@006"), + @Deprecated IMAGEN_2_V005("imagegeneration@005"), + @Deprecated IMAGEN_1_V002("imagegeneration@002"); private final String value; @@ -45,4 +57,5 @@ public enum VertexAiImagenImageModelName { public String getValue() { return this.value; } + } diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageOptions.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageOptions.java index 2e0cabaef4..d62ddf58e8 100644 --- a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageOptions.java +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageOptions.java @@ -15,75 +15,148 @@ */ package org.springframework.ai.vertexai.imagen; -import java.util.List; - import com.fasterxml.jackson.annotation.JsonProperty; - import org.springframework.ai.image.ImageOptions; +import java.util.List; + import static org.springframework.ai.vertexai.imagen.VertexAiImagenUtils.calculateSizeFromAspectRatio; /** - * Options for the Vertex AI Image service. + *

Options for the Vertex AI Image service.

* * @author Sami Marzouki */ public class VertexAiImagenImageOptions implements ImageOptions { - public static final String DEFAULT_MODEL_NAME = VertexAiImagenImageModelName.IMAGEN_2_V006.getValue(); + public static final String DEFAULT_MODEL_NAME = VertexAiImagenImageModelName.IMAGEN_3_V002.getValue(); /** - * Required: int - * The number of images to generate. The default value is 4. The - * imagen-3.0-generate-001 model supports values 1 through 4. The - * imagen-3.0-fast-generate-001 model supports values 1 through 4. The - * imagegeneration@006 model supports values 1 through 4. The imagegeneration@005 - * model supports values 1 through 4. The imagegeneration@002 model supports values 1 - * through 8. + * Required: int + *

+ * The number of images to generate. The default value is 4. + *

+ *
    + *
  • imagen-3.0-generate-001 model supports values 1 through 4.
  • + *
  • imagen-3.0-fast-generate-001 model supports values 1 through 4.
  • + *
  • imagegeneration@006 model supports values 1 through 4.
  • + *
  • imagegeneration@005 model supports values 1 through 4.
  • + *
  • imagegeneration@002 model supports values 1 through 8.
  • + *
*/ @JsonProperty("sampleCount") private Integer n; /** + *

* The model to use for image generation. + *

*/ @JsonProperty("model") private String model; /** - * Optional: Uint32 - * The random seed for image generation. This is not available when addWatermark is set to true. + * Optional: Uint32 + *

+ * The random seed for image generation. This is not available when addWatermark is + * set to true. + *

*/ @JsonProperty("seed") private Integer seed; /** - * Optional: string + * Optional: string + *

* A description of what to discourage in the generated images. - * The imagen-3.0-generate-001 model supports up to 480 tokens. - * The imagen-3.0-fast-generate-001 model supports up to 480 tokens. - * The imagegeneration@006 model supports up to 128 tokens. - * The imagegeneration@005 model supports up to 128 tokens. - * The imagegeneration@002 model supports up to 64 tokens. + *

+ *

    + *
  • The imagen-3.0-generate-001 model supports up to 480 tokens.
  • + *
  • The imagen-3.0-fast-generate-001 model supports up to 480 tokens.
  • + *
  • The imagegeneration@006 model supports up to 128 tokens.
  • + *
  • The imagegeneration@005 model supports up to 128 tokens.
  • + *
  • The imagegeneration@002 model supports up to 64 tokens.
  • + *
+ * negativePrompt isn't supported by imagen-3.0-generate-002 and newer models. */ @JsonProperty("negativePrompt") private String negativePrompt; /** - * Optional: string + * Optional: string + *

+ * Specifies the generated image's output resolution.
+ * The accepted values are "1K" or "2K".
+ * The default value is "1K". + *

+ */ + @JsonProperty("sampleImageSize") + private String sampleImageSize; + + /** + * Optional: boolean + *

+ * An optional parameter to use an LLM-based prompt rewriting feature to deliver + * higher quality images that better reflect the original prompt's intent. Disabling + * this feature may impact image quality and prompt adherence. + *

+ */ + @JsonProperty("enhancePrompt") + private Boolean enhancePrompt; + + /** + * Optional: string + *

+ * The language code that corresponds to your text prompt language. + *

+ *
    + *
  • auto: Automatic detection. + *

    + * If Imagen detects a supported language, the prompt and an optional negative prompt + * are translated to English. If the language detected isn't supported, Imagen uses + * the input text verbatim, which might result in an unexpected output. No error code + * is returned. + *

    + *
  • + *
  • en: English (if omitted, the default value)
  • + *
  • zh or zh-CN: Chinese (simplified)
  • + *
  • zh-TW: Chinese (traditional)
  • + *
  • hi: Hindi
  • + *
  • ja: Japanese
  • + *
  • ko: Korean
  • + *
  • pt: Portuguese
  • + *
  • es: Spanish
  • + *
+ */ + @JsonProperty("language") + private String language; + + /** + * Optional: string + *

* The aspect ratio for the image. The default value is "1:1". - * The imagen-3.0-generate-001 model supports "1:1", "9:16", "16:9", "3:4", or "4:3". - * The imagen-3.0-fast-generate-001 model supports "1:1", "9:16", "16:9", "3:4", or "4:3". - * The imagegeneration@006 model supports "1:1", "9:16", "16:9", "3:4", or "4:3". - * The imagegeneration@005 model supports "1:1" or "9:16". - * The imagegeneration@002 model supports "1:1". + *

+ *
    + *
  • The imagen-3.0-generate-002 model supports "1:1", "9:16", "16:9", "3:4", or + * "4:3".
  • + *
  • The imagen-3.0-generate-001 model supports "1:1", "9:16", "16:9", "3:4", or + * "4:3".
  • + *
  • The imagen-3.0-fast-generate-001 model supports "1:1", "9:16", "16:9", "3:4", + * or "4:3".
  • + *
  • The imagegeneration@006 model supports "1:1", "9:16", "16:9", "3:4", or + * "4:3".
  • + *
  • The imagegeneration@005 model supports "1:1" or "9:16".
  • + *
  • The imagegeneration@002 model supports "1:1".
  • + *
*/ @JsonProperty("aspectRatio") private String aspectRatio; /** - * Optional: outputOptions + * Optional: outputOptions + *

* Describes the output image format in an outputOptions object. + *

* * @see OutputOptions */ @@ -91,50 +164,82 @@ public class VertexAiImagenImageOptions implements ImageOptions { private OutputOptions outputOptions; /** - * Optional: string (imagegeneration@002 only) + * Optional: string (imagegeneration@002 only) + *

* Describes the style for the generated images. The following values are supported: - * "photograph", "digital_art", "landscape", "sketch", "watercolor", "cyberpunk", "pop_art". + *

+ *
    + *
  • "photograph"
  • + *
  • "digital_art"
  • + *
  • "landscape"
  • + *
  • "sketch"
  • + *
  • "watercolor"
  • + *
  • "cyberpunk"
  • + *
  • "pop_art"
  • + *
*/ @JsonProperty("sampleImageStyle") private String style; /** - * Optional: string (imagen-3.0-generate-001, imagen-3.0-fast-generate-001, and imagegeneration@006 only) + * Optional: string (imagen-3.0-generate-001, imagen-3.0-fast-generate-001, and + * imagegeneration@006 only) + *

* Allow generation of people by the model. The following values are supported: - * "dont_allow": Disallow the inclusion of people or faces in images. - * "allow_adult": Allow generation of adults only. - * "allow_all": Allow generation of people of all ages. + *

+ *
    + *
  • "dont_allow": Disallow the inclusion of people or faces in images.
  • + *
  • "allow_adult": Allow generation of adults only.
  • + *
  • "allow_all": Allow generation of people of all ages.
  • + *
+ *

* The default value is "allow_adult". + *

*/ @JsonProperty("personGeneration") private String personGeneration; /** - * Optional: string (imagen-3.0-generate-001, imagen-3.0-fast-generate-001, and imagegeneration@006 only) + * Optional: string (imagen-3.0-generate-001, imagen-3.0-fast-generate-001, and + * imagegeneration@006 only) + *

* Adds a filter level to safety filtering. The following values are supported: - * "block_low_and_above": Strongest filtering level, most strict blocking. Deprecated value: "block_most". - * "block_medium_and_above": Block some problematic prompts and responses. Deprecated value: "block_some". - * "block_only_high": Reduces the number of requests blocked due to safety filters. May increase objectionable - * content generated by Imagen. Deprecated value: "block_few". - * "block_none": Block very few problematic prompts and responses. Access to this feature is restricted. - * Previous field value: "block_fewest". + *

+ *
    + *
  • "block_low_and_above": Strongest filtering level, most strict blocking.
    + * Deprecated value: "block_most".
  • + *
  • "block_medium_and_above": Block some problematic prompts and responses.
    + * Deprecated value: "block_some".
  • + *
  • "block_only_high": Reduces the number of requests blocked due to safety + * filters. May increase objectionable content generated by Imagen.
    + * Deprecated value: "block_few".
  • + *
  • "block_none": Block very few problematic prompts and responses. Access to this + * feature is restricted.
    + * Previous field value: "block_fewest".
  • + *
+ *

* The default value is "block_medium_and_above". + *

*/ @JsonProperty("safetySetting") private String safetySetting; /** - * Optional: bool - * Add an invisible watermark to the generated images. - * The default value is false for the imagegeneration@002 and imagegeneration@005 models, - * and true for the imagen-3.0-fast-generate-001, imagegeneration@006, and imagegeneration@006 models. + * Optional: bool + *

+ * Add an invisible watermark to the generated images. The default value is false for + * the imagegeneration@002 and imagegeneration@005 models, and true for the + * imagen-3.0-fast-generate-001, imagegeneration@006, and imagegeneration@006 models. + *

*/ @JsonProperty("addWatermark") private Boolean addWatermark; /** - * Optional: string + * Optional: string + *

* Cloud Storage URI to store the generated images. + *

*/ @JsonProperty("storageUri") private String storageUri; @@ -267,11 +372,56 @@ public void setSize(List size) { this.size = size; } + public Boolean getEnhancePrompt() { + return enhancePrompt; + } + + public void setEnhancePrompt(Boolean enhancePrompt) { + this.enhancePrompt = enhancePrompt; + } + + public String getLanguage() { + return language; + } + + public void setLanguage(String language) { + this.language = language; + } + + public String getSampleImageSize() { + return sampleImageSize; + } + + public void setSampleImageSize(String sampleImageSize) { + this.sampleImageSize = sampleImageSize; + } + public static final class OutputOptions { + /** + * Optional: string + *

+ * The image format that the output should be saved as. The following values are + * supported: + *

+ *
    + *
  • "image/png": Save as a PNG image
  • + *
  • "image/jpeg": Save as a JPEG image
  • + *
+ *

+ * The default value is "image/png". + *

+ */ @JsonProperty("mimeType") private String mimeType; + /** + * Optional: int + *

+ * The level of compression if the output type is "image/jpeg". Accepted values + * are 0 through 100. The default value is 75. + *

+ */ @JsonProperty("compressionQuality") private Integer compressionQuality; @@ -316,7 +466,9 @@ public Builder compressionQuality(Integer compressionQuality) { public OutputOptions build() { return this.options; } + } + } public static final class Builder { @@ -367,6 +519,15 @@ public Builder from(VertexAiImagenImageOptions fromOptions) { if (fromOptions.getStorageUri() != null) { this.options.setStorageUri(fromOptions.getStorageUri()); } + if (fromOptions.getLanguage() != null) { + this.options.setLanguage(fromOptions.getLanguage()); + } + if (fromOptions.getEnhancePrompt() != null) { + this.options.setEnhancePrompt(fromOptions.getEnhancePrompt()); + } + if (fromOptions.getSampleImageSize() != null) { + this.options.setSampleImageSize(fromOptions.getSampleImageSize()); + } return this; } @@ -427,9 +588,25 @@ public Builder style(String style) { return this; } + public Builder language(String language) { + this.options.setLanguage(language); + return this; + } + + public Builder enhancePrompt(Boolean enhancePrompt) { + this.options.setEnhancePrompt(enhancePrompt); + return this; + } + + public Builder sampleImageSize(String sampleImageSize) { + this.options.setSampleImageSize(sampleImageSize); + return this; + } + public VertexAiImagenImageOptions build() { return this.options; } } + } diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenUtils.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenUtils.java index d4650105e7..9858bb5947 100644 --- a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenUtils.java +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenUtils.java @@ -16,18 +16,18 @@ package org.springframework.ai.vertexai.imagen; -import java.util.Arrays; -import java.util.List; - import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Struct; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; - import org.springframework.util.Assert; +import java.util.Arrays; +import java.util.List; + /** - * Utility class for constructing parameter objects for Imagen on Vertex AI requests. + * VertexAiImagenUtils is a Utility class for constructing parameter objects for + * Imagen on Vertex AI requests. * * @author Sami Marzouki */ @@ -63,8 +63,8 @@ public static List calculateSizeFromAspectRatio(String aspectRatio) { case "16:9" -> List.of(1600, 900); case "3:4" -> List.of(750, 1000); case "4:3" -> List.of(1000, 750); - default -> throw new IllegalStateException("Unexpected value: " + aspectRatio + - " aspect ratio must be one of these values : ['1:1', '9:16', '16:9', '3:4', or '4:3']"); + default -> throw new IllegalStateException("Unexpected value: " + aspectRatio + + " aspect ratio must be one of these values : ['1:1', '9:16', '16:9', '3:4', or '4:3']"); }; } return Arrays.asList(1024, 1024); @@ -86,20 +86,35 @@ public Struct build() { textBuilder.putFields("prompt", valueOf(this.prompt)); return textBuilder.build(); } + } public static class ImageParametersBuilder { public Integer sampleCount; + public Integer seed; + public String negativePrompt; + public String aspectRatio; + public Boolean addWatermark; + public String storageUri; + public String personGeneration; + public String safetySetting; + public Struct outputOptions; + public String language; + + public Boolean enhancePrompt; + + public String sampleImageSize; + public static ImageParametersBuilder of() { return new ImageParametersBuilder(); } @@ -158,6 +173,24 @@ public ImageParametersBuilder outputOptions(Struct outputOptions) { return this; } + public ImageParametersBuilder language(String language) { + Assert.notNull(language, "language must not be null"); + this.language = language; + return this; + } + + public ImageParametersBuilder enhancePrompt(Boolean enhancePrompt) { + Assert.notNull(enhancePrompt, "enhancePrompt must not be null"); + this.enhancePrompt = enhancePrompt; + return this; + } + + public ImageParametersBuilder sampleImageSize(String sampleImageSize) { + Assert.notNull(sampleImageSize, "sampleImageSize must not be null"); + this.sampleImageSize = sampleImageSize; + return this; + } + public Struct build() { Struct.Builder imageParametersBuilder = Struct.newBuilder(); @@ -186,13 +219,25 @@ public Struct build() { imageParametersBuilder.putFields("safetySetting", valueOf(this.safetySetting)); } if (this.outputOptions != null) { - imageParametersBuilder.putFields("outputOptions", Value.newBuilder().setStructValue(this.outputOptions).build()); + imageParametersBuilder.putFields("outputOptions", + Value.newBuilder().setStructValue(this.outputOptions).build()); + } + if (this.language != null) { + imageParametersBuilder.putFields("language", valueOf(this.language)); + } + if (this.enhancePrompt != null) { + imageParametersBuilder.putFields("enhancePrompt", valueOf(this.enhancePrompt)); + } + if (this.sampleImageSize != null) { + imageParametersBuilder.putFields("sampleImageSize", valueOf(this.sampleImageSize)); } return imageParametersBuilder.build(); } public static class OutputOptions { + public String mimeType; + public Integer compressionQuality; public static OutputOptions of() { diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java index 774f210dba..b262da93bd 100644 --- a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java @@ -54,11 +54,8 @@ public String getMimeType() { @Override public String toString() { - return "VertexAiImagenImageGenerationMetadata{" + - "prompt='" + prompt + '\'' + - ", model='" + model + '\'' + - ", mimeType='" + mimeType + '\'' + - '}'; + return "VertexAiImagenImageGenerationMetadata{" + "prompt='" + prompt + '\'' + ", model='" + model + '\'' + + ", mimeType='" + mimeType + '\'' + '}'; } @Override @@ -68,8 +65,7 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; VertexAiImagenImageGenerationMetadata that = (VertexAiImagenImageGenerationMetadata) o; - return Objects.equals(prompt, that.prompt) - && Objects.equals(model, that.model) + return Objects.equals(prompt, that.prompt) && Objects.equals(model, that.model) && Objects.equals(mimeType, that.mimeType); } diff --git a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/TestVertexAiImagenImageModel.java b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/TestVertexAiImagenImageModel.java index 4c84c25419..aa4f78f9d4 100644 --- a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/TestVertexAiImagenImageModel.java +++ b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/TestVertexAiImagenImageModel.java @@ -36,7 +36,7 @@ public class TestVertexAiImagenImageModel extends VertexAiImagenImageModel { private PredictRequest.Builder mockPredictRequestBuilder; public TestVertexAiImagenImageModel(VertexAiImagenConnectionDetails connectionDetails, - VertexAiImagenImageOptions defaultOptions, RetryTemplate retryTemplate) { + VertexAiImagenImageOptions defaultOptions, RetryTemplate retryTemplate) { super(connectionDetails, defaultOptions, retryTemplate); } @@ -53,7 +53,8 @@ public PredictionServiceClient createPredictionServiceClient() { } @Override - public PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) { + public PredictResponse getPredictResponse(PredictionServiceClient client, + PredictRequest.Builder predictRequestBuilder) { if (this.mockPredictionServiceClient != null) { return this.mockPredictionServiceClient.predict(predictRequestBuilder.build()); } @@ -66,7 +67,7 @@ public void setMockPredictRequestBuilder(PredictRequest.Builder mockPredictReque @Override protected PredictRequest.Builder getPredictRequestBuilder(ImagePrompt imagePrompt, EndpointName endpointName, - VertexAiImagenImageOptions finalOptions) { + VertexAiImagenImageOptions finalOptions) { if (this.mockPredictRequestBuilder != null) { return this.mockPredictRequestBuilder; } diff --git a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelIT.java b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelIT.java index 15465acee7..41a6af0520 100644 --- a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelIT.java +++ b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelIT.java @@ -20,7 +20,6 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; - import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.vertexai.imagen.VertexAiImagenConnectionDetails; @@ -44,21 +43,31 @@ public class VertexAiImagenImageModelIT { protected VertexAiImagenImageModel imageModel; @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = {"imagen-3.0-generate-001", "imagen-3.0-fast-generate-001", "imagen-3.0-capability-001", - "imagegeneration@006", "imagegeneration@005", "imagegeneration@002"}) + @ValueSource(strings = { "imagen-4.0-generate-001", "imagen-4.0-fast-generate-001", "imagen-4.0-ultra-generate-001", + "imagen-3.0-generate-002", "imagen-3.0-generate-001", "imagen-3.0-fast-generate-001", + "imagen-3.0-capability-001" }) void defaultImageGenerator(String modelName) { Assertions.assertThat(this.imageModel).isNotNull(); var options = VertexAiImagenImageOptions.builder().model(modelName).N(1).build(); ImageResponse imageResponse = this.imageModel - .call(new ImagePrompt("little kitten sitting on a purple cushion", options)); + .call(new ImagePrompt("little kitten sitting on a purple cushion", options)); Assertions.assertThat(imageResponse.getResults()).hasSize(2); Assertions.assertThat(imageResponse.getResults().get(0).getOutput().getB64Json()).isNotEmpty(); - Assertions.assertThat(((VertexAiImagenImageGenerationMetadata) imageResponse.getResults().get(0).getMetadata()).getModel()).isNotEmpty(); - Assertions.assertThat(((VertexAiImagenImageGenerationMetadata) imageResponse.getResults().get(0).getMetadata()).getPrompt()).isNotEmpty(); - Assertions.assertThat(((VertexAiImagenImageGenerationMetadata) imageResponse.getResults().get(0).getMetadata()).getMimeType()).isNotEmpty(); + Assertions + .assertThat(((VertexAiImagenImageGenerationMetadata) imageResponse.getResults().get(0).getMetadata()) + .getModel()) + .isNotEmpty(); + Assertions + .assertThat(((VertexAiImagenImageGenerationMetadata) imageResponse.getResults().get(0).getMetadata()) + .getPrompt()) + .isNotEmpty(); + Assertions + .assertThat(((VertexAiImagenImageGenerationMetadata) imageResponse.getResults().get(0).getMetadata()) + .getMimeType()) + .isNotEmpty(); } @SpringBootConfiguration @@ -67,17 +76,17 @@ static class Config { @Bean public VertexAiImagenConnectionDetails connectionDetails() { return VertexAiImagenConnectionDetails.builder() - .projectId(System.getenv("VERTEX_AI_IMAGEN_PROJECT_ID")) - .location(System.getenv("VERTEX_AI_IMAGEN_LOCATION")) - .build(); + .projectId(System.getenv("VERTEX_AI_IMAGEN_PROJECT_ID")) + .location(System.getenv("VERTEX_AI_IMAGEN_LOCATION")) + .build(); } @Bean public VertexAiImagenImageModel imageModel(VertexAiImagenConnectionDetails connectionDetails) { VertexAiImagenImageOptions options = VertexAiImagenImageOptions.builder() - .model(VertexAiImagenImageOptions.DEFAULT_MODEL_NAME) - .build(); + .model(VertexAiImagenImageOptions.DEFAULT_MODEL_NAME) + .build(); return new VertexAiImagenImageModel(connectionDetails, options); } diff --git a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelObservationIT.java b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelObservationIT.java index 8aa3adfe1a..6cab5bdee5 100644 --- a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelObservationIT.java +++ b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelObservationIT.java @@ -61,9 +61,9 @@ public class VertexAiImagenImageModelObservationIT { @Test void observationForImageOperation() { var options = VertexAiImagenImageOptions.builder() - .model(VertexAiImagenImageModelName.IMAGEN_2_V006.getValue()) - .N(1) - .build(); + .model(VertexAiImagenImageModelName.IMAGEN_3_V002.getValue()) + .N(1) + .build(); ImagePrompt imagePrompt = new ImagePrompt("Little kitten sitting on a purple cushion", options); ImageResponse imageResponse = this.imageModel.call(imagePrompt); @@ -73,20 +73,20 @@ void observationForImageOperation() { assertThat(responseMetadata).isNotNull(); TestObservationRegistryAssert.assertThat(this.observationRegistry) - .doesNotHaveAnyRemainingCurrentObservation() - .hasObservationWithNameEqualTo(DefaultImageModelObservationConvention.DEFAULT_NAME) - .that() - .hasContextualNameEqualTo("image " + VertexAiImagenImageModelName.IMAGEN_2_V006.getValue()) - .hasLowCardinalityKeyValue( - ImageModelObservationDocumentation.LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), - AiOperationType.IMAGE.value()) - .hasLowCardinalityKeyValue(ImageModelObservationDocumentation.LowCardinalityKeyNames.AI_PROVIDER.asString(), - AiProvider.VERTEX_AI.value()) - .hasLowCardinalityKeyValue( - ImageModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL.asString(), - VertexAiImagenImageModelName.IMAGEN_2_V006.getValue()) - .hasBeenStarted() - .hasBeenStopped(); + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultImageModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("image " + VertexAiImagenImageModelName.IMAGEN_3_V002.getValue()) + .hasLowCardinalityKeyValue( + ImageModelObservationDocumentation.LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.IMAGE.value()) + .hasLowCardinalityKeyValue(ImageModelObservationDocumentation.LowCardinalityKeyNames.AI_PROVIDER.asString(), + AiProvider.VERTEX_AI.value()) + .hasLowCardinalityKeyValue( + ImageModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL.asString(), + VertexAiImagenImageModelName.IMAGEN_3_V002.getValue()) + .hasBeenStarted() + .hasBeenStopped(); } @SpringBootConfiguration @@ -100,18 +100,18 @@ public TestObservationRegistry observationRegistry() { @Bean public VertexAiImagenConnectionDetails connectionDetails() { return VertexAiImagenConnectionDetails.builder() - .projectId(System.getenv("VERTEX_AI_IMAGEN_PROJECT_ID")) - .location(System.getenv("VERTEX_AI_IMAGEN_LOCATION")) - .build(); + .projectId(System.getenv("VERTEX_AI_IMAGEN_PROJECT_ID")) + .location(System.getenv("VERTEX_AI_IMAGEN_LOCATION")) + .build(); } @Bean public VertexAiImagenImageModel imageModel(VertexAiImagenConnectionDetails connectionDetails, - ObservationRegistry observationRegistry) { + ObservationRegistry observationRegistry) { VertexAiImagenImageOptions options = VertexAiImagenImageOptions.builder() - .model(VertexAiImagenImageOptions.DEFAULT_MODEL_NAME) - .build(); + .model(VertexAiImagenImageOptions.DEFAULT_MODEL_NAME) + .build(); return new VertexAiImagenImageModel(connectionDetails, options, RetryUtils.DEFAULT_RETRY_TEMPLATE, observationRegistry); diff --git a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageRetryTests.java b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageRetryTests.java index 0086338969..a921dac515 100644 --- a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageRetryTests.java +++ b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageRetryTests.java @@ -82,25 +82,24 @@ public void setUp() { public void vertexAiImageGeneratorTransientError() { // Set up the mock PredictResponse PredictResponse mockResponse = PredictResponse.newBuilder() - .addPredictions(Value.newBuilder() - .setStructValue(Struct.newBuilder() - .putFields("bytesBase64Encoded", Value.newBuilder().setStringValue("BASE64_IMG_BYTES").build()) - .putFields("mimeType", Value.newBuilder().setStringValue("image/png").build()) - .build()) - .build()) - .addPredictions(Value.newBuilder() - .setStructValue(Struct.newBuilder() - .putFields("mimeType", Value.newBuilder().setStringValue("image/png").build()) - .putFields("bytesBase64Encoded", Value.newBuilder().setStringValue("BASE64_IMG_BYTES").build()) - .build()) - .build()) - .build(); + .addPredictions(Value.newBuilder() + .setStructValue(Struct.newBuilder() + .putFields("bytesBase64Encoded", Value.newBuilder().setStringValue("BASE64_IMG_BYTES").build()) + .putFields("mimeType", Value.newBuilder().setStringValue("image/png").build()) + .build()) + .build()) + .addPredictions(Value.newBuilder() + .setStructValue(Struct.newBuilder() + .putFields("mimeType", Value.newBuilder().setStringValue("image/png").build()) + .putFields("bytesBase64Encoded", Value.newBuilder().setStringValue("BASE64_IMG_BYTES").build()) + .build()) + .build()) + .build(); // Set up the mock PredictionServiceClient - given(this.mockPredictionServiceClient.predict(any())) - .willThrow(new TransientAiException("Transient Error 1")) - .willThrow(new TransientAiException("Transient Error 2")) - .willReturn(mockResponse); + given(this.mockPredictionServiceClient.predict(any())).willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(mockResponse); ImageResponse result = this.imageModel.call(new ImagePrompt("text1", null)); @@ -121,7 +120,7 @@ public void vertexAiImageGeneratorNonTransientError() { // Assert that a RuntimeException is thrown and not retried assertThatThrownBy(() -> this.imageModel.call(new ImagePrompt("text1", null))) - .isInstanceOf(RuntimeException.class); + .isInstanceOf(RuntimeException.class); // Verify that predict was called only once (no retries for non-transient errors) verify(this.mockPredictionServiceClient, times(1)).predict(any()); @@ -130,6 +129,7 @@ public void vertexAiImageGeneratorNonTransientError() { private static class TestRetryListener implements RetryListener { int onErrorRetryCount = 0; + int onSuccessRetryCount = 0; @Override @@ -139,7 +139,7 @@ public void onSuccess(RetryContext context, RetryCallba @Override public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { + Throwable throwable) { this.onErrorRetryCount = context.getRetryCount(); }