Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting Both inheriting directions for context (from Base to derived and derived to base) #86

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
59 changes: 45 additions & 14 deletions router_serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (rootRouter *Router) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
closure.Routers = make([]*Router, 1, rootRouter.maxChildrenDepth)
closure.Routers[0] = rootRouter
closure.Contexts = make([]reflect.Value, 1, rootRouter.maxChildrenDepth)
closure.Contexts[0] = reflect.New(rootRouter.contextType)
closure.Contexts[0] = reflect.New(rootRouter.contextType.Type)
closure.currentMiddlewareLen = len(rootRouter.middleware)
closure.RootRouter = rootRouter
closure.Request.rootContext = closure.Contexts[0]
Expand Down Expand Up @@ -220,20 +220,53 @@ func contextsFor(contexts []reflect.Value, routers []*Router) []reflect.Value {

for i := 1; i < routersLen; i++ {
var ctx reflect.Value
if routers[i].contextType == routers[i-1].contextType {
if routers[i].contextType.Type == routers[i-1].contextType.Type {
ctx = contexts[i-1]
} else {
ctx = reflect.New(routers[i].contextType)
ctxType := routers[i].contextType.Type
// set the first field to the parent
f := reflect.Indirect(ctx).Field(0)
f.Set(contexts[i-1])
if routers[i].contextType.IsDerived {
ctx = createDrivedContext(contexts[i-1], ctxType)
} else {
ctxType = reflect.PtrTo(ctxType)
ctx = getMatchedParentContext(contexts[i-1], ctxType)
}
}
contexts = append(contexts, ctx)
}

return contexts
}

func createDrivedContext(context reflect.Value, neededType reflect.Type) reflect.Value {
ctx := reflect.New(neededType)
childCtx := ctx
for {
f := reflect.Indirect(childCtx).Field(0)
if f.Type() != context.Type() && f.Kind() == reflect.Ptr {
childCtx = reflect.New(f.Type().Elem())
f.Set(childCtx)
continue
} else {
f.Set(context)
break
}
}
return ctx
}

func getMatchedParentContext(context reflect.Value, neededType reflect.Type) reflect.Value {
if neededType != context.Type() {
for {
context = reflect.Indirect(context).Field(0)
if context.Type() == neededType {
break
}
}
}
return context
}

// If there's a panic in the root middleware (so that we don't have a route/target), then invoke the root handler or default.
// If there's a panic in other middleware, then invoke the target action's function.
// If there's a panic in the action handler, then invoke the target action's function.
Expand All @@ -250,19 +283,17 @@ func (rootRouter *Router) handlePanic(rw *appResponseWriter, req *Request, err i

for !targetRouter.errorHandler.IsValid() && targetRouter.parent != nil {
targetRouter = targetRouter.parent

// Need to set context to the next context, UNLESS the context is the same type.
curContextStruct := reflect.Indirect(context)
if targetRouter.contextType != curContextStruct.Type() {
context = curContextStruct.Field(0)
if reflect.Indirect(context).Type() != targetRouter.contextType {
panic("bug: shouldn't get here")
}
}
}
}

if targetRouter.errorHandler.IsValid() {
// Need to set context to the next context, UNLESS the context is the same type.
if _, err := validateContext(reflect.Indirect(reflect.New(targetRouter.contextType.Type)).Interface(), reflect.Indirect(context).Type()); err != nil {
panic(err)
}

ctxType := reflect.PtrTo(targetRouter.contextType.Type)
context = getMatchedParentContext(context, ctxType)
invoke(targetRouter.errorHandler, context, []reflect.Value{reflect.ValueOf(rw), reflect.ValueOf(req), reflect.ValueOf(err)})
} else {
http.Error(rw, DefaultPanicResponse, http.StatusInternalServerError)
Expand Down
81 changes: 53 additions & 28 deletions router_setup.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package web

import (
"errors"
"reflect"
"strings"
)
Expand All @@ -19,6 +20,11 @@ const (

var httpMethods = []httpMethod{httpMethodGet, httpMethodPost, httpMethodPut, httpMethodDelete, httpMethodPatch, httpMethodHead, httpMethodOptions}

type ContextSt struct {
Type reflect.Type
IsDerived bool //true if it's drived from main route, false if main route is drived from it
}

// Router implements net/http's Handler interface and is what you attach middleware, routes/handlers, and subrouters to.
type Router struct {
// Hierarchy:
Expand All @@ -27,7 +33,7 @@ type Router struct {
maxChildrenDepth int

// For each request we'll create one of these objects
contextType reflect.Type
contextType ContextSt

// Eg, "/" or "/admin". Any routes added to this router will be prefixed with this.
pathPrefix string
Expand Down Expand Up @@ -89,10 +95,10 @@ var emptyInterfaceType = reflect.TypeOf((*interface{})(nil)).Elem()
// whose purpose is to communicate type information. On each request, an instance of this
// context type will be automatically allocated and sent to handlers.
func New(ctx interface{}) *Router {
validateContext(ctx, nil)
// validateContext(ctx, nil)

r := &Router{}
r.contextType = reflect.TypeOf(ctx)
r.contextType = ContextSt{Type: reflect.TypeOf(ctx)}
r.pathPrefix = "/"
r.maxChildrenDepth = 1
r.root = make(map[httpMethod]*pathNode)
Expand All @@ -116,10 +122,14 @@ func NewWithPrefix(ctx interface{}, pathPrefix string) *Router {
// embed a pointer to the previous context in the first slot. You can also pass
// a pathPrefix that each route will have. If "" is passed, then no path prefix is applied.
func (r *Router) Subrouter(ctx interface{}, pathPrefix string) *Router {
validateContext(ctx, r.contextType)

// Create new router, link up hierarchy
newRouter := &Router{parent: r}
contextType, err := validateContext(ctx, r.contextType.Type)
if err != nil {
panic(err)
}
newRouter.contextType = *contextType
r.children = append(r.children, newRouter)

// Increment maxChildrenDepth if this is the first child of the router
Expand All @@ -131,7 +141,6 @@ func (r *Router) Subrouter(ctx interface{}, pathPrefix string) *Router {
}
}

newRouter.contextType = reflect.TypeOf(ctx)
newRouter.pathPrefix = appendPath(r.pathPrefix, pathPrefix)
newRouter.root = r.root

Expand All @@ -141,7 +150,7 @@ func (r *Router) Subrouter(ctx interface{}, pathPrefix string) *Router {
// Middleware adds the specified middleware tot he router and returns the router.
func (r *Router) Middleware(fn interface{}) *Router {
vfn := reflect.ValueOf(fn)
validateMiddleware(vfn, r.contextType)
validateMiddleware(vfn, r.contextType.Type)
if vfn.Type().NumIn() == 3 {
r.middleware = append(r.middleware, &middlewareHandler{Generic: true, GenericMiddleware: fn.(func(ResponseWriter, *Request, NextMiddlewareFunc))})
} else {
Expand All @@ -154,7 +163,7 @@ func (r *Router) Middleware(fn interface{}) *Router {
// Error sets the specified function as the error handler (when panics happen) and returns the router.
func (r *Router) Error(fn interface{}) *Router {
vfn := reflect.ValueOf(fn)
validateErrorHandler(vfn, r.contextType)
validateErrorHandler(vfn, r.contextType.Type)
r.errorHandler = vfn
return r
}
Expand All @@ -166,7 +175,7 @@ func (r *Router) NotFound(fn interface{}) *Router {
panic("You can only set a NotFoundHandler on the root router.")
}
vfn := reflect.ValueOf(fn)
validateNotFoundHandler(vfn, r.contextType)
validateNotFoundHandler(vfn, r.contextType.Type)
r.notFoundHandler = vfn
return r
}
Expand All @@ -178,7 +187,7 @@ func (r *Router) OptionsHandler(fn interface{}) *Router {
panic("You can only set an OptionsHandler on the root router.")
}
vfn := reflect.ValueOf(fn)
validateOptionsHandler(vfn, r.contextType)
validateOptionsHandler(vfn, r.contextType.Type)
r.optionsHandler = vfn
return r
}
Expand Down Expand Up @@ -220,7 +229,7 @@ func (r *Router) Options(path string, fn interface{}) *Router {

func (r *Router) addRoute(method httpMethod, path string, fn interface{}) *Router {
vfn := reflect.ValueOf(fn)
validateHandler(vfn, r.contextType)
validateHandler(vfn, r.contextType.Type)
fullPath := appendPath(r.pathPrefix, path)
route := &route{Method: method, Path: fullPath, Router: r}
if vfn.Type().NumIn() == 2 {
Expand Down Expand Up @@ -249,26 +258,40 @@ func (r *Router) depth() int {
// Private methods:
//

// Panics unless validation is correct
func validateContext(ctx interface{}, parentCtxType reflect.Type) {
ctxType := reflect.TypeOf(ctx)

if ctxType.Kind() != reflect.Struct {
panic("web: Context needs to be a struct type")
}

if parentCtxType != nil && parentCtxType != ctxType {
if ctxType.NumField() == 0 {
panic("web: Context needs to have first field be a pointer to parent context")
// validate contexts
func validateContext(ctx interface{}, parentCtxType reflect.Type) (*ContextSt, error) {
doCheck := func(ctxType reflect.Type, parentCtxType reflect.Type) error {
for {
if ctxType.Kind() == reflect.Ptr {
ctxType = ctxType.Elem()
}
if ctxType.Kind() != reflect.Struct {
if ctxType == reflect.TypeOf(ctx) {
return errors.New("web: Context needs to be a struct type\n " + ctxType.String())
}
return errors.New("web: Context needs to have first field be a pointer to parent context\n" +
"Main Context: " + parentCtxType.String() + " Given Context: " + reflect.TypeOf(ctx).String())

}
if ctxType == parentCtxType {
break
}
if ctxType.NumField() == 0 {
return errors.New("web: Context needs to have first field be a pointer to parent context")
}
ctxType = ctxType.Field(0).Type
}
return nil
}

fldType := ctxType.Field(0).Type

// Ensure fld is a pointer to parentCtxType
if fldType != reflect.PtrTo(parentCtxType) {
panic("web: Context needs to have first field be a pointer to parent context")
ctxType := reflect.TypeOf(ctx)
if err1 := doCheck(ctxType, parentCtxType); err1 != nil {
if err2 := doCheck(parentCtxType, ctxType); err2 != nil {
return nil, err1
}
return &ContextSt{ctxType, false}, nil
}
return &ContextSt{ctxType, true}, nil
}

// Panics unless fn is a proper handler wrt ctxType
Expand Down Expand Up @@ -338,8 +361,10 @@ func isValidHandler(vfn reflect.Value, ctxType reflect.Type, types ...reflect.Ty
} else if numIn == (typesLen + 1) {
// context, types
firstArgType := fnType.In(0)
if firstArgType != reflect.PtrTo(ctxType) && firstArgType != emptyInterfaceType {
return false
if firstArgType != emptyInterfaceType {
if _, err := validateContext(reflect.Indirect(reflect.New(firstArgType.Elem())).Interface(), ctxType); err != nil {
return false
}
}
typesStartIdx = 1
} else {
Expand Down