Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport into 1.4.x. Fix SRV Lookups #8533

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ const EnvVaultCAPath = "VAULT_CAPATH"
const EnvVaultClientCert = "VAULT_CLIENT_CERT"
const EnvVaultClientKey = "VAULT_CLIENT_KEY"
const EnvVaultClientTimeout = "VAULT_CLIENT_TIMEOUT"
const EnvVaultSRVLookup = "VAULT_SRV_LOOKUP"
const EnvVaultSkipVerify = "VAULT_SKIP_VERIFY"
const EnvVaultNamespace = "VAULT_NAMESPACE"
const EnvVaultTLSServerName = "VAULT_TLS_SERVER_NAME"
Expand Down Expand Up @@ -105,6 +106,9 @@ type Config struct {
// Note: It is not thread-safe to set this and make concurrent requests
// with the same client. Cloning a client will not clone this value.
OutputCurlString bool

// SRVLookup enables the client to lookup the host through DNS SRV lookup
SRVLookup bool
}

// TLSConfig contains the parameters needed to configure TLS on the HTTP client
Expand Down Expand Up @@ -245,6 +249,7 @@ func (c *Config) ReadEnvironment() error {
var envInsecure bool
var envTLSServerName string
var envMaxRetries *uint64
var envSRVLookup bool
var limit *rate.Limiter

// Parse the environment variables
Expand Down Expand Up @@ -302,6 +307,13 @@ func (c *Config) ReadEnvironment() error {
return fmt.Errorf("could not parse VAULT_INSECURE")
}
}
if v := os.Getenv(EnvVaultSRVLookup); v != "" {
var err error
envSRVLookup, err = strconv.ParseBool(v)
if err != nil {
return fmt.Errorf("could not parse %s", EnvVaultSRVLookup)
}
}

if v := os.Getenv(EnvVaultTLSServerName); v != "" {
envTLSServerName = v
Expand All @@ -320,6 +332,7 @@ func (c *Config) ReadEnvironment() error {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()

c.SRVLookup = envSRVLookup
c.Limiter = limit

if err := c.ConfigureTLS(t); err != nil {
Expand Down Expand Up @@ -686,12 +699,6 @@ func (c *Client) SetPolicyOverride(override bool) {
c.policyOverride = override
}

// portMap defines the standard port map
var portMap = map[string]string{
"http": "80",
"https": "443",
}

// NewRequest creates a new raw request object to query the Vault server
// configured for this client. This is an advanced method and generally
// doesn't need to be called externally.
Expand All @@ -704,20 +711,14 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
policyOverride := c.policyOverride
c.modifyLock.RUnlock()

var host = addr.Host
// if SRV records exist (see https://tools.ietf.org/html/draft-andrews-http-srv-02), lookup the SRV
// record and take the highest match; this is not designed for high-availability, just discovery
var host string = addr.Host
if addr.Port() == "" {
// Avoid lookup of SRV record if scheme is known
port, ok := portMap[addr.Scheme]
if ok {
host = net.JoinHostPort(host, port)
} else {
// Internet Draft specifies that the SRV record is ignored if a port is given
_, addrs, err := net.LookupSRV("http", "tcp", addr.Hostname())
if err == nil && len(addrs) > 0 {
host = fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port)
}
// Internet Draft specifies that the SRV record is ignored if a port is given
if addr.Port() == "" && c.config.SRVLookup {
_, addrs, err := net.LookupSRV("http", "tcp", addr.Hostname())
if err == nil && len(addrs) > 0 {
host = fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port)
}
}

Expand All @@ -729,6 +730,7 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
Host: host,
Path: path.Join(addr.Path, requestPath),
},
Host: addr.Host,
ClientToken: token,
Params: make(map[string][]string),
}
Expand Down
31 changes: 31 additions & 0 deletions api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,37 @@ func TestClientToken(t *testing.T) {
}
}

func TestClientHostHeader(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(req.Host))
}
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
defer ln.Close()

config.Address = strings.ReplaceAll(config.Address, "127.0.0.1", "localhost")
client, err := NewClient(config)
if err != nil {
t.Fatalf("err: %s", err)
}

// Set the token manually
client.SetToken("foo")

resp, err := client.RawRequest(client.NewRequest("PUT", "/"))
if err != nil {
t.Fatal(err)
}

// Copy the response
var buf bytes.Buffer
io.Copy(&buf, resp.Body)

// Verify we got the response from the primary
if buf.String() != strings.ReplaceAll(config.Address, "http://", "") {
t.Fatalf("Bad address: %s", buf.String())
}
}

func TestClientBadToken(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {}

Expand Down
3 changes: 2 additions & 1 deletion api/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
type Request struct {
Method string
URL *url.URL
Host string
Params url.Values
Headers http.Header
ClientToken string
Expand Down Expand Up @@ -115,7 +116,7 @@ func (r *Request) toRetryableHTTP() (*retryablehttp.Request, error) {
req.URL.User = r.URL.User
req.URL.Scheme = r.URL.Scheme
req.URL.Host = r.URL.Host
req.Host = r.URL.Host
req.Host = r.Host

if r.Headers != nil {
for header, vals := range r.Headers {
Expand Down