diff --git a/firebase-ai/src/testUtil/java/com/google/firebase/ai/JavaCompileTests.java b/firebase-ai/src/testUtil/java/com/google/firebase/ai/JavaCompileTests.java index 5e363ed95b2..4d5d1185f14 100644 --- a/firebase-ai/src/testUtil/java/com/google/firebase/ai/JavaCompileTests.java +++ b/firebase-ai/src/testUtil/java/com/google/firebase/ai/JavaCompileTests.java @@ -17,11 +17,15 @@ package java.com.google.firebase.ai; import android.graphics.Bitmap; +import androidx.annotation.Nullable; import com.google.common.util.concurrent.ListenableFuture; import com.google.firebase.ai.FirebaseAI; import com.google.firebase.ai.GenerativeModel; +import com.google.firebase.ai.LiveGenerativeModel; import com.google.firebase.ai.java.ChatFutures; import com.google.firebase.ai.java.GenerativeModelFutures; +import com.google.firebase.ai.java.LiveModelFutures; +import com.google.firebase.ai.java.LiveSessionFutures; import com.google.firebase.ai.type.BlockReason; import com.google.firebase.ai.type.Candidate; import com.google.firebase.ai.type.Citation; @@ -32,25 +36,40 @@ import com.google.firebase.ai.type.FileDataPart; import com.google.firebase.ai.type.FinishReason; import com.google.firebase.ai.type.FunctionCallPart; +import com.google.firebase.ai.type.FunctionResponsePart; import com.google.firebase.ai.type.GenerateContentResponse; +import com.google.firebase.ai.type.GenerationConfig; import com.google.firebase.ai.type.HarmCategory; import com.google.firebase.ai.type.HarmProbability; import com.google.firebase.ai.type.HarmSeverity; import com.google.firebase.ai.type.ImagePart; import com.google.firebase.ai.type.InlineDataPart; +import com.google.firebase.ai.type.LiveGenerationConfig; +import com.google.firebase.ai.type.LiveServerContent; +import com.google.firebase.ai.type.LiveServerMessage; +import com.google.firebase.ai.type.LiveServerSetupComplete; +import com.google.firebase.ai.type.LiveServerToolCall; +import com.google.firebase.ai.type.LiveServerToolCallCancellation; +import com.google.firebase.ai.type.MediaData; import com.google.firebase.ai.type.ModalityTokenCount; import com.google.firebase.ai.type.Part; import com.google.firebase.ai.type.PromptFeedback; +import com.google.firebase.ai.type.PublicPreviewAPI; +import com.google.firebase.ai.type.ResponseModality; import com.google.firebase.ai.type.SafetyRating; +import com.google.firebase.ai.type.SpeechConfig; import com.google.firebase.ai.type.TextPart; import com.google.firebase.ai.type.UsageMetadata; +import com.google.firebase.ai.type.Voices; import com.google.firebase.concurrent.FirebaseExecutors; import java.util.Calendar; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; +import kotlin.OptIn; import kotlinx.serialization.json.JsonElement; import kotlinx.serialization.json.JsonNull; +import kotlinx.serialization.json.JsonObject; import org.junit.Assert; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -59,13 +78,36 @@ /** * Tests in this file exist to be compiled, not invoked */ +@OptIn(markerClass = PublicPreviewAPI.class) public class JavaCompileTests { public void initializeJava() throws Exception { FirebaseAI vertex = FirebaseAI.getInstance(); - GenerativeModel model = vertex.generativeModel("fake-model-name"); + GenerativeModel model = vertex.generativeModel("fake-model-name", getConfig()); + LiveGenerativeModel live = vertex.liveModel("fake-model-name", getLiveConfig()); GenerativeModelFutures futures = GenerativeModelFutures.from(model); + LiveModelFutures liveFutures = LiveModelFutures.from(live); testFutures(futures); + testLiveFutures(liveFutures); + } + + private GenerationConfig getConfig() { + return new GenerationConfig.Builder().build(); + // TODO b/406558430 GenerationConfig.Builder.setParts returns void + } + + private LiveGenerationConfig getLiveConfig() { + return new LiveGenerationConfig.Builder() + .setTopK(10) + .setTopP(11.0F) + .setTemperature(32.0F) + .setCandidateCount(1) + .setMaxOutputTokens(0xCAFEBABE) + .setFrequencyPenalty(1.0F) + .setPresencePenalty(2.0F) + .setResponseModality(ResponseModality.AUDIO) + .setSpeechConfig(new SpeechConfig(Voices.AOEDE)) + .build(); } private void testFutures(GenerativeModelFutures futures) throws Exception { @@ -159,7 +201,10 @@ public void validateCandidates(List candidates) { } } - public void validateContent(Content content) { + public void validateContent(@Nullable Content content) { + if (content == null) { + return; + } String role = content.getRole(); for (Part part : content.getParts()) { if (part instanceof TextPart) { @@ -236,4 +281,67 @@ public void validateUsageMetadata(UsageMetadata metadata) { } } } + + private void testLiveFutures(LiveModelFutures futures) throws Exception { + LiveSessionFutures session = futures.connect().get(); + session + .receive() + .subscribe( + new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(LiveServerMessage message) { + validateLiveContentResponse(message); + } + + @Override + public void onError(Throwable t) { + // Ignore + } + + @Override + public void onComplete() { + // Also ignore + } + }); + + session.send("Fake message"); + session.send(new Content.Builder().addText("Fake message").build()); + + byte[] bytes = new byte[] {(byte) 0xCA, (byte) 0xFE, (byte) 0xBA, (byte) 0xBE}; + session.sendMediaStream(List.of(new MediaData(bytes, "image/jxl"))); + + FunctionResponsePart functionResponse = + new FunctionResponsePart("myFunction", new JsonObject(Map.of())); + session.sendFunctionResponse(List.of(functionResponse, functionResponse)); + + session.startAudioConversation(part -> functionResponse); + session.startAudioConversation(); + session.stopAudioConversation(); + session.stopReceiving(); + session.close(); + } + + private void validateLiveContentResponse(LiveServerMessage message) { + if (message instanceof LiveServerContent) { + LiveServerContent content = (LiveServerContent) message; + validateContent(content.getContent()); + boolean complete = content.getGenerationComplete(); + boolean interrupted = content.getInterrupted(); + boolean turnComplete = content.getTurnComplete(); + } else if (message instanceof LiveServerSetupComplete) { + LiveServerSetupComplete setup = (LiveServerSetupComplete) message; + // No methods + } else if (message instanceof LiveServerToolCall) { + LiveServerToolCall call = (LiveServerToolCall) message; + validateFunctionCalls(call.getFunctionCalls()); + } else if (message instanceof LiveServerToolCallCancellation) { + LiveServerToolCallCancellation cancel = (LiveServerToolCallCancellation) message; + List functions = cancel.getFunctionIds(); + } + } } diff --git a/firebase-vertexai/src/testUtil/java/com/google/firebase/vertexai/JavaCompileTests.java b/firebase-vertexai/src/testUtil/java/com/google/firebase/vertexai/JavaCompileTests.java index 066e672ffb8..cf71db18798 100644 --- a/firebase-vertexai/src/testUtil/java/com/google/firebase/vertexai/JavaCompileTests.java +++ b/firebase-vertexai/src/testUtil/java/com/google/firebase/vertexai/JavaCompileTests.java @@ -21,8 +21,11 @@ import com.google.firebase.concurrent.FirebaseExecutors; import com.google.firebase.vertexai.FirebaseVertexAI; import com.google.firebase.vertexai.GenerativeModel; +import com.google.firebase.vertexai.LiveGenerativeModel; import com.google.firebase.vertexai.java.ChatFutures; import com.google.firebase.vertexai.java.GenerativeModelFutures; +import com.google.firebase.vertexai.java.LiveModelFutures; +import com.google.firebase.vertexai.java.LiveSessionFutures; import com.google.firebase.vertexai.type.BlockReason; import com.google.firebase.vertexai.type.Candidate; import com.google.firebase.vertexai.type.Citation; @@ -33,24 +36,33 @@ import com.google.firebase.vertexai.type.FileDataPart; import com.google.firebase.vertexai.type.FinishReason; import com.google.firebase.vertexai.type.FunctionCallPart; +import com.google.firebase.vertexai.type.FunctionResponsePart; import com.google.firebase.vertexai.type.GenerateContentResponse; +import com.google.firebase.vertexai.type.GenerationConfig; import com.google.firebase.vertexai.type.HarmCategory; import com.google.firebase.vertexai.type.HarmProbability; import com.google.firebase.vertexai.type.HarmSeverity; import com.google.firebase.vertexai.type.ImagePart; import com.google.firebase.vertexai.type.InlineDataPart; +import com.google.firebase.vertexai.type.LiveContentResponse; +import com.google.firebase.vertexai.type.LiveGenerationConfig; +import com.google.firebase.vertexai.type.MediaData; import com.google.firebase.vertexai.type.ModalityTokenCount; import com.google.firebase.vertexai.type.Part; import com.google.firebase.vertexai.type.PromptFeedback; +import com.google.firebase.vertexai.type.ResponseModality; import com.google.firebase.vertexai.type.SafetyRating; +import com.google.firebase.vertexai.type.SpeechConfig; import com.google.firebase.vertexai.type.TextPart; import com.google.firebase.vertexai.type.UsageMetadata; +import com.google.firebase.vertexai.type.Voices; import java.util.Calendar; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; import kotlinx.serialization.json.JsonElement; import kotlinx.serialization.json.JsonNull; +import kotlinx.serialization.json.JsonObject; import org.junit.Assert; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -63,9 +75,31 @@ public class JavaCompileTests { public void initializeJava() throws Exception { FirebaseVertexAI vertex = FirebaseVertexAI.getInstance(); - GenerativeModel model = vertex.generativeModel("fake-model-name"); + GenerativeModel model = vertex.generativeModel("fake-model-name", getConfig()); + LiveGenerativeModel live = vertex.liveModel("fake-model-name", getLiveConfig()); GenerativeModelFutures futures = GenerativeModelFutures.from(model); + LiveModelFutures liveFutures = LiveModelFutures.from(live); testFutures(futures); + testLiveFutures(liveFutures); + } + + private GenerationConfig getConfig() { + return new GenerationConfig.Builder().build(); + // TODO b/406558430 GenerationConfig.Builder.setParts returns void + } + + private LiveGenerationConfig getLiveConfig() { + return new LiveGenerationConfig.Builder() + .setTopK(10) + .setTopP(11.0F) + .setTemperature(32.0F) + .setCandidateCount(1) + .setMaxOutputTokens(0xCAFEBABE) + .setFrequencyPenalty(1.0F) + .setPresencePenalty(2.0F) + .setResponseModality(ResponseModality.AUDIO) + .setSpeechConfig(new SpeechConfig(Voices.AOEDE)) + .build(); } private void testFutures(GenerativeModelFutures futures) throws Exception { @@ -236,4 +270,62 @@ public void validateUsageMetadata(UsageMetadata metadata) { } } } + + private void testLiveFutures(LiveModelFutures futures) throws Exception { + LiveSessionFutures session = futures.connect().get(); + session + .receive() + .subscribe( + new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(LiveContentResponse response) { + validateLiveContentResponse(response); + } + + @Override + public void onError(Throwable t) { + // Ignore + } + + @Override + public void onComplete() { + // Also ignore + } + }); + + session.send("Fake message"); + session.send(new Content.Builder().addText("Fake message").build()); + + byte[] bytes = new byte[] {(byte) 0xCA, (byte) 0xFE, (byte) 0xBA, (byte) 0xBE}; + session.sendMediaStream(List.of(new MediaData(bytes, "image/jxl"))); + + FunctionResponsePart functionResponse = + new FunctionResponsePart("myFunction", new JsonObject(Map.of())); + session.sendFunctionResponse(List.of(functionResponse, functionResponse)); + + session.startAudioConversation(part -> functionResponse); + session.startAudioConversation(); + session.stopAudioConversation(); + session.stopReceiving(); + session.close(); + } + + private void validateLiveContentResponse(LiveContentResponse response) { + // int status = response.getStatus(); + // Assert.assertEquals(status, LiveContentResponse.Status.Companion.getNORMAL()); + // Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getINTERRUPTED()); + // Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getTURN_COMPLETE()); + // TODO b/412743328 LiveContentResponse.Status inaccessible for Java users + Content data = response.getData(); + if (data != null) { + validateContent(data); + } + String text = response.getText(); + validateFunctionCalls(response.getFunctionCalls()); + } }