From 48f7cbaddfc87d31b4a569405b9ec6ecbe79beed Mon Sep 17 00:00:00 2001
From: sashayakovtseva <sashayakovtseva@gmail.com>
Date: Thu, 17 Jun 2021 18:18:41 +0300
Subject: [PATCH] Support custom labels

---
 go.mod                                        |   4 +-
 internal/mocks/doc.go                         |   5 +-
 internal/mocks/metrics/Recorder.go            |   2 +-
 .../mocks/middleware/CustomLabelReporter.go   | 102 +++++++++++
 internal/mocks/middleware/Reporter.go         |   2 +-
 metrics/metrics.go                            |   6 +
 metrics/prometheus/prometheus.go              |  39 +++-
 metrics/prometheus/prometheus_test.go         |  36 +++-
 middleware/fasthttp/example_test.go           |  89 +++++++++-
 middleware/middleware.go                      |  25 ++-
 middleware/middleware_test.go                 | 168 ++++++++++++------
 11 files changed, 395 insertions(+), 83 deletions(-)
 create mode 100644 internal/mocks/middleware/CustomLabelReporter.go

diff --git a/go.mod b/go.mod
index 69f130a..a2ca9b0 100644
--- a/go.mod
+++ b/go.mod
@@ -1,5 +1,7 @@
 module github.com/slok/go-http-metrics
 
+go 1.15
+
 require (
 	contrib.go.opencensus.io/exporter/prometheus v0.3.0
 	github.com/emicklei/go-restful v2.15.0+incompatible
@@ -17,5 +19,3 @@ require (
 	go.opencensus.io v0.23.0
 	goji.io v2.0.2+incompatible
 )
-
-go 1.15
diff --git a/internal/mocks/doc.go b/internal/mocks/doc.go
index 74a6c32..ea51b4d 100644
--- a/internal/mocks/doc.go
+++ b/internal/mocks/doc.go
@@ -3,5 +3,6 @@ Package mocks will have all the mocks of the library.
 */
 package mocks // import "github.com/slok/go-http-metrics/internal/mocks"
 
-//go:generate mockery -output ./metrics -outpkg metrics -dir ../../metrics -name Recorder
-//go:generate mockery -output ./middleware -outpkg middleware -dir ../../middleware -name Reporter
+//go:generate mockery --output ./metrics --outpkg metrics --dir ../../metrics --name Recorder
+//go:generate mockery --output ./middleware --outpkg middleware --dir ../../middleware --name Reporter
+//go:generate mockery --output ./middleware --outpkg middleware --dir ../../middleware --name CustomLabelReporter
diff --git a/internal/mocks/metrics/Recorder.go b/internal/mocks/metrics/Recorder.go
index 2e04256..a297b2c 100644
--- a/internal/mocks/metrics/Recorder.go
+++ b/internal/mocks/metrics/Recorder.go
@@ -1,4 +1,4 @@
-// Code generated by mockery v1.0.0. DO NOT EDIT.
+// Code generated by mockery v2.8.0. DO NOT EDIT.
 
 package metrics
 
diff --git a/internal/mocks/middleware/CustomLabelReporter.go b/internal/mocks/middleware/CustomLabelReporter.go
new file mode 100644
index 0000000..79141dd
--- /dev/null
+++ b/internal/mocks/middleware/CustomLabelReporter.go
@@ -0,0 +1,102 @@
+// Code generated by mockery v2.8.0. DO NOT EDIT.
+
+package middleware
+
+import (
+	context "context"
+
+	mock "github.com/stretchr/testify/mock"
+)
+
+// CustomLabelReporter is an autogenerated mock type for the CustomLabelReporter type
+type CustomLabelReporter struct {
+	mock.Mock
+}
+
+// BytesWritten provides a mock function with given fields:
+func (_m *CustomLabelReporter) BytesWritten() int64 {
+	ret := _m.Called()
+
+	var r0 int64
+	if rf, ok := ret.Get(0).(func() int64); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(int64)
+	}
+
+	return r0
+}
+
+// Context provides a mock function with given fields:
+func (_m *CustomLabelReporter) Context() context.Context {
+	ret := _m.Called()
+
+	var r0 context.Context
+	if rf, ok := ret.Get(0).(func() context.Context); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(context.Context)
+		}
+	}
+
+	return r0
+}
+
+// CustomLabels provides a mock function with given fields:
+func (_m *CustomLabelReporter) CustomLabels() []string {
+	ret := _m.Called()
+
+	var r0 []string
+	if rf, ok := ret.Get(0).(func() []string); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).([]string)
+		}
+	}
+
+	return r0
+}
+
+// Method provides a mock function with given fields:
+func (_m *CustomLabelReporter) Method() string {
+	ret := _m.Called()
+
+	var r0 string
+	if rf, ok := ret.Get(0).(func() string); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(string)
+	}
+
+	return r0
+}
+
+// StatusCode provides a mock function with given fields:
+func (_m *CustomLabelReporter) StatusCode() int {
+	ret := _m.Called()
+
+	var r0 int
+	if rf, ok := ret.Get(0).(func() int); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(int)
+	}
+
+	return r0
+}
+
+// URLPath provides a mock function with given fields:
+func (_m *CustomLabelReporter) URLPath() string {
+	ret := _m.Called()
+
+	var r0 string
+	if rf, ok := ret.Get(0).(func() string); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(string)
+	}
+
+	return r0
+}
diff --git a/internal/mocks/middleware/Reporter.go b/internal/mocks/middleware/Reporter.go
index 7ed1742..b618092 100644
--- a/internal/mocks/middleware/Reporter.go
+++ b/internal/mocks/middleware/Reporter.go
@@ -1,4 +1,4 @@
-// Code generated by mockery v1.0.0. DO NOT EDIT.
+// Code generated by mockery v2.8.0. DO NOT EDIT.
 
 package middleware
 
