Skip to content

Commit

Permalink
Merge pull request #12 from carverauto/updates/cleanup
Browse files Browse the repository at this point in the history
Updates/cleanup
  • Loading branch information
mfreeman451 authored Oct 6, 2024
2 parents 1fad3b7 + c31c148 commit ba200d8
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package main
import (
"github.com/carverauto/eventrunner/cmd/api/migrations"
"github.com/carverauto/eventrunner/pkg/api/handlers"
"github.com/carverauto/eventrunner/pkg/api/middleware"
middlewarePkg "github.com/carverauto/eventrunner/pkg/api/middleware"

Check failure on line 6 in cmd/api/main.go

View workflow job for this annotation

GitHub Actions / lint

"github.com/carverauto/eventrunner/pkg/api/middleware" imported as middlewarePkg and not used
"gofr.dev/pkg/gofr"
"gofr.dev/pkg/gofr/datasource/mongo"
)
Expand Down
13 changes: 9 additions & 4 deletions cmd/event-ingest/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"errors"
"log"

"github.com/carverauto/eventrunner/pkg/api/middleware"
Expand All @@ -25,6 +26,8 @@ func main() {
log.Fatalf("Failed to initialize JWT middleware: %v", err)
}

app.UseMiddleware(middleware.CustomHeadersMiddleware())

// Set up gRPC connection to API
grpcServerAddress := app.Config.Get("GRPC_SERVER_ADDRESS")

Expand Down Expand Up @@ -57,8 +60,11 @@ func main() {
// combineMiddleware chains multiple middleware functions together
func combineMiddleware(middlewares ...interface{}) gofr.Handler {
return func(c *gofr.Context) (interface{}, error) {
// Create the initial custom context from the GoFr context
cc := customctx.NewCustomContext(c)
// Retrieve the custom context from the original context
customCtx, ok := c.Request.Context().Value("customCtx").(*customctx.CustomContext)
if !ok {
return nil, errors.New("failed to retrieve custom context")
}

// Define the final handler that will be called after applying all middleware
finalHandler := func(ctx customctx.Context) (interface{}, error) {
Expand All @@ -74,7 +80,6 @@ func combineMiddleware(middlewares ...interface{}) gofr.Handler {
finalHandler = m
} else {
// Wrap the final handler in the current function
// nextHandler := finalHandler
finalHandler = func(ctx customctx.Context) (interface{}, error) {
return m(ctx)
}
Expand All @@ -86,6 +91,6 @@ func combineMiddleware(middlewares ...interface{}) gofr.Handler {
}

// Execute the final middleware chain with the custom context
return finalHandler(cc)
return finalHandler(customCtx)
}
}
39 changes: 39 additions & 0 deletions pkg/api/middleware/http_headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Package middleware pkg/api/middleware/http_headers.go
package middleware

import (
"context"
"net/http"

customctx "github.com/carverauto/eventrunner/pkg/context"
gofr "gofr.dev/pkg/gofr"
gofrHTTP "gofr.dev/pkg/gofr/http"
)

func CustomHeadersMiddleware() gofrHTTP.Middleware {
return func(inner http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Use the existing context to create a custom context
gofrCtx := &gofr.Context{
Request: gofrHTTP.NewRequest(r),
}
customCtx := customctx.NewCustomContext(gofrCtx)

// Extract headers from the HTTP request and store them in the custom context
for key, values := range r.Header {
if len(values) > 0 {
customCtx.SetClaim(key, values[0])
}
}

// Create a new context with the custom context
ctxWithCustom := context.WithValue(r.Context(), "customCtx", customCtx)

// Store the custom context back into the request
r = r.WithContext(ctxWithCustom)

// Call the next handler in the chain
inner.ServeHTTP(w, r)
})
}
}
48 changes: 48 additions & 0 deletions pkg/api/middleware/http_headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Package middleware pkg/api/middleware/http_headers_test.go
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

customctx "github.com/carverauto/eventrunner/pkg/context"
"github.com/stretchr/testify/assert"
)

func TestCustomHeadersMiddleware(t *testing.T) {
// Create a sample HTTP request with headers
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("X-Custom-Header", "CustomValue")
req.Header.Set("Authorization", "Bearer token")

// Create a response recorder to capture the response
rr := httptest.NewRecorder()

// Create a mock final handler that will be called after the middleware
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract the custom context from the request context
customCtx, ok := r.Context().Value("customCtx").(*customctx.CustomContext)
if !ok {
t.Errorf("Failed to retrieve custom context")
return
}

// Validate that the headers were correctly set in the custom context
headerValue, _ := customCtx.GetClaim("X-Custom-Header")
assert.Equal(t, "CustomValue", headerValue)

authValue, _ := customCtx.GetClaim("Authorization")
assert.Equal(t, "Bearer token", authValue)
})

// Wrap the handler with the CustomHeadersMiddleware
middleware := CustomHeadersMiddleware()
wrappedHandler := middleware(handler)

// Serve the request
wrappedHandler.ServeHTTP(rr, req)

// Verify the response status code
assert.Equal(t, http.StatusOK, rr.Code)
}
18 changes: 18 additions & 0 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@ package context
import (
"github.com/google/uuid"
"gofr.dev/pkg/gofr"
"net/http"
)

type CustomContext struct {
gofrContext *gofr.Context
claims map[string]interface{}
headers http.Header
}

// NewCustomContext creates a new Context.
func NewCustomContext(c *gofr.Context) *CustomContext {
return &CustomContext{
gofrContext: c,
claims: make(map[string]interface{}),
headers: http.Header{},
}
}

Expand Down Expand Up @@ -106,3 +109,18 @@ func (c *CustomContext) Bind(v interface{}) error {
func (c *CustomContext) Context() *gofr.Context {
return c.gofrContext
}

// SetHeader sets an HTTP header.
func (c *CustomContext) SetHeader(key, value string) {
c.headers.Set(key, value)
}

// GetHeader retrieves an HTTP header value.
func (c *CustomContext) GetHeader(key string) string {
return c.headers.Get(key)
}

// Headers returns all HTTP headers.
func (c *CustomContext) Headers() http.Header {
return c.headers
}

0 comments on commit ba200d8

Please sign in to comment.