Skip to content

Commit 8d5c103

Browse files
authored
Add more VertexAI unit tests (#6104)
1 parent a41af7d commit 8d5c103

File tree

2 files changed

+162
-0
lines changed

2 files changed

+162
-0
lines changed

firebase-vertexai/src/test/java/com/google/firebase/vertexai/StreamingSnapshotTests.kt

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,17 @@ internal class StreamingSnapshotTests {
8686
}
8787
}
8888

89+
@Test
90+
fun `unknown enum in finish reason`() =
91+
goldenStreamingFile("failure-unknown-finish-enum.txt") {
92+
val responses = model.generateContentStream("prompt")
93+
94+
withTimeout(testTimeout) {
95+
val exception = shouldThrow<ResponseStoppedException> { responses.collect() }
96+
exception.response.candidates.first().finishReason shouldBe FinishReason.UNKNOWN
97+
}
98+
}
99+
89100
@Test
90101
fun `quotes escaped`() =
91102
goldenStreamingFile("success-quotes-escaped.txt") {
@@ -184,4 +195,20 @@ internal class StreamingSnapshotTests {
184195

185196
withTimeout(testTimeout) { shouldThrow<InvalidAPIKeyException> { responses.collect() } }
186197
}
198+
199+
@Test
200+
fun `invalid json`() =
201+
goldenStreamingFile("failure-invalid-json.txt") {
202+
val responses = model.generateContentStream("prompt")
203+
204+
withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
205+
}
206+
207+
@Test
208+
fun `malformed content`() =
209+
goldenStreamingFile("failure-malformed-content.txt") {
210+
val responses = model.generateContentStream("prompt")
211+
212+
withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
213+
}
187214
}

firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import io.kotest.matchers.nulls.shouldNotBeNull
3838
import io.kotest.matchers.should
3939
import io.kotest.matchers.shouldBe
4040
import io.kotest.matchers.shouldNotBe
41+
import io.kotest.matchers.string.shouldContain
4142
import io.kotest.matchers.string.shouldNotBeEmpty
4243
import io.kotest.matchers.types.shouldBeInstanceOf
4344
import io.ktor.http.HttpStatusCode
@@ -84,6 +85,56 @@ internal class UnarySnapshotTests {
8485
response.candidates.isEmpty() shouldBe false
8586
val candidate = response.candidates.first()
8687
candidate.safetyRatings.any { it.category == HarmCategory.UNKNOWN } shouldBe true
88+
response.promptFeedback?.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } shouldBe
89+
true
90+
}
91+
}
92+
93+
@Test
94+
fun `unknown enum in finish reason`() =
95+
goldenUnaryFile("failure-unknown-enum-finish-reason.json") {
96+
withTimeout(testTimeout) {
97+
shouldThrow<ResponseStoppedException> { model.generateContent("prompt") } should
98+
{
99+
it.response.candidates.first().finishReason shouldBe FinishReason.UNKNOWN
100+
}
101+
}
102+
}
103+
104+
@Test
105+
fun `unknown enum in block reason`() =
106+
goldenUnaryFile("failure-unknown-enum-prompt-blocked.json") {
107+
withTimeout(testTimeout) {
108+
shouldThrow<PromptBlockedException> { model.generateContent("prompt") } should
109+
{
110+
it.response.promptFeedback?.blockReason shouldBe BlockReason.UNKNOWN
111+
}
112+
}
113+
}
114+
115+
@Test
116+
fun `quotes escaped`() =
117+
goldenUnaryFile("success-quote-reply.json") {
118+
withTimeout(testTimeout) {
119+
val response = model.generateContent("prompt")
120+
121+
response.candidates.isEmpty() shouldBe false
122+
response.candidates.first().content.parts.isEmpty() shouldBe false
123+
val part = response.candidates.first().content.parts.first() as TextPart
124+
part.text shouldContain "\""
125+
}
126+
}
127+
128+
@Test
129+
fun `safetyRatings missing`() =
130+
goldenUnaryFile("success-missing-safety-ratings.json") {
131+
withTimeout(testTimeout) {
132+
val response = model.generateContent("prompt")
133+
134+
response.candidates.isEmpty() shouldBe false
135+
response.candidates.first().content.parts.isEmpty() shouldBe false
136+
response.candidates.first().safetyRatings.isEmpty() shouldBe true
137+
response.promptFeedback?.safetyRatings?.isEmpty() shouldBe true
87138
}
88139
}
89140