diff --git a/metrics/metrics.go b/metrics/metrics.go
index f4fa46a..d17b5d5 100644
--- a/metrics/metrics.go
+++ b/metrics/metrics.go
@@ -16,6 +16,9 @@ type HTTPReqProperties struct {
 	Method string
 	// Code is the response of the request.
 	Code string
+
+	// CustomLabels hold values of the custom labels, if any.
+	CustomLabels []string
 }
 
 // HTTPProperties are the metric properties for the global server metrics.
@@ -24,6 +27,9 @@ type HTTPProperties struct {
 	Service string
 	// ID is the id of the request handler.
 	ID string
+
+	// CustomLabels hold values of the custom labels, if any.
+	CustomLabels []string
 }
 
 // Recorder knows how to record and measure the metrics. This
diff --git a/metrics/prometheus/prometheus.go b/metrics/prometheus/prometheus.go
index 41e95b5..4ed0689 100644
--- a/metrics/prometheus/prometheus.go
+++ b/metrics/prometheus/prometheus.go
@@ -30,6 +30,9 @@ type Config struct {
 	MethodLabel string
 	// ServiceLabel is the name that will be set to the service label, by default is `service`.
 	ServiceLabel string
+
+	// CustomLabels hold names of the custom labels, if any.
+	CustomLabels []string
 }
 
 func (c *Config) defaults() {
@@ -73,6 +76,24 @@ type recorder struct {
 func NewRecorder(cfg Config) metrics.Recorder {
 	cfg.defaults()
 
+	perReqLabels := append(
+		[]string{
+			cfg.ServiceLabel,
+			cfg.HandlerIDLabel,
+			cfg.MethodLabel,
+			cfg.StatusCodeLabel,
+		},
+		cfg.CustomLabels...,
+	)
+
+	serviceLabels := append(
+		[]string{
+			cfg.ServiceLabel,
+			cfg.HandlerIDLabel,
+		},
+		cfg.CustomLabels...,
+	)
+
 	r := &recorder{
 		httpRequestDurHistogram: prometheus.NewHistogramVec(prometheus.HistogramOpts{
 			Namespace: cfg.Prefix,
@@ -80,7 +101,7 @@ func NewRecorder(cfg Config) metrics.Recorder {
 			Name:      "request_duration_seconds",
 			Help:      "The latency of the HTTP requests.",
 			Buckets:   cfg.DurationBuckets,
-		}, []string{cfg.ServiceLabel, cfg.HandlerIDLabel, cfg.MethodLabel, cfg.StatusCodeLabel}),
+		}, perReqLabels),
 
 		httpResponseSizeHistogram: prometheus.NewHistogramVec(prometheus.HistogramOpts{
 			Namespace: cfg.Prefix,
@@ -88,14 +109,14 @@ func NewRecorder(cfg Config) metrics.Recorder {
 			Name:      "response_size_bytes",
 			Help:      "The size of the HTTP responses.",
 			Buckets:   cfg.SizeBuckets,
-		}, []string{cfg.ServiceLabel, cfg.HandlerIDLabel, cfg.MethodLabel, cfg.StatusCodeLabel}),
+		}, perReqLabels),
 
 		httpRequestsInflight: prometheus.NewGaugeVec(prometheus.GaugeOpts{
 			Namespace: cfg.Prefix,
 			Subsystem: "http",
 			Name:      "requests_inflight",
 			Help:      "The number of inflight requests being handled at the same time.",
-		}, []string{cfg.ServiceLabel, cfg.HandlerIDLabel}),
+		}, serviceLabels),
 	}
 
 	cfg.Registry.MustRegister(
@@ -108,13 +129,19 @@ func NewRecorder(cfg Config) metrics.Recorder {
 }
 
 func (r recorder) ObserveHTTPRequestDuration(_ context.Context, p metrics.HTTPReqProperties, duration time.Duration) {
-	r.httpRequestDurHistogram.WithLabelValues(p.Service, p.ID, p.Method, p.Code).Observe(duration.Seconds())
+	lvs := []string{p.Service, p.ID, p.Method, p.Code}
+	lvs = append(lvs, p.CustomLabels...)
+	r.httpRequestDurHistogram.WithLabelValues(lvs...).Observe(duration.Seconds())
 }
 
 func (r recorder) ObserveHTTPResponseSize(_ context.Context, p metrics.HTTPReqProperties, sizeBytes int64) {
-	r.httpResponseSizeHistogram.WithLabelValues(p.Service, p.ID, p.Method, p.Code).Observe(float64(sizeBytes))
+	lvs := []string{p.Service, p.ID, p.Method, p.Code}
+	lvs = append(lvs, p.CustomLabels...)
+	r.httpResponseSizeHistogram.WithLabelValues(lvs...).Observe(float64(sizeBytes))
 }
 
 func (r recorder) AddInflightRequests(_ context.Context, p metrics.HTTPProperties, quantity int) {
-	r.httpRequestsInflight.WithLabelValues(p.Service, p.ID).Add(float64(quantity))
+	lvs := []string{p.Service, p.ID}
+	lvs = append(lvs, p.CustomLabels...)
+	r.httpRequestsInflight.WithLabelValues(lvs...).Add(float64(quantity))
 }
diff --git a/metrics/prometheus/prometheus_test.go b/metrics/prometheus/prometheus_test.go
index bbc8be8..936d80d 100644
--- a/metrics/prometheus/prometheus_test.go
+++ b/metrics/prometheus/prometheus_test.go
@@ -159,7 +159,7 @@ func TestPrometheusRecorder(t *testing.T) {
 			},
 		},
 		{
-			name: "Using a custom labels in the configuration should measure with those labels.",
+			name: "Using a custom label names in the configuration should measure with those labels.",
 			config: libprometheus.Config{
 				HandlerIDLabel:  "route_id",
 				StatusCodeLabel: "status_code",
@@ -186,12 +186,38 @@ func TestPrometheusRecorder(t *testing.T) {
 				`http_request_duration_seconds_count{http_method="GET",http_service="svc1",route_id="test1",status_code="200"} 2`,
 			},
 		},
+		{
+			name: "Using a custom labels in the configuration should measure with those labels.",
+			config: libprometheus.Config{
+				DurationBuckets: []float64{1, 10},
+				CustomLabels:    []string{"user_id"},
+			},
+			recordMetrics: func(r metrics.Recorder) {
+				r.ObserveHTTPRequestDuration(context.TODO(), metrics.HTTPReqProperties{
+					Service:      "svc1",
+					ID:           "test1",
+					Method:       http.MethodGet,
+					Code:         "200",
+					CustomLabels: []string{"userVIP"},
+				}, 6*time.Second)
+				r.AddInflightRequests(context.TODO(), metrics.HTTPProperties{
+					Service:      "svc1",
+					ID:           "test1",
+					CustomLabels: []string{"userVIP"},
+				}, 1)
+			},
+			expMetrics: []string{
+				`http_request_duration_seconds_bucket{code="200",handler="test1",method="GET",service="svc1",user_id="userVIP",le="1"} 0`,
+				`http_request_duration_seconds_bucket{code="200",handler="test1",method="GET",service="svc1",user_id="userVIP",le="10"} 1`,
+				`http_request_duration_seconds_bucket{code="200",handler="test1",method="GET",service="svc1",user_id="userVIP",le="+Inf"} 1`,
+				`http_request_duration_seconds_count{code="200",handler="test1",method="GET",service="svc1",user_id="userVIP"} 1`,
+				`http_requests_inflight{handler="test1",service="svc1",user_id="userVIP"} 1`,
+			},
+		},
 	}
 
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			assert := assert.New(t)
-
 			reg := prometheus.NewRegistry()
 			test.config.Registry = reg
 			mrecorder := libprometheus.NewRecorder(test.config)
