Skip to content

Commit da6eea8

Browse files
authored
feat: [vertexai] sync the vertexai to google3 sot (#10225)
It includes the following changes: chore: remove the term "url" and replace with "uri" chore: add user-agent header in Java SDK feat: support "publishers/google/models/" prefix feat: add apiEndpoint in VertexAI chore: change the implementation of countTokens. chore: switch to v1 gapic clients. chore: remove URL support in from MultiModalData chore: remove the logic to throw an exception when getting multi-modal data in chat
1 parent 57b0587 commit da6eea8

File tree

318 files changed

+21341
-17113
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

318 files changed

+21341
-17113
lines changed

java-vertexai/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,8 @@ import java.util.Arrays;
192192
import java.util.List;
193193

194194
public class Main {
195-
private static final String PROJECT_ID = "cloud-llm-preview1";
196-
private static final String LOCATION = "us-central1";
195+
private static final String PROJECT_ID = <your project id>;
196+
private static final String LOCATION = <location>;
197197
private static final String MODEL_NAME = "gemini-pro";
198198

199199
public static void main(String[] args) throws IOException {

java-vertexai/google-cloud-vertexai-bom/pom.xml

+4-4
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@
3131
</dependency>
3232
<dependency>
3333
<groupId>com.google.api.grpc</groupId>
34-
<artifactId>grpc-google-cloud-vertexai-v1beta1</artifactId>
35-
<version>0.3.0-SNAPSHOT</version><!-- {x-version-update:grpc-google-cloud-vertexai-v1beta1:current} -->
34+
<artifactId>grpc-google-cloud-vertexai-v1</artifactId>
35+
<version>0.3.0-SNAPSHOT</version><!-- {x-version-update:grpc-google-cloud-vertexai-v1:current} -->
3636
</dependency>
3737
<dependency>
3838
<groupId>com.google.api.grpc</groupId>
39-
<artifactId>proto-google-cloud-vertexai-v1beta1</artifactId>
40-
<version>0.3.0-SNAPSHOT</version><!-- {x-version-update:proto-google-cloud-vertexai-v1beta1:current} -->
39+
<artifactId>proto-google-cloud-vertexai-v1</artifactId>
40+
<version>0.3.0-SNAPSHOT</version><!-- {x-version-update:proto-google-cloud-vertexai-v1:current} -->
4141
</dependency>
4242
</dependencies>
4343
</dependencyManagement>

java-vertexai/google-cloud-vertexai/pom.xml

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
<dependency>
4545
<groupId>com.google.api.grpc</groupId>
46-
<artifactId>proto-google-cloud-vertexai-v1beta1</artifactId>
46+
<artifactId>proto-google-cloud-vertexai-v1</artifactId>
4747
</dependency>
4848
<dependency>
4949
<groupId>com.google.guava</groupId>
@@ -97,7 +97,7 @@
9797

9898
<dependency>
9999
<groupId>com.google.api.grpc</groupId>
100-
<artifactId>grpc-google-cloud-vertexai-v1beta1</artifactId>
100+
<artifactId>grpc-google-cloud-vertexai-v1</artifactId>
101101
<scope>test</scope>
102102
</dependency>
103103
<!-- Need testing utility classes for generated gRPC clients tests -->
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.vertexai;
18+
19+
/** A class that holds all constants for vertexai. */
20+
public final class Constants {
21+
// Constants for VertexAI class
22+
public static final String USER_AGENT_HEADER = "model-builder";
23+
24+
private Constants() {}
25+
}

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/Transport.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023 Google LLC
2+
* Copyright 2024 Google LLC
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java

+168-35
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023 Google LLC
2+
* Copyright 2024 Google LLC
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,13 +16,19 @@
1616

1717
package com.google.cloud.vertexai;
1818

19+
import com.google.api.gax.core.CredentialsProvider;
1920
import com.google.api.gax.core.FixedCredentialsProvider;
20-
import com.google.auth.oauth2.GoogleCredentials;
21+
import com.google.api.gax.core.GaxProperties;
22+
import com.google.api.gax.core.GoogleCredentialsProvider;
23+
import com.google.api.gax.rpc.FixedHeaderProvider;
24+
import com.google.api.gax.rpc.HeaderProvider;
25+
import com.google.auth.Credentials;
26+
import com.google.cloud.vertexai.api.LlmUtilityServiceClient;
27+
import com.google.cloud.vertexai.api.LlmUtilityServiceSettings;
2128
import com.google.cloud.vertexai.api.PredictionServiceClient;
2229
import com.google.cloud.vertexai.api.PredictionServiceSettings;
23-
import com.google.cloud.vertexai.api.stub.PredictionServiceStubSettings;
2430
import java.io.IOException;
25-
import java.util.List;
31+
import java.util.Arrays;
2632
import java.util.logging.Level;
2733
import java.util.logging.Logger;
2834

@@ -44,11 +50,14 @@ public class VertexAI implements AutoCloseable {
4450

4551
private final String projectId;
4652
private final String location;
47-
private final GoogleCredentials credentials;
53+
private String apiEndpoint;
54+
private CredentialsProvider credentialsProvider = null;
4855
private Transport transport = Transport.GRPC;
4956
// The clients will be instantiated lazily
5057
private PredictionServiceClient predictionServiceClient = null;
5158
private PredictionServiceClient predictionServiceRestClient = null;
59+
private LlmUtilityServiceClient llmUtilityClient = null;
60+
private LlmUtilityServiceClient llmUtilityRestClient = null;
5261

5362
/**
5463
* Construct a VertexAI instance with custom credentials.
@@ -57,10 +66,11 @@ public class VertexAI implements AutoCloseable {
5766
* @param location the default location to use when making API calls
5867
* @param credentials the custom credentials to use when making API calls
5968
*/
60-
public VertexAI(String projectId, String location, GoogleCredentials credentials) {
69+
public VertexAI(String projectId, String location, Credentials credentials) {
6170
this.projectId = projectId;
6271
this.location = location;
63-
this.credentials = credentials;
72+
this.apiEndpoint = String.format("%s-aiplatform.googleapis.com", this.location);
73+
this.credentialsProvider = FixedCredentialsProvider.create(credentials);
6474
}
6575

6676
/**
@@ -71,8 +81,7 @@ public VertexAI(String projectId, String location, GoogleCredentials credentials
7181
* @param transport the default {@link Transport} layer to use to send API requests
7282
* @param credentials the default custom credentials to use when making API calls
7383
*/
74-
public VertexAI(
75-
String projectId, String location, Transport transport, GoogleCredentials credentials) {
84+
public VertexAI(String projectId, String location, Transport transport, Credentials credentials) {
7685
this(projectId, location, credentials);
7786
this.transport = transport;
7887
}
@@ -82,24 +91,22 @@ public VertexAI(
8291
*
8392
* @param projectId the default project to use when making API calls
8493
* @param location the default location to use when making API calls
85-
* @param scopes collection of scopes in the default credentials
94+
* @param scopes collection of scopes in the default credentials. Make sure you have specified
95+
* "https://www.googleapis.com/auth/cloud-platform" scope to access resources on Vertex AI.
8696
*/
8797
public VertexAI(String projectId, String location, String... scopes) throws IOException {
88-
// Disable the warning message logged in getApplicationDefault
89-
Logger logger = Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
90-
Level previousLevel = logger.getLevel();
91-
logger.setLevel(Level.SEVERE);
92-
List<String> defaultScopes =
93-
PredictionServiceStubSettings.defaultCredentialsProviderBuilder().getScopesToApply();
94-
GoogleCredentials credentials =
98+
CredentialsProvider credentialsProvider =
9599
scopes.length == 0
96-
? GoogleCredentials.getApplicationDefault().createScoped(defaultScopes)
97-
: GoogleCredentials.getApplicationDefault().createScoped(scopes);
98-
logger.setLevel(previousLevel);
100+
? null
101+
: GoogleCredentialsProvider.newBuilder()
102+
.setScopesToApply(Arrays.asList(scopes))
103+
.setUseJwtAccessWithScope(true)
104+
.build();
99105

100106
this.projectId = projectId;
101107
this.location = location;
102-
this.credentials = credentials;
108+
this.apiEndpoint = String.format("%s-aiplatform.googleapis.com", this.location);
109+
this.credentialsProvider = credentialsProvider;
103110
}
104111

105112
/**
@@ -131,28 +138,72 @@ public String getLocation() {
131138
return this.location;
132139
}
133140

141+
/** Returns the default endpoint to use when making API calls. */
142+
public String getApiEndpoint() {
143+
return this.apiEndpoint;
144+
}
145+
134146
/** Returns the default credentials to use when making API calls. */
135-
public GoogleCredentials getCredentials() {
136-
return credentials;
147+
public Credentials getCredentials() throws IOException {
148+
return credentialsProvider.getCredentials();
137149
}
138150

139151
/** Sets the value for {@link #getTransport()}. */
140152
public void setTransport(Transport transport) {
141153
this.transport = transport;
142154
}
143155

156+
/** Sets the value for {@link #getApiEndpoint()}. */
157+
public void setApiEndpoint(String apiEndpoint) {
158+
this.apiEndpoint = apiEndpoint;
159+
160+
if (this.predictionServiceClient != null) {
161+
this.predictionServiceClient.close();
162+
this.predictionServiceClient = null;
163+
}
164+
165+
if (this.predictionServiceRestClient != null) {
166+
this.predictionServiceRestClient.close();
167+
this.predictionServiceRestClient = null;
168+
}
169+
170+
if (this.llmUtilityClient != null) {
171+
this.llmUtilityClient.close();
172+
this.llmUtilityClient = null;
173+
}
174+
175+
if (this.llmUtilityRestClient != null) {
176+
this.llmUtilityRestClient.close();
177+
this.llmUtilityRestClient = null;
178+
}
179+
}
180+
144181
/**
145182
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
146183
* first prediction API call is made.
147184
*/
148185
public PredictionServiceClient getPredictionServiceClient() throws IOException {
149186
if (predictionServiceClient == null) {
150-
PredictionServiceSettings settings =
151-
PredictionServiceSettings.newBuilder()
152-
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", this.location))
153-
.setCredentialsProvider(FixedCredentialsProvider.create(this.credentials))
154-
.build();
155-
predictionServiceClient = PredictionServiceClient.create(settings);
187+
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder();
188+
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
189+
if (this.credentialsProvider != null) {
190+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
191+
}
192+
HeaderProvider headerProvider =
193+
FixedHeaderProvider.create(
194+
"user-agent",
195+
String.format(
196+
"%s/%s",
197+
Constants.USER_AGENT_HEADER,
198+
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
199+
settingsBuilder.setHeaderProvider(headerProvider);
200+
// Disable the warning message logged in getApplicationDefault
201+
Logger defaultCredentialsProviderLogger =
202+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
203+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
204+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
205+
predictionServiceClient = PredictionServiceClient.create(settingsBuilder.build());
206+
defaultCredentialsProviderLogger.setLevel(previousLevel);
156207
}
157208
return predictionServiceClient;
158209
}
@@ -163,16 +214,92 @@ public PredictionServiceClient getPredictionServiceClient() throws IOException {
163214
*/
164215
public PredictionServiceClient getPredictionServiceRestClient() throws IOException {
165216
if (predictionServiceRestClient == null) {
166-
PredictionServiceSettings settings =
167-
PredictionServiceSettings.newHttpJsonBuilder()
168-
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", this.location))
169-
.setCredentialsProvider(FixedCredentialsProvider.create(this.credentials))
170-
.build();
171-
predictionServiceRestClient = PredictionServiceClient.create(settings);
217+
PredictionServiceSettings.Builder settingsBuilder =
218+
PredictionServiceSettings.newHttpJsonBuilder();
219+
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
220+
if (this.credentialsProvider != null) {
221+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
222+
}
223+
HeaderProvider headerProvider =
224+
FixedHeaderProvider.create(
225+
"user-agent",
226+
String.format(
227+
"%s/%s",
228+
Constants.USER_AGENT_HEADER,
229+
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
230+
settingsBuilder.setHeaderProvider(headerProvider);
231+
// Disable the warning message logged in getApplicationDefault
232+
Logger defaultCredentialsProviderLogger =
233+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
234+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
235+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
236+
predictionServiceRestClient = PredictionServiceClient.create(settingsBuilder.build());
237+
defaultCredentialsProviderLogger.setLevel(previousLevel);
172238
}
173239
return predictionServiceRestClient;
174240
}
175241

242+
/**
243+
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
244+
* first prediction API call is made.
245+
*/
246+
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
247+
if (llmUtilityClient == null) {
248+
LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder();
249+
settingsBuilder.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", this.location));
250+
if (this.credentialsProvider != null) {
251+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
252+
}
253+
HeaderProvider headerProvider =
254+
FixedHeaderProvider.create(
255+
"user-agent",
256+
String.format(
257+
"%s/%s",
258+
Constants.USER_AGENT_HEADER,
259+
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
260+
settingsBuilder.setHeaderProvider(headerProvider);
261+
// Disable the warning message logged in getApplicationDefault
262+
Logger defaultCredentialsProviderLogger =
263+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
264+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
265+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
266+
llmUtilityClient = LlmUtilityServiceClient.create(settingsBuilder.build());
267+
defaultCredentialsProviderLogger.setLevel(previousLevel);
268+
}
269+
return llmUtilityClient;
270+
}
271+
272+
/**
273+
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
274+
* first prediction API call is made.
275+
*/
276+
public LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
277+
if (llmUtilityRestClient == null) {
278+
LlmUtilityServiceSettings.Builder settingsBuilder =
279+
LlmUtilityServiceSettings.newHttpJsonBuilder();
280+
settingsBuilder.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", this.location));
281+
if (this.credentialsProvider != null) {
282+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
283+
}
284+
HeaderProvider headerProvider =
285+
FixedHeaderProvider.create(
286+
"user-agent",
287+
String.format(
288+
"%s/%s",
289+
Constants.USER_AGENT_HEADER,
290+
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
291+
settingsBuilder.setHeaderProvider(headerProvider);
292+
// Disable the warning message logged in getApplicationDefault
293+
Logger defaultCredentialsProviderLogger =
294+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
295+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
296+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
297+
llmUtilityRestClient = LlmUtilityServiceClient.create(settingsBuilder.build());
298+
defaultCredentialsProviderLogger.setLevel(previousLevel);
299+
}
300+
return llmUtilityRestClient;
301+
}
302+
176303
/** Closes the VertexAI instance together with all its instantiated clients. */
177304
@Override
178305
public void close() {
@@ -182,5 +309,11 @@ public void close() {
182309
if (predictionServiceRestClient != null) {
183310
predictionServiceRestClient.close();
184311
}
312+
if (llmUtilityClient != null) {
313+
llmUtilityClient.close();
314+
}
315+
if (llmUtilityRestClient != null) {
316+
llmUtilityRestClient.close();
317+
}
185318
}
186319
}

0 commit comments

Comments
 (0)