Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions auth/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,17 @@ const (
func NewAuthClient(address string, enabled bool, logger *log.Logger) *AuthClient {
var l log.Logger

var err error

if logger != nil {
l = *logger
} else {
l = zap.InitializeLogger()
l, err = zap.InitializeLoggerWithError()
if err != nil {
l = &log.NoneLogger{}

l.Errorf("failed to initialize logger, using NoneLogger: %v\n", err)
}
}

if !enabled || address == "" {
Expand Down Expand Up @@ -112,8 +119,7 @@ func (auth *AuthClient) Authorize(sub, resource, action string) fiber.Handler {
return func(c *fiber.Ctx) error {
ctx := opentelemetry.ExtractHTTPContext(c)

tracer := commons.NewTracerFromContext(ctx)
reqID := commons.NewHeaderIDFromContext(ctx)
_, tracer, reqID, _ := commons.NewTrackingFromContext(ctx)

if !auth.Enabled || auth.Address == "" {
return c.Next()
Expand Down Expand Up @@ -158,8 +164,7 @@ func (auth *AuthClient) Authorize(sub, resource, action string) fiber.Handler {

// checkAuthorization sends an authorization request to the external service and returns whether the action is authorized.
func (auth *AuthClient) checkAuthorization(ctx context.Context, sub, resource, action, accessToken string) (bool, int, error) {
tracer := commons.NewTracerFromContext(ctx)
reqID := commons.NewHeaderIDFromContext(ctx)
_, tracer, reqID, _ := commons.NewTrackingFromContext(ctx)

ctx, span := tracer.Start(ctx, "lib_auth.check_authorization")
defer span.End()
Expand Down Expand Up @@ -298,8 +303,7 @@ func (auth *AuthClient) checkAuthorization(ctx context.Context, sub, resource, a
// It takes the client ID and client secret as parameters and returns the access token if the request is successful.
// If the request fails at any step, an error is returned with a descriptive message.
func (auth *AuthClient) GetApplicationToken(ctx context.Context, clientID, clientSecret string) (string, error) {
tracer := commons.NewTracerFromContext(ctx)
reqID := commons.NewHeaderIDFromContext(ctx)
_, tracer, reqID, _ := commons.NewTrackingFromContext(ctx)

ctx, span := tracer.Start(ctx, "lib_auth.get_application_token")
defer span.End()
Expand Down
128 changes: 126 additions & 2 deletions auth/middleware/middlewareGRPC.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package middleware

import (
"context"
"errors"
"fmt"
"net/http"
"os"
"strings"

"github.com/LerianStudio/lib-commons/v3/commons"
"github.com/LerianStudio/lib-commons/v3/commons/opentelemetry"
jwt "github.com/golang-jwt/jwt/v5"
"go.opentelemetry.io/otel/attribute"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -47,8 +50,7 @@ func NewGRPCAuthUnaryPolicy(auth *AuthClient, cfg PolicyConfig) grpc.UnaryServer
}

token, ok := extractTokenFromMD(ctx)
tracer := commons.NewTracerFromContext(ctx)
reqID := commons.NewHeaderIDFromContext(ctx)
_, tracer, reqID, _ := commons.NewTrackingFromContext(ctx)

ctx, span := tracer.Start(ctx, "lib_auth.authorize_grpc_unary_policy")
defer span.End()
Expand Down Expand Up @@ -97,6 +99,27 @@ func NewGRPCAuthUnaryPolicy(auth *AuthClient, cfg PolicyConfig) grpc.UnaryServer
return nil, status.Error(codes.PermissionDenied, "forbidden")
}

// Propagate tenant claims if multi-tenant mode is enabled
if os.Getenv("MULTI_TENANT_ENABLED") == "true" {
tenantID, tenantSlug, tOwner, _ := extractTenantClaims(token)
md, _ := metadata.FromIncomingContext(ctx)
md = md.Copy()

if tenantID != "" {
md.Set("md-tenant-id", tenantID)
}

if tenantSlug != "" {
md.Set("md-tenant-slug", tenantSlug)
}

if tOwner != "" {
md.Set("md-tenant-owner", tOwner)
}

ctx = metadata.NewIncomingContext(ctx, md)
}

return handler(ctx, req)
}
}
Expand Down Expand Up @@ -179,3 +202,104 @@ func SubFromMetadata(key string) func(ctx context.Context, fullMethod string, re
return vals[0], nil
}
}

// extractTenantClaims extracts tenant-related claims from a JWT without signature verification.
// Returns tenantID, tenantSlug, and owner from the token's custom claims.
// Used by gRPC interceptors to propagate tenant context to downstream services.
func extractTenantClaims(tokenString string) (tenantID, tenantSlug, owner string, err error) {
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
if err != nil {
return "", "", "", err
}

claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return "", "", "", errors.New("invalid token claims")
}

tenantID, _ = claims["tenantId"].(string)
tenantSlug, _ = claims["tenantSlug"].(string)
owner, _ = claims["owner"].(string)

return tenantID, tenantSlug, owner, nil
}

// NewGRPCAuthStreamPolicy authorizes streaming RPCs via per-method Policy.
// Mirrors NewGRPCAuthUnaryPolicy behavior for streaming calls:
// - Resolves Policy by info.FullMethod; falls back to DefaultPolicy.
// - Rejects missing tokens with codes.Unauthenticated.
// - Propagates tenant claims when MULTI_TENANT_ENABLED=true.
func NewGRPCAuthStreamPolicy(auth *AuthClient, cfg PolicyConfig) grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if auth == nil || !auth.Enabled || auth.Address == "" {
return handler(srv, ss)
}

ctx := ss.Context()
token, ok := extractTokenFromMD(ctx)

if !ok || commons.IsNilOrEmpty(&token) {
return status.Error(codes.Unauthenticated, "missing token")
}

pol, found := policyForMethod(cfg, info.FullMethod)
if !found {
return status.Error(codes.Internal, "internal configuration error")
}

var sub string

if cfg.SubResolver != nil {
var err error

sub, err = cfg.SubResolver(ctx, info.FullMethod, nil)
if err != nil {
return status.Error(codes.Internal, "internal configuration error")
}
}

authorized, httpStatus, err := auth.checkAuthorization(ctx, sub, pol.Resource, pol.Action, token)
if err != nil {
return grpcErrorFromHTTP(httpStatus)
}

if !authorized {
return status.Error(codes.PermissionDenied, "forbidden")
}

// Propagate tenant claims if multi-tenant mode is enabled
if os.Getenv("MULTI_TENANT_ENABLED") == "true" {
tenantID, tenantSlug, tOwner, _ := extractTenantClaims(token)
md, _ := metadata.FromIncomingContext(ctx)
md = md.Copy()

if tenantID != "" {
md.Set("md-tenant-id", tenantID)
}

if tenantSlug != "" {
md.Set("md-tenant-slug", tenantSlug)
}

if tOwner != "" {
md.Set("md-tenant-owner", tOwner)
}

ctx = metadata.NewIncomingContext(ctx, md)
ss = &wrappedServerStream{ServerStream: ss, ctx: ctx}
}

return handler(srv, ss)
}
}

// wrappedServerStream wraps grpc.ServerStream to override Context().
type wrappedServerStream struct {
grpc.ServerStream
ctx context.Context
}

// Context returns the wrapped context.
func (w *wrappedServerStream) Context() context.Context {
return w.ctx
}
Loading
Loading