Skip to content

Commit 14bc5c4

Browse files
authored
Support proxy in case kms-plugin can not access key vault (#119)
1 parent 14cd0a7 commit 14bc5c4

File tree

7 files changed

+151
-39
lines changed

7 files changed

+151
-39
lines changed

cmd/server/main.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ var (
4343
healthzTimeout = flag.Duration("healthz-timeout", 20*time.Second, "RPC timeout for health check")
4444
metricsBackend = flag.String("metrics-backend", "prometheus", "Backend used for metrics")
4545
metricsAddress = flag.String("metrics-addr", "8095", "The address the metric endpoint binds to")
46+
47+
proxyMode = flag.Bool("proxy-mode", false, "Proxy mode")
48+
proxyAddress = flag.String("proxy-address", "", "proxy address")
49+
proxyPort = flag.Int("proxy-port", 7788, "port for proxy")
4650
)
4751

4852
func main() {
@@ -68,7 +72,7 @@ func main() {
6872
}
6973

7074
klog.InfoS("Starting KeyManagementServiceServer service", "version", version.BuildVersion, "buildDate", version.BuildDate)
71-
kmsServer, err := plugin.New(ctx, *configFilePath, *keyvaultName, *keyName, *keyVersion)
75+
kmsServer, err := plugin.New(ctx, *configFilePath, *keyvaultName, *keyName, *keyVersion, *proxyMode, *proxyAddress, *proxyPort)
7276
if err != nil {
7377
klog.Fatalf("failed to create server, error: %v", err)
7478
}

pkg/auth/auth.go

+35-5
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ import (
99
"crypto/rsa"
1010
"crypto/x509"
1111
"fmt"
12+
"net/http"
1213
"os"
1314
"regexp"
1415
"strings"
1516

1617
"github.com/Azure/kubernetes-kms/pkg/config"
18+
"github.com/Azure/kubernetes-kms/pkg/consts"
1719

1820
"github.com/Azure/go-autorest/autorest"
1921
"github.com/Azure/go-autorest/autorest/adal"
@@ -23,9 +25,9 @@ import (
2325
)
2426

2527
// GetKeyvaultToken() returns token for Keyvault endpoint
26-
func GetKeyvaultToken(config *config.AzureConfig, env *azure.Environment) (authorizer autorest.Authorizer, err error) {
28+
func GetKeyvaultToken(config *config.AzureConfig, env *azure.Environment, proxyMode bool) (authorizer autorest.Authorizer, err error) {
2729
kvEndPoint := strings.TrimSuffix(env.KeyVaultEndpoint, "/")
28-
servicePrincipalToken, err := GetServicePrincipalToken(config, env.ActiveDirectoryEndpoint, kvEndPoint)
30+
servicePrincipalToken, err := GetServicePrincipalToken(config, env.ActiveDirectoryEndpoint, kvEndPoint, proxyMode)
2931
if err != nil {
3032
return nil, err
3133
}
@@ -34,7 +36,7 @@ func GetKeyvaultToken(config *config.AzureConfig, env *azure.Environment) (autho
3436
}
3537

3638
// GetServicePrincipalToken creates a new service principal token based on the configuration
37-
func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource string) (adal.OAuthTokenProvider, error) {
39+
func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource string, proxyMode bool) (adal.OAuthTokenProvider, error) {
3840
oauthConfig, err := adal.NewOAuthConfig(aadEndpoint, config.TenantID)
3941
if err != nil {
4042
return nil, fmt.Errorf("failed to create OAuth config, error: %v", err)
@@ -64,11 +66,18 @@ func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource
6466
klog.V(2).InfoS("azure: using client_id+client_secret to retrieve access token",
6567
"clientID", redactClientCredentials(config.ClientID), "clientSecret", redactClientCredentials(config.ClientSecret))
6668

67-
return adal.NewServicePrincipalToken(
69+
spt, err := adal.NewServicePrincipalToken(
6870
*oauthConfig,
6971
config.ClientID,
7072
config.ClientSecret,
7173
resource)
74+
if err != nil {
75+
return nil, err
76+
}
77+
if proxyMode {
78+
return addTargetTypeHeader(spt), nil
79+
}
80+
return spt, nil
7281
}
7382

7483
if len(config.AADClientCertPath) > 0 && len(config.AADClientCertPassword) > 0 {
@@ -81,12 +90,19 @@ func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource
8190
if err != nil {
8291
return nil, fmt.Errorf("failed to decode the client certificate, error: %v", err)
8392
}
84-
return adal.NewServicePrincipalTokenFromCertificate(
93+
spt, err := adal.NewServicePrincipalTokenFromCertificate(
8594
*oauthConfig,
8695
config.ClientID,
8796
certificate,
8897
privateKey,
8998
resource)
99+
if err != nil {
100+
return nil, err
101+
}
102+
if proxyMode {
103+
return addTargetTypeHeader(spt), nil
104+
}
105+
return spt, nil
90106
}
91107

92108
return nil, fmt.Errorf("no credentials provided for accessing keyvault")
@@ -124,3 +140,17 @@ func redactClientCredentials(sensitiveString string) string {
124140
r, _ := regexp.Compile(`^(\S{4})(\S|\s)*(\S{4})$`)
125141
return r.ReplaceAllString(sensitiveString, "$1##### REDACTED #####$3")
126142
}
143+
144+
// addTargetTypeHeader adds the target header if proxy mode is enabled
145+
func addTargetTypeHeader(spt *adal.ServicePrincipalToken) *adal.ServicePrincipalToken {
146+
spt.SetSender(autorest.CreateSender(
147+
(func() autorest.SendDecorator {
148+
return func(s autorest.Sender) autorest.Sender {
149+
return autorest.SenderFunc(func(r *http.Request) (*http.Response, error) {
150+
r.Header.Set(consts.RequestHeaderTargetType, consts.TargetTypeAzureActiveDirectory)
151+
return s.Do(r)
152+
})
153+
}
154+
})()))
155+
return spt
156+
}

pkg/auth/auth_test.go

+14-8
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,9 @@ func TestRedactClientCredentials(t *testing.T) {
6161

6262
func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
6363
tests := []struct {
64-
name string
65-
config *config.AzureConfig
64+
name string
65+
config *config.AzureConfig
66+
proxyMode bool // The proxy mode doesn't matter if user-assigned managed identity is used to get service principal token
6667
}{
6768
{
6869
name: "using user-assigned managed identity to access keyvault",
@@ -73,6 +74,7 @@ func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
7374
ClientID: "AADClientID",
7475
ClientSecret: "AADClientSecret",
7576
},
77+
proxyMode: false,
7678
},
7779
// The Azure service principal is ignored when
7880
// UseManagedIdentityExtension is set to true
@@ -82,12 +84,13 @@ func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
8284
UseManagedIdentityExtension: true,
8385
UserAssignedIdentityID: "clientID",
8486
},
87+
proxyMode: true,
8588
},
8689
}
8790

8891
for _, test := range tests {
8992
t.Run(test.name, func(t *testing.T) {
90-
token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net")
93+
token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net", test.proxyMode)
9194
if err != nil {
9295
t.Fatalf("expected err to be nil, got: %v", err)
9396
}
@@ -108,14 +111,16 @@ func TestGetServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
108111

109112
func TestGetServicePrincipalTokenFromMSI(t *testing.T) {
110113
tests := []struct {
111-
name string
112-
config *config.AzureConfig
114+
name string
115+
config *config.AzureConfig
116+
proxyMode bool // The proxy mode doesn't matter if MSI is used to get service principal token
113117
}{
114118
{
115119
name: "using system-assigned managed identity to access keyvault",
116120
config: &config.AzureConfig{
117121
UseManagedIdentityExtension: true,
118122
},
123+
proxyMode: false,
119124
},
120125
// The Azure service principal is ignored when
121126
// UseManagedIdentityExtension is set to true
@@ -127,12 +132,13 @@ func TestGetServicePrincipalTokenFromMSI(t *testing.T) {
127132
ClientID: "AADClientID",
128133
ClientSecret: "AADClientSecret",
129134
},
135+
proxyMode: true,
130136
},
131137
}
132138

133139
for _, test := range tests {
134140
t.Run(test.name, func(t *testing.T) {
135-
token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net")
141+
token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net", test.proxyMode)
136142
if err != nil {
137143
t.Fatalf("expected err to be nil, got: %v", err)
138144
}
@@ -168,7 +174,7 @@ func TestGetServicePrincipalToken(t *testing.T) {
168174

169175
for _, test := range tests {
170176
t.Run(test.name, func(t *testing.T) {
171-
token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net")
177+
token, err := GetServicePrincipalToken(test.config, "https://login.microsoftonline.com/", "https://vault.azure.net", false)
172178
if err != nil {
173179
t.Fatalf("expected err to be nil, got: %v", err)
174180
}
@@ -183,7 +189,7 @@ func TestGetServicePrincipalToken(t *testing.T) {
183189
t.Fatalf("expected err to be nil, got: %v", err)
184190
}
185191
if !reflect.DeepEqual(token, spt) {
186-
t.Fatalf("expected: %v, got: %v", spt, token)
192+
t.Fatalf("expected: %+v, got: %+v", spt, token)
187193
}
188194
})
189195
}

pkg/consts/consts.go

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Copyright (c) Microsoft and contributors. All rights reserved.
2+
//
3+
// This source code is licensed under the MIT license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
6+
package consts
7+
8+
const (
9+
// In proxy mode, the header is added into the requests from kms-plugin.
10+
// The proxy will check the header and forward the request to different destinations.
11+
// e.g. When the value of the header "x-azure-proxy-target" is "KeyVault", the request
12+
// is forwared to Azure Key Vault by the proxy.
13+
RequestHeaderTargetType = "x-azure-proxy-target"
14+
TargetTypeAzureActiveDirectory = "AzureActiveDirectory"
15+
TargetTypeKeyVault = "KeyVault"
16+
)

pkg/plugin/keyvault.go

+21-3
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@ import (
1010
"encoding/base64"
1111
"fmt"
1212
"regexp"
13+
"strings"
1314

1415
"github.com/Azure/kubernetes-kms/pkg/auth"
1516
"github.com/Azure/kubernetes-kms/pkg/config"
17+
"github.com/Azure/kubernetes-kms/pkg/consts"
1618
"github.com/Azure/kubernetes-kms/pkg/utils"
1719
"github.com/Azure/kubernetes-kms/pkg/version"
1820

1921
kv "github.com/Azure/azure-sdk-for-go/services/keyvault/2016-10-01/keyvault"
22+
"github.com/Azure/go-autorest/autorest"
2023
"github.com/Azure/go-autorest/autorest/azure"
2124
"k8s.io/klog/v2"
2225
)
@@ -38,7 +41,7 @@ type keyVaultClient struct {
3841
}
3942

4043
// NewKeyVaultClient returns a new key vault client to use for kms operations
41-
func newKeyVaultClient(config *config.AzureConfig, vaultName, keyName, keyVersion string) (*keyVaultClient, error) {
44+
func newKeyVaultClient(config *config.AzureConfig, vaultName, keyName, keyVersion string, proxyMode bool, proxyAddress string, proxyPort int) (*keyVaultClient, error) {
4245
// Sanitize vaultName, keyName, keyVersion. (https://github.com/Azure/kubernetes-kms/issues/85)
4346
vaultName = utils.SanitizeString(vaultName)
4447
keyName = utils.SanitizeString(keyName)
@@ -58,7 +61,11 @@ func newKeyVaultClient(config *config.AzureConfig, vaultName, keyName, keyVersio
5861
if err != nil {
5962
return nil, fmt.Errorf("failed to parse cloud environment: %s, error: %+v", config.Cloud, err)
6063
}
61-
token, err := auth.GetKeyvaultToken(config, env)
64+
if proxyMode {
65+
env.ActiveDirectoryEndpoint = fmt.Sprintf("http://%s:%d/", proxyAddress, proxyPort)
66+
}
67+
68+
token, err := auth.GetKeyvaultToken(config, env, proxyMode)
6269
if err != nil {
6370
return nil, fmt.Errorf("failed to get key vault token, error: %+v", err)
6471
}
@@ -69,7 +76,12 @@ func newKeyVaultClient(config *config.AzureConfig, vaultName, keyName, keyVersio
6976
return nil, fmt.Errorf("failed to get vault url, error: %+v", err)
7077
}
7178

72-
klog.InfoS("using kms key for encrypt/decrypt", "vaultName", vaultName, "keyName", keyName, "keyVersion", keyVersion)
79+
if proxyMode {
80+
kvClient.RequestInspector = autorest.WithHeader(consts.RequestHeaderTargetType, consts.TargetTypeKeyVault)
81+
vaultURL = getProxiedVaultURL(vaultURL, proxyAddress, proxyPort)
82+
}
83+
84+
klog.InfoS("using kms key for encrypt/decrypt", "vaultURL", *vaultURL, "keyName", keyName, "keyVersion", keyVersion)
7385

7486
client := &keyVaultClient{
7587
baseClient: kvClient,
@@ -130,5 +142,11 @@ func getVaultURL(vaultName string, azureEnvironment *azure.Environment) (vaultUR
130142

131143
vaultDNSSuffixValue := azureEnvironment.KeyVaultDNSSuffix
132144
vaultURI := "https://" + vaultName + "." + vaultDNSSuffixValue + "/"
145+
133146
return &vaultURI, nil
134147
}
148+
149+
func getProxiedVaultURL(vaultURL *string, proxyAddress string, proxyPort int) *string {
150+
proxiedVaultURL := fmt.Sprintf("http://%s:%d/%s", proxyAddress, proxyPort, strings.TrimPrefix(*vaultURL, "https://"))
151+
return &proxiedVaultURL
152+
}

0 commit comments

Comments
 (0)