1
1
/*
2
- * Copyright 2023 Google LLC
2
+ * Copyright 2024 Google LLC
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
16
16
17
17
package com .google .cloud .vertexai ;
18
18
19
+ import com .google .api .gax .core .CredentialsProvider ;
19
20
import com .google .api .gax .core .FixedCredentialsProvider ;
20
- import com .google .auth .oauth2 .GoogleCredentials ;
21
+ import com .google .api .gax .core .GaxProperties ;
22
+ import com .google .api .gax .core .GoogleCredentialsProvider ;
23
+ import com .google .api .gax .rpc .FixedHeaderProvider ;
24
+ import com .google .api .gax .rpc .HeaderProvider ;
25
+ import com .google .auth .Credentials ;
26
+ import com .google .cloud .vertexai .api .LlmUtilityServiceClient ;
27
+ import com .google .cloud .vertexai .api .LlmUtilityServiceSettings ;
21
28
import com .google .cloud .vertexai .api .PredictionServiceClient ;
22
29
import com .google .cloud .vertexai .api .PredictionServiceSettings ;
23
- import com .google .cloud .vertexai .api .stub .PredictionServiceStubSettings ;
24
30
import java .io .IOException ;
25
- import java .util .List ;
31
+ import java .util .Arrays ;
26
32
import java .util .logging .Level ;
27
33
import java .util .logging .Logger ;
28
34
@@ -44,11 +50,14 @@ public class VertexAI implements AutoCloseable {
44
50
45
51
private final String projectId ;
46
52
private final String location ;
47
- private final GoogleCredentials credentials ;
53
+ private String apiEndpoint ;
54
+ private CredentialsProvider credentialsProvider = null ;
48
55
private Transport transport = Transport .GRPC ;
49
56
// The clients will be instantiated lazily
50
57
private PredictionServiceClient predictionServiceClient = null ;
51
58
private PredictionServiceClient predictionServiceRestClient = null ;
59
+ private LlmUtilityServiceClient llmUtilityClient = null ;
60
+ private LlmUtilityServiceClient llmUtilityRestClient = null ;
52
61
53
62
/**
54
63
* Construct a VertexAI instance with custom credentials.
@@ -57,10 +66,11 @@ public class VertexAI implements AutoCloseable {
57
66
* @param location the default location to use when making API calls
58
67
* @param credentials the custom credentials to use when making API calls
59
68
*/
60
- public VertexAI (String projectId , String location , GoogleCredentials credentials ) {
69
+ public VertexAI (String projectId , String location , Credentials credentials ) {
61
70
this .projectId = projectId ;
62
71
this .location = location ;
63
- this .credentials = credentials ;
72
+ this .apiEndpoint = String .format ("%s-aiplatform.googleapis.com" , this .location );
73
+ this .credentialsProvider = FixedCredentialsProvider .create (credentials );
64
74
}
65
75
66
76
/**
@@ -71,8 +81,7 @@ public VertexAI(String projectId, String location, GoogleCredentials credentials
71
81
* @param transport the default {@link Transport} layer to use to send API requests
72
82
* @param credentials the default custom credentials to use when making API calls
73
83
*/
74
- public VertexAI (
75
- String projectId , String location , Transport transport , GoogleCredentials credentials ) {
84
+ public VertexAI (String projectId , String location , Transport transport , Credentials credentials ) {
76
85
this (projectId , location , credentials );
77
86
this .transport = transport ;
78
87
}
@@ -82,24 +91,22 @@ public VertexAI(
82
91
*
83
92
* @param projectId the default project to use when making API calls
84
93
* @param location the default location to use when making API calls
85
- * @param scopes collection of scopes in the default credentials
94
+ * @param scopes collection of scopes in the default credentials. Make sure you have specified
95
+ * "https://www.googleapis.com/auth/cloud-platform" scope to access resources on Vertex AI.
86
96
*/
87
97
public VertexAI (String projectId , String location , String ... scopes ) throws IOException {
88
- // Disable the warning message logged in getApplicationDefault
89
- Logger logger = Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
90
- Level previousLevel = logger .getLevel ();
91
- logger .setLevel (Level .SEVERE );
92
- List <String > defaultScopes =
93
- PredictionServiceStubSettings .defaultCredentialsProviderBuilder ().getScopesToApply ();
94
- GoogleCredentials credentials =
98
+ CredentialsProvider credentialsProvider =
95
99
scopes .length == 0
96
- ? GoogleCredentials .getApplicationDefault ().createScoped (defaultScopes )
97
- : GoogleCredentials .getApplicationDefault ().createScoped (scopes );
98
- logger .setLevel (previousLevel );
100
+ ? null
101
+ : GoogleCredentialsProvider .newBuilder ()
102
+ .setScopesToApply (Arrays .asList (scopes ))
103
+ .setUseJwtAccessWithScope (true )
104
+ .build ();
99
105
100
106
this .projectId = projectId ;
101
107
this .location = location ;
102
- this .credentials = credentials ;
108
+ this .apiEndpoint = String .format ("%s-aiplatform.googleapis.com" , this .location );
109
+ this .credentialsProvider = credentialsProvider ;
103
110
}
104
111
105
112
/**
@@ -131,28 +138,72 @@ public String getLocation() {
131
138
return this .location ;
132
139
}
133
140
141
+ /** Returns the default endpoint to use when making API calls. */
142
+ public String getApiEndpoint () {
143
+ return this .apiEndpoint ;
144
+ }
145
+
134
146
/** Returns the default credentials to use when making API calls. */
135
- public GoogleCredentials getCredentials () {
136
- return credentials ;
147
+ public Credentials getCredentials () throws IOException {
148
+ return credentialsProvider . getCredentials () ;
137
149
}
138
150
139
151
/** Sets the value for {@link #getTransport()}. */
140
152
public void setTransport (Transport transport ) {
141
153
this .transport = transport ;
142
154
}
143
155
156
+ /** Sets the value for {@link #getApiEndpoint()}. */
157
+ public void setApiEndpoint (String apiEndpoint ) {
158
+ this .apiEndpoint = apiEndpoint ;
159
+
160
+ if (this .predictionServiceClient != null ) {
161
+ this .predictionServiceClient .close ();
162
+ this .predictionServiceClient = null ;
163
+ }
164
+
165
+ if (this .predictionServiceRestClient != null ) {
166
+ this .predictionServiceRestClient .close ();
167
+ this .predictionServiceRestClient = null ;
168
+ }
169
+
170
+ if (this .llmUtilityClient != null ) {
171
+ this .llmUtilityClient .close ();
172
+ this .llmUtilityClient = null ;
173
+ }
174
+
175
+ if (this .llmUtilityRestClient != null ) {
176
+ this .llmUtilityRestClient .close ();
177
+ this .llmUtilityRestClient = null ;
178
+ }
179
+ }
180
+
144
181
/**
145
182
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
146
183
* first prediction API call is made.
147
184
*/
148
185
public PredictionServiceClient getPredictionServiceClient () throws IOException {
149
186
if (predictionServiceClient == null ) {
150
- PredictionServiceSettings settings =
151
- PredictionServiceSettings .newBuilder ()
152
- .setEndpoint (String .format ("%s-aiplatform.googleapis.com:443" , this .location ))
153
- .setCredentialsProvider (FixedCredentialsProvider .create (this .credentials ))
154
- .build ();
155
- predictionServiceClient = PredictionServiceClient .create (settings );
187
+ PredictionServiceSettings .Builder settingsBuilder = PredictionServiceSettings .newBuilder ();
188
+ settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
189
+ if (this .credentialsProvider != null ) {
190
+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
191
+ }
192
+ HeaderProvider headerProvider =
193
+ FixedHeaderProvider .create (
194
+ "user-agent" ,
195
+ String .format (
196
+ "%s/%s" ,
197
+ Constants .USER_AGENT_HEADER ,
198
+ GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
199
+ settingsBuilder .setHeaderProvider (headerProvider );
200
+ // Disable the warning message logged in getApplicationDefault
201
+ Logger defaultCredentialsProviderLogger =
202
+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
203
+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
204
+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
205
+ predictionServiceClient = PredictionServiceClient .create (settingsBuilder .build ());
206
+ defaultCredentialsProviderLogger .setLevel (previousLevel );
156
207
}
157
208
return predictionServiceClient ;
158
209
}
@@ -163,16 +214,92 @@ public PredictionServiceClient getPredictionServiceClient() throws IOException {
163
214
*/
164
215
public PredictionServiceClient getPredictionServiceRestClient () throws IOException {
165
216
if (predictionServiceRestClient == null ) {
166
- PredictionServiceSettings settings =
167
- PredictionServiceSettings .newHttpJsonBuilder ()
168
- .setEndpoint (String .format ("%s-aiplatform.googleapis.com:443" , this .location ))
169
- .setCredentialsProvider (FixedCredentialsProvider .create (this .credentials ))
170
- .build ();
171
- predictionServiceRestClient = PredictionServiceClient .create (settings );
217
+ PredictionServiceSettings .Builder settingsBuilder =
218
+ PredictionServiceSettings .newHttpJsonBuilder ();
219
+ settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
220
+ if (this .credentialsProvider != null ) {
221
+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
222
+ }
223
+ HeaderProvider headerProvider =
224
+ FixedHeaderProvider .create (
225
+ "user-agent" ,
226
+ String .format (
227
+ "%s/%s" ,
228
+ Constants .USER_AGENT_HEADER ,
229
+ GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
230
+ settingsBuilder .setHeaderProvider (headerProvider );
231
+ // Disable the warning message logged in getApplicationDefault
232
+ Logger defaultCredentialsProviderLogger =
233
+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
234
+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
235
+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
236
+ predictionServiceRestClient = PredictionServiceClient .create (settingsBuilder .build ());
237
+ defaultCredentialsProviderLogger .setLevel (previousLevel );
172
238
}
173
239
return predictionServiceRestClient ;
174
240
}
175
241
242
+ /**
243
+ * Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
244
+ * first prediction API call is made.
245
+ */
246
+ public LlmUtilityServiceClient getLlmUtilityClient () throws IOException {
247
+ if (llmUtilityClient == null ) {
248
+ LlmUtilityServiceSettings .Builder settingsBuilder = LlmUtilityServiceSettings .newBuilder ();
249
+ settingsBuilder .setEndpoint (String .format ("%s-aiplatform.googleapis.com:443" , this .location ));
250
+ if (this .credentialsProvider != null ) {
251
+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
252
+ }
253
+ HeaderProvider headerProvider =
254
+ FixedHeaderProvider .create (
255
+ "user-agent" ,
256
+ String .format (
257
+ "%s/%s" ,
258
+ Constants .USER_AGENT_HEADER ,
259
+ GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
260
+ settingsBuilder .setHeaderProvider (headerProvider );
261
+ // Disable the warning message logged in getApplicationDefault
262
+ Logger defaultCredentialsProviderLogger =
263
+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
264
+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
265
+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
266
+ llmUtilityClient = LlmUtilityServiceClient .create (settingsBuilder .build ());
267
+ defaultCredentialsProviderLogger .setLevel (previousLevel );
268
+ }
269
+ return llmUtilityClient ;
270
+ }
271
+
272
+ /**
273
+ * Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
274
+ * first prediction API call is made.
275
+ */
276
+ public LlmUtilityServiceClient getLlmUtilityRestClient () throws IOException {
277
+ if (llmUtilityRestClient == null ) {
278
+ LlmUtilityServiceSettings .Builder settingsBuilder =
279
+ LlmUtilityServiceSettings .newHttpJsonBuilder ();
280
+ settingsBuilder .setEndpoint (String .format ("%s-aiplatform.googleapis.com:443" , this .location ));
281
+ if (this .credentialsProvider != null ) {
282
+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
283
+ }
284
+ HeaderProvider headerProvider =
285
+ FixedHeaderProvider .create (
286
+ "user-agent" ,
287
+ String .format (
288
+ "%s/%s" ,
289
+ Constants .USER_AGENT_HEADER ,
290
+ GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
291
+ settingsBuilder .setHeaderProvider (headerProvider );
292
+ // Disable the warning message logged in getApplicationDefault
293
+ Logger defaultCredentialsProviderLogger =
294
+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
295
+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
296
+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
297
+ llmUtilityRestClient = LlmUtilityServiceClient .create (settingsBuilder .build ());
298
+ defaultCredentialsProviderLogger .setLevel (previousLevel );
299
+ }
300
+ return llmUtilityRestClient ;
301
+ }
302
+
176
303
/** Closes the VertexAI instance together with all its instantiated clients. */
177
304
@ Override
178
305
public void close () {
@@ -182,5 +309,11 @@ public void close() {
182
309
if (predictionServiceRestClient != null ) {
183
310
predictionServiceRestClient .close ();
184
311
}
312
+ if (llmUtilityClient != null ) {
313
+ llmUtilityClient .close ();
314
+ }
315
+ if (llmUtilityRestClient != null ) {
316
+ llmUtilityRestClient .close ();
317
+ }
185
318
}
186
319
}
0 commit comments