Skip to content

Commit f7a57ea

Browse files
committed
Add change from #6925
1 parent 4162383 commit f7a57ea

File tree

4 files changed

+50
-12
lines changed

4 files changed

+50
-12
lines changed

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ package com.google.firebase.vertexai.type
2121
import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer
2222
import java.util.Calendar
2323
import kotlinx.serialization.ExperimentalSerializationApi
24+
import kotlinx.serialization.InternalSerializationApi
2425
import kotlinx.serialization.KSerializer
2526
import kotlinx.serialization.SerialName
2627
import kotlinx.serialization.Serializable
@@ -55,7 +56,7 @@ internal constructor(
5556
val groundingMetadata: GroundingMetadata? = null,
5657
) {
5758
internal fun toPublic(): Candidate {
58-
val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty()
59+
val safetyRatings = safetyRatings?.mapNotNull { it.toPublic() }.orEmpty()
5960
val citations = citationMetadata?.toPublic()
6061
val finishReason = finishReason?.toPublic()
6162

@@ -128,23 +129,32 @@ internal constructor(
128129
internal data class Internal
129130
@JvmOverloads
130131
constructor(
131-
val category: HarmCategory.Internal,
132-
val probability: HarmProbability.Internal,
132+
val category: HarmCategory.Internal? = null,
133+
val probability: HarmProbability.Internal? = null,
133134
val blocked: Boolean? = null, // TODO(): any reason not to default to false?
134135
val probabilityScore: Float? = null,
135136
val severity: HarmSeverity.Internal? = null,
136137
val severityScore: Float? = null,
137138
) {
138139

139140
internal fun toPublic() =
140-
SafetyRating(
141-
category = category.toPublic(),
142-
probability = probability.toPublic(),
143-
probabilityScore = probabilityScore ?: 0f,
144-
blocked = blocked,
145-
severity = severity?.toPublic(),
146-
severityScore = severityScore
147-
)
141+
/**
142+
* Due to a bug in the backend, it's possible that we receive an invalid `SafetyRating` value,
143+
* without either category or probability. We return null in those cases to enable filtering
144+
* by the higher level types.
145+
*/
146+
if (category == null || probability == null) {
147+
null
148+
} else {
149+
SafetyRating(
150+
category = category.toPublic(),
151+
probability = probability.toPublic(),
152+
probabilityScore = probabilityScore ?: 0f,
153+
blocked = blocked,
154+
severity = severity?.toPublic(),
155+
severityScore = severityScore
156+
)
157+
}
148158
}
149159
}
150160

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public class PromptFeedback(
4646
) {
4747

4848
internal fun toPublic(): PromptFeedback {
49-
val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty()
49+
val safetyRatings = safetyRatings?.mapNotNull { it.toPublic() }.orEmpty()
5050
return PromptFeedback(blockReason?.toPublic(), safetyRatings, blockReasonMessage)
5151
}
5252
}

firebase-vertexai/src/test/java/com/google/firebase/vertexai/VertexAIStreamingSnapshotTests.kt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ import kotlinx.coroutines.flow.collect
3636
import kotlinx.coroutines.flow.toList
3737
import kotlinx.coroutines.withTimeout
3838
import org.junit.Test
39+
import org.junit.runner.RunWith
40+
import org.robolectric.RobolectricTestRunner
3941

42+
@RunWith(RobolectricTestRunner::class)
4043
internal class VertexAIStreamingSnapshotTests {
4144
private val testTimeout = 5.seconds
4245

@@ -69,6 +72,18 @@ internal class VertexAIStreamingSnapshotTests {
6972
}
7073
}
7174

75+
@Test
76+
fun `invalid safety ratings during image generation`() =
77+
goldenVertexStreamingFile("streaming-success-image-invalid-safety-ratings.txt") {
78+
val responses = model.generateContentStream("prompt")
79+
80+
withTimeout(testTimeout) {
81+
val responseList = responses.toList()
82+
83+
responseList.isEmpty() shouldBe false
84+
}
85+
}
86+
7287
@Test
7388
fun `unknown enum in safety ratings`() =
7489
goldenVertexStreamingFile("streaming-success-unknown-safety-enum.txt") {

firebase-vertexai/src/test/java/com/google/firebase/vertexai/VertexAIUnarySnapshotTests.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,11 @@ import kotlinx.serialization.json.jsonObject
5555
import kotlinx.serialization.json.jsonPrimitive
5656
import org.json.JSONArray
5757
import org.junit.Test
58+
import org.junit.runner.RunWith
59+
import org.robolectric.RobolectricTestRunner
5860

5961
@OptIn(PublicPreviewAPI::class)
62+
@RunWith(RobolectricTestRunner::class)
6063
internal class VertexAIUnarySnapshotTests {
6164
private val testTimeout = 5.seconds
6265

@@ -125,6 +128,16 @@ internal class VertexAIUnarySnapshotTests {
125128
}
126129
}
127130

131+
@Test
132+
fun `invalid safety ratings during image generation`() =
133+
goldenVertexUnaryFile("unary-success-image-invalid-safety-ratings.json") {
134+
withTimeout(testTimeout) {
135+
val response = model.generateContent("prompt")
136+
137+
response.candidates.isEmpty() shouldBe false
138+
}
139+
}
140+
128141
@Test
129142
fun `unknown enum in finish reason`() =
130143
goldenVertexUnaryFile("unary-failure-unknown-enum-finish-reason.json") {

0 commit comments

Comments
 (0)