22
22
import com .google .cloud .vertexai .api .CountTokensRequest ;
23
23
import com .google .cloud .vertexai .api .CountTokensResponse ;
24
24
import com .google .cloud .vertexai .api .GenerateContentRequest ;
25
- import com .google .cloud .vertexai .api .GenerateContentRequest .Builder ;
26
25
import com .google .cloud .vertexai .api .GenerateContentResponse ;
27
26
import com .google .cloud .vertexai .api .GenerationConfig ;
28
27
import com .google .cloud .vertexai .api .Part ;
29
28
import com .google .cloud .vertexai .api .SafetySetting ;
29
+ import com .google .cloud .vertexai .api .Tool ;
30
30
import java .io .IOException ;
31
31
import java .util .ArrayList ;
32
32
import java .util .Arrays ;
@@ -40,8 +40,131 @@ public class GenerativeModel {
40
40
private final VertexAI vertexAi ;
41
41
private GenerationConfig generationConfig = null ;
42
42
private List <SafetySetting > safetySettings = null ;
43
+ private List <Tool > tools = null ;
43
44
private Transport transport ;
44
45
46
+ public static Builder newBuilder () {
47
+ return new Builder ();
48
+ }
49
+
50
+ private GenerativeModel (Builder builder ) {
51
+ this .modelName = builder .modelName ;
52
+
53
+ this .vertexAi = builder .vertexAi ;
54
+
55
+ this .resourceName =
56
+ String .format (
57
+ "projects/%s/locations/%s/publishers/google/models/%s" ,
58
+ this .vertexAi .getProjectId (), this .vertexAi .getLocation (), this .modelName );
59
+
60
+ if (builder .generationConfig != null ) {
61
+ this .generationConfig = builder .generationConfig ;
62
+ }
63
+ if (builder .safetySettings != null ) {
64
+ this .safetySettings = builder .safetySettings ;
65
+ }
66
+ if (builder .tools != null ) {
67
+ this .tools = builder .tools ;
68
+ }
69
+
70
+ if (builder .transport != null ) {
71
+ this .transport = builder .transport ;
72
+ } else {
73
+ this .transport = this .vertexAi .getTransport ();
74
+ }
75
+ }
76
+
77
+ /** Builder class for {@link GenerativeModel}. */
78
+ public static class Builder {
79
+ private String modelName ;
80
+ private VertexAI vertexAi ;
81
+ private GenerationConfig generationConfig ;
82
+ private List <SafetySetting > safetySettings ;
83
+ private List <Tool > tools ;
84
+ private Transport transport ;
85
+
86
+ private Builder () {}
87
+
88
+ public GenerativeModel build () {
89
+ if (this .modelName == null ) {
90
+ throw new IllegalArgumentException (
91
+ "modelName is required. Please call setModelName() before building." );
92
+ }
93
+ if (this .vertexAi == null ) {
94
+ throw new IllegalArgumentException (
95
+ "vertexAi is required. Please call setVertexAi() before building." );
96
+ }
97
+ return new GenerativeModel (this );
98
+ }
99
+
100
+ /**
101
+ * Set the name of the generative model. This is required for building a GenerativeModel
102
+ * instance. Supported format: "gemini-pro", "models/gemini-pro",
103
+ * "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model
104
+ * names can be found at
105
+ * https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
106
+ */
107
+ public Builder setModelName (String modelName ) {
108
+ this .modelName = validateModelName (modelName );
109
+ return this ;
110
+ }
111
+
112
+ /**
113
+ * Set {@link com.google.cloud.vertexai.VertexAI} that contains the default configs for the
114
+ * generative model. This is required for building a GenerativeModel instance.
115
+ */
116
+ public Builder setVertexAi (VertexAI vertexAi ) {
117
+ this .vertexAi = vertexAi ;
118
+ return this ;
119
+ }
120
+
121
+ /**
122
+ * Set {@link com.google.cloud.vertexai.api.GenerationConfig} that will be used by default to
123
+ * interact with the generative model.
124
+ */
125
+ public Builder setGenerationConfig (GenerationConfig generationConfig ) {
126
+ this .generationConfig = generationConfig ;
127
+ return this ;
128
+ }
129
+
130
+ /**
131
+ * Set a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will be used by
132
+ * default to interact with the generative model.
133
+ */
134
+ public Builder setSafetySettings (List <SafetySetting > safetySettings ) {
135
+ this .safetySettings = new ArrayList <>();
136
+ for (SafetySetting safetySetting : safetySettings ) {
137
+ if (safetySetting != null ) {
138
+ this .safetySettings .add (safetySetting );
139
+ }
140
+ }
141
+ return this ;
142
+ }
143
+
144
+ /**
145
+ * Set a list of {@link com.google.cloud.vertexai.api.Tool} that will be used by default to
146
+ * interact with the generative model.
147
+ */
148
+ public Builder setTools (List <Tool > tools ) {
149
+ this .tools = new ArrayList <>();
150
+ for (Tool tool : tools ) {
151
+ if (tool != null ) {
152
+ this .tools .add (tool );
153
+ }
154
+ }
155
+ return this ;
156
+ }
157
+
158
+ /**
159
+ * Set the {@link Transport} layer for API calls in the generative model. It overrides the
160
+ * transport setting in {@link com.google.cloud.vertexai.VertexAI}
161
+ */
162
+ public Builder setTransport (Transport transport ) {
163
+ this .transport = transport ;
164
+ return this ;
165
+ }
166
+ }
167
+
45
168
/**
46
169
* Construct a GenerativeModel instance.
47
170
*
@@ -384,7 +507,8 @@ public GenerateContentResponse generateContent(
384
507
public GenerateContentResponse generateContent (
385
508
List <Content > contents , GenerationConfig generationConfig , List <SafetySetting > safetySettings )
386
509
throws IOException {
387
- Builder requestBuilder = GenerateContentRequest .newBuilder ().addAllContents (contents );
510
+ GenerateContentRequest .Builder requestBuilder =
511
+ GenerateContentRequest .newBuilder ().addAllContents (contents );
388
512
if (generationConfig != null ) {
389
513
requestBuilder .setGenerationConfig (generationConfig );
390
514
} else if (this .generationConfig != null ) {
@@ -395,6 +519,9 @@ public GenerateContentResponse generateContent(
395
519
} else if (this .safetySettings != null ) {
396
520
requestBuilder .addAllSafetySettings (this .safetySettings );
397
521
}
522
+ if (this .tools != null ) {
523
+ requestBuilder .addAllTools (this .tools );
524
+ }
398
525
return ResponseHandler .aggregateStreamIntoResponse (generateContentStream (requestBuilder ));
399
526
}
400
527
@@ -655,7 +782,8 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
655
782
public ResponseStream <GenerateContentResponse > generateContentStream (
656
783
List <Content > contents , GenerationConfig generationConfig , List <SafetySetting > safetySettings )
657
784
throws IOException {
658
- Builder requestBuilder = GenerateContentRequest .newBuilder ().addAllContents (contents );
785
+ GenerateContentRequest .Builder requestBuilder =
786
+ GenerateContentRequest .newBuilder ().addAllContents (contents );
659
787
if (generationConfig != null ) {
660
788
requestBuilder .setGenerationConfig (generationConfig );
661
789
} else if (this .generationConfig != null ) {
@@ -666,6 +794,9 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
666
794
} else if (this .safetySettings != null ) {
667
795
requestBuilder .addAllSafetySettings (this .safetySettings );
668
796
}
797
+ if (this .tools != null ) {
798
+ requestBuilder .addAllTools (this .tools );
799
+ }
669
800
return generateContentStream (requestBuilder );
670
801
}
671
802
@@ -678,8 +809,8 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
678
809
* com.google.cloud.vertexai.api.GenerateContentResponse}
679
810
* @throws IOException if an I/O error occurs while making the API call
680
811
*/
681
- private ResponseStream <GenerateContentResponse > generateContentStream (Builder requestBuilder )
682
- throws IOException {
812
+ private ResponseStream <GenerateContentResponse > generateContentStream (
813
+ GenerateContentRequest . Builder requestBuilder ) throws IOException {
683
814
GenerateContentRequest request = requestBuilder .setModel (this .resourceName ).build ();
684
815
ResponseStream <GenerateContentResponse > responseStream = null ;
685
816
if (this .transport == Transport .REST ) {
@@ -723,6 +854,16 @@ public void setSafetySettings(List<SafetySetting> safetySettings) {
723
854
}
724
855
}
725
856
857
+ /**
858
+ * Sets the value for {@link #getTools}, which will be used by default for generating response.
859
+ */
860
+ public void setTools (List <Tool > tools ) {
861
+ this .tools = new ArrayList <>();
862
+ for (Tool tool : tools ) {
863
+ this .tools .add (tool );
864
+ }
865
+ }
866
+
726
867
/**
727
868
* Sets the value for {@link #getTransport}, which defines the layer for API calls in this
728
869
* generative model.
@@ -760,6 +901,15 @@ public List<SafetySetting> getSafetySettings() {
760
901
}
761
902
}
762
903
904
+ /** Returns a list of {@link com.google.cloud.vertexai.api.Tool} of this generative model. */
905
+ public List <Tool > getTools () {
906
+ if (this .tools != null ) {
907
+ return Collections .unmodifiableList (this .tools );
908
+ } else {
909
+ return null ;
910
+ }
911
+ }
912
+
763
913
public ChatSession startChat () {
764
914
return new ChatSession (this );
765
915
}
0 commit comments