Browse Source

Remove/cleanup request context helpers (#525)

* Remove context helpers in context.go
* Update request context funcs to take concrete types
* Move TestNativeContextMiddleware to mux_test.go
* Clarify KeepContext Go 1.7+ comment

Mux doesn't build on Go < 1.7 so the comment doesn't really need to
clarify anymore.
pull/463/merge
Franklin Harding 5 years ago committed by Matt Silverlock
parent
commit
f395758b85
  1. 18
      context.go
  2. 30
      context_test.go
  3. 22
      mux.go
  4. 24
      mux_test.go
  5. 2
      test_helpers.go

18
context.go

@ -1,18 +0,0 @@
package mux
import (
"context"
"net/http"
)
func contextGet(r *http.Request, key interface{}) interface{} {
return r.Context().Value(key)
}
func contextSet(r *http.Request, key, val interface{}) *http.Request {
if val == nil {
return r
}
return r.WithContext(context.WithValue(r.Context(), key, val))
}

30
context_test.go

@ -1,30 +0,0 @@
package mux
import (
"context"
"net/http"
"testing"
"time"
)
func TestNativeContextMiddleware(t *testing.T) {
withTimeout := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
defer cancel()
h.ServeHTTP(w, r.WithContext(ctx))
})
}
r := NewRouter()
r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
vars := Vars(r)
if vars["foo"] != "bar" {
t.Fatal("Expected foo var to be set")
}
})))
rec := NewRecorder()
req := newRequest("GET", "/path/bar")
r.ServeHTTP(rec, req)
}

22
mux.go

@ -5,6 +5,7 @@
package mux
import (
"context"
"errors"
"fmt"
"net/http"
@ -58,8 +59,7 @@ type Router struct {
// If true, do not clear the request context after handling the request.
//
// Deprecated: No effect when go1.7+ is used, since the context is stored
// on the request itself.
// Deprecated: No effect, since the context is stored on the request itself.
KeepContext bool
// Slice of middlewares to be called after a match is found
@ -195,8 +195,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var handler http.Handler
if r.Match(req, &match) {
handler = match.Handler
req = setVars(req, match.Vars)
req = setCurrentRoute(req, match.Route)
req = requestWithVars(req, match.Vars)
req = requestWithRoute(req, match.Route)
}
if handler == nil && match.MatchErr == ErrMethodMismatch {
@ -426,7 +426,7 @@ const (
// Vars returns the route variables for the current request, if any.
func Vars(r *http.Request) map[string]string {
if rv := contextGet(r, varsKey); rv != nil {
if rv := r.Context().Value(varsKey); rv != nil {
return rv.(map[string]string)
}
return nil
@ -438,18 +438,20 @@ func Vars(r *http.Request) map[string]string {
// after the handler returns, unless the KeepContext option is set on the
// Router.
func CurrentRoute(r *http.Request) *Route {
if rv := contextGet(r, routeKey); rv != nil {
if rv := r.Context().Value(routeKey); rv != nil {
return rv.(*Route)
}
return nil
}
func setVars(r *http.Request, val interface{}) *http.Request {
return contextSet(r, varsKey, val)
func requestWithVars(r *http.Request, vars map[string]string) *http.Request {
ctx := context.WithValue(r.Context(), varsKey, vars)
return r.WithContext(ctx)
}
func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
return contextSet(r, routeKey, val)
func requestWithRoute(r *http.Request, route *Route) *http.Request {
ctx := context.WithValue(r.Context(), routeKey, route)
return r.WithContext(ctx)
}
// ----------------------------------------------------------------------------

24
mux_test.go

@ -7,6 +7,7 @@ package mux
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io/ioutil"
@ -16,6 +17,7 @@ import (
"reflect"
"strings"
"testing"
"time"
)
func (r *Route) GoString() string {
@ -2804,6 +2806,28 @@ func TestSubrouterNotFound(t *testing.T) {
}
}
func TestContextMiddleware(t *testing.T) {
withTimeout := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
defer cancel()
h.ServeHTTP(w, r.WithContext(ctx))
})
}
r := NewRouter()
r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
vars := Vars(r)
if vars["foo"] != "bar" {
t.Fatal("Expected foo var to be set")
}
})))
rec := NewRecorder()
req := newRequest("GET", "/path/bar")
r.ServeHTTP(rec, req)
}
// mapToPairs converts a string map to a slice of string pairs
func mapToPairs(m map[string]string) []string {
var i int

2
test_helpers.go

@ -15,5 +15,5 @@ import "net/http"
// can be set by making a route that captures the required variables,
// starting a server and sending the request to that server.
func SetURLVars(r *http.Request, val map[string]string) *http.Request {
return setVars(r, val)
return requestWithVars(r, val)
}

Loading…
Cancel
Save