Skip to content

Commit 89d2b15

Browse files
feat: [vertexai] Support Function calling (#10242)
PiperOrigin-RevId: 599919035 Co-authored-by: Jaycee Li <[email protected]>
1 parent 2da4e3e commit 89d2b15

File tree

2 files changed

+156
-6
lines changed

2 files changed

+156
-6
lines changed

java-vertexai/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ If you are using Maven with [BOM][libraries-bom], add this to your pom.xml file:
1818
<dependency>
1919
<groupId>com.google.cloud</groupId>
2020
<artifactId>libraries-bom</artifactId>
21-
<version>26.30.0</version>
21+
<version>26.29.0</version>
2222
<type>pom</type>
2323
<scope>import</scope>
2424
</dependency>

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/GenerativeModel.java

+155-5
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
import com.google.cloud.vertexai.api.CountTokensRequest;
2323
import com.google.cloud.vertexai.api.CountTokensResponse;
2424
import com.google.cloud.vertexai.api.GenerateContentRequest;
25-
import com.google.cloud.vertexai.api.GenerateContentRequest.Builder;
2625
import com.google.cloud.vertexai.api.GenerateContentResponse;
2726
import com.google.cloud.vertexai.api.GenerationConfig;
2827
import com.google.cloud.vertexai.api.Part;
2928
import com.google.cloud.vertexai.api.SafetySetting;
29+
import com.google.cloud.vertexai.api.Tool;
3030
import java.io.IOException;
3131
import java.util.ArrayList;
3232
import java.util.Arrays;
@@ -40,8 +40,131 @@ public class GenerativeModel {
4040
private final VertexAI vertexAi;
4141
private GenerationConfig generationConfig = null;
4242
private List<SafetySetting> safetySettings = null;
43+
private List<Tool> tools = null;
4344
private Transport transport;
4445

46+
public static Builder newBuilder() {
47+
return new Builder();
48+
}
49+
50+
private GenerativeModel(Builder builder) {
51+
this.modelName = builder.modelName;
52+
53+
this.vertexAi = builder.vertexAi;
54+
55+
this.resourceName =
56+
String.format(
57+
"projects/%s/locations/%s/publishers/google/models/%s",
58+
this.vertexAi.getProjectId(), this.vertexAi.getLocation(), this.modelName);
59+
60+
if (builder.generationConfig != null) {
61+
this.generationConfig = builder.generationConfig;
62+
}
63+
if (builder.safetySettings != null) {
64+
this.safetySettings = builder.safetySettings;
65+
}
66+
if (builder.tools != null) {
67+
this.tools = builder.tools;
68+
}
69+
70+
if (builder.transport != null) {
71+
this.transport = builder.transport;
72+
} else {
73+
this.transport = this.vertexAi.getTransport();
74+
}
75+
}
76+
77+
/** Builder class for {@link GenerativeModel}. */
78+
public static class Builder {
79+
private String modelName;
80+
private VertexAI vertexAi;
81+
private GenerationConfig generationConfig;
82+
private List<SafetySetting> safetySettings;
83+
private List<Tool> tools;
84+
private Transport transport;
85+
86+
private Builder() {}
87+
88+
public GenerativeModel build() {
89+
if (this.modelName == null) {
90+
throw new IllegalArgumentException(
91+
"modelName is required. Please call setModelName() before building.");
92+
}
93+
if (this.vertexAi == null) {
94+
throw new IllegalArgumentException(
95+
"vertexAi is required. Please call setVertexAi() before building.");
96+
}
97+
return new GenerativeModel(this);
98+
}
99+
100+
/**
101+
* Set the name of the generative model. This is required for building a GenerativeModel
102+
* instance. Supported format: "gemini-pro", "models/gemini-pro",
103+
* "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model
104+
* names can be found at
105+
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
106+
*/
107+
public Builder setModelName(String modelName) {
108+
this.modelName = validateModelName(modelName);
109+
return this;
110+
}
111+
112+
/**
113+
* Set {@link com.google.cloud.vertexai.VertexAI} that contains the default configs for the
114+
* generative model. This is required for building a GenerativeModel instance.
115+
*/
116+
public Builder setVertexAi(VertexAI vertexAi) {
117+
this.vertexAi = vertexAi;
118+
return this;
119+
}
120+
121+
/**
122+
* Set {@link com.google.cloud.vertexai.api.GenerationConfig} that will be used by default to
123+
* interact with the generative model.
124+
*/
125+
public Builder setGenerationConfig(GenerationConfig generationConfig) {
126+
this.generationConfig = generationConfig;
127+
return this;
128+
}
129+
130+
/**
131+
* Set a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will be used by
132+
* default to interact with the generative model.
133+
*/
134+
public Builder setSafetySettings(List<SafetySetting> safetySettings) {
135+
this.safetySettings = new ArrayList<>();
136+
for (SafetySetting safetySetting : safetySettings) {
137+
if (safetySetting != null) {
138+
this.safetySettings.add(safetySetting);
139+
}
140+
}
141+
return this;
142+
}
143+
144+
/**
145+
* Set a list of {@link com.google.cloud.vertexai.api.Tool} that will be used by default to
146+
* interact with the generative model.
147+
*/
148+
public Builder setTools(List<Tool> tools) {
149+
this.tools = new ArrayList<>();
150+
for (Tool tool : tools) {
151+
if (tool != null) {
152+
this.tools.add(tool);
153+
}
154+
}
155+
return this;
156+
}
157+
158+
/**
159+
* Set the {@link Transport} layer for API calls in the generative model. It overrides the
160+
* transport setting in {@link com.google.cloud.vertexai.VertexAI}
161+
*/
162+
public Builder setTransport(Transport transport) {
163+
this.transport = transport;
164+
return this;
165+
}
166+
}
167+
45168
/**
46169
* Construct a GenerativeModel instance.
47170
*
@@ -384,7 +507,8 @@ public GenerateContentResponse generateContent(
384507
public GenerateContentResponse generateContent(
385508
List<Content> contents, GenerationConfig generationConfig, List<SafetySetting> safetySettings)
386509
throws IOException {
387-
Builder requestBuilder = GenerateContentRequest.newBuilder().addAllContents(contents);
510+
GenerateContentRequest.Builder requestBuilder =
511+
GenerateContentRequest.newBuilder().addAllContents(contents);
388512
if (generationConfig != null) {
389513
requestBuilder.setGenerationConfig(generationConfig);
390514
} else if (this.generationConfig != null) {
@@ -395,6 +519,9 @@ public GenerateContentResponse generateContent(
395519
} else if (this.safetySettings != null) {
396520
requestBuilder.addAllSafetySettings(this.safetySettings);
397521
}
522+
if (this.tools != null) {
523+
requestBuilder.addAllTools(this.tools);
524+
}
398525
return ResponseHandler.aggregateStreamIntoResponse(generateContentStream(requestBuilder));
399526
}
400527

@@ -655,7 +782,8 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
655782
public ResponseStream<GenerateContentResponse> generateContentStream(
656783
List<Content> contents, GenerationConfig generationConfig, List<SafetySetting> safetySettings)
657784
throws IOException {
658-
Builder requestBuilder = GenerateContentRequest.newBuilder().addAllContents(contents);
785+
GenerateContentRequest.Builder requestBuilder =
786+
GenerateContentRequest.newBuilder().addAllContents(contents);
659787
if (generationConfig != null) {
660788
requestBuilder.setGenerationConfig(generationConfig);
661789
} else if (this.generationConfig != null) {
@@ -666,6 +794,9 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
666794
} else if (this.safetySettings != null) {
667795
requestBuilder.addAllSafetySettings(this.safetySettings);
668796
}
797+
if (this.tools != null) {
798+
requestBuilder.addAllTools(this.tools);
799+
}
669800
return generateContentStream(requestBuilder);
670801
}
671802

@@ -678,8 +809,8 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
678809
* com.google.cloud.vertexai.api.GenerateContentResponse}
679810
* @throws IOException if an I/O error occurs while making the API call
680811
*/
681-
private ResponseStream<GenerateContentResponse> generateContentStream(Builder requestBuilder)
682-
throws IOException {
812+
private ResponseStream<GenerateContentResponse> generateContentStream(
813+
GenerateContentRequest.Builder requestBuilder) throws IOException {
683814
GenerateContentRequest request = requestBuilder.setModel(this.resourceName).build();
684815
ResponseStream<GenerateContentResponse> responseStream = null;
685816
if (this.transport == Transport.REST) {
@@ -723,6 +854,16 @@ public void setSafetySettings(List<SafetySetting> safetySettings) {
723854
}
724855
}
725856

857+
/**
858+
* Sets the value for {@link #getTools}, which will be used by default for generating response.
859+
*/
860+
public void setTools(List<Tool> tools) {
861+
this.tools = new ArrayList<>();
862+
for (Tool tool : tools) {
863+
this.tools.add(tool);
864+
}
865+
}
866+
726867
/**
727868
* Sets the value for {@link #getTransport}, which defines the layer for API calls in this
728869
* generative model.
@@ -760,6 +901,15 @@ public List<SafetySetting> getSafetySettings() {
760901
}
761902
}
762903

904+
/** Returns a list of {@link com.google.cloud.vertexai.api.Tool} of this generative model. */
905+
public List<Tool> getTools() {
906+
if (this.tools != null) {
907+
return Collections.unmodifiableList(this.tools);
908+
} else {
909+
return null;
910+
}
911+
}
912+
763913
public ChatSession startChat() {
764914
return new ChatSession(this);
765915
}

0 commit comments

Comments
 (0)