Skip to content

Commit

Permalink
feat: Enable MTLS_S2A bound token by default for gRPC S2A enabled flo…
Browse files Browse the repository at this point in the history
…ws (#3591)

Similar to implementation for DirectPath in
#3572.

This is part of the experimental S2A feature (see #3400)
  • Loading branch information
rmehta19 authored Feb 5, 2025
1 parent 0e69784 commit 81e21f2
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import com.google.common.io.Files;
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.CompositeChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
Expand All @@ -69,6 +70,7 @@
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -139,14 +141,15 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@Nullable private final Boolean keepAliveWithoutCalls;
private final ChannelPoolSettings channelPoolSettings;
@Nullable private final Credentials credentials;
@Nullable private final CallCredentials mtlsS2ACallCredentials;
@Nullable private final ChannelPrimer channelPrimer;
@Nullable private final Boolean attemptDirectPath;
@Nullable private final Boolean attemptDirectPathXds;
@Nullable private final Boolean allowNonDefaultServiceAccount;
@VisibleForTesting final ImmutableMap<String, ?> directPathServiceConfig;
@Nullable private final MtlsProvider mtlsProvider;
@Nullable private final SecureSessionAgent s2aConfigProvider;
@Nullable private final List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
private final List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
@VisibleForTesting final Map<String, String> headersWithDuplicatesRemoved = new HashMap<>();

@Nullable
Expand Down Expand Up @@ -188,6 +191,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.channelPoolSettings = builder.channelPoolSettings;
this.channelConfigurator = builder.channelConfigurator;
this.credentials = builder.credentials;
this.mtlsS2ACallCredentials = builder.mtlsS2ACallCredentials;
this.channelPrimer = builder.channelPrimer;
this.attemptDirectPath = builder.attemptDirectPath;
this.attemptDirectPathXds = builder.attemptDirectPathXds;
Expand Down Expand Up @@ -648,6 +652,12 @@ private ManagedChannel createSingleChannel() throws IOException {
}
if (channelCredentials != null) {
// Create the channel using S2A-secured channel credentials.
if (mtlsS2ACallCredentials != null) {
// Set {@code mtlsS2ACallCredentials} to be per-RPC call credentials,
// which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
channelCredentials =
CompositeChannelCredentials.create(channelCredentials, mtlsS2ACallCredentials);
}
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
} else {
// Use default if we cannot initialize channel credentials via DCA or S2A.
Expand Down Expand Up @@ -812,18 +822,20 @@ public static final class Builder {
@Nullable private Boolean keepAliveWithoutCalls;
@Nullable private ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;
@Nullable private Credentials credentials;
@Nullable private CallCredentials mtlsS2ACallCredentials;
@Nullable private ChannelPrimer channelPrimer;
private ChannelPoolSettings channelPoolSettings;
@Nullable private Boolean attemptDirectPath;
@Nullable private Boolean attemptDirectPathXds;
@Nullable private Boolean allowNonDefaultServiceAccount;
@Nullable private ImmutableMap<String, ?> directPathServiceConfig;
@Nullable private List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
private List<HardBoundTokenTypes> allowedHardBoundTokenTypes;

private Builder() {
processorCount = Runtime.getRuntime().availableProcessors();
envProvider = System::getenv;
channelPoolSettings = ChannelPoolSettings.staticallySized(1);
allowedHardBoundTokenTypes = new ArrayList<>();
}

private Builder(InstantiatingGrpcChannelProvider provider) {
Expand All @@ -841,11 +853,13 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
this.keepAliveWithoutCalls = provider.keepAliveWithoutCalls;
this.channelConfigurator = provider.channelConfigurator;
this.credentials = provider.credentials;
this.mtlsS2ACallCredentials = provider.mtlsS2ACallCredentials;
this.channelPrimer = provider.channelPrimer;
this.channelPoolSettings = provider.channelPoolSettings;
this.attemptDirectPath = provider.attemptDirectPath;
this.attemptDirectPathXds = provider.attemptDirectPathXds;
this.allowNonDefaultServiceAccount = provider.allowNonDefaultServiceAccount;
this.allowedHardBoundTokenTypes = provider.allowedHardBoundTokenTypes;
this.directPathServiceConfig = provider.directPathServiceConfig;
this.mtlsProvider = provider.mtlsProvider;
this.s2aConfigProvider = provider.s2aConfigProvider;
Expand Down Expand Up @@ -914,7 +928,10 @@ Builder setUseS2A(boolean useS2A) {
*/
@InternalApi
public Builder setAllowHardBoundTokenTypes(List<HardBoundTokenTypes> allowedValues) {
this.allowedHardBoundTokenTypes = allowedValues;
this.allowedHardBoundTokenTypes =
Preconditions.checkNotNull(
allowedValues, "List of allowed HardBoundTokenTypes cannot be null");
;
return this;
}

Expand Down Expand Up @@ -1133,7 +1150,50 @@ public Builder setDirectPathServiceConfig(Map<String, ?> serviceConfig) {
return this;
}

boolean isMtlsS2AHardBoundTokensEnabled() {
// If S2A cannot be used, the list of allowed hard bound token types is empty or doesn't
// contain
// {@code HardBoundTokenTypes.MTLS_S2A}, the {@code credentials} are null or not of type
// {@code
// ComputeEngineCredentials} then {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens
// should
// not
// be used. {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens can only be used on MTLS
// channels established using S2A and when tokens from MDS (i.e {@code
// ComputeEngineCredentials}
// are being used.
if (!this.useS2A
|| this.allowedHardBoundTokenTypes.isEmpty()
|| this.credentials == null
|| !(this.credentials instanceof ComputeEngineCredentials)) {
return false;
}
return allowedHardBoundTokenTypes.stream()
.anyMatch(val -> val.equals(HardBoundTokenTypes.MTLS_S2A));
}

CallCredentials createHardBoundTokensCallCredentials(
ComputeEngineCredentials.GoogleAuthTransport googleAuthTransport,
ComputeEngineCredentials.BindingEnforcement bindingEnforcement) {
// We only set scopes and HTTP transport factory from the original credentials because
// only those are used in gRPC CallCredentials to fetch request metadata.
return MoreCallCredentials.from(
((ComputeEngineCredentials) this.credentials)
.toBuilder()
.setGoogleAuthTransport(googleAuthTransport)
.setBindingEnforcement(bindingEnforcement)
.build());
}

public InstantiatingGrpcChannelProvider build() {
if (isMtlsS2AHardBoundTokensEnabled()) {
// Set a {@code ComputeEngineCredentials} instance to be per-RPC call credentials,
// which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
this.mtlsS2ACallCredentials =
createHardBoundTokensCallCredentials(
ComputeEngineCredentials.GoogleAuthTransport.MTLS,
ComputeEngineCredentials.BindingEnforcement.ON);
}
InstantiatingGrpcChannelProvider instantiatingGrpcChannelProvider =
new InstantiatingGrpcChannelProvider(this);
instantiatingGrpcChannelProvider.removeApiKeyCredentialDuplicateHeaders();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,79 @@ void createS2ASecuredChannelCredentials_returnsPlaintextToS2AS2AChannelCredentia
InstantiatingGrpcChannelProvider.LOG.removeHandler(logHandler);
}

@Test
void isMtlsS2AHardBoundTokensEnabled_useS2AFalse() {
InstantiatingGrpcChannelProvider.Builder providerBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(false)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(computeEngineCredentials);
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_hardBoundTokenTypesEmpty() {
InstantiatingGrpcChannelProvider.Builder providerBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(new ArrayList<>())
.setCredentials(computeEngineCredentials);
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_nullCreds() {
InstantiatingGrpcChannelProvider.Builder providerBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(null);
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_notComputeEngineCreds() {
InstantiatingGrpcChannelProvider.Builder providerBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(CloudShellCredentials.create(3000));
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_mtlsS2ANotInList() {
InstantiatingGrpcChannelProvider.Builder providerBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.ALTS))
.setCredentials(computeEngineCredentials);
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_mtlsS2ATokenAllowedInList() {
List<InstantiatingGrpcChannelProvider.HardBoundTokenTypes> allowHardBoundTokenTypes =
new ArrayList<>();
allowHardBoundTokenTypes.add(InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A);
allowHardBoundTokenTypes.add(InstantiatingGrpcChannelProvider.HardBoundTokenTypes.ALTS);

InstantiatingGrpcChannelProvider.Builder providerBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(allowHardBoundTokenTypes)
.setCredentials(computeEngineCredentials);
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isTrue();
}

private static class FakeLogHandler extends Handler {

List<LogRecord> records = new ArrayList<>();
Expand Down

0 comments on commit 81e21f2

Please sign in to comment.