Skip to content

Commit

Permalink
Move handler transformers to a property on the router
Browse files Browse the repository at this point in the history
Refs #9

* Improve non-group route registration
* Refactor
* Remove debug line
  • Loading branch information
tmus authored Jan 4, 2023
1 parent 5490202 commit cd79fae
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 105 deletions.
30 changes: 26 additions & 4 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ type Group struct {
prefix string
routes []*Route

router *Router

middleware []Middleware
}

Expand All @@ -13,11 +15,11 @@ func (g *Group) calculateRouteRegexs() {
}
}

func NewGroup(routes ...*Route) *Group {
g := &Group{}
func newGroup(router *Router, routes ...*Route) *Group {
g := &Group{router: router}

for _, r := range routes {
r.group = g
g.Add(r)
g.calculateRouteRegexs()
}

Expand All @@ -38,10 +40,30 @@ func (g *Group) Middleware(middleware ...Middleware) *Group {
}

func (g *Group) Add(routes ...*Route) *Group {
g.routes = append(g.routes, routes...)
for _, r := range routes {
i, found := g.findExistingRoute(r)
if found {
g.routes[i] = r
} else {
g.routes = append(g.routes, r)
}

r.group = g
r.buildHandler()
}
return g
}

func (g *Group) findExistingRoute(route *Route) (int, bool) {
for i, r := range g.routes {
if r.path == route.path && methodsMatch(r, route) {
return i, true
}
}

return -1, false
}

func (g *Group) Routes() []*Route {
return g.routes
}
10 changes: 7 additions & 3 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,22 @@ func TestFindRoutesInGroups(t *testing.T) {
server := httptest.NewServer(r)
defer server.Close()

resp, _ := http.Get(server.URL + "/group/test")
resp, err := http.Get(server.URL + "/group/test")
if err != nil {
t.Fatal(err)
}

if resp.StatusCode != http.StatusOK {
t.Errorf("Expected %d, got %d", http.StatusOK, resp.StatusCode)
}
}

func TestGroupCanAddRoute(t *testing.T) {
group := router.NewGroup()
rtr := router.New()
group := rtr.Group()

group.Add(
router.NewRoute([]string{http.MethodGet}, "/", func() string { return "Hello" }),
router.Get("/", func() string { return "Hello" }),
)

assert.Equal(t, 1, len(group.Routes()))
Expand Down
32 changes: 11 additions & 21 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,13 @@ import (
"reflect"
)

// handlerTransformers is a package-level variable that contains a map between
// function signatures and the function declaration used to turn a provided
// route resolver into a valid http.Handler. This value is used at the point
// of building a handler for a newly registered route.
//
// See `handler_defaults.go` for examples of "standard" transformers.
var handlerTransformers = make(map[string]interface{})

func init() {
for _, h := range defaultHandlers {
if err := AddHandlerTransformer(h); err != nil {
panic(err)
}
}
}

// AddHandlerTransformer adds the provided transformer function to the router,
// enabling routes to be registered that match the given signature.
//
// Note, the expected signature of `fn` is:
//
// func(v interface{}) http.Handler
func AddHandlerTransformer(fn interface{}) error {
func (r *Router) AddHandlerTransformer(fn interface{}) error {
t := reflect.TypeOf(fn)
if err := validateHandlerTransformer(t); err != nil {
return err
Expand All @@ -37,25 +21,31 @@ func AddHandlerTransformer(fn interface{}) error {
// Get the type of the first (and only) parameter to `fn`. This will be the
// signature of any route resolvers that are added to the router.
sig := t.In(0).String()
if _, ok := handlerTransformers[sig]; ok {
if _, ok := r.transformers[sig]; ok {
return fmt.Errorf("handler signature `%s` already exists, transformer not added", sig)
}

handlerTransformers[sig] = fn
r.transformers[sig] = fn
return nil
}

// buildHandler dynamically creates an http.Handler based on the function signature
// of the passed in function `fn`.
func buildHandler(v interface{}) http.Handler {
func (r *Router) buildHandler(v interface{}) http.Handler {
// Retrieve the of the function from the transformer map.
t := fmt.Sprintf("%T", v)
f := reflect.ValueOf(handlerTransformers[t])
val, ok := r.transformers[t]
if !ok {
panic("transformer does not exist")
}
f := reflect.ValueOf(val)

// Reflect the value of `fn` and use it as the argument for the transformer
// function. Return the value coerced to an http.Handler.
in := []reflect.Value{reflect.ValueOf(v)}

handler, ok := f.Call(in)[0].Interface().(http.Handler)
// TODO: return an error here instead.
if !ok {
panic(fmt.Sprintf("expected http.Handler return type, got %T", v))
}
Expand Down
5 changes: 3 additions & 2 deletions handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ import (
)

func TestCanRegisterCustomHandlerTransformer(t *testing.T) {
err := router.AddHandlerTransformer(func(fn func() int) http.Handler {
r := router.New()

err := r.AddHandlerTransformer(func(fn func() int) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
str := strconv.Itoa(fn())
w.Write([]byte(str))
})
})
assert.NoError(t, err)

r := router.New()
r.Get("/", func() int {
return 99
})
Expand Down
20 changes: 9 additions & 11 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,18 @@ func threeMiddleware(next http.Handler) http.Handler {
}

func TestCanAddMiddlewareToRoute(t *testing.T) {
r := router.NewRoute(
[]string{http.MethodGet},
"/",
helloHandler,
).Middleware(oneMiddleware, twoMiddleware)

req, _ := http.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
rtr := router.New()
rtr.Get("/", helloHandler).Middleware(oneMiddleware, twoMiddleware)

r.Serve(rr, req)
server := httptest.NewServer(rtr)
defer server.Close()

body, _ := ioutil.ReadAll(rr.Body)
resp, err := http.Get(server.URL + "/")
if err != nil {
t.Fatal(err)
}
body, _ := ioutil.ReadAll(resp.Body)
expected := "21Hello"

if string(body) != expected {
t.Errorf("Got %s, wanted %s.", string(body), expected)
}
Expand Down
22 changes: 15 additions & 7 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ import (
type Route struct {
methods []string
path string
handler http.Handler
regex *regexp.Regexp
// rawHandler is the handler provided to the route as-is, i.e., before it has
// been transformed into an http.Handler
rawHandler interface{}
handler http.Handler

regex *regexp.Regexp

// The {} bits of a route
params []string
Expand All @@ -32,18 +36,18 @@ func (route *Route) Middleware(middleware ...Middleware) *Route {

// NewRoute creates a new route definition for a given method, path and handler.
func NewRoute(methods []string, path string, handler interface{}) *Route {
return newHandlerRoute(methods, path, buildHandler(handler))
return newHandlerRoute(methods, path, handler)
}

func newHandlerRoute(methods []string, path string, handler http.Handler) *Route {
func newHandlerRoute(methods []string, path string, handler interface{}) *Route {
if path[0] != '/' {
path = "/" + path
}

r := &Route{
methods: methods,
path: path,
handler: handler,
methods: methods,
path: path,
rawHandler: handler,
}

r.regex = r.calculateRouteRegex()
Expand All @@ -68,6 +72,10 @@ func (route *Route) Serve(w http.ResponseWriter, r *http.Request) {
handler.ServeHTTP(w, r)
}

func (r *Route) buildHandler() {
r.handler = r.group.router.buildHandler(r.rawHandler)
}

// Get defines a new `GET` route.
func Get(path string, handler interface{}) *Route {
return NewRoute([]string{http.MethodGet}, path, handler)
Expand Down
41 changes: 24 additions & 17 deletions route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,31 @@ import (
)

func TestRouteInference(t *testing.T) {
routes := []*router.Route{
// TODO: Fix the stringer implementation
// router.NewRoute([]string{http.MethodGet}, "/", testStringer),
router.NewRoute([]string{http.MethodGet}, "/", testHandler),
router.NewRoute([]string{http.MethodGet}, "/", testFunc),
}
rtr := router.New()
rtr.Get("/handler", testHandler)
rtr.Get("func", testFunc)

rr := httptest.NewRecorder()
server := httptest.NewServer(rtr)
defer server.Close()

for _, rt := range routes {
req, _ := http.NewRequest(http.MethodGet, "/", nil)
rt.Serve(rr, req)
body, _ := ioutil.ReadAll(rr.Body)
expected := "Test"
resp, err := http.Get(server.URL + "/handler")
if err != nil {
t.Fatal(err)
}
body, _ := ioutil.ReadAll(resp.Body)
expected := "handler"
if string(body) != expected {
t.Errorf("Got %s, wanted %s.", string(body), expected)
}

if string(body) != expected {
t.Errorf("Got %s, wanted %s.", string(body), expected)
}
resp, err = http.Get(server.URL + "/func")
if err != nil {
t.Fatal(err)
}
body, _ = ioutil.ReadAll(resp.Body)
expected = "func"
if string(body) != expected {
t.Errorf("Got %s, wanted %s.", string(body), expected)
}
}

Expand All @@ -40,9 +47,9 @@ func (stringHandler) String() string {
var (
testStringer = stringHandler{}
testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Test"))
w.Write([]byte("handler"))
})
testFunc = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Test"))
w.Write([]byte("func"))
}
)
Loading

0 comments on commit cd79fae

Please sign in to comment.