@@ -205,10 +231,10 @@ func TestPrometheusRecorder(t *testing.T) {
 			resp := rec.Result()
 
 			// Check all metrics are present.
-			if assert.Equal(http.StatusOK, resp.StatusCode) {
+			if assert.Equal(t, http.StatusOK, resp.StatusCode) {
 				body, _ := ioutil.ReadAll(resp.Body)
 				for _, expMetric := range test.expMetrics {
-					assert.Contains(string(body), expMetric, "metric not present on the result")
+					assert.Contains(t, string(body), expMetric, "metric not present on the result")
 				}
 			}
 		})
diff --git a/middleware/fasthttp/example_test.go b/middleware/fasthttp/example_test.go
index ac565e0..8cd1d9f 100644
--- a/middleware/fasthttp/example_test.go
+++ b/middleware/fasthttp/example_test.go
@@ -1,33 +1,73 @@
 package fasthttp_test
 
 import (
+	"context"
+	"fmt"
 	"log"
 	"net/http"
 
+	"github.com/fasthttp/router"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
-	metrics "github.com/slok/go-http-metrics/metrics/prometheus"
+	promMetrics "github.com/slok/go-http-metrics/metrics/prometheus"
 	"github.com/slok/go-http-metrics/middleware"
 	fasthttpMiddleware "github.com/slok/go-http-metrics/middleware/fasthttp"
 	"github.com/valyala/fasthttp"
 )
 