@@ -147,6 +198,15 @@ internal class UnarySnapshotTests {
147198
}
148199
}
149200

201+
@Test
202+
fun `stopped for safety with no content`() =
203+
goldenUnaryFile("failure-finish-reason-safety-no-content.json") {
204+
withTimeout(testTimeout) {
205+
val exception = shouldThrow<ResponseStoppedException> { model.generateContent("prompt") }
206+
exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY
207+
}
208+
}
209+
150210
@Test
151211
fun `citation returns correctly`() =
152212
goldenUnaryFile("success-citations.json") {
@@ -292,4 +352,79 @@ internal class UnarySnapshotTests {
292352
callPart.args["current"] shouldBe "true"
293353
}
294354
}
355+
356+
@Test
357+
fun `function call contains no arguments`() =
358+
goldenUnaryFile("success-function-call-no-arguments.json") {
359+
withTimeout(testTimeout) {
360+
val response = model.generateContent("prompt")
361+
val callPart = response.functionCalls.shouldNotBeEmpty().first()
362+
363+
callPart.name shouldBe "current_time"
364+
callPart.args.isEmpty() shouldBe true
365+
}
366+
}
367+
368+
@Test
369+
fun `function call contains arguments`() =
370+
goldenUnaryFile("success-function-call-with-arguments.json") {
371+
withTimeout(testTimeout) {
372+
val response = model.generateContent("prompt")
373+
val callPart = response.functionCalls.shouldNotBeEmpty().first()
374+
375+
callPart.name shouldBe "sum"
376+
callPart.args["x"] shouldBe "4"
377+
callPart.args["y"] shouldBe "5"
378+
}
379+
}
380+
381+
@Test
382+
fun `function call with parallel calls`() =
383+
goldenUnaryFile("success-function-call-parallel-calls.json") {
384+
withTimeout(testTimeout) {
385+
val response = model.generateContent("prompt")
386+
val callList = response.functionCalls
387+
388+
callList.size shouldBe 3
389+
callList.forEach {
390+
it.name shouldBe "sum"
391+
it.args.size shouldBe 2
392+
}
393+
}
394+
}
395+
396+
@Test
397+
fun `function call with mixed content`() =
398+
goldenUnaryFile("success-function-call-mixed-content.json") {
399+
withTimeout(testTimeout) {
400+
val response = model.generateContent("prompt")
401+
val callList = response.functionCalls
402+
403+
response.text shouldBe "The sum of [1, 2, 3] is"
404+
callList.size shouldBe 2
405+
callList.forEach { it.args.size shouldBe 2 }
406+
}
407+
}
408+
409+
@Test
410+
fun `countTokens succeeds`() =
411+
goldenUnaryFile("success-total-tokens.json") {
412+
withTimeout(testTimeout) {
413+
val response = model.countTokens("prompt")
414+
415+
response.totalTokens shouldBe 6
416+
response.totalBillableCharacters shouldBe 16
417+
}
418+
}
419+
420+
@Test
421+
fun `countTokens succeeds with no billable characters`() =
422+
goldenUnaryFile("success-no-billable-characters.json") {
423+
withTimeout(testTimeout) {
424+
val response = model.countTokens("prompt")
425+
426+
response.totalTokens shouldBe 258
427+
response.totalBillableCharacters shouldBe 0
428+
}
429+
}
295430
}

0 commit comments

Comments
 (0)