Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support processing parameters sent as a URL-encoded form #8325

Merged
merged 5 commits into from
Feb 12, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"io/ioutil"
"mime"
"net"
"net/http"
"net/textproto"
Expand Down Expand Up @@ -566,7 +567,7 @@ func parseQuery(values url.Values) map[string]interface{} {
return nil
}

func parseRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, out interface{}) (io.ReadCloser, error) {
func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, out interface{}) (io.ReadCloser, error) {
// Limit the maximum number of bytes to MaxRequestSize to protect
// against an indefinite amount of data being read.
reader := r.Body
Expand Down Expand Up @@ -598,6 +599,43 @@ func parseRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, out
return nil, err
}

// parseFormRequest parses values from a form POST.
//
// A nil map will be returned if the format is empty or invalid.
func parseFormRequest(r *http.Request) (map[string]interface{}, error) {
maxRequestSize := r.Context().Value("max_request_size")
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if !ok {
return nil, errors.New("could not parse max_request_size from request context")
}
if max > 0 {
r.Body = ioutil.NopCloser(io.LimitReader(r.Body, max))
}
}
if err := r.ParseForm(); err != nil {
return nil, err
}

var data map[string]interface{}

if len(r.PostForm) != 0 {
data = make(map[string]interface{}, len(r.PostForm))
for k, v := range r.PostForm {
switch len(v) {
case 0, 1:
data[k] = v[0]
default:
// Almost anywhere taking in a string list can take in comma
// separated values, and really this is super niche anyways
data[k] = strings.Join(v, ",")
}
}
}

return data, nil
}

// handleRequestForwarding determines whether to forward a request or not,
// falling back on the older behavior of redirecting the client
func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handler {
Expand Down Expand Up @@ -960,6 +998,40 @@ func parseMFAHeader(req *logical.Request) error {
return nil
}

// isForm tries to determine whether the request should be
// processed as a form or as JSON.
//
// Virtually all existing use cases have assumed processing as JSON,
// and there has not been a Content-Type requirement in the API. In order to
// maintain backwards compatibility, this will err on the side of JSON.
// The request will be considered a form only if:
//
// 1. The content type is "application/x-www-form-urlencoded"
// 2. The start of the request doesn't look like JSON. For this test we
// we expect the body to begin with { or [, ignoring leading whitespace.
func isForm(head []byte, contentType string) bool {
contentType, _, err := mime.ParseMediaType(contentType)

if err != nil || contentType != "application/x-www-form-urlencoded" {
return false
}

// Look for the start of JSON or not-JSON, skipping any insignificant
// whitespace (per https://tools.ietf.org/html/rfc7159#section-2).
for _, c := range head {
switch c {
case ' ', '\t', '\n', '\r':
continue
case '[', '{': // JSON
return false
default: // not JSON
return true
}
}

return true
}

func respondError(w http.ResponseWriter, status int, err error) {
logical.RespondError(w, status, err)
}
Expand Down
63 changes: 63 additions & 0 deletions http/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package http

import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/textproto"
"net/url"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -676,3 +679,63 @@ func testNonPrintable(t *testing.T, disable bool) {
testResponseStatus(t, resp, 400)
}
}

func TestHandler_Parse_Form(t *testing.T) {
cluster := vault.NewTestCluster(t, &vault.CoreConfig{}, &vault.TestClusterOptions{
HandlerFunc: Handler,
})
cluster.Start()
defer cluster.Cleanup()

cores := cluster.Cores

core := cores[0].Core
vault.TestWaitActive(t, core)

c := cleanhttp.DefaultClient()
c.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: cluster.RootCAs,
},
}

values := url.Values{
"zip": []string{"zap"},
"abc": []string{"xyz"},
"multi": []string{"first", "second"},
}
req, err := http.NewRequest("POST", cores[0].Client.Address()+"/v1/secret/foo", nil)
if err != nil {
t.Fatal(err)
}
req.Body = ioutil.NopCloser(strings.NewReader(values.Encode()))
req.Header.Set("x-vault-token", cluster.RootToken)
req.Header.Set("content-type", "application/x-www-form-urlencoded")
resp, err := c.Do(req)
if err != nil {
t.Fatal(err)
}

if resp.StatusCode != 204 {
t.Fatalf("bad response: %#v\nrequest was: %#v\nurl was: %#v", *resp, *req, req.URL)
}

client := cores[0].Client
client.SetToken(cluster.RootToken)

apiResp, err := client.Logical().Read("secret/foo")
if err != nil {
t.Fatal(err)
}
if apiResp == nil {
t.Fatal("api resp is nil")
}
expected := map[string]interface{}{
"zip": "zap",
"abc": "xyz",
"multi": "first,second",
}
if diff := deep.Equal(expected, apiResp.Data); diff != nil {
t.Fatal(diff)
}
}
58 changes: 49 additions & 9 deletions http/logical.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http

import (
"bufio"
"encoding/base64"
"encoding/json"
"fmt"
Expand All @@ -20,6 +21,24 @@ import (
"go.uber.org/atomic"
)

// bufferedReader can be used to replace a request body with a buffered
// version. The Close method invokes the original Closer.
type bufferedReader struct {
*bufio.Reader
rOrig io.ReadCloser
}

func newBufferedReader(r io.ReadCloser) *bufferedReader {
return &bufferedReader{
Reader: bufio.NewReader(r),
rOrig: r,
}
}

func (b *bufferedReader) Close() error {
return b.rOrig.Close()
}

func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) {
ns, err := namespace.FromContext(r.Context())
if err != nil {
Expand Down Expand Up @@ -71,16 +90,37 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http.

case "POST", "PUT":
op = logical.UpdateOperation
// Parse the request if we can
if op == logical.UpdateOperation {
// If we are uploading a snapshot we don't want to parse it. Instead
// we will simply add the HTTP request to the logical request object
// for later consumption.
if path == "sys/storage/raft/snapshot" || path == "sys/storage/raft/snapshot-force" {
passHTTPReq = true
origBody = r.Body

// Buffer the request body in order to allow us to peek at the beginning
// without consuming it. This approach involves no copying.
bufferedBody := newBufferedReader(r.Body)
r.Body = bufferedBody

// If we are uploading a snapshot we don't want to parse it. Instead
// we will simply add the HTTP request to the logical request object
// for later consumption.
if path == "sys/storage/raft/snapshot" || path == "sys/storage/raft/snapshot-force" {
passHTTPReq = true
origBody = r.Body
} else {
// Sample the first bytes to determine whether this should be parsed as
// a form or as JSON. The amount to look ahead (512 bytes) is arbitrary
// but extremely tolerant (i.e. allowing 511 bytes of leading whitespace
// and an incorrect content-type).
head, err := bufferedBody.Peek(512)
if err != nil && err != bufio.ErrBufferFull && err != io.EOF {
return nil, nil, http.StatusBadRequest, err
}

if isForm(head, r.Header.Get("Content-Type")) {
formData, err := parseFormRequest(r)
if err != nil {
return nil, nil, http.StatusBadRequest, fmt.Errorf("error parsing form data: %w", err)
}

data = formData
} else {
origBody, err = parseRequest(perfStandby, r, w, &data)
origBody, err = parseJSONRequest(perfStandby, r, w, &data)
if err == io.EOF {
data = nil
err = nil
Expand Down
25 changes: 25 additions & 0 deletions http/logical_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,28 @@ func TestLogical_Audit_invalidWrappingToken(t *testing.T) {
}
}
}

func TestLogical_ShouldParseForm(t *testing.T) {
const formCT = "application/x-www-form-urlencoded"

tests := map[string]struct {
prefix string
contentType string
isForm bool
}{
"JSON": {`{"a":42}`, formCT, false},
"JSON 2": {`[42]`, formCT, false},
"JSON w/leading space": {" \n\n\r\t [42] ", formCT, false},
"Form": {"a=42&b=dog", formCT, true},
"Form w/wrong CT": {"a=42&b=dog", "application/json", false},
}

for name, test := range tests {
isForm := isForm([]byte(test.prefix), test.contentType)

if isForm != test.isForm {
t.Fatalf("%s fail: expected isForm %t, got %t", name, test.isForm, isForm)
}
}

}
4 changes: 2 additions & 2 deletions http/sys_generate_root.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r
func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r *http.Request, generateStrategy vault.GenerateRootStrategy) {
// Parse the request
var req GenerateRootInitRequest
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF {
if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF {
respondError(w, http.StatusBadRequest, err)
return
}
Expand Down Expand Up @@ -132,7 +132,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Parse the request
var req GenerateRootUpdateRequest
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion http/sys_init.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request)

// Parse the request
var req InitRequest
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion http/sys_raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func handleSysRaftJoin(core *vault.Core) http.Handler {
func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Request) {
// Parse the request
var req JoinRequest
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF {
if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF {
respondError(w, http.StatusBadRequest, err)
return
}
Expand Down
6 changes: 3 additions & 3 deletions http/sys_rekey.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool,
func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) {
// Parse the request
var req RekeyRequest
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
Expand Down Expand Up @@ -158,7 +158,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler {

// Parse the request
var req RekeyUpdateRequest
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
Expand Down Expand Up @@ -306,7 +306,7 @@ func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery
func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) {
// Parse the request
var req RekeyVerificationUpdateRequest
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion http/sys_seal.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func handleSysUnseal(core *vault.Core) http.Handler {

// Parse the request
var req UnsealRequest
if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil {
if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
Expand Down