From 70cc020c21d1d6647c839edb122b89d32de381f0 Mon Sep 17 00:00:00 2001 From: Rodrigo Lazo Paz Date: Thu, 1 May 2025 17:56:25 -0400 Subject: [PATCH] [Ai] Add workaround for invalid SafetyRating from the backend. Due to a bug in the backend, it's possible that we receive an invalid `SafetyRating` value, without either category or probability. We return null in those cases to enable filtering by the higher level types. --- .../com/google/firebase/ai/type/Candidate.kt | 30 ++++++++++++------- .../google/firebase/ai/type/PromptFeedback.kt | 2 +- .../ai/VertexAIStreamingSnapshotTests.kt | 15 ++++++++++ .../firebase/ai/VertexAIUnarySnapshotTests.kt | 13 ++++++++ 4 files changed, 48 insertions(+), 12 deletions(-) diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt index a95231b583e..d5fc51f21c0 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt @@ -51,7 +51,7 @@ internal constructor( val groundingMetadata: GroundingMetadata? = null, ) { internal fun toPublic(): Candidate { - val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() + val safetyRatings = safetyRatings?.mapNotNull { it.toPublic() }.orEmpty() val citations = citationMetadata?.toPublic() val finishReason = finishReason?.toPublic() @@ -120,8 +120,8 @@ internal constructor( internal data class Internal @JvmOverloads constructor( - val category: HarmCategory.Internal, - val probability: HarmProbability.Internal, + val category: HarmCategory.Internal? = null, + val probability: HarmProbability.Internal? = null, val blocked: Boolean? = null, // TODO(): any reason not to default to false? val probabilityScore: Float? = null, val severity: HarmSeverity.Internal? = null, @@ -129,14 +129,22 @@ internal constructor( ) { internal fun toPublic() = - SafetyRating( - category = category.toPublic(), - probability = probability.toPublic(), - probabilityScore = probabilityScore ?: 0f, - blocked = blocked, - severity = severity?.toPublic(), - severityScore = severityScore - ) + // Due to a bug in the backend, it's possible that we receive + // an invalid `SafetyRating` value, without either category or + // probability. We return null in those cases to enable + // filtering by the higher level types. + if (category == null || probability == null) { + null + } else { + SafetyRating( + category = category.toPublic(), + probability = probability.toPublic(), + probabilityScore = probabilityScore ?: 0f, + blocked = blocked, + severity = severity?.toPublic(), + severityScore = severityScore + ) + } } } diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/PromptFeedback.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/PromptFeedback.kt index 817d1358f67..5f9840263eb 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/PromptFeedback.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/PromptFeedback.kt @@ -42,7 +42,7 @@ public class PromptFeedback( ) { internal fun toPublic(): PromptFeedback { - val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() + val safetyRatings = safetyRatings?.mapNotNull { it.toPublic() }.orEmpty() return PromptFeedback(blockReason?.toPublic(), safetyRatings, blockReasonMessage) } } diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIStreamingSnapshotTests.kt b/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIStreamingSnapshotTests.kt index e5a58541123..8b6079cc6a5 100644 --- a/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIStreamingSnapshotTests.kt +++ b/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIStreamingSnapshotTests.kt @@ -36,7 +36,10 @@ import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.toList import kotlinx.coroutines.withTimeout import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +@RunWith(RobolectricTestRunner::class) internal class VertexAIStreamingSnapshotTests { private val testTimeout = 5.seconds @@ -85,6 +88,18 @@ internal class VertexAIStreamingSnapshotTests { } } + @Test + fun `invalid safety ratings during image generation`() = + goldenVertexStreamingFile("streaming-success-image-invalid-safety-ratings.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { + val responseList = responses.toList() + + responseList.isEmpty() shouldBe false + } + } + @Test fun `unknown enum in finish reason`() = goldenVertexStreamingFile("streaming-failure-unknown-finish-enum.txt") { diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt b/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt index a19339b4981..ca1d279d288 100644 --- a/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt +++ b/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt @@ -55,8 +55,11 @@ import kotlinx.serialization.json.jsonObject import kotlinx.serialization.json.jsonPrimitive import org.json.JSONArray import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner @OptIn(PublicPreviewAPI::class) +@RunWith(RobolectricTestRunner::class) internal class VertexAIUnarySnapshotTests { private val testTimeout = 5.seconds @@ -125,6 +128,16 @@ internal class VertexAIUnarySnapshotTests { } } + @Test + fun `invalid safety ratings during image generation`() = + goldenVertexUnaryFile("unary-success-image-invalid-safety-ratings.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + } + } + @Test fun `unknown enum in finish reason`() = goldenVertexUnaryFile("unary-failure-unknown-enum-finish-reason.json") {