+func handleHello(rCtx *fasthttp.RequestCtx) {
+	userID, ok := rCtx.UserValue("user_id").(string)
+	if !ok {
+		userID = "unknown"
+	}
+
+	rCtx.SetStatusCode(fasthttp.StatusOK)
+	rCtx.SetBodyString(fmt.Sprintf("Hello, %s!", userID))
+}
+
 // FasthttpMiddleware shows how you would create a default middleware
 // factory and use it to create a fasthttp compatible middleware.
 func Example_fasthttpMiddleware() {
 	// Create our middleware factory with the default settings.
 	mdlw := middleware.New(middleware.Config{
-		Recorder: metrics.NewRecorder(metrics.Config{}),
+		Recorder: promMetrics.NewRecorder(promMetrics.Config{}),
 	})
 
-	// Add our handler and middleware
-	h := func(rCtx *fasthttp.RequestCtx) {
-		rCtx.SetStatusCode(fasthttp.StatusOK)
-		rCtx.SetBodyString("OK")
+	// Create our fasthttp instance.
+	srv := &fasthttp.Server{
+		Handler: fasthttpMiddleware.Handler("", mdlw, handleHello),
 	}
 
-	// Create our fasthttp instance.
+	// Serve metrics from the default prometheus registry.
+	log.Printf("serving metrics at: %s", ":8081")
+	go func() {
+		_ = http.ListenAndServe(":8081", promhttp.Handler())
+	}()
+
+	// Serve our handler.
+	log.Printf("listening at: %s", ":8080")
+	if err := srv.ListenAndServe(":8080"); err != nil {
+		log.Panicf("error while serving: %s", err)
+	}
+}
+
+func Example_fasthttpCustomLabels() {
+	mdlw := middleware.New(middleware.Config{
+		Recorder: promMetrics.NewRecorder(promMetrics.Config{
+			CustomLabels: []string{"user_id"},
+		}),
+	})
+
+	mux := router.New()
+	mux.GET("/{user_id}",
+		func(c *fasthttp.RequestCtx) {
+			mdlw.Measure("/hello", userIDReporter{c}, func() {
+				handleHello(c)
+			})
+		},
+	)
+
 	srv := &fasthttp.Server{
-		Handler: fasthttpMiddleware.Handler("", mdlw, h),
+		Handler: mux.Handler,
 	}
 
 	// Serve metrics from the default prometheus registry.
@@ -42,3 +82,36 @@ func Example_fasthttpMiddleware() {
 		log.Panicf("error while serving: %s", err)
 	}
 }
+
+type userIDReporter struct {
+	c *fasthttp.RequestCtx
+}
+
+func (r userIDReporter) Method() string {
+	return string(r.c.Method())
+}
+
+func (r userIDReporter) Context() context.Context {
+	return r.c
+}
+
+func (r userIDReporter) URLPath() string {
+	return string(r.c.Path())
+}
+
+func (r userIDReporter) StatusCode() int {
+	return r.c.Response.StatusCode()
+}
+
+func (r userIDReporter) BytesWritten() int64 {
+	return int64(len(r.c.Response.Body()))
+}
+
+func (r userIDReporter) CustomLabels() []string {
+	userID, ok := r.c.UserValue("user_id").(string)
+	if !ok {
+		return nil
+	}
+
+	return []string{userID}
+}
diff --git a/middleware/middleware.go b/middleware/middleware.go
index 079e77c..b6d0b77 100644
--- a/middleware/middleware.go
+++ b/middleware/middleware.go
@@ -67,6 +67,11 @@ func New(cfg Config) Middleware {
 func (m Middleware) Measure(handlerID string, reporter Reporter, next func()) {
 	ctx := reporter.Context()
 
+	var customLabels []string
+	if cr, ok := reporter.(CustomLabelReporter); ok {
+		customLabels = cr.CustomLabels()
+	}
+
 	// If there isn't predefined handler ID we
 	// set that ID as the URL path.
 	hid := handlerID
@@ -77,9 +82,11 @@ func (m Middleware) Measure(handlerID string, reporter Reporter, next func()) {
 	// Measure inflights if required.
 	if !m.cfg.DisableMeasureInflight {
 		props := metrics.HTTPProperties{
-			Service: m.cfg.Service,
-			ID:      hid,
+			Service:      m.cfg.Service,
+			ID:           hid,
+			CustomLabels: customLabels,
 		}
+
 		m.cfg.Recorder.AddInflightRequests(ctx, props, 1)
 		defer m.cfg.Recorder.AddInflightRequests(ctx, props, -1)
 	}
@@ -100,10 +107,11 @@ func (m Middleware) Measure(handlerID string, reporter Reporter, next func()) {
 		}
 
 		props := metrics.HTTPReqProperties{
-			Service: m.cfg.Service,
-			ID:      hid,
-			Method:  reporter.Method(),
-			Code:    code,
+			Service:      m.cfg.Service,
+			ID:           hid,
+			Method:       reporter.Method(),
+			Code:         code,
+			CustomLabels: customLabels,
 		}
 		m.cfg.Recorder.ObserveHTTPRequestDuration(ctx, props, duration)
 
@@ -126,3 +134,8 @@ type Reporter interface {
 	StatusCode() int
 	BytesWritten() int64
 }
+
+type CustomLabelReporter interface {
+	Reporter
+	CustomLabels() []string
+}
diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go
index 0e70e68..2255274 100644
--- a/middleware/middleware_test.go
+++ b/middleware/middleware_test.go
@@ -14,19 +14,23 @@ import (
 )
 
 func TestMiddlewareMeasure(t *testing.T) {
-	tests := map[string]struct {
+	tests := []struct {
+		name      string
 		handlerID string
-		config    func() middleware.Config
-		mock      func(mrec *mockmetrics.Recorder, mrep *mockmiddleware.Reporter)
+		config    middleware.Config
+		recorder  func() metrics.Recorder
+		setup     func() (metrics.Recorder, middleware.Reporter, func(t *testing.T))
 	}{
-		"Having default config with service, it should measure the metrics.": {
+		{
+			name:      "Having default config with service, it should measure the metrics.",
 			handlerID: "test01",
-			config: func() middleware.Config {
-				return middleware.Config{
-					Service: "svc1",
-				}
+			config: middleware.Config{
+				Service: "svc1",
 			},
-			mock: func(mrec *mockmetrics.Recorder, mrep *mockmiddleware.Reporter) {
+			setup: func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) {
+				mrec := &mockmetrics.Recorder{}
+				mrep := &mockmiddleware.Reporter{}
+
 				// Reporter mocks.
 				mrep.On("Context").Once().Return(context.TODO())
 				mrep.On("StatusCode").Once().Return(418)
@@ -41,15 +45,61 @@ func TestMiddlewareMeasure(t *testing.T) {
 				mrec.On("AddInflightRequests", mock.Anything, expProps, -1).Once()
 				mrec.On("ObserveHTTPRequestDuration", mock.Anything, expRepProps, mock.Anything).Once()
 				mrec.On("ObserveHTTPResponseSize", mock.Anything, expRepProps, int64(42)).Once()
+
+				return mrec, mrep, func(t *testing.T) {
+					mrec.AssertExpectations(t)
+					mrep.AssertExpectations(t)
+				}
 			},
 		},
+		{
+			name:      "Custom labels should work",
+			handlerID: "test01",
+			config: middleware.Config{
+				Service: "svc1",
+			},
+			setup: func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) {
+				mrec := &mockmetrics.Recorder{}
+				mrep := &mockmiddleware.CustomLabelReporter{}
 
-		"Without having handler ID, it should measure the metrics using the request path.": {
-			handlerID: "",
-			config: func() middleware.Config {
-				return middleware.Config{}
+				mrep.On("Context").Once().Return(context.TODO())
+				mrep.On("StatusCode").Once().Return(418)
+				mrep.On("Method").Once().Return("PATCH")
+				mrep.On("BytesWritten").Once().Return(int64(42))
+				mrep.On("CustomLabels").Once().Return([]string{"user_VIP"})
+
+				expProps := metrics.HTTPProperties{
+					Service:      "svc1",
+					ID:           "test01",
+					CustomLabels: []string{"user_VIP"},
+				}
+				expRepProps := metrics.HTTPReqProperties{
+					Service:      "svc1",
+					ID:           "test01",
+					Method:       "PATCH",
+					Code:         "418",
+					CustomLabels: []string{"user_VIP"},
+				}
+
+				mrec.On("AddInflightRequests", mock.Anything, expProps, 1).Once()
+				mrec.On("AddInflightRequests", mock.Anything, expProps, -1).Once()
+				mrec.On("ObserveHTTPRequestDuration", mock.Anything, expRepProps, mock.Anything).Once()
+				mrec.On("ObserveHTTPResponseSize", mock.Anything, expRepProps, int64(42)).Once()
+
+				return mrec, mrep, func(t *testing.T) {
+					mrec.AssertExpectations(t)
+					mrep.AssertExpectations(t)
+				}
 			},
-			mock: func(mrec *mockmetrics.Recorder, mrep *mockmiddleware.Reporter) {
+		},
+		{
+			name:      "Without having handler ID, it should measure the metrics using the request path.",
+			handlerID: "",
+			config:    middleware.Config{},
+			setup: func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) {
+				mrec := &mockmetrics.Recorder{}
+				mrep := &mockmiddleware.Reporter{}
+
 				// Reporter mocks.
 				mrep.On("URLPath").Once().Return("/test/01")
 				mrep.On("Context").Once().Return(context.TODO())
@@ -64,17 +114,23 @@ func TestMiddlewareMeasure(t *testing.T) {
 				mrec.On("AddInflightRequests", mock.Anything, mock.Anything, mock.Anything).Once()
 				mrec.On("ObserveHTTPRequestDuration", mock.Anything, expRepProps, mock.Anything).Once()
 				mrec.On("ObserveHTTPResponseSize", mock.Anything, expRepProps, mock.Anything).Once()
+
+				return mrec, mrep, func(t *testing.T) {
+					mrec.AssertExpectations(t)
+					mrep.AssertExpectations(t)
+				}
 			},
 		},
-
-		"Having grouped status code, it should measure the metrics using grouped status codes.": {
+		{
+			name:      "Having grouped status code, it should measure the metrics using grouped status codes.",
 			handlerID: "test01",
-			config: func() middleware.Config {
-				return middleware.Config{
-					GroupedStatus: true,
-				}
+			config: middleware.Config{
+				GroupedStatus: true,
 			},
-			mock: func(mrec *mockmetrics.Recorder, mrep *mockmiddleware.Reporter) {
+			setup: func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) {
+				mrec := &mockmetrics.Recorder{}
+				mrep := &mockmiddleware.Reporter{}
+
 				// Reporter mocks.
 				mrep.On("Context").Once().Return(context.TODO())
 				mrep.On("StatusCode").Once().Return(418)
@@ -88,17 +144,23 @@ func TestMiddlewareMeasure(t *testing.T) {
 				mrec.On("AddInflightRequests", mock.Anything, mock.Anything, mock.Anything).Once()
 				mrec.On("ObserveHTTPRequestDuration", mock.Anything, expRepProps, mock.Anything).Once()
 				mrec.On("ObserveHTTPResponseSize", mock.Anything, expRepProps, mock.Anything).Once()
+
+				return mrec, mrep, func(t *testing.T) {
+					mrec.AssertExpectations(t)
+					mrep.AssertExpectations(t)
+				}
 			},
 		},
-
-		"Disabling inflight requests measuring, it shouldn't measure inflight metrics.": {
+		{
+			name:      "Disabling inflight requests measuring, it shouldn't measure inflight metrics.",
 			handlerID: "test01",
-			config: func() middleware.Config {
-				return middleware.Config{
-					DisableMeasureInflight: true,
-				}
+			config: middleware.Config{
+				DisableMeasureInflight: true,
 			},
-			mock: func(mrec *mockmetrics.Recorder, mrep *mockmiddleware.Reporter) {
+			setup: func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) {
+				mrec := &mockmetrics.Recorder{}
+				mrep := &mockmiddleware.Reporter{}
+
 				// Reporter mocks.
 				mrep.On("Context").Once().Return(context.TODO())
 				mrep.On("StatusCode").Once().Return(418)
@@ -110,17 +172,23 @@ func TestMiddlewareMeasure(t *testing.T) {
 
 				mrec.On("ObserveHTTPRequestDuration", mock.Anything, expRepProps, mock.Anything).Once()
 				mrec.On("ObserveHTTPResponseSize", mock.Anything, expRepProps, mock.Anything).Once()
+
+				return mrec, mrep, func(t *testing.T) {
+					mrec.AssertExpectations(t)
+					mrep.AssertExpectations(t)
+				}
 			},
 		},
-
-		"Disabling size measuring, it shouldn't measure size metrics.": {
+		{
+			name:      "Disabling size measuring, it shouldn't measure size metrics.",
 			handlerID: "test01",
-			config: func() middleware.Config {
-				return middleware.Config{
-					DisableMeasureSize: true,
-				}
+			config: middleware.Config{
+				DisableMeasureSize: true,
 			},
-			mock: func(mrec *mockmetrics.Recorder, mrep *mockmiddleware.Reporter) {
+			setup: func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) {
+				mrec := &mockmetrics.Recorder{}
+				mrep := &mockmiddleware.Reporter{}
+
 				// Reporter mocks.
 				mrep.On("Context").Once().Return(context.TODO())
 				mrep.On("StatusCode").Once().Return(418)
@@ -132,31 +200,27 @@ func TestMiddlewareMeasure(t *testing.T) {
 				mrec.On("AddInflightRequests", mock.Anything, mock.Anything, mock.Anything).Once()
 				mrec.On("AddInflightRequests", mock.Anything, mock.Anything, mock.Anything).Once()
 				mrec.On("ObserveHTTPRequestDuration", mock.Anything, expRepProps, mock.Anything).Once()
+
+				return mrec, mrep, func(t *testing.T) {
+					mrec.AssertExpectations(t)
+					mrep.AssertExpectations(t)
+				}
 			},
 		},
 	}
 
-	for name, test := range tests {
-		t.Run(name, func(t *testing.T) {
-			assert := assert.New(t)
-
-			// Mocks.
-			mrec := &mockmetrics.Recorder{}
-			mrep := &mockmiddleware.Reporter{}
-			test.mock(mrec, mrep)
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			mrec, mrep, cleanup := tc.setup()
 
-			// Execute.
-			config := test.config()
-			config.Recorder = mrec // Set mocked recorder.
-			mdlw := middleware.New(config)
+			tc.config.Recorder = mrec
+			mdlw := middleware.New(tc.config)
 
 			calledNext := false
-			mdlw.Measure(test.handlerID, mrep, func() { calledNext = true })
+			mdlw.Measure(tc.handlerID, mrep, func() { calledNext = true })
 
-			// Check.
-			mrec.AssertExpectations(t)
-			mrep.AssertExpectations(t)
-			assert.True(calledNext)
+			cleanup(t)
+			assert.True(t, calledNext)
 		})
 	}
 }