Skip to content

Fix AI builders for Java consumers #6930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,17 @@ constructor(public val role: String? = "user", public val parts: List<Part>) {
public class Builder {

/** The producer of the content. Must be either 'user' or 'model'. By default, it's "user". */
public var role: String? = "user"
@JvmField public var role: String? = "user"

/**
* The mutable list of [Part]s comprising the [Content].
*
* Prefer using the provided helper methods over modifying this list directly.
*/
public var parts: MutableList<Part> = arrayListOf()
@JvmField public var parts: MutableList<Part> = arrayListOf()

public fun setRole(role: String?): Content.Builder = apply { this.role = role }
public fun setParts(parts: MutableList<Part>): Content.Builder = apply { this.parts = parts }

/** Adds a new [Part] to [parts]. */
@JvmName("addPart")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,36 @@ private constructor(
@JvmField public var responseSchema: Schema? = null
@JvmField public var responseModalities: List<ResponseModality>? = null

public fun setTemperature(temperature: Float?): Builder = apply {
this.temperature = temperature
}
public fun setTopK(topK: Int?): Builder = apply { this.topK = topK }
public fun setTopP(topP: Float?): Builder = apply { this.topP = topP }
public fun setCandidateCount(candidateCount: Int?): Builder = apply {
this.candidateCount = candidateCount
}
public fun setMaxOutputTokens(maxOutputTokens: Int?): Builder = apply {
this.maxOutputTokens = maxOutputTokens
}
public fun setPresencePenalty(presencePenalty: Float?): Builder = apply {
this.presencePenalty = presencePenalty
}
public fun setFrequencyPenalty(frequencyPenalty: Float?): Builder = apply {
this.frequencyPenalty = frequencyPenalty
}
public fun setStopSequences(stopSequences: List<String>?): Builder = apply {
this.stopSequences = stopSequences
}
public fun setResponseMimeType(responseMimeType: String?): Builder = apply {
this.responseMimeType = responseMimeType
}
public fun setResponseSchema(responseSchema: Schema?): Builder = apply {
this.responseSchema = responseSchema
}
public fun setResponseModalities(responseModalities: List<ResponseModality>?): Builder = apply {
this.responseModalities = responseModalities
}

/** Create a new [GenerationConfig] with the attached arguments. */
public fun build(): GenerationConfig =
GenerationConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@
import com.google.firebase.ai.type.PublicPreviewAPI;
import com.google.firebase.ai.type.ResponseModality;
import com.google.firebase.ai.type.SafetyRating;
import com.google.firebase.ai.type.Schema;
import com.google.firebase.ai.type.SpeechConfig;
import com.google.firebase.ai.type.TextPart;
import com.google.firebase.ai.type.UsageMetadata;
import com.google.firebase.ai.type.Voices;
import com.google.firebase.concurrent.FirebaseExecutors;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -92,8 +94,37 @@ public void initializeJava() throws Exception {
}

private GenerationConfig getConfig() {
return new GenerationConfig.Builder().build();
// TODO b/406558430 GenerationConfig.Builder.setParts returns void
return new GenerationConfig.Builder()
.setTopK(10)
.setTopP(11.0F)
.setTemperature(32.0F)
.setCandidateCount(1)
.setMaxOutputTokens(0xCAFEBABE)
.setFrequencyPenalty(1.0F)
.setPresencePenalty(2.0F)
.setStopSequences(List.of("foo", "bar"))
.setResponseMimeType("image/jxl")
.setResponseModalities(List.of(ResponseModality.TEXT, ResponseModality.TEXT))
.setResponseSchema(getSchema())
.build();
}

private Schema getSchema() {
return Schema.obj(
Map.of(
"foo", Schema.numInt(),
"bar", Schema.numInt("Some integer"),
"baz", Schema.numInt("Some integer", false),
"qux", Schema.numDouble(),
"quux", Schema.numFloat("Some floating point number"),
"xyzzy", Schema.array(Schema.numInt(), "A list of integers"),
"fee", Schema.numLong(),
"ber",
Schema.obj(
Map.of(
"bez", Schema.array(Schema.numDouble("Nullable double", true)),
"qez", Schema.enumeration(List.of("A", "B", "C"), "One of 3 letters"),
"qeez", Schema.str("A funny string")))));
}

private LiveGenerationConfig getLiveConfig() {
Expand All @@ -113,13 +144,14 @@ private LiveGenerationConfig getLiveConfig() {
private void testFutures(GenerativeModelFutures futures) throws Exception {
Content content =
new Content.Builder()
.setParts(new ArrayList<>())
.addText("Fake prompt")
.addFileData("fakeuri", "image/png")
.addInlineData(new byte[] {}, "text/json")
.addImage(Bitmap.createBitmap(0, 0, Bitmap.Config.HARDWARE))
.addPart(new FunctionCallPart("fakeFunction", Map.of("fakeArg", JsonNull.INSTANCE)))
.setRole("user")
.build();
// TODO b/406558430 Content.Builder.setParts and Content.Builder.setRole return void
Executor executor = FirebaseExecutors.directExecutor();
ListenableFuture<CountTokensResponse> countResponse = futures.countTokens(content);
validateCountTokensResponse(countResponse.get());
Expand Down
Loading