Skip to content

Commit

Permalink
Merge branch 'master' into typhon-raw-request
Browse files Browse the repository at this point in the history
  • Loading branch information
milesbxf authored Oct 9, 2024
2 parents 9692c67 + c2c0b01 commit 84f203b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 9 deletions.
46 changes: 37 additions & 9 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ import (
// https://play.golang.org/p/MxhRiL37R-9
type routerContextKeyType struct{}
type routerRequestPatternContextKeyType struct{}
type routerRequestMethodContextKeyType struct{}

var (
routerContextKey = routerContextKeyType{}
routerRequestPatternContextKey = routerRequestPatternContextKeyType{}
routerRequestMethodContextKey = routerRequestMethodContextKeyType{}
routerComponentsRe = regexp.MustCompile(`(?:^|/)(\*\w*|:\w+)`)
)

Expand Down Expand Up @@ -53,6 +55,22 @@ func routerPathPatternForRequest(r Request) string {
return ""
}

// RequestPatternFromContext returns the pattern that was matched for the request, if available.
func RequestPatternFromContext(ctx context.Context) (string, bool) {
if v := ctx.Value(routerRequestPatternContextKey); v != nil {
return v.(string), true
}
return "", false
}

// RequestMethodFromContext returns the method of the request, if available.
func RequestMethodFromContext(ctx context.Context) (string, bool) {
if v := ctx.Value(routerRequestMethodContextKey); v != nil {
return v.(string), true
}
return "", false
}

func (r *Router) compile(pattern string) *regexp.Regexp {
re, pos := ``, 0
for _, m := range routerComponentsRe.FindAllStringSubmatchIndex(pattern, -1) {
Expand Down Expand Up @@ -134,6 +152,7 @@ func (r Router) Serve() Service {
}
req.Context = context.WithValue(req.Context, routerContextKey, &r)
req.Context = context.WithValue(req.Context, routerRequestPatternContextKey, pathPattern)
req.Context = context.WithValue(req.Context, routerRequestMethodContextKey, req.Method)
rsp := svc(req)
if rsp.Request == nil {
rsp.Request = &req
Expand All @@ -157,37 +176,46 @@ func (r Router) Params(req Request) map[string]string {
// Sugar

// GET is shorthand for:
// r.Register("GET", pattern, svc)
//
// r.Register("GET", pattern, svc)
func (r *Router) GET(pattern string, svc Service) { r.Register("GET", pattern, svc) }

// CONNECT is shorthand for:
// r.Register("CONNECT", pattern, svc)
//
// r.Register("CONNECT", pattern, svc)
func (r *Router) CONNECT(pattern string, svc Service) { r.Register("CONNECT", pattern, svc) }

// DELETE is shorthand for:
// r.Register("DELETE", pattern, svc)
//
// r.Register("DELETE", pattern, svc)
func (r *Router) DELETE(pattern string, svc Service) { r.Register("DELETE", pattern, svc) }

// HEAD is shorthand for:
// r.Register("HEAD", pattern, svc)
//
// r.Register("HEAD", pattern, svc)
func (r *Router) HEAD(pattern string, svc Service) { r.Register("HEAD", pattern, svc) }

// OPTIONS is shorthand for:
// r.Register("OPTIONS", pattern, svc)
//
// r.Register("OPTIONS", pattern, svc)
func (r *Router) OPTIONS(pattern string, svc Service) { r.Register("OPTIONS", pattern, svc) }

// PATCH is shorthand for:
// r.Register("PATCH", pattern, svc)
//
// r.Register("PATCH", pattern, svc)
func (r *Router) PATCH(pattern string, svc Service) { r.Register("PATCH", pattern, svc) }

// POST is shorthand for:
// r.Register("POST", pattern, svc)
//
// r.Register("POST", pattern, svc)
func (r *Router) POST(pattern string, svc Service) { r.Register("POST", pattern, svc) }

// PUT is shorthand for:
// r.Register("PUT", pattern, svc)
//
// r.Register("PUT", pattern, svc)
func (r *Router) PUT(pattern string, svc Service) { r.Register("PUT", pattern, svc) }

// TRACE is shorthand for:
// r.Register("TRACE", pattern, svc)
//
// r.Register("TRACE", pattern, svc)
func (r *Router) TRACE(pattern string, svc Service) { r.Register("TRACE", pattern, svc) }
22 changes: 22 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,25 @@ func TestRouterSetsRequest(t *testing.T) {
req.Context = rsp.Request.Context
assert.Equal(t, req, *rsp.Request)
}

func TestRouterSetsContextValues(t *testing.T) {
t.Parallel()

router := Router{}
router.GET("/", func(req Request) Response {
return Response{}
})

ctx := context.Background()
req := NewRequest(ctx, "GET", "/", map[string]string{"r": "foo"})
rsp := router.Serve()(req)
require.NotNil(t, rsp.Request)

ctxPattern, ok := RequestPatternFromContext(rsp.Request.Context)
assert.True(t, ok)
assert.Equal(t, "/", ctxPattern)

ctxMethod, ok := RequestMethodFromContext(rsp.Request.Context)
assert.True(t, ok)
assert.Equal(t, "GET", ctxMethod)
}

0 comments on commit 84f203b

Please sign in to comment.