diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java index 0d731a50fd..4bd63441e5 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java @@ -56,6 +56,7 @@ import com.google.common.io.Files; import io.grpc.CallCredentials; import io.grpc.ChannelCredentials; +import io.grpc.CompositeChannelCredentials; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; @@ -69,6 +70,7 @@ import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; import java.security.KeyStore; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -139,6 +141,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP @Nullable private final Boolean keepAliveWithoutCalls; private final ChannelPoolSettings channelPoolSettings; @Nullable private final Credentials credentials; + @Nullable private final CallCredentials mtlsS2ACallCredentials; @Nullable private final ChannelPrimer channelPrimer; @Nullable private final Boolean attemptDirectPath; @Nullable private final Boolean attemptDirectPathXds; @@ -146,7 +149,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP @VisibleForTesting final ImmutableMap directPathServiceConfig; @Nullable private final MtlsProvider mtlsProvider; @Nullable private final SecureSessionAgent s2aConfigProvider; - @Nullable private final List allowedHardBoundTokenTypes; + private final List allowedHardBoundTokenTypes; @VisibleForTesting final Map headersWithDuplicatesRemoved = new HashMap<>(); @Nullable @@ -188,6 +191,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) { this.channelPoolSettings = builder.channelPoolSettings; this.channelConfigurator = builder.channelConfigurator; this.credentials = builder.credentials; + this.mtlsS2ACallCredentials = builder.mtlsS2ACallCredentials; this.channelPrimer = builder.channelPrimer; this.attemptDirectPath = builder.attemptDirectPath; this.attemptDirectPathXds = builder.attemptDirectPathXds; @@ -648,6 +652,12 @@ private ManagedChannel createSingleChannel() throws IOException { } if (channelCredentials != null) { // Create the channel using S2A-secured channel credentials. + if (mtlsS2ACallCredentials != null) { + // Set {@code mtlsS2ACallCredentials} to be per-RPC call credentials, + // which will be used to fetch MTLS_S2A hard bound tokens from the metdata server. + channelCredentials = + CompositeChannelCredentials.create(channelCredentials, mtlsS2ACallCredentials); + } builder = Grpc.newChannelBuilder(endpoint, channelCredentials); } else { // Use default if we cannot initialize channel credentials via DCA or S2A. @@ -812,18 +822,20 @@ public static final class Builder { @Nullable private Boolean keepAliveWithoutCalls; @Nullable private ApiFunction channelConfigurator; @Nullable private Credentials credentials; + @Nullable private CallCredentials mtlsS2ACallCredentials; @Nullable private ChannelPrimer channelPrimer; private ChannelPoolSettings channelPoolSettings; @Nullable private Boolean attemptDirectPath; @Nullable private Boolean attemptDirectPathXds; @Nullable private Boolean allowNonDefaultServiceAccount; @Nullable private ImmutableMap directPathServiceConfig; - @Nullable private List allowedHardBoundTokenTypes; + private List allowedHardBoundTokenTypes; private Builder() { processorCount = Runtime.getRuntime().availableProcessors(); envProvider = System::getenv; channelPoolSettings = ChannelPoolSettings.staticallySized(1); + allowedHardBoundTokenTypes = new ArrayList<>(); } private Builder(InstantiatingGrpcChannelProvider provider) { @@ -841,11 +853,13 @@ private Builder(InstantiatingGrpcChannelProvider provider) { this.keepAliveWithoutCalls = provider.keepAliveWithoutCalls; this.channelConfigurator = provider.channelConfigurator; this.credentials = provider.credentials; + this.mtlsS2ACallCredentials = provider.mtlsS2ACallCredentials; this.channelPrimer = provider.channelPrimer; this.channelPoolSettings = provider.channelPoolSettings; this.attemptDirectPath = provider.attemptDirectPath; this.attemptDirectPathXds = provider.attemptDirectPathXds; this.allowNonDefaultServiceAccount = provider.allowNonDefaultServiceAccount; + this.allowedHardBoundTokenTypes = provider.allowedHardBoundTokenTypes; this.directPathServiceConfig = provider.directPathServiceConfig; this.mtlsProvider = provider.mtlsProvider; this.s2aConfigProvider = provider.s2aConfigProvider; @@ -914,7 +928,10 @@ Builder setUseS2A(boolean useS2A) { */ @InternalApi public Builder setAllowHardBoundTokenTypes(List allowedValues) { - this.allowedHardBoundTokenTypes = allowedValues; + this.allowedHardBoundTokenTypes = + Preconditions.checkNotNull( + allowedValues, "List of allowed HardBoundTokenTypes cannot be null"); + ; return this; } @@ -1133,7 +1150,50 @@ public Builder setDirectPathServiceConfig(Map serviceConfig) { return this; } + boolean isMtlsS2AHardBoundTokensEnabled() { + // If S2A cannot be used, the list of allowed hard bound token types is empty or doesn't + // contain + // {@code HardBoundTokenTypes.MTLS_S2A}, the {@code credentials} are null or not of type + // {@code + // ComputeEngineCredentials} then {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens + // should + // not + // be used. {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens can only be used on MTLS + // channels established using S2A and when tokens from MDS (i.e {@code + // ComputeEngineCredentials} + // are being used. + if (!this.useS2A + || this.allowedHardBoundTokenTypes.isEmpty() + || this.credentials == null + || !(this.credentials instanceof ComputeEngineCredentials)) { + return false; + } + return allowedHardBoundTokenTypes.stream() + .anyMatch(val -> val.equals(HardBoundTokenTypes.MTLS_S2A)); + } + + CallCredentials createHardBoundTokensCallCredentials( + ComputeEngineCredentials.GoogleAuthTransport googleAuthTransport, + ComputeEngineCredentials.BindingEnforcement bindingEnforcement) { + // We only set scopes and HTTP transport factory from the original credentials because + // only those are used in gRPC CallCredentials to fetch request metadata. + return MoreCallCredentials.from( + ((ComputeEngineCredentials) this.credentials) + .toBuilder() + .setGoogleAuthTransport(googleAuthTransport) + .setBindingEnforcement(bindingEnforcement) + .build()); + } + public InstantiatingGrpcChannelProvider build() { + if (isMtlsS2AHardBoundTokensEnabled()) { + // Set a {@code ComputeEngineCredentials} instance to be per-RPC call credentials, + // which will be used to fetch MTLS_S2A hard bound tokens from the metdata server. + this.mtlsS2ACallCredentials = + createHardBoundTokensCallCredentials( + ComputeEngineCredentials.GoogleAuthTransport.MTLS, + ComputeEngineCredentials.BindingEnforcement.ON); + } InstantiatingGrpcChannelProvider instantiatingGrpcChannelProvider = new InstantiatingGrpcChannelProvider(this); instantiatingGrpcChannelProvider.removeApiKeyCredentialDuplicateHeaders(); diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java index 82738cae02..9540235b18 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java @@ -1103,6 +1103,79 @@ void createS2ASecuredChannelCredentials_returnsPlaintextToS2AS2AChannelCredentia InstantiatingGrpcChannelProvider.LOG.removeHandler(logHandler); } + @Test + void isMtlsS2AHardBoundTokensEnabled_useS2AFalse() { + InstantiatingGrpcChannelProvider.Builder providerBuilder = + InstantiatingGrpcChannelProvider.newBuilder() + .setUseS2A(false) + .setAllowHardBoundTokenTypes( + Collections.singletonList( + InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A)) + .setCredentials(computeEngineCredentials); + Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse(); + } + + @Test + void isMtlsS2AHardBoundTokensEnabled_hardBoundTokenTypesEmpty() { + InstantiatingGrpcChannelProvider.Builder providerBuilder = + InstantiatingGrpcChannelProvider.newBuilder() + .setUseS2A(true) + .setAllowHardBoundTokenTypes(new ArrayList<>()) + .setCredentials(computeEngineCredentials); + Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse(); + } + + @Test + void isMtlsS2AHardBoundTokensEnabled_nullCreds() { + InstantiatingGrpcChannelProvider.Builder providerBuilder = + InstantiatingGrpcChannelProvider.newBuilder() + .setUseS2A(true) + .setAllowHardBoundTokenTypes( + Collections.singletonList( + InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A)) + .setCredentials(null); + Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse(); + } + + @Test + void isMtlsS2AHardBoundTokensEnabled_notComputeEngineCreds() { + InstantiatingGrpcChannelProvider.Builder providerBuilder = + InstantiatingGrpcChannelProvider.newBuilder() + .setUseS2A(true) + .setAllowHardBoundTokenTypes( + Collections.singletonList( + InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A)) + .setCredentials(CloudShellCredentials.create(3000)); + Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse(); + } + + @Test + void isMtlsS2AHardBoundTokensEnabled_mtlsS2ANotInList() { + InstantiatingGrpcChannelProvider.Builder providerBuilder = + InstantiatingGrpcChannelProvider.newBuilder() + .setUseS2A(true) + .setAllowHardBoundTokenTypes( + Collections.singletonList( + InstantiatingGrpcChannelProvider.HardBoundTokenTypes.ALTS)) + .setCredentials(computeEngineCredentials); + Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse(); + } + + @Test + void isMtlsS2AHardBoundTokensEnabled_mtlsS2ATokenAllowedInList() { + List allowHardBoundTokenTypes = + new ArrayList<>(); + allowHardBoundTokenTypes.add(InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A); + allowHardBoundTokenTypes.add(InstantiatingGrpcChannelProvider.HardBoundTokenTypes.ALTS); + + InstantiatingGrpcChannelProvider.Builder providerBuilder = + InstantiatingGrpcChannelProvider.newBuilder() + .setUseS2A(true) + .setAllowHardBoundTokenTypes(allowHardBoundTokenTypes) + .setCredentials(computeEngineCredentials); + Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isTrue(); + } + private static class FakeLogHandler extends Handler { List records = new ArrayList<>();