Skip to content

credentials, transport, grpc : add a call option to override the :authority header on a per-RPC basis #8068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ type AuthInfo interface {
AuthType() string
}

// AuthorityValidator validates the authority used to override the `:authority`
// header. A struct implementing AuthInfo should also implement
// AuthorityValidator if the credentials need to support per-RPC authority overrides.
type AuthorityValidator interface {
ValidateAuthority(authority string) error
}

// ErrConnDispatched indicates that rawConn has been dispatched out of gRPC
// and the caller should not close rawConn.
var ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
Expand Down
338 changes: 338 additions & 0 deletions credentials/credentials_ext_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,338 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package credentials_test

import (
"context"
"crypto/tls"
"fmt"
"log"
"net"
"testing"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/stubserver"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
)

var cert tls.Certificate
var creds credentials.TransportCredentials

func init() {
var err error
cert, err = tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
if err != nil {
log.Fatalf("failed to load key pair: %s", err)
}
creds, err = credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
if err != nil {
log.Fatalf("Failed to create credentials %v", err)
}
}

func authorityChecker(ctx context.Context, expectedAuthority string) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.InvalidArgument, "failed to parse metadata")
}
auths, ok := md[":authority"]
if !ok {
return nil, status.Error(codes.InvalidArgument, "no authority header")
}
if len(auths) != 1 {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("no authority header, auths = %v", auths))
}
if auths[0] != expectedAuthority {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid authority header %v, expected %v", auths[0], expectedAuthority))
}
return &testpb.Empty{}, nil
}

func checkUnavailableRPCError(t *testing.T, err error) {
t.Helper()
if err == nil {
t.Fatalf("EmptyCall() should fail")
}
s, ok := status.FromError(err)
if !ok {
t.Fatalf("unexpected error: %v", err)
}
if s.Code() != codes.Unavailable {
t.Fatalf("EmptyCall() = _, %v, want _, error code: %v", s.Code(), codes.Unavailable)
}
}

// Tests the grpc.CallAuthority option with TLS credentials. This test verifies
// that the provided authority is correctly propagated to the server when using TLS.
// It covers both positive and negative cases: correct authority and incorrect
// authority, expecting the RPC to fail with `UNAVAILABLE` status code error in
// the latter case.
func (s) TestAuthorityCallOptionsWithTLSCreds(t *testing.T) {
tests := []struct {
name string
expectedAuth string
wantRPCError bool
}{
{
name: "CorrectAuthorityWithTLS",
expectedAuth: "auth.test.example.com",
wantRPCError: false,
},
{
name: "IncorrectAuthorityWithTLS",
expectedAuth: "auth.example.com",
wantRPCError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return authorityChecker(ctx, tt.expectedAuth)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm.. I'm not convinced about checking the authority header on the server handler, because when the validation fails, the RPC will not even make it to the server. Maybe, checking from the server handler is OK when you actually expect validation to succeed on the client and expect the RPC to reach the server.

You know the certs you are using for the server. So, you can specify an authority override on the client that you expect to work and one that you dont expect to work, because it will fail validation with the peer certificate.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what you mean , when the authority is not correct , we want the RPC call to return UNAVAILABLE error in client and that is what we are checking. And when it passes the validation , we want it to correctly reach the server and check if the correct authority has reached. We do not expect to check authority on server even when it is wrong or fails validation?

},
}
if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

_, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.expectedAuth))
if tt.wantRPCError {
checkUnavailableRPCError(t, err)
} else if err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
})
}
}

func (s) TestTLSCredsWithNoAuthorityOverride(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test comment for this.

What scenario is this testing? If this test fails, what does it indicate about the authority override feature?

ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return authorityChecker(ctx, "x.test.example.com")
},
}
if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

_, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{})
if err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
}

// Tests the scenario where grpc.CallAuthority option is used with insecure credentials.
// The test verifies that the CallAuthority option is correctly passed even when
// insecure credentials are used.
func (s) TestAuthorityCallOptionWithInsecureCreds(t *testing.T) {
const expectedAuthority = "test.server.name"

ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return authorityChecker(ctx, expectedAuthority)
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(expectedAuthority)); err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
}

// FakeCredsNoAuthValidator is a test credential that does not implement AuthorityValidator.
type FakeCredsNoAuthValidator struct {
}

// ClientHandshake performs the client-side handshake.
func (c *FakeCredsNoAuthValidator) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, TestAuthInfo{}, nil
}

// TestAuthInfo implements the AuthInfo interface.
type TestAuthInfo struct{}

// AuthType returns the authentication type.
func (TestAuthInfo) AuthType() string { return "test" }

// Clone creates a copy of FakeCredsNoAuthValidator.
func (c *FakeCredsNoAuthValidator) Clone() credentials.TransportCredentials {
return c
}

// Info provides protocol information.
func (c *FakeCredsNoAuthValidator) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}

// OverrideServerName overrides the server name used for verification.
func (c *FakeCredsNoAuthValidator) OverrideServerName(serverName string) error {
return nil
}

// ServerHandshake performs the server-side handshake.
// Returns a test AuthInfo object to satisfy the interface requirements.
func (c *FakeCredsNoAuthValidator) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, TestAuthInfo{}, nil
}

// FakeCredsWithAuthValidator is a test credential that does not implement AuthorityValidator.
type FakeCredsWithAuthValidator struct {
}

// ClientHandshake performs the client-side handshake.
func (c *FakeCredsWithAuthValidator) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, FakeAuthInfo{}, nil
}

// TestAuthInfo implements the AuthInfo interface.
type FakeAuthInfo struct{}

// AuthType returns the authentication type.
func (FakeAuthInfo) AuthType() string { return "test" }

// AuthType returns the authentication type.
func (FakeAuthInfo) ValidateAuthority(authority string) error {
if authority == "auth.test.example.com" {
return nil
} else {
return fmt.Errorf("invalid authority")
}
}

// Clone creates a copy of FakeCredsWithAuthValidator.
func (c *FakeCredsWithAuthValidator) Clone() credentials.TransportCredentials {
return c
}

// Info provides protocol information.
func (c *FakeCredsWithAuthValidator) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}

// OverrideServerName overrides the server name used for verification.
func (c *FakeCredsWithAuthValidator) OverrideServerName(serverName string) error {
return nil
}

// ServerHandshake performs the server-side handshake.
// Returns a test AuthInfo object to satisfy the interface requirements.
func (c *FakeCredsWithAuthValidator) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, FakeAuthInfo{}, nil
}

// TestCorrectAuthorityWithCustomCreds tests the CallAuthority call option
// with custom credentials that implement AuthorityValidator and verifies
// it with both correct and incorrect authority override.
func (s) TestCorrectAuthorityWithCustomCreds(t *testing.T) {
tests := []struct {
name string
creds credentials.TransportCredentials
expectedAuth string
wantRPCError bool
}{
{
name: "CorrectAuthorityWithFakeCreds",
expectedAuth: "auth.test.example.com",
creds: &FakeCredsWithAuthValidator{},
wantRPCError: false,
},
{
name: "IncorrectAuthorityWithFakeCreds",
expectedAuth: "auth.example.com",
creds: &FakeCredsWithAuthValidator{},
wantRPCError: true,
},
{
name: "FakeCredsWithNoAuthValidator",
creds: &FakeCredsNoAuthValidator{},
expectedAuth: "auth.test.example.com",
wantRPCError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ss := stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return authorityChecker(ctx, tt.expectedAuth)
},
}
if err := ss.StartServer(); err != nil {
t.Fatalf("Failed to start stub server: %v", err)
}
defer ss.Stop()

// Create a gRPC client connection with FakeCredsWithAuthValidator.
clientConn, err := grpc.NewClient(ss.Address,
grpc.WithTransportCredentials(tt.creds))
if err != nil {
t.Fatalf("Failed to create gRPC client connection: %v", err)
}
defer clientConn.Close()

// Perform a test RPC with a specified call authority.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

_, err = testgrpc.NewTestServiceClient(clientConn).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.expectedAuth))
if tt.wantRPCError {
checkUnavailableRPCError(t, err)
} else if err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
})
}
}
4 changes: 4 additions & 0 deletions credentials/insecure/insecure.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ func (info) AuthType() string {
return "insecure"
}

func (info) ValidateAuthority(_ string) error {
return nil
}

// insecureBundle implements an insecure bundle.
// An insecure bundle provides a thin wrapper around insecureTC to support
// the credentials.Bundle interface.
Expand Down
11 changes: 11 additions & 0 deletions credentials/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ func (t TLSInfo) AuthType() string {
return "tls"
}

// ValidateAuthority validates the authority by checking it against the peer certificates.
func (t TLSInfo) ValidateAuthority(authority string) error {
var err error
for _, cert := range t.State.PeerCertificates {
if err = cert.VerifyHostname(authority); err == nil {
return nil
}
}
return fmt.Errorf("credentials: failed to verify authority %s", err)
}

// cipherSuiteLookup returns the string version of a TLS cipher suite ID.
func cipherSuiteLookup(cipherSuiteID uint16) string {
for _, s := range tls.CipherSuites() {
Expand Down
Loading
Loading