Skip to content

Commit d5226e9

Browse files
committed
feat(middleware): add chain middleware
Fixes: #12
1 parent ccce0b2 commit d5226e9

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

middleware/chain.go

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
)
6+
7+
// Middleware represents a function that can wrap a http.Handler with
8+
// additional functionality. It takes a http.Handler and returns a new
9+
// http.Handler that includes the middleware's behavior.
10+
type Middleware func(http.Handler) http.Handler
11+
12+
// Chain creates a chain of HTTP middleware functions to wrap around a
13+
// http.Handler.
14+
// It applies each middleware in the order they are provided, allowing for
15+
// layered processing of HTTP requests and responses.
16+
func Chain(next http.Handler, ms ...Middleware) http.Handler {
17+
for i := len(ms) - 1; i >= 0; i-- {
18+
m := ms[i]
19+
next = m(next)
20+
}
21+
return next
22+
}

middleware/chain_test.go

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
7+
"github.com/candango/httpok/testrunner"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
type PlainHandler struct {
12+
http.Handler
13+
}
14+
15+
func (h *PlainHandler) GetSomething(w http.ResponseWriter, r *http.Request) {
16+
w.Write([]byte("Something"))
17+
}
18+
19+
func (h *PlainHandler) GetSomethingElse(w http.ResponseWriter, r *http.Request) {
20+
w.Write([]byte("Something else"))
21+
}
22+
23+
func NewPlainServeMux() http.Handler {
24+
plain := &PlainHandler{}
25+
h := http.NewServeMux()
26+
h.HandleFunc("/something", plain.GetSomething)
27+
h.HandleFunc("/something_else", plain.GetSomethingElse)
28+
return h
29+
}
30+
31+
func TestChainMiddlewareServer(t *testing.T) {
32+
plain := NewPlainServeMux()
33+
34+
runner := testrunner.NewHttpTestRunner(t).WithHandler(plain)
35+
36+
t.Run("Plain runner", func(t *testing.T) {
37+
res, err := runner.WithPath("/something").Get()
38+
if err != nil {
39+
t.Error(err)
40+
}
41+
assert.Equal(t, "200 OK", res.Status)
42+
assert.Equal(t, "Something", testrunner.BodyAsString(t, res))
43+
44+
res, err = runner.WithPath("/something_else").Get()
45+
if err != nil {
46+
t.Error(err)
47+
}
48+
assert.Equal(t, "200 OK", res.Status)
49+
assert.Equal(t, "Something else", testrunner.BodyAsString(t, res))
50+
})
51+
52+
changeSomething := func(next http.Handler) http.Handler {
53+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
54+
if r.URL.String() == "/something" {
55+
w.Write([]byte("First Middleware with "))
56+
}
57+
next.ServeHTTP(w, r)
58+
})
59+
}
60+
61+
blockSomethingElse := func(next http.Handler) http.Handler {
62+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
63+
if r.URL.String() == "/something_else" {
64+
http.Error(w, "Not allowed", http.StatusMethodNotAllowed)
65+
return
66+
}
67+
next.ServeHTTP(w, r)
68+
})
69+
}
70+
71+
chain := Chain(plain, changeSomething, blockSomethingElse)
72+
runner = testrunner.NewHttpTestRunner(t).WithHandler(chain)
73+
74+
t.Run("Chained runner", func(t *testing.T) {
75+
res, err := runner.WithPath("/something").Get()
76+
if err != nil {
77+
t.Error(err)
78+
}
79+
assert.Equal(t, "200 OK", res.Status)
80+
assert.Equal(t, "First Middleware with Something", testrunner.BodyAsString(t, res))
81+
82+
res, err = runner.WithPath("/something_else").Get()
83+
if err != nil {
84+
t.Error(err)
85+
}
86+
assert.Equal(t, "405 Method Not Allowed", res.Status)
87+
assert.Equal(t, "Not allowed\n", testrunner.BodyAsString(t, res))
88+
})
89+
}

0 commit comments

Comments
 (0)