diff --git a/ee/http/middleware/pat.go b/ee/http/middleware/pat.go new file mode 100644 index 0000000000..3cc500ce53 --- /dev/null +++ b/ee/http/middleware/pat.go @@ -0,0 +1,36 @@ +package middleware + +import ( + "net/http" + + "go.signoz.io/signoz/pkg/types/authtypes" +) + +type Pat struct { + uuid *authtypes.UUID + headers []string +} + +func NewPat(headers []string) *Pat { + return &Pat{uuid: authtypes.NewUUID(), headers: headers} +} + +func (p *Pat) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var values []string + for _, header := range p.headers { + values = append(values, r.Header.Get(header)) + } + + ctx, err := p.uuid.ContextFromRequest(r.Context(), values...) + if err != nil { + next.ServeHTTP(w, r) + return + } + + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) + +} diff --git a/ee/query-service/app/api/api.go b/ee/query-service/app/api/api.go index 6c507179fb..24989c8140 100644 --- a/ee/query-service/app/api/api.go +++ b/ee/query-service/app/api/api.go @@ -20,6 +20,7 @@ import ( basemodel "go.signoz.io/signoz/pkg/query-service/model" rules "go.signoz.io/signoz/pkg/query-service/rules" "go.signoz.io/signoz/pkg/query-service/version" + "go.signoz.io/signoz/pkg/types/authtypes" ) type APIHandlerOptions struct { @@ -41,6 +42,7 @@ type APIHandlerOptions struct { FluxInterval time.Duration UseLogsNewSchema bool UseTraceNewSchema bool + JWT *authtypes.JWT } type APIHandler struct { diff --git a/ee/query-service/app/api/auth.go b/ee/query-service/app/api/auth.go index 9a28fce263..23ddeb1d0f 100644 --- a/ee/query-service/app/api/auth.go +++ b/ee/query-service/app/api/auth.go @@ -50,7 +50,7 @@ func (ah *APIHandler) loginUser(w http.ResponseWriter, r *http.Request) { } // if all looks good, call auth - resp, err := baseauth.Login(ctx, &req) + resp, err := baseauth.Login(ctx, &req, ah.opts.JWT) if ah.HandleError(w, err, http.StatusUnauthorized) { return } @@ -253,7 +253,7 @@ func (ah *APIHandler) receiveGoogleAuth(w http.ResponseWriter, r *http.Request) return } - nextPage, err := ah.AppDao().PrepareSsoRedirect(ctx, redirectUri, identity.Email) + nextPage, err := ah.AppDao().PrepareSsoRedirect(ctx, redirectUri, identity.Email, ah.opts.JWT) if err != nil { zap.L().Error("[receiveGoogleAuth] failed to generate redirect URI after successful login ", zap.String("domain", domain.String()), zap.Error(err)) handleSsoError(w, r, redirectUri) @@ -331,7 +331,7 @@ func (ah *APIHandler) receiveSAML(w http.ResponseWriter, r *http.Request) { return } - nextPage, err := ah.AppDao().PrepareSsoRedirect(ctx, redirectUri, email) + nextPage, err := ah.AppDao().PrepareSsoRedirect(ctx, redirectUri, email, ah.opts.JWT) if err != nil { zap.L().Error("[receiveSAML] failed to generate redirect URI after successful login ", zap.String("domain", domain.String()), zap.Error(err)) handleSsoError(w, r, redirectUri) diff --git a/ee/query-service/app/api/cloudIntegrations.go b/ee/query-service/app/api/cloudIntegrations.go index 5cfc5f06ba..8f7568480b 100644 --- a/ee/query-service/app/api/cloudIntegrations.go +++ b/ee/query-service/app/api/cloudIntegrations.go @@ -37,7 +37,7 @@ func (ah *APIHandler) CloudIntegrationsGenerateConnectionParams(w http.ResponseW return } - currentUser, err := auth.GetUserFromRequest(r) + currentUser, err := auth.GetUserFromReqContext(r.Context()) if err != nil { RespondError(w, basemodel.UnauthorizedError(fmt.Errorf( "couldn't deduce current user: %w", err, diff --git a/ee/query-service/app/api/pat.go b/ee/query-service/app/api/pat.go index 3ff8be74a2..08304788e8 100644 --- a/ee/query-service/app/api/pat.go +++ b/ee/query-service/app/api/pat.go @@ -34,7 +34,7 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) { RespondError(w, model.BadRequest(err), nil) return } - user, err := auth.GetUserFromRequest(r) + user, err := auth.GetUserFromReqContext(r.Context()) if err != nil { RespondError(w, &model.ApiError{ Typ: model.ErrorUnauthorized, @@ -97,7 +97,7 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) { return } - user, err := auth.GetUserFromRequest(r) + user, err := auth.GetUserFromReqContext(r.Context()) if err != nil { RespondError(w, &model.ApiError{ Typ: model.ErrorUnauthorized, @@ -127,7 +127,7 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) { ctx := context.Background() - user, err := auth.GetUserFromRequest(r) + user, err := auth.GetUserFromReqContext(r.Context()) if err != nil { RespondError(w, &model.ApiError{ Typ: model.ErrorUnauthorized, @@ -147,7 +147,7 @@ func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) revokePAT(w http.ResponseWriter, r *http.Request) { ctx := context.Background() id := mux.Vars(r)["id"] - user, err := auth.GetUserFromRequest(r) + user, err := auth.GetUserFromReqContext(r.Context()) if err != nil { RespondError(w, &model.ApiError{ Typ: model.ErrorUnauthorized, diff --git a/ee/query-service/app/server.go b/ee/query-service/app/server.go index a8addbc177..8ed5c83d75 100644 --- a/ee/query-service/app/server.go +++ b/ee/query-service/app/server.go @@ -14,6 +14,7 @@ import ( "github.com/rs/cors" "github.com/soheilhy/cmux" + eemiddleware "go.signoz.io/signoz/ee/http/middleware" "go.signoz.io/signoz/ee/query-service/app/api" "go.signoz.io/signoz/ee/query-service/app/db" "go.signoz.io/signoz/ee/query-service/auth" @@ -24,6 +25,7 @@ import ( "go.signoz.io/signoz/ee/query-service/rules" "go.signoz.io/signoz/pkg/http/middleware" "go.signoz.io/signoz/pkg/signoz" + "go.signoz.io/signoz/pkg/types/authtypes" "go.signoz.io/signoz/pkg/web" licensepkg "go.signoz.io/signoz/ee/query-service/license" @@ -72,6 +74,7 @@ type ServerOptions struct { GatewayUrl string UseLogsNewSchema bool UseTraceNewSchema bool + Jwt *authtypes.JWT } // Server runs HTTP api service @@ -261,6 +264,7 @@ func NewServer(serverOptions *ServerOptions) (*Server, error) { GatewayUrl: serverOptions.GatewayUrl, UseLogsNewSchema: serverOptions.UseLogsNewSchema, UseTraceNewSchema: serverOptions.UseTraceNewSchema, + JWT: serverOptions.Jwt, } apiHandler, err := api.NewAPIHandler(apiOpts) @@ -303,6 +307,8 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server, r := baseapp.NewRouter() + r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) + r.Use(eemiddleware.NewPat([]string{"SIGNOZ-API-KEY"}).Wrap) r.Use(middleware.NewTimeout(zap.L(), s.serverOptions.Config.APIServer.Timeout.ExcludedRoutes, s.serverOptions.Config.APIServer.Timeout.Default, @@ -334,8 +340,8 @@ func (s *Server) createPublicServer(apiHandler *api.APIHandler, web web.Web) (*h r := baseapp.NewRouter() // add auth middleware - getUserFromRequest := func(r *http.Request) (*basemodel.UserPayload, error) { - user, err := auth.GetUserFromRequest(r, apiHandler) + getUserFromRequest := func(ctx context.Context) (*basemodel.UserPayload, error) { + user, err := auth.GetUserFromRequestContext(ctx, apiHandler) if err != nil { return nil, err @@ -349,6 +355,8 @@ func (s *Server) createPublicServer(apiHandler *api.APIHandler, web web.Web) (*h } am := baseapp.NewAuthMiddleware(getUserFromRequest) + r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) + r.Use(eemiddleware.NewPat([]string{"SIGNOZ-API-KEY"}).Wrap) r.Use(middleware.NewTimeout(zap.L(), s.serverOptions.Config.APIServer.Timeout.ExcludedRoutes, s.serverOptions.Config.APIServer.Timeout.Default, diff --git a/ee/query-service/auth/auth.go b/ee/query-service/auth/auth.go index d45d050cca..ebaefa700d 100644 --- a/ee/query-service/auth/auth.go +++ b/ee/query-service/auth/auth.go @@ -3,20 +3,20 @@ package auth import ( "context" "fmt" - "net/http" "time" "go.signoz.io/signoz/ee/query-service/app/api" baseauth "go.signoz.io/signoz/pkg/query-service/auth" basemodel "go.signoz.io/signoz/pkg/query-service/model" "go.signoz.io/signoz/pkg/query-service/telemetry" + "go.signoz.io/signoz/pkg/types/authtypes" "go.uber.org/zap" ) -func GetUserFromRequest(r *http.Request, apiHandler *api.APIHandler) (*basemodel.UserPayload, error) { - patToken := r.Header.Get("SIGNOZ-API-KEY") - if len(patToken) > 0 { +func GetUserFromRequestContext(ctx context.Context, apiHandler *api.APIHandler) (*basemodel.UserPayload, error) { + patToken, ok := authtypes.UUIDFromContext(ctx) + if ok && patToken != "" { zap.L().Debug("Received a non-zero length PAT token") ctx := context.Background() dao := apiHandler.AppDao() @@ -52,5 +52,5 @@ func GetUserFromRequest(r *http.Request, apiHandler *api.APIHandler) (*basemodel return nil, err } } - return baseauth.GetUserFromRequest(r) + return baseauth.GetUserFromReqContext(ctx) } diff --git a/ee/query-service/dao/interface.go b/ee/query-service/dao/interface.go index 2fc81468d5..1708a4ada2 100644 --- a/ee/query-service/dao/interface.go +++ b/ee/query-service/dao/interface.go @@ -10,6 +10,7 @@ import ( basedao "go.signoz.io/signoz/pkg/query-service/dao" baseint "go.signoz.io/signoz/pkg/query-service/interfaces" basemodel "go.signoz.io/signoz/pkg/query-service/model" + "go.signoz.io/signoz/pkg/types/authtypes" ) type ModelDao interface { @@ -22,7 +23,7 @@ type ModelDao interface { // auth methods CanUsePassword(ctx context.Context, email string) (bool, basemodel.BaseApiError) - PrepareSsoRedirect(ctx context.Context, redirectUri, email string) (redirectURL string, apierr basemodel.BaseApiError) + PrepareSsoRedirect(ctx context.Context, redirectUri, email string, jwt *authtypes.JWT) (redirectURL string, apierr basemodel.BaseApiError) GetDomainFromSsoResponse(ctx context.Context, relayState *url.URL) (*model.OrgDomain, error) // org domain (auth domains) CRUD ops diff --git a/ee/query-service/dao/sqlite/auth.go b/ee/query-service/dao/sqlite/auth.go index b8bc5e0fa0..f97a6736be 100644 --- a/ee/query-service/dao/sqlite/auth.go +++ b/ee/query-service/dao/sqlite/auth.go @@ -14,6 +14,7 @@ import ( baseconst "go.signoz.io/signoz/pkg/query-service/constants" basemodel "go.signoz.io/signoz/pkg/query-service/model" "go.signoz.io/signoz/pkg/query-service/utils" + "go.signoz.io/signoz/pkg/types/authtypes" "go.uber.org/zap" ) @@ -64,7 +65,7 @@ func (m *modelDao) createUserForSAMLRequest(ctx context.Context, email string) ( // PrepareSsoRedirect prepares redirect page link after SSO response // is successfully parsed (i.e. valid email is available) -func (m *modelDao) PrepareSsoRedirect(ctx context.Context, redirectUri, email string) (redirectURL string, apierr basemodel.BaseApiError) { +func (m *modelDao) PrepareSsoRedirect(ctx context.Context, redirectUri, email string, jwt *authtypes.JWT) (redirectURL string, apierr basemodel.BaseApiError) { userPayload, apierr := m.GetUserByEmail(ctx, email) if !apierr.IsNil() { @@ -85,7 +86,7 @@ func (m *modelDao) PrepareSsoRedirect(ctx context.Context, redirectUri, email st user = &userPayload.User } - tokenStore, err := baseauth.GenerateJWTForUser(user) + tokenStore, err := baseauth.GenerateJWTForUser(user, jwt) if err != nil { zap.L().Error("failed to generate token for SSO login user", zap.Error(err)) return "", model.InternalErrorStr("failed to generate token for the user") diff --git a/ee/query-service/license/manager.go b/ee/query-service/license/manager.go index bce2d3d4dc..992490e644 100644 --- a/ee/query-service/license/manager.go +++ b/ee/query-service/license/manager.go @@ -10,8 +10,8 @@ import ( "sync" - "go.signoz.io/signoz/pkg/query-service/auth" baseconstants "go.signoz.io/signoz/pkg/query-service/constants" + "go.signoz.io/signoz/pkg/types/authtypes" validate "go.signoz.io/signoz/ee/query-service/integrations/signozio" "go.signoz.io/signoz/ee/query-service/model" @@ -237,10 +237,10 @@ func (lm *Manager) ValidateV3(ctx context.Context) (reterr error) { func (lm *Manager) ActivateV3(ctx context.Context, licenseKey string) (licenseResponse *model.LicenseV3, errResponse *model.ApiError) { defer func() { if errResponse != nil { - userEmail, err := auth.GetEmailFromJwt(ctx) - if err == nil { + claims, ok := authtypes.ClaimsFromContext(ctx) + if ok { telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_LICENSE_ACT_FAILED, - map[string]interface{}{"err": errResponse.Err.Error()}, userEmail, true, false) + map[string]interface{}{"err": errResponse.Err.Error()}, claims.Email, true, false) } } }() diff --git a/ee/query-service/main.go b/ee/query-service/main.go index 9d471210b7..5fab8286c9 100644 --- a/ee/query-service/main.go +++ b/ee/query-service/main.go @@ -20,6 +20,7 @@ import ( baseconst "go.signoz.io/signoz/pkg/query-service/constants" "go.signoz.io/signoz/pkg/query-service/version" "go.signoz.io/signoz/pkg/signoz" + "go.signoz.io/signoz/pkg/types/authtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -154,6 +155,16 @@ func main() { zap.L().Fatal("Failed to create signoz struct", zap.Error(err)) } + jwtSecret := os.Getenv("SIGNOZ_JWT_SECRET") + + if len(jwtSecret) == 0 { + zap.L().Warn("No JWT secret key is specified.") + } else { + zap.L().Info("JWT secret key set successfully.") + } + + jwt := authtypes.NewJWT(jwtSecret, 30*time.Minute, 30*24*time.Hour) + serverOptions := &app.ServerOptions{ Config: config, SigNoz: signoz, @@ -171,15 +182,7 @@ func main() { GatewayUrl: gatewayUrl, UseLogsNewSchema: useLogsNewSchema, UseTraceNewSchema: useTraceNewSchema, - } - - // Read the jwt secret key - auth.JwtSecret = os.Getenv("SIGNOZ_JWT_SECRET") - - if len(auth.JwtSecret) == 0 { - zap.L().Warn("No JWT secret key is specified.") - } else { - zap.L().Info("JWT secret key set successfully.") + Jwt: jwt, } server, err := app.NewServer(serverOptions) diff --git a/go.mod b/go.mod index 080da0822b..8345fb98fe 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,6 @@ require ( github.com/SigNoz/zap_otlp/zap_otlp_encoder v0.0.0-20230822164844-1b861a431974 github.com/SigNoz/zap_otlp/zap_otlp_sync v0.0.0-20230822164844-1b861a431974 github.com/antonmedv/expr v1.15.3 - github.com/auth0/go-jwt-middleware v1.0.1 github.com/cespare/xxhash/v2 v2.3.0 github.com/coreos/go-oidc/v3 v3.11.0 github.com/dustin/go-humanize v1.0.1 @@ -24,7 +23,7 @@ require ( github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redismock/v8 v8.11.5 github.com/go-viper/mapstructure/v2 v2.1.0 - github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.1 @@ -112,7 +111,6 @@ require ( github.com/expr-lang/expr v1.16.9 // indirect github.com/facette/natsort v0.0.0-20181210072756-2cd4dd1e2dcb // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/form3tech-oss/jwt-go v3.2.5+incompatible // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/go-faster/city v1.0.1 // indirect github.com/go-faster/errors v0.7.1 // indirect @@ -132,7 +130,6 @@ require ( github.com/goccy/go-json v0.10.3 // indirect github.com/gofrs/uuid v4.4.0+incompatible // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect diff --git a/go.sum b/go.sum index e4676c0993..ad7dc839f5 100644 --- a/go.sum +++ b/go.sum @@ -124,8 +124,6 @@ github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= -github.com/auth0/go-jwt-middleware v1.0.1 h1:/fsQ4vRr4zod1wKReUH+0A3ySRjGiT9G34kypO/EKwI= -github.com/auth0/go-jwt-middleware v1.0.1/go.mod h1:YSeUX3z6+TF2H+7padiEqNJ73Zy9vXW72U//IgN0BIM= github.com/aws/aws-sdk-go v1.38.35/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= @@ -246,9 +244,6 @@ github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= -github.com/form3tech-oss/jwt-go v3.2.5+incompatible h1:/l4kBbb4/vGSsdtB5nUe8L7B9mImVMaBPw9L/0TBHU8= -github.com/form3tech-oss/jwt-go v3.2.5+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/frankban/quicktest v1.14.3/go.mod h1:mgiwOwqx65TmIk1wJ6Q7wvnVMocbUorkibMOrVTHZps= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= @@ -337,8 +332,6 @@ github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -444,13 +437,10 @@ github.com/googleapis/gax-go/v2 v2.4.0/go.mod h1:XOTVJ59hdnfJLIP/dh8n5CGryZR2LxK github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gophercloud/gophercloud v1.14.0 h1:Bt9zQDhPrbd4qX7EILGmy+i7GP35cc+AAL2+wIJpUE8= github.com/gophercloud/gophercloud v1.14.0/go.mod h1:aAVqcocTSXh2vYFZ1JTvx4EQmfgzxRcNupUfxZbBNDM= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/handlers v1.5.1 h1:9lRY6j8DEeeBT10CvO9hGW0gmky0BprnvDI5vfhUHH4= github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= -github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= @@ -864,9 +854,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= @@ -920,8 +907,6 @@ github.com/uptrace/bun/dialect/pgdialect v1.2.9 h1:caf5uFbOGiXvadV6pA5gn87k0awFF github.com/uptrace/bun/dialect/pgdialect v1.2.9/go.mod h1:m7L9JtOp/Lt8HccET70ULxplMweE/u0S9lNUSxz2duo= github.com/uptrace/bun/dialect/sqlitedialect v1.2.9 h1:HLzGWXBh07sT8zhVPy6veYbbGrAtYq0KzyRHXBj+GjA= github.com/uptrace/bun/dialect/sqlitedialect v1.2.9/go.mod h1:dUR+ecoCWA0FIa9vhQVRnGtYYPpuCLJoEEtX9E1aiBU= -github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc= -github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= github.com/valyala/fastjson v1.6.4 h1:uAUNq9Z6ymTgGhcm0UynUAB6tlbakBrz6CQFax3BXVQ= github.com/valyala/fastjson v1.6.4/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= github.com/vjeantet/grok v1.0.1 h1:2rhIR7J4gThTgcZ1m2JY4TrJZNgjn985U28kT2wQrJ4= @@ -1366,7 +1351,6 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3 golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= diff --git a/pkg/http/middleware/analytics.go b/pkg/http/middleware/analytics.go index f9922d878f..76b20f71fd 100644 --- a/pkg/http/middleware/analytics.go +++ b/pkg/http/middleware/analytics.go @@ -7,12 +7,10 @@ import ( "net/http" "regexp" - // TODO(remove): Remove auth packages - "go.signoz.io/signoz/pkg/query-service/auth" - "github.com/gorilla/mux" v3 "go.signoz.io/signoz/pkg/query-service/model/v3" "go.signoz.io/signoz/pkg/query-service/telemetry" + "go.signoz.io/signoz/pkg/types/authtypes" "go.uber.org/zap" ) @@ -30,8 +28,6 @@ func NewAnalytics(logger *zap.Logger) *Analytics { func (a *Analytics) Wrap(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := auth.AttachJwtToContext(r.Context(), r) - r = r.WithContext(ctx) route := mux.CurrentRoute(r) path, _ := route.GetPathTemplate() @@ -50,9 +46,9 @@ func (a *Analytics) Wrap(next http.Handler) http.Handler { } if _, ok := telemetry.EnabledPaths()[path]; ok { - userEmail, err := auth.GetEmailFromJwt(r.Context()) - if err == nil { - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_PATH, data, userEmail, true, false) + claims, ok := authtypes.ClaimsFromContext(r.Context()) + if ok { + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_PATH, data, claims.Email, true, false) } } @@ -138,8 +134,8 @@ func (a *Analytics) extractQueryRangeData(path string, r *http.Request) (map[str data["queryType"] = queryInfoResult.QueryType data["panelType"] = queryInfoResult.PanelType - userEmail, err := auth.GetEmailFromJwt(r.Context()) - if err == nil { + claims, ok := authtypes.ClaimsFromContext(r.Context()) + if ok { // switch case to set data["screen"] based on the referrer switch { case dashboardMatched: @@ -154,7 +150,7 @@ func (a *Analytics) extractQueryRangeData(path string, r *http.Request) (map[str data["screen"] = "unknown" return data, true } - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_QUERY_RANGE_API, data, userEmail, true, false) + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_QUERY_RANGE_API, data, claims.Email, true, false) } } return data, true diff --git a/pkg/http/middleware/auth.go b/pkg/http/middleware/auth.go new file mode 100644 index 0000000000..5ee3ba7de4 --- /dev/null +++ b/pkg/http/middleware/auth.go @@ -0,0 +1,44 @@ +package middleware + +import ( + "net/http" + + "go.signoz.io/signoz/pkg/types/authtypes" + "go.uber.org/zap" +) + +type Auth struct { + logger *zap.Logger + jwt *authtypes.JWT + headers []string +} + +func NewAuth(logger *zap.Logger, jwt *authtypes.JWT, headers []string) *Auth { + if logger == nil { + panic("cannot build auth middleware, logger is empty") + } + + return &Auth{logger: logger, jwt: jwt, headers: headers} +} + +func (a *Auth) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var values []string + for _, header := range a.headers { + values = append(values, r.Header.Get(header)) + } + + ctx, err := a.jwt.ContextFromRequest( + r.Context(), + values...) + if err != nil { + next.ServeHTTP(w, r) + return + } + + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) + +} diff --git a/pkg/http/middleware/logging.go b/pkg/http/middleware/logging.go index 10b608d94c..fe3e463d71 100644 --- a/pkg/http/middleware/logging.go +++ b/pkg/http/middleware/logging.go @@ -11,8 +11,8 @@ import ( "github.com/gorilla/mux" semconv "go.opentelemetry.io/otel/semconv/v1.26.0" - "go.signoz.io/signoz/pkg/query-service/auth" "go.signoz.io/signoz/pkg/query-service/common" + "go.signoz.io/signoz/pkg/types/authtypes" "go.uber.org/zap" ) @@ -133,7 +133,11 @@ func (middleware *Logging) getLogCommentKVs(r *http.Request) map[string]string { client = "api" } - email, _ := auth.GetEmailFromJwt(r.Context()) + var email string + claims, ok := authtypes.ClaimsFromContext(r.Context()) + if ok { + email = claims.Email + } kvs := map[string]string{ "path": path, diff --git a/pkg/query-service/app/auth.go b/pkg/query-service/app/auth.go index abdbdc9c9c..aa2529b017 100644 --- a/pkg/query-service/app/auth.go +++ b/pkg/query-service/app/auth.go @@ -12,10 +12,10 @@ import ( ) type AuthMiddleware struct { - GetUserFromRequest func(r *http.Request) (*model.UserPayload, error) + GetUserFromRequest func(r context.Context) (*model.UserPayload, error) } -func NewAuthMiddleware(f func(r *http.Request) (*model.UserPayload, error)) *AuthMiddleware { +func NewAuthMiddleware(f func(ctx context.Context) (*model.UserPayload, error)) *AuthMiddleware { return &AuthMiddleware{ GetUserFromRequest: f, } @@ -29,7 +29,7 @@ func (am *AuthMiddleware) OpenAccess(f func(http.ResponseWriter, *http.Request)) func (am *AuthMiddleware) ViewAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user, err := am.GetUserFromRequest(r) + user, err := am.GetUserFromRequest(r.Context()) if err != nil { RespondError(w, &model.ApiError{ Typ: model.ErrorUnauthorized, @@ -53,7 +53,7 @@ func (am *AuthMiddleware) ViewAccess(f func(http.ResponseWriter, *http.Request)) func (am *AuthMiddleware) EditAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user, err := am.GetUserFromRequest(r) + user, err := am.GetUserFromRequest(r.Context()) if err != nil { RespondError(w, &model.ApiError{ Typ: model.ErrorUnauthorized, @@ -76,7 +76,7 @@ func (am *AuthMiddleware) EditAccess(f func(http.ResponseWriter, *http.Request)) func (am *AuthMiddleware) SelfAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user, err := am.GetUserFromRequest(r) + user, err := am.GetUserFromRequest(r.Context()) if err != nil { RespondError(w, &model.ApiError{ Typ: model.ErrorUnauthorized, @@ -100,7 +100,7 @@ func (am *AuthMiddleware) SelfAccess(f func(http.ResponseWriter, *http.Request)) func (am *AuthMiddleware) AdminAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user, err := am.GetUserFromRequest(r) + user, err := am.GetUserFromRequest(r.Context()) if err != nil { RespondError(w, &model.ApiError{ Typ: model.ErrorUnauthorized, diff --git a/pkg/query-service/app/clickhouseReader/reader.go b/pkg/query-service/app/clickhouseReader/reader.go index a923b826cb..970872b07d 100644 --- a/pkg/query-service/app/clickhouseReader/reader.go +++ b/pkg/query-service/app/clickhouseReader/reader.go @@ -34,6 +34,7 @@ import ( "github.com/ClickHouse/clickhouse-go/v2/lib/driver" "github.com/jmoiron/sqlx" "go.signoz.io/signoz/pkg/cache" + "go.signoz.io/signoz/pkg/types/authtypes" promModel "github.com/prometheus/common/model" "go.uber.org/zap" @@ -43,7 +44,6 @@ import ( "go.signoz.io/signoz/pkg/query-service/app/resource" "go.signoz.io/signoz/pkg/query-service/app/services" "go.signoz.io/signoz/pkg/query-service/app/traces/tracedetail" - "go.signoz.io/signoz/pkg/query-service/auth" "go.signoz.io/signoz/pkg/query-service/common" "go.signoz.io/signoz/pkg/query-service/constants" chErrors "go.signoz.io/signoz/pkg/query-service/errors" @@ -1164,23 +1164,23 @@ func (r *ClickHouseReader) SearchTracesV2(ctx context.Context, params *model.Sea if traceSummary.NumSpans > uint64(params.MaxSpansInTrace) { zap.L().Error("Max spans allowed in a trace limit reached", zap.Int("MaxSpansInTrace", params.MaxSpansInTrace), zap.Uint64("Count", traceSummary.NumSpans)) - userEmail, err := auth.GetEmailFromJwt(ctx) - if err == nil { + claims, ok := authtypes.ClaimsFromContext(ctx) + if ok { data := map[string]interface{}{ "traceSize": traceSummary.NumSpans, "maxSpansInTraceLimit": params.MaxSpansInTrace, } - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_MAX_SPANS_ALLOWED_LIMIT_REACHED, data, userEmail, true, false) + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_MAX_SPANS_ALLOWED_LIMIT_REACHED, data, claims.Email, true, false) } return nil, fmt.Errorf("max spans allowed in trace limit reached, please contact support for more details") } - userEmail, err := auth.GetEmailFromJwt(ctx) - if err == nil { + claims, ok := authtypes.ClaimsFromContext(ctx) + if ok { data := map[string]interface{}{ "traceSize": traceSummary.NumSpans, } - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, data, userEmail, true, false) + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, data, claims.Email, true, false) } var startTime, endTime, durationNano uint64 @@ -1266,13 +1266,13 @@ func (r *ClickHouseReader) SearchTracesV2(ctx context.Context, params *model.Sea } end = time.Now() zap.L().Debug("smartTraceAlgo took: ", zap.Duration("duration", end.Sub(start))) - userEmail, err := auth.GetEmailFromJwt(ctx) - if err == nil { + claims, ok := authtypes.ClaimsFromContext(ctx) + if ok { data := map[string]interface{}{ "traceSize": len(searchScanResponses), "spansRenderLimit": params.SpansRenderLimit, } - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LARGE_TRACE_OPENED, data, userEmail, true, false) + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LARGE_TRACE_OPENED, data, claims.Email, true, false) } } else { for i, item := range searchSpanResponses { @@ -1306,23 +1306,23 @@ func (r *ClickHouseReader) SearchTraces(ctx context.Context, params *model.Searc if countSpans > uint64(params.MaxSpansInTrace) { zap.L().Error("Max spans allowed in a trace limit reached", zap.Int("MaxSpansInTrace", params.MaxSpansInTrace), zap.Uint64("Count", countSpans)) - userEmail, err := auth.GetEmailFromJwt(ctx) - if err == nil { + claims, ok := authtypes.ClaimsFromContext(ctx) + if ok { data := map[string]interface{}{ "traceSize": countSpans, "maxSpansInTraceLimit": params.MaxSpansInTrace, } - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_MAX_SPANS_ALLOWED_LIMIT_REACHED, data, userEmail, true, false) + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_MAX_SPANS_ALLOWED_LIMIT_REACHED, data, claims.Email, true, false) } return nil, fmt.Errorf("max spans allowed in trace limit reached, please contact support for more details") } - userEmail, err := auth.GetEmailFromJwt(ctx) - if err == nil { + claims, ok := authtypes.ClaimsFromContext(ctx) + if ok { data := map[string]interface{}{ "traceSize": countSpans, } - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, data, userEmail, true, false) + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, data, claims.Email, true, false) } var startTime, endTime, durationNano uint64 @@ -1379,13 +1379,13 @@ func (r *ClickHouseReader) SearchTraces(ctx context.Context, params *model.Searc } end = time.Now() zap.L().Debug("smartTraceAlgo took: ", zap.Duration("duration", end.Sub(start))) - userEmail, err := auth.GetEmailFromJwt(ctx) - if err == nil { + claims, ok := authtypes.ClaimsFromContext(ctx) + if ok { data := map[string]interface{}{ "traceSize": len(searchScanResponses), "spansRenderLimit": params.SpansRenderLimit, } - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LARGE_TRACE_OPENED, data, userEmail, true, false) + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LARGE_TRACE_OPENED, data, claims.Email, true, false) } } else { for i, item := range searchSpanResponses { @@ -1455,7 +1455,7 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con var serviceNameIntervalMap = map[string][]tracedetail.Interval{} var hasMissingSpans bool - userEmail , emailErr := auth.GetEmailFromJwt(ctx) + claims, claimsPresent := authtypes.ClaimsFromContext(ctx) cachedTraceData, err := r.GetWaterfallSpansForTraceWithMetadataCache(ctx, traceID) if err == nil { startTime = cachedTraceData.StartTime @@ -1468,8 +1468,8 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con totalErrorSpans = cachedTraceData.TotalErrorSpans hasMissingSpans = cachedTraceData.HasMissingSpans - if emailErr == nil { - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, map[string]interface{}{"traceSize": totalSpans}, userEmail, true, false) + if claimsPresent { + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, map[string]interface{}{"traceSize": totalSpans}, claims.Email, true, false) } } @@ -1485,8 +1485,8 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con } totalSpans = uint64(len(searchScanResponses)) - if emailErr == nil { - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, map[string]interface{}{"traceSize": totalSpans}, userEmail, true, false) + if claimsPresent { + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, map[string]interface{}{"traceSize": totalSpans}, claims.Email, true, false) } processingBeforeCache := time.Now() @@ -1531,8 +1531,8 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con if startTime == 0 || startTimeUnixNano < startTime { startTime = startTimeUnixNano } - if endTime == 0 || (startTimeUnixNano + jsonItem.DurationNano ) > endTime { - endTime = (startTimeUnixNano + jsonItem.DurationNano ) + if endTime == 0 || (startTimeUnixNano+jsonItem.DurationNano) > endTime { + endTime = (startTimeUnixNano + jsonItem.DurationNano) } if durationNano == 0 || jsonItem.DurationNano > durationNano { durationNano = jsonItem.DurationNano @@ -1709,12 +1709,12 @@ func (r *ClickHouseReader) GetFlamegraphSpansForTrace(ctx context.Context, trace } // metadata calculation - startTimeUnixNano := uint64(item.TimeUnixNano.UnixNano()) + startTimeUnixNano := uint64(item.TimeUnixNano.UnixNano()) if startTime == 0 || startTimeUnixNano < startTime { startTime = startTimeUnixNano } - if endTime == 0 || ( startTimeUnixNano + jsonItem.DurationNano ) > endTime { - endTime = (startTimeUnixNano + jsonItem.DurationNano ) + if endTime == 0 || (startTimeUnixNano+jsonItem.DurationNano) > endTime { + endTime = (startTimeUnixNano + jsonItem.DurationNano) } if durationNano == 0 || jsonItem.DurationNano > durationNano { durationNano = jsonItem.DurationNano @@ -1778,7 +1778,7 @@ func (r *ClickHouseReader) GetFlamegraphSpansForTrace(ctx context.Context, trace trace.Spans = selectedSpansForRequest trace.StartTimestampMillis = startTime / 1000000 - trace.EndTimestampMillis = endTime / 1000000 + trace.EndTimestampMillis = endTime / 1000000 return trace, nil } @@ -3464,9 +3464,9 @@ func (r *ClickHouseReader) GetLogs(ctx context.Context, params *model.LogsFilter "lenFilters": lenFilters, } if lenFilters != 0 { - userEmail, err := auth.GetEmailFromJwt(ctx) - if err == nil { - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, userEmail, true, false) + claims, ok := authtypes.ClaimsFromContext(ctx) + if ok { + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false) } } @@ -3506,9 +3506,9 @@ func (r *ClickHouseReader) TailLogs(ctx context.Context, client *model.LogsTailC "lenFilters": lenFilters, } if lenFilters != 0 { - userEmail, err := auth.GetEmailFromJwt(ctx) - if err == nil { - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, userEmail, true, false) + claims, ok := authtypes.ClaimsFromContext(ctx) + if ok { + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false) } } @@ -3598,9 +3598,9 @@ func (r *ClickHouseReader) AggregateLogs(ctx context.Context, params *model.Logs "lenFilters": lenFilters, } if lenFilters != 0 { - userEmail, err := auth.GetEmailFromJwt(ctx) - if err == nil { - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, userEmail, true, false) + claims, ok := authtypes.ClaimsFromContext(ctx) + if ok { + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false) } } diff --git a/pkg/query-service/app/explorer/db.go b/pkg/query-service/app/explorer/db.go index 07c9a18bfa..acb7aefc15 100644 --- a/pkg/query-service/app/explorer/db.go +++ b/pkg/query-service/app/explorer/db.go @@ -10,10 +10,10 @@ import ( "github.com/google/uuid" "github.com/jmoiron/sqlx" - "go.signoz.io/signoz/pkg/query-service/auth" "go.signoz.io/signoz/pkg/query-service/model" v3 "go.signoz.io/signoz/pkg/query-service/model/v3" "go.signoz.io/signoz/pkg/query-service/telemetry" + "go.signoz.io/signoz/pkg/types/authtypes" "go.uber.org/zap" ) @@ -125,13 +125,13 @@ func CreateView(ctx context.Context, view v3.SavedView) (string, error) { createdAt := time.Now() updatedAt := time.Now() - email, err := auth.GetEmailFromJwt(ctx) - if err != nil { - return "", err + claims, ok := authtypes.ClaimsFromContext(ctx) + if !ok { + return "", fmt.Errorf("error in getting email from context") } - createBy := email - updatedBy := email + createBy := claims.Email + updatedBy := claims.Email _, err = db.Exec( "INSERT INTO saved_views (uuid, name, category, created_at, created_by, updated_at, updated_by, source_page, tags, data, extra_data) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", @@ -186,13 +186,13 @@ func UpdateView(ctx context.Context, uuid_ string, view v3.SavedView) error { return fmt.Errorf("error in marshalling explorer query data: %s", err.Error()) } - email, err := auth.GetEmailFromJwt(ctx) - if err != nil { - return err + claims, ok := authtypes.ClaimsFromContext(ctx) + if !ok { + return fmt.Errorf("error in getting email from context") } updatedAt := time.Now() - updatedBy := email + updatedBy := claims.Email _, err = db.Exec("UPDATE saved_views SET updated_at = ?, updated_by = ?, name = ?, category = ?, source_page = ?, tags = ?, data = ?, extra_data = ? WHERE uuid = ?", updatedAt, updatedBy, view.Name, view.Category, view.SourcePage, strings.Join(view.Tags, ","), data, view.ExtraData, uuid_) diff --git a/pkg/query-service/app/http_handler.go b/pkg/query-service/app/http_handler.go index d406bb8c63..9072b3a891 100644 --- a/pkg/query-service/app/http_handler.go +++ b/pkg/query-service/app/http_handler.go @@ -49,6 +49,7 @@ import ( "go.signoz.io/signoz/pkg/query-service/contextlinks" v3 "go.signoz.io/signoz/pkg/query-service/model/v3" "go.signoz.io/signoz/pkg/query-service/postprocess" + "go.signoz.io/signoz/pkg/types/authtypes" "go.uber.org/zap" @@ -126,6 +127,8 @@ type APIHandler struct { jobsRepo *inframetrics.JobsRepo pvcsRepo *inframetrics.PvcsRepo + + JWT *authtypes.JWT } type APIHandlerOpts struct { @@ -165,6 +168,8 @@ type APIHandlerOpts struct { UseLogsNewSchema bool UseTraceNewSchema bool + + JWT *authtypes.JWT } // NewAPIHandler returns an APIHandler @@ -237,6 +242,7 @@ func NewAPIHandler(opts APIHandlerOpts) (*APIHandler, error) { statefulsetsRepo: statefulsetsRepo, jobsRepo: jobsRepo, pvcsRepo: pvcsRepo, + JWT: opts.JWT, } logsQueryBuilder := logsv3.PrepareLogsQuery @@ -1616,9 +1622,9 @@ func (aH *APIHandler) submitFeedback(w http.ResponseWriter, r *http.Request) { "email": email, "message": message, } - userEmail, err := auth.GetEmailFromJwt(r.Context()) - if err == nil { - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_INPRODUCT_FEEDBACK, data, userEmail, true, false) + claims, ok := authtypes.ClaimsFromContext(r.Context()) + if ok { + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_INPRODUCT_FEEDBACK, data, claims.Email, true, false) } } @@ -1628,9 +1634,9 @@ func (aH *APIHandler) registerEvent(w http.ResponseWriter, r *http.Request) { if aH.HandleError(w, err, http.StatusBadRequest) { return } - userEmail, err := auth.GetEmailFromJwt(r.Context()) - if err == nil { - telemetry.GetInstance().SendEvent(request.EventName, request.Attributes, userEmail, request.RateLimited, true) + claims, ok := authtypes.ClaimsFromContext(r.Context()) + if ok { + telemetry.GetInstance().SendEvent(request.EventName, request.Attributes, claims.Email, request.RateLimited, true) aH.WriteJSON(w, r, map[string]string{"data": "Event Processed Successfully"}) } else { RespondError(w, &model.ApiError{Typ: model.ErrorInternal, Err: err}, nil) @@ -1734,9 +1740,9 @@ func (aH *APIHandler) getServices(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ "number": len(*result), } - userEmail, err := auth.GetEmailFromJwt(r.Context()) - if err == nil { - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_NUMBER_OF_SERVICES, data, userEmail, true, false) + claims, ok := authtypes.ClaimsFromContext(r.Context()) + if ok { + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_NUMBER_OF_SERVICES, data, claims.Email, true, false) } if (data["number"] != 0) && (data["number"] != telemetry.DEFAULT_NUMBER_OF_SERVICES) { @@ -2160,7 +2166,7 @@ func (aH *APIHandler) loginUser(w http.ResponseWriter, r *http.Request) { // req.RefreshToken = c.Value // } - resp, err := auth.Login(context.Background(), req) + resp, err := auth.Login(context.Background(), req, aH.JWT) if aH.HandleError(w, err, http.StatusUnauthorized) { return } @@ -2442,11 +2448,11 @@ func (aH *APIHandler) editOrg(w http.ResponseWriter, r *http.Request) { "isAnonymous": req.IsAnonymous, "organizationName": req.Name, } - userEmail, err := auth.GetEmailFromJwt(r.Context()) - if err != nil { - zap.L().Error("failed to get user email from jwt", zap.Error(err)) + claims, ok := authtypes.ClaimsFromContext(r.Context()) + if !ok { + zap.L().Error("failed to get user email from jwt") } - telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_ORG_SETTINGS, data, userEmail, true, false) + telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_ORG_SETTINGS, data, claims.Email, true, false) aH.WriteJSON(w, r, map[string]string{"data": "org updated successfully"}) } @@ -5006,8 +5012,8 @@ func sendQueryResultEvents(r *http.Request, result []*v3.Result, queryRangeParam if len(result) > 0 && (len(result[0].Series) > 0 || len(result[0].List) > 0) { - userEmail, err := auth.GetEmailFromJwt(r.Context()) - if err == nil { + claims, ok := authtypes.ClaimsFromContext(r.Context()) + if ok { queryInfoResult := telemetry.GetInstance().CheckQueryInfo(queryRangeParams) if queryInfoResult.LogsUsed || queryInfoResult.MetricsUsed || queryInfoResult.TracesUsed { @@ -5047,7 +5053,7 @@ func sendQueryResultEvents(r *http.Request, result []*v3.Result, queryRangeParam "filterApplied": queryInfoResult.FilterApplied, "dashboardId": dashboardID, "widgetId": widgetID, - }, userEmail, true, false) + }, claims.Email, true, false) } if alertMatched { var alertID string @@ -5074,7 +5080,7 @@ func sendQueryResultEvents(r *http.Request, result []*v3.Result, queryRangeParam "aggregateAttributeKey": queryInfoResult.AggregateAttributeKey, "filterApplied": queryInfoResult.FilterApplied, "alertId": alertID, - }, userEmail, true, false) + }, claims.Email, true, false) } } } diff --git a/pkg/query-service/app/logparsingpipeline/controller.go b/pkg/query-service/app/logparsingpipeline/controller.go index 6425319313..645d21741f 100644 --- a/pkg/query-service/app/logparsingpipeline/controller.go +++ b/pkg/query-service/app/logparsingpipeline/controller.go @@ -11,10 +11,10 @@ import ( "github.com/jmoiron/sqlx" "github.com/pkg/errors" "go.signoz.io/signoz/pkg/query-service/agentConf" - "go.signoz.io/signoz/pkg/query-service/auth" "go.signoz.io/signoz/pkg/query-service/constants" "go.signoz.io/signoz/pkg/query-service/model" "go.signoz.io/signoz/pkg/query-service/utils" + "go.signoz.io/signoz/pkg/types/authtypes" "go.uber.org/zap" ) @@ -50,9 +50,9 @@ func (ic *LogParsingPipelineController) ApplyPipelines( postable []PostablePipeline, ) (*PipelinesResponse, *model.ApiError) { // get user id from context - userId, authErr := auth.ExtractUserIdFromContext(ctx) - if authErr != nil { - return nil, model.UnauthorizedError(errors.Wrap(authErr, "failed to get userId from context")) + claims, ok := authtypes.ClaimsFromContext(ctx) + if !ok { + return nil, model.UnauthorizedError(fmt.Errorf("failed to get userId from context")) } var pipelines []Pipeline @@ -84,7 +84,7 @@ func (ic *LogParsingPipelineController) ApplyPipelines( } // prepare config by calling gen func - cfg, err := agentConf.StartNewVersion(ctx, userId, agentConf.ElementTypeLogPipelines, elements) + cfg, err := agentConf.StartNewVersion(ctx, claims.UserID, agentConf.ElementTypeLogPipelines, elements) if err != nil || cfg == nil { return nil, err } diff --git a/pkg/query-service/app/logparsingpipeline/db.go b/pkg/query-service/app/logparsingpipeline/db.go index 1e8efeb0e0..3a622c4cae 100644 --- a/pkg/query-service/app/logparsingpipeline/db.go +++ b/pkg/query-service/app/logparsingpipeline/db.go @@ -9,8 +9,8 @@ import ( "github.com/google/uuid" "github.com/jmoiron/sqlx" "github.com/pkg/errors" - "go.signoz.io/signoz/pkg/query-service/auth" "go.signoz.io/signoz/pkg/query-service/model" + "go.signoz.io/signoz/pkg/types/authtypes" "go.uber.org/zap" ) @@ -45,14 +45,9 @@ func (r *Repo) insertPipeline( )) } - jwt, ok := auth.ExtractJwtFromContext(ctx) + claims, ok := authtypes.ClaimsFromContext(ctx) if !ok { - return nil, model.UnauthorizedError(err) - } - - claims, err := auth.ParseJWT(jwt) - if err != nil { - return nil, model.UnauthorizedError(err) + return nil, model.UnauthorizedError(fmt.Errorf("failed to get email from context")) } insertRow := &Pipeline{ @@ -66,7 +61,7 @@ func (r *Repo) insertPipeline( Config: postable.Config, RawConfig: string(rawConfig), Creator: Creator{ - CreatedBy: claims["email"].(string), + CreatedBy: claims.Email, CreatedAt: time.Now(), }, } diff --git a/pkg/query-service/app/server.go b/pkg/query-service/app/server.go index 6509c161bb..21549bdc9f 100644 --- a/pkg/query-service/app/server.go +++ b/pkg/query-service/app/server.go @@ -25,6 +25,7 @@ import ( opAmpModel "go.signoz.io/signoz/pkg/query-service/app/opamp/model" "go.signoz.io/signoz/pkg/query-service/app/preferences" "go.signoz.io/signoz/pkg/signoz" + "go.signoz.io/signoz/pkg/types/authtypes" "go.signoz.io/signoz/pkg/web" "go.signoz.io/signoz/pkg/query-service/app/explorer" @@ -61,6 +62,7 @@ type ServerOptions struct { UseLogsNewSchema bool UseTraceNewSchema bool SigNoz *signoz.SigNoz + Jwt *authtypes.JWT } // Server runs HTTP, Mux and a grpc server @@ -193,6 +195,7 @@ func NewServer(serverOptions *ServerOptions) (*Server, error) { FluxInterval: fluxInterval, UseLogsNewSchema: serverOptions.UseLogsNewSchema, UseTraceNewSchema: serverOptions.UseTraceNewSchema, + JWT: serverOptions.Jwt, }) if err != nil { return nil, err @@ -247,6 +250,7 @@ func (s *Server) createPrivateServer(api *APIHandler) (*http.Server, error) { r := NewRouter() + r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) r.Use(middleware.NewTimeout(zap.L(), s.serverOptions.Config.APIServer.Timeout.ExcludedRoutes, s.serverOptions.Config.APIServer.Timeout.Default, @@ -277,6 +281,7 @@ func (s *Server) createPublicServer(api *APIHandler, web web.Web) (*http.Server, r := NewRouter() + r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) r.Use(middleware.NewTimeout(zap.L(), s.serverOptions.Config.APIServer.Timeout.ExcludedRoutes, s.serverOptions.Config.APIServer.Timeout.Default, @@ -286,8 +291,8 @@ func (s *Server) createPublicServer(api *APIHandler, web web.Web) (*http.Server, r.Use(middleware.NewLogging(zap.L(), s.serverOptions.Config.APIServer.Logging.ExcludedRoutes).Wrap) // add auth middleware - getUserFromRequest := func(r *http.Request) (*model.UserPayload, error) { - user, err := auth.GetUserFromRequest(r) + getUserFromRequest := func(ctx context.Context) (*model.UserPayload, error) { + user, err := auth.GetUserFromReqContext(ctx) if err != nil { return nil, err diff --git a/pkg/query-service/auth/auth.go b/pkg/query-service/auth/auth.go index 685d79a135..708bf9585b 100644 --- a/pkg/query-service/auth/auth.go +++ b/pkg/query-service/auth/auth.go @@ -8,7 +8,6 @@ import ( "text/template" "time" - "github.com/golang-jwt/jwt" "github.com/google/uuid" "github.com/pkg/errors" @@ -18,6 +17,7 @@ import ( "go.signoz.io/signoz/pkg/query-service/telemetry" "go.signoz.io/signoz/pkg/query-service/utils" smtpservice "go.signoz.io/signoz/pkg/query-service/utils/smtpService" + "go.signoz.io/signoz/pkg/types/authtypes" "go.uber.org/zap" "golang.org/x/crypto/bcrypt" ) @@ -75,17 +75,12 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons return nil, errors.Wrap(err, "invalid invite request") } - jwtAdmin, ok := ExtractJwtFromContext(ctx) + claims, ok := authtypes.ClaimsFromContext(ctx) if !ok { - return nil, errors.Wrap(err, "failed to extract admin jwt token") + return nil, errors.Wrap(err, "failed to extract admin user id") } - adminUser, err := validateUser(jwtAdmin) - if err != nil { - return nil, errors.Wrap(err, "failed to validate admin jwt token") - } - - au, apiErr := dao.DB().GetUser(ctx, adminUser.Id) + au, apiErr := dao.DB().GetUser(ctx, claims.UserID) if apiErr != nil { return nil, errors.Wrap(err, "failed to query admin user from the DB") } @@ -123,17 +118,12 @@ func InviteUsers(ctx context.Context, req *model.BulkInviteRequest) (*model.Bulk FailedInvites: []model.FailedInvite{}, } - jwtAdmin, ok := ExtractJwtFromContext(ctx) + claims, ok := authtypes.ClaimsFromContext(ctx) if !ok { - return nil, errors.New("failed to extract admin jwt token") - } - - adminUser, err := validateUser(jwtAdmin) - if err != nil { - return nil, errors.Wrap(err, "failed to validate admin jwt token") + return nil, errors.New("failed to extract admin user id") } - au, apiErr := dao.DB().GetUser(ctx, adminUser.Id) + au, apiErr := dao.DB().GetUser(ctx, claims.UserID) if apiErr != nil { return nil, errors.Wrap(apiErr.Err, "failed to query admin user from the DB") } @@ -550,16 +540,16 @@ func Register(ctx context.Context, req *RegisterRequest) (*model.User, *model.Ap } // Login method returns access and refresh tokens on successful login, else it errors out. -func Login(ctx context.Context, request *model.LoginRequest) (*model.LoginResponse, error) { +func Login(ctx context.Context, request *model.LoginRequest, jwt *authtypes.JWT) (*model.LoginResponse, error) { zap.L().Debug("Login method called for user", zap.String("email", request.Email)) - user, err := authenticateLogin(ctx, request) + user, err := authenticateLogin(ctx, request, jwt) if err != nil { zap.L().Error("Failed to authenticate login request", zap.Error(err)) return nil, err } - userjwt, err := GenerateJWTForUser(&user.User) + userjwt, err := GenerateJWTForUser(&user.User, jwt) if err != nil { zap.L().Error("Failed to generate JWT against login creds", zap.Error(err)) return nil, err @@ -576,20 +566,36 @@ func Login(ctx context.Context, request *model.LoginRequest) (*model.LoginRespon }, nil } -// authenticateLogin is responsible for querying the DB and validating the credentials. -func authenticateLogin(ctx context.Context, req *model.LoginRequest) (*model.UserPayload, error) { +func claimsToUserPayload(claims authtypes.Claims) (*model.UserPayload, error) { + user := &model.UserPayload{ + User: model.User{ + Id: claims.UserID, + GroupId: claims.GroupID, + Email: claims.Email, + OrgId: claims.OrgID, + }, + } + return user, nil +} +// authenticateLogin is responsible for querying the DB and validating the credentials. +func authenticateLogin(ctx context.Context, req *model.LoginRequest, jwt *authtypes.JWT) (*model.UserPayload, error) { // If refresh token is valid, then simply authorize the login request. if len(req.RefreshToken) > 0 { - user, err := validateUser(req.RefreshToken) + // parse the refresh token + claims, err := jwt.Claims(req.RefreshToken) if err != nil { - return nil, errors.Wrap(err, "failed to validate refresh token") + return nil, errors.Wrap(err, "failed to parse refresh token") } - if user.OrgId == "" { + if claims.OrgID == "" { return nil, model.UnauthorizedError(errors.New("orgId is missing in the claims")) } + user, err := claimsToUserPayload(claims) + if err != nil { + return nil, errors.Wrap(err, "failed to convert claims to user payload") + } return user, nil } @@ -618,34 +624,17 @@ func passwordMatch(hash, password string) bool { return err == nil } -func GenerateJWTForUser(user *model.User) (model.UserJwtObject, error) { +func GenerateJWTForUser(user *model.User, jwt *authtypes.JWT) (model.UserJwtObject, error) { j := model.UserJwtObject{} var err error - j.AccessJwtExpiry = time.Now().Add(JwtExpiry).Unix() - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "id": user.Id, - "gid": user.GroupId, - "email": user.Email, - "exp": j.AccessJwtExpiry, - "orgId": user.OrgId, - }) - - j.AccessJwt, err = token.SignedString([]byte(JwtSecret)) + j.AccessJwtExpiry = time.Now().Add(jwt.JwtExpiry).Unix() + j.AccessJwt, err = jwt.AccessToken(user.OrgId, user.Id, user.GroupId, user.Email) if err != nil { return j, errors.Errorf("failed to encode jwt: %v", err) } - j.RefreshJwtExpiry = time.Now().Add(JwtRefresh).Unix() - token = jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "id": user.Id, - "gid": user.GroupId, - "email": user.Email, - "exp": j.RefreshJwtExpiry, - "orgId": user.OrgId, - }) - - j.RefreshJwt, err = token.SignedString([]byte(JwtSecret)) + j.RefreshJwtExpiry = time.Now().Add(jwt.JwtRefresh).Unix() + j.RefreshJwt, err = jwt.RefreshToken(user.OrgId, user.Id, user.GroupId, user.Email) if err != nil { return j, errors.Errorf("failed to encode jwt: %v", err) } diff --git a/pkg/query-service/auth/jwt.go b/pkg/query-service/auth/jwt.go deleted file mode 100644 index 637119fe18..0000000000 --- a/pkg/query-service/auth/jwt.go +++ /dev/null @@ -1,134 +0,0 @@ -package auth - -import ( - "context" - "fmt" - "net/http" - "time" - - jwtmiddleware "github.com/auth0/go-jwt-middleware" - "github.com/golang-jwt/jwt" - "github.com/pkg/errors" - "go.signoz.io/signoz/pkg/query-service/model" - "go.uber.org/zap" -) - -var ( - JwtSecret string - JwtExpiry = 30 * time.Minute - JwtRefresh = 30 * 24 * time.Hour -) - -func ParseJWT(jwtStr string) (jwt.MapClaims, error) { - // TODO[@vikrantgupta25] : to update this to the claims check function for better integrity of JWT - // reference - https://pkg.go.dev/github.com/golang-jwt/jwt/v5#Parser.ParseWithClaims - token, err := jwt.Parse(jwtStr, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, errors.Errorf("unknown signing algo: %v", token.Header["alg"]) - } - return []byte(JwtSecret), nil - }) - - if err != nil { - return nil, errors.Wrapf(err, "failed to parse jwt token") - } - - claims, ok := token.Claims.(jwt.MapClaims) - if !ok || !token.Valid { - return nil, errors.Errorf("Not a valid jwt claim") - } - - return claims, nil -} - -func validateUser(tok string) (*model.UserPayload, error) { - claims, err := ParseJWT(tok) - if err != nil { - return nil, err - } - now := time.Now().Unix() - if !claims.VerifyExpiresAt(now, true) { - return nil, model.ErrorTokenExpired - } - - var orgId string - if claims["orgId"] != nil { - orgId = claims["orgId"].(string) - } - - return &model.UserPayload{ - User: model.User{ - Id: claims["id"].(string), - GroupId: claims["gid"].(string), - Email: claims["email"].(string), - OrgId: orgId, - }, - }, nil -} - -// AttachJwtToContext attached the jwt token from the request header to the context. -func AttachJwtToContext(ctx context.Context, r *http.Request) context.Context { - token, err := ExtractJwtFromRequest(r) - if err != nil { - zap.L().Error("Error while getting token from header", zap.Error(err)) - return ctx - } - - return context.WithValue(ctx, AccessJwtKey, token) -} - -func ExtractJwtFromContext(ctx context.Context) (string, bool) { - jwtToken, ok := ctx.Value(AccessJwtKey).(string) - return jwtToken, ok -} - -func ExtractJwtFromRequest(r *http.Request) (string, error) { - authHeaderJwt, err := jwtmiddleware.FromAuthHeader(r) - if err != nil { - return "", err - } - - if len(authHeaderJwt) > 0 { - return authHeaderJwt, nil - } - - // We expect websocket connections to send auth JWT in the - // `Sec-Websocket-Protocol` header. - // - // The standard js websocket API doesn't allow setting headers - // other than the `Sec-WebSocket-Protocol` header, which is often - // used for auth purposes as a result. - return r.Header.Get("Sec-WebSocket-Protocol"), nil -} - -func ExtractUserIdFromContext(ctx context.Context) (string, error) { - userId := "" - jwt, ok := ExtractJwtFromContext(ctx) - if !ok { - return "", model.InternalError(fmt.Errorf("failed to extract jwt from context")) - } - - claims, err := ParseJWT(jwt) - if err != nil { - return "", model.InternalError(fmt.Errorf("failed get claims from jwt %v", err)) - } - - if v, ok := claims["id"]; ok { - userId = v.(string) - } - return userId, nil -} - -func GetEmailFromJwt(ctx context.Context) (string, error) { - jwt, ok := ExtractJwtFromContext(ctx) - if !ok { - return "", model.InternalError(fmt.Errorf("failed to extract jwt from context")) - } - - claims, err := ParseJWT(jwt) - if err != nil { - return "", model.InternalError(fmt.Errorf("failed get claims from jwt %v", err)) - } - - return claims["email"].(string), nil -} diff --git a/pkg/query-service/auth/rbac.go b/pkg/query-service/auth/rbac.go index 44f65576ed..e5fe687b2c 100644 --- a/pkg/query-service/auth/rbac.go +++ b/pkg/query-service/auth/rbac.go @@ -2,12 +2,12 @@ package auth import ( "context" - "net/http" "github.com/pkg/errors" "go.signoz.io/signoz/pkg/query-service/constants" "go.signoz.io/signoz/pkg/query-service/dao" "go.signoz.io/signoz/pkg/query-service/model" + "go.signoz.io/signoz/pkg/types/authtypes" ) type Group struct { @@ -48,15 +48,19 @@ func InitAuthCache(ctx context.Context) error { return nil } -func GetUserFromRequest(r *http.Request) (*model.UserPayload, error) { - accessJwt, err := ExtractJwtFromRequest(r) - if err != nil { - return nil, err +func GetUserFromReqContext(ctx context.Context) (*model.UserPayload, error) { + claims, ok := authtypes.ClaimsFromContext(ctx) + if !ok { + return nil, errors.New("no claims found in context") } - user, err := validateUser(accessJwt) - if err != nil { - return nil, err + user := &model.UserPayload{ + User: model.User{ + Id: claims.UserID, + GroupId: claims.GroupID, + Email: claims.Email, + OrgId: claims.OrgID, + }, } return user, nil } diff --git a/pkg/query-service/main.go b/pkg/query-service/main.go index 2d03746232..55c00a54c5 100644 --- a/pkg/query-service/main.go +++ b/pkg/query-service/main.go @@ -17,6 +17,7 @@ import ( "go.signoz.io/signoz/pkg/query-service/constants" "go.signoz.io/signoz/pkg/query-service/version" "go.signoz.io/signoz/pkg/signoz" + "go.signoz.io/signoz/pkg/types/authtypes" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -98,6 +99,17 @@ func main() { zap.L().Fatal("Failed to create signoz struct", zap.Error(err)) } + // Read the jwt secret key + jwtSecret := os.Getenv("SIGNOZ_JWT_SECRET") + + if len(jwtSecret) == 0 { + zap.L().Warn("No JWT secret key is specified.") + } else { + zap.L().Info("JWT secret key set successfully.") + } + + jwt := authtypes.NewJWT(jwtSecret, 30*time.Minute, 30*24*time.Hour) + serverOptions := &app.ServerOptions{ Config: config, HTTPHostPort: constants.HTTPHostPort, @@ -114,15 +126,7 @@ func main() { UseLogsNewSchema: useLogsNewSchema, UseTraceNewSchema: useTraceNewSchema, SigNoz: signoz, - } - - // Read the jwt secret key - auth.JwtSecret = os.Getenv("SIGNOZ_JWT_SECRET") - - if len(auth.JwtSecret) == 0 { - zap.L().Warn("No JWT secret key is specified.") - } else { - zap.L().Info("JWT secret key set successfully.") + Jwt: jwt, } server, err := app.NewServer(serverOptions) diff --git a/pkg/query-service/rules/db.go b/pkg/query-service/rules/db.go index 343023dd88..f3f89f0156 100644 --- a/pkg/query-service/rules/db.go +++ b/pkg/query-service/rules/db.go @@ -10,11 +10,12 @@ import ( "time" "github.com/jmoiron/sqlx" - "go.signoz.io/signoz/pkg/query-service/auth" + "github.com/pkg/errors" "go.signoz.io/signoz/pkg/query-service/common" am "go.signoz.io/signoz/pkg/query-service/integrations/alertManager" "go.signoz.io/signoz/pkg/query-service/model" v3 "go.signoz.io/signoz/pkg/query-service/model/v3" + "go.signoz.io/signoz/pkg/types/authtypes" "go.uber.org/zap" ) @@ -267,10 +268,13 @@ func (r *ruleDB) GetPlannedMaintenanceByID(ctx context.Context, id string) (*Pla func (r *ruleDB) CreatePlannedMaintenance(ctx context.Context, maintenance PlannedMaintenance) (int64, error) { - email, _ := auth.GetEmailFromJwt(ctx) - maintenance.CreatedBy = email + claims, ok := authtypes.ClaimsFromContext(ctx) + if !ok { + return 0, errors.New("no claims found in context") + } + maintenance.CreatedBy = claims.Email maintenance.CreatedAt = time.Now() - maintenance.UpdatedBy = email + maintenance.UpdatedBy = claims.Email maintenance.UpdatedAt = time.Now() query := "INSERT INTO planned_maintenance (name, description, schedule, alert_ids, created_at, created_by, updated_at, updated_by) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" @@ -298,8 +302,11 @@ func (r *ruleDB) DeletePlannedMaintenance(ctx context.Context, id string) (strin } func (r *ruleDB) EditPlannedMaintenance(ctx context.Context, maintenance PlannedMaintenance, id string) (string, error) { - email, _ := auth.GetEmailFromJwt(ctx) - maintenance.UpdatedBy = email + claims, ok := authtypes.ClaimsFromContext(ctx) + if !ok { + return "", errors.New("no claims found in context") + } + maintenance.UpdatedBy = claims.Email maintenance.UpdatedAt = time.Now() query := "UPDATE planned_maintenance SET name=$1, description=$2, schedule=$3, alert_ids=$4, updated_at=$5, updated_by=$6 WHERE id=$7" diff --git a/pkg/query-service/tests/integration/filter_suggestions_test.go b/pkg/query-service/tests/integration/filter_suggestions_test.go index 793ca7d442..4fa9dff51e 100644 --- a/pkg/query-service/tests/integration/filter_suggestions_test.go +++ b/pkg/query-service/tests/integration/filter_suggestions_test.go @@ -11,6 +11,7 @@ import ( mockhouse "github.com/srikanthccv/ClickHouse-go-mock" "github.com/stretchr/testify/require" + "go.signoz.io/signoz/pkg/http/middleware" "go.signoz.io/signoz/pkg/query-service/app" "go.signoz.io/signoz/pkg/query-service/auth" "go.signoz.io/signoz/pkg/query-service/constants" @@ -299,13 +300,16 @@ func NewFilterSuggestionsTestBed(t *testing.T) *FilterSuggestionsTestBed { Reader: reader, AppDao: dao.DB(), FeatureFlags: fm, + JWT: jwt, }) if err != nil { t.Fatalf("could not create a new ApiHandler: %v", err) } router := app.NewRouter() - am := app.NewAuthMiddleware(auth.GetUserFromRequest) + //add the jwt middleware + router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) + am := app.NewAuthMiddleware(auth.GetUserFromReqContext) apiHandler.RegisterRoutes(router, am) apiHandler.RegisterQueryRangeV3Routes(router, am) diff --git a/pkg/query-service/tests/integration/logparsingpipeline_test.go b/pkg/query-service/tests/integration/logparsingpipeline_test.go index 50c577002b..7c0b218168 100644 --- a/pkg/query-service/tests/integration/logparsingpipeline_test.go +++ b/pkg/query-service/tests/integration/logparsingpipeline_test.go @@ -22,7 +22,6 @@ import ( "go.signoz.io/signoz/pkg/query-service/app/logparsingpipeline" "go.signoz.io/signoz/pkg/query-service/app/opamp" opampModel "go.signoz.io/signoz/pkg/query-service/app/opamp/model" - "go.signoz.io/signoz/pkg/query-service/auth" "go.signoz.io/signoz/pkg/query-service/constants" "go.signoz.io/signoz/pkg/query-service/dao" "go.signoz.io/signoz/pkg/query-service/model" @@ -470,6 +469,7 @@ func NewTestbedWithoutOpamp(t *testing.T, testDB *sqlx.DB) *LogPipelinesTestBed apiHandler, err := app.NewAPIHandler(app.APIHandlerOpts{ AppDao: dao.DB(), LogsParsingPipelineController: controller, + JWT: jwt, }) if err != nil { t.Fatalf("could not create a new ApiHandler: %v", err) @@ -540,7 +540,12 @@ func (tb *LogPipelinesTestBed) PostPipelinesToQSExpectingStatusCode( } respWriter := httptest.NewRecorder() - ctx := auth.AttachJwtToContext(req.Context(), req) + + ctx, err := tb.apiHandler.JWT.ContextFromRequest(req.Context(), req.Header.Get("Authorization")) + if err != nil { + tb.t.Fatalf("couldn't get jwt from request: %v", err) + } + req = req.WithContext(ctx) tb.apiHandler.CreateLogsPipeline(respWriter, req) diff --git a/pkg/query-service/tests/integration/signoz_cloud_integrations_test.go b/pkg/query-service/tests/integration/signoz_cloud_integrations_test.go index a0d9c99f7f..7de502a563 100644 --- a/pkg/query-service/tests/integration/signoz_cloud_integrations_test.go +++ b/pkg/query-service/tests/integration/signoz_cloud_integrations_test.go @@ -12,6 +12,7 @@ import ( "github.com/jmoiron/sqlx" mockhouse "github.com/srikanthccv/ClickHouse-go-mock" "github.com/stretchr/testify/require" + "go.signoz.io/signoz/pkg/http/middleware" "go.signoz.io/signoz/pkg/query-service/app" "go.signoz.io/signoz/pkg/query-service/app/cloudintegrations" "go.signoz.io/signoz/pkg/query-service/auth" @@ -19,6 +20,7 @@ import ( "go.signoz.io/signoz/pkg/query-service/featureManager" "go.signoz.io/signoz/pkg/query-service/model" "go.signoz.io/signoz/pkg/query-service/utils" + "go.uber.org/zap" ) func TestAWSIntegrationAccountLifecycle(t *testing.T) { @@ -361,13 +363,15 @@ func NewCloudIntegrationsTestBed(t *testing.T, testDB *sqlx.DB) *CloudIntegratio AppDao: dao.DB(), CloudIntegrationsController: controller, FeatureFlags: fm, + JWT: jwt, }) if err != nil { t.Fatalf("could not create a new ApiHandler: %v", err) } router := app.NewRouter() - am := app.NewAuthMiddleware(auth.GetUserFromRequest) + router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) + am := app.NewAuthMiddleware(auth.GetUserFromReqContext) apiHandler.RegisterRoutes(router, am) apiHandler.RegisterCloudIntegrationsRoutes(router, am) diff --git a/pkg/query-service/tests/integration/signoz_integrations_test.go b/pkg/query-service/tests/integration/signoz_integrations_test.go index 9c124e9fc4..92308aafc5 100644 --- a/pkg/query-service/tests/integration/signoz_integrations_test.go +++ b/pkg/query-service/tests/integration/signoz_integrations_test.go @@ -11,6 +11,7 @@ import ( "github.com/jmoiron/sqlx" mockhouse "github.com/srikanthccv/ClickHouse-go-mock" "github.com/stretchr/testify/require" + "go.signoz.io/signoz/pkg/http/middleware" "go.signoz.io/signoz/pkg/query-service/app" "go.signoz.io/signoz/pkg/query-service/app/cloudintegrations" "go.signoz.io/signoz/pkg/query-service/app/dashboards" @@ -22,6 +23,7 @@ import ( "go.signoz.io/signoz/pkg/query-service/model" v3 "go.signoz.io/signoz/pkg/query-service/model/v3" "go.signoz.io/signoz/pkg/query-service/utils" + "go.uber.org/zap" ) // Higher level tests for UI facing APIs @@ -568,6 +570,7 @@ func NewIntegrationsTestBed(t *testing.T, testDB *sqlx.DB) *IntegrationsTestBed AppDao: dao.DB(), IntegrationsController: controller, FeatureFlags: fm, + JWT: jwt, CloudIntegrationsController: cloudIntegrationsController, }) if err != nil { @@ -575,7 +578,8 @@ func NewIntegrationsTestBed(t *testing.T, testDB *sqlx.DB) *IntegrationsTestBed } router := app.NewRouter() - am := app.NewAuthMiddleware(auth.GetUserFromRequest) + router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) + am := app.NewAuthMiddleware(auth.GetUserFromReqContext) apiHandler.RegisterRoutes(router, am) apiHandler.RegisterIntegrationRoutes(router, am) diff --git a/pkg/query-service/tests/integration/test_utils.go b/pkg/query-service/tests/integration/test_utils.go index 464b7ad81a..8e0fbf6480 100644 --- a/pkg/query-service/tests/integration/test_utils.go +++ b/pkg/query-service/tests/integration/test_utils.go @@ -25,9 +25,12 @@ import ( "go.signoz.io/signoz/pkg/query-service/dao" "go.signoz.io/signoz/pkg/query-service/interfaces" "go.signoz.io/signoz/pkg/query-service/model" + "go.signoz.io/signoz/pkg/types/authtypes" "golang.org/x/exp/maps" ) +var jwt = authtypes.NewJWT("secret", 1*time.Hour, 2*time.Hour) + func NewMockClickhouseReader( t *testing.T, testDB *sqlx.DB, featureFlags interfaces.FeatureLookup, ) ( @@ -184,7 +187,7 @@ func AuthenticatedRequestForTest( path string, postData interface{}, ) (*http.Request, error) { - userJwt, err := auth.GenerateJWTForUser(user) + userJwt, err := auth.GenerateJWTForUser(user, jwt) if err != nil { return nil, err } diff --git a/pkg/types/authtypes/jwt.go b/pkg/types/authtypes/jwt.go new file mode 100644 index 0000000000..7754eaf185 --- /dev/null +++ b/pkg/types/authtypes/jwt.go @@ -0,0 +1,141 @@ +package authtypes + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +type jwtClaimsKey struct{} + +type Claims struct { + jwt.RegisteredClaims + UserID string `json:"id"` + GroupID string `json:"gid"` + Email string `json:"email"` + OrgID string `json:"orgId"` +} + +type JWT struct { + JwtSecret string + JwtExpiry time.Duration + JwtRefresh time.Duration +} + +func NewJWT(jwtSecret string, jwtExpiry time.Duration, jwtRefresh time.Duration) *JWT { + return &JWT{ + JwtSecret: jwtSecret, + JwtExpiry: jwtExpiry, + JwtRefresh: jwtRefresh, + } +} + +func parseBearerAuth(auth string) (string, bool) { + const prefix = "Bearer " + // Case insensitive prefix match + if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) { + return "", false + } + + return auth[len(prefix):], true +} + +func (j *JWT) ContextFromRequest(ctx context.Context, values ...string) (context.Context, error) { + var value string + for _, v := range values { + if v != "" { + value = v + break + } + } + + if value == "" { + return ctx, errors.New("missing Authorization header") + } + + // parse from + bearerToken, ok := parseBearerAuth(value) + if !ok { + // this will take care that if the value is not of type bearer token, directly use it + bearerToken = value + } + + claims, err := j.Claims(bearerToken) + if err != nil { + return ctx, err + } + + return NewContextWithClaims(ctx, claims), nil +} + +func (j *JWT) Claims(jwtStr string) (Claims, error) { + token, err := jwt.ParseWithClaims(jwtStr, &Claims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unknown signing algo: %v", token.Header["alg"]) + } + return []byte(j.JwtSecret), nil + }) + + if err != nil { + return Claims{}, fmt.Errorf("failed to parse jwt token: %w", err) + } + + // Type assertion to retrieve claims from the token + userClaims, ok := token.Claims.(*Claims) + if !ok { + return Claims{}, errors.New("failed to retrieve claims from token") + } + + return *userClaims, nil +} + +// NewContextWithClaims attaches individual claims to the context. +func NewContextWithClaims(ctx context.Context, claims Claims) context.Context { + ctx = context.WithValue(ctx, jwtClaimsKey{}, claims) + return ctx +} + +// signToken creates and signs a JWT token with the given claims +func (j *JWT) signToken(claims Claims) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(j.JwtSecret)) +} + +// AccessToken creates an access token with the provided claims +func (j *JWT) AccessToken(orgId, userId, groupId, email string) (string, error) { + claims := Claims{ + UserID: userId, + GroupID: groupId, + Email: email, + OrgID: orgId, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.JwtExpiry)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + return j.signToken(claims) +} + +// RefreshToken creates a refresh token with the provided claims +func (j *JWT) RefreshToken(orgId, userId, groupId, email string) (string, error) { + claims := Claims{ + UserID: userId, + GroupID: groupId, + Email: email, + OrgID: orgId, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.JwtRefresh)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + return j.signToken(claims) +} + +func ClaimsFromContext(ctx context.Context) (Claims, bool) { + claims, ok := ctx.Value(jwtClaimsKey{}).(Claims) + return claims, ok +} diff --git a/pkg/types/authtypes/jwt_test.go b/pkg/types/authtypes/jwt_test.go new file mode 100644 index 0000000000..bfd9749f54 --- /dev/null +++ b/pkg/types/authtypes/jwt_test.go @@ -0,0 +1,129 @@ +package authtypes + +import ( + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +func TestGetAccessJwt(t *testing.T) { + jwtService := NewJWT("secret", time.Minute, time.Hour) + token, err := jwtService.AccessToken("orgId", "userId", "groupId", "email@example.com") + + assert.NoError(t, err) + assert.NotEmpty(t, token) +} + +func TestGetRefreshJwt(t *testing.T) { + jwtService := NewJWT("secret", time.Minute, time.Hour) + token, err := jwtService.RefreshToken("orgId", "userId", "groupId", "email@example.com") + + assert.NoError(t, err) + assert.NotEmpty(t, token) +} + +func TestGetJwtClaims(t *testing.T) { + jwtService := NewJWT("secret", time.Minute, time.Hour) + + // Create a valid token + claims := Claims{ + UserID: "userId", + GroupID: "groupId", + Email: "email@example.com", + OrgID: "orgId", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + tokenString, err := jwtService.signToken(claims) + assert.NoError(t, err) + + // Test retrieving claims from the token + retrievedClaims, err := jwtService.Claims(tokenString) + assert.NoError(t, err) + assert.Equal(t, claims.UserID, retrievedClaims.UserID) + assert.Equal(t, claims.GroupID, retrievedClaims.GroupID) + assert.Equal(t, claims.Email, retrievedClaims.Email) + assert.Equal(t, claims.OrgID, retrievedClaims.OrgID) +} + +func TestGetJwtClaimsInvalidToken(t *testing.T) { + jwtService := NewJWT("secret", time.Minute, time.Hour) + + // Test retrieving claims from an invalid token + _, err := jwtService.Claims("invalid.token.string") + assert.Error(t, err) + assert.Contains(t, err.Error(), "token is malformed") +} + +func TestGetJwtClaimsExpiredToken(t *testing.T) { + jwtService := NewJWT("secret", time.Minute, time.Hour) + + // Create an expired token + claims := Claims{ + UserID: "userId", + GroupID: "groupId", + Email: "email@example.com", + OrgID: "orgId", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + tokenString, err := jwtService.signToken(claims) + assert.NoError(t, err) + + _, err = jwtService.Claims(tokenString) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token is expired") +} + +func TestGetJwtClaimsInvalidSignature(t *testing.T) { + jwtService := NewJWT("secret", time.Minute, time.Hour) + + // Create a valid token + claims := Claims{ + UserID: "userId", + GroupID: "groupId", + Email: "email@example.com", + OrgID: "orgId", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + }, + } + validToken, err := jwtService.signToken(claims) + assert.NoError(t, err) + + // Modify the token to create an invalid signature + invalidToken := validToken + "tampered" + + // Test retrieving claims from the invalid signature token + _, err = jwtService.Claims(invalidToken) + assert.Error(t, err) + assert.Contains(t, err.Error(), "signature is invalid") +} + +func TestParseBearerAuth(t *testing.T) { + tests := []struct { + auth string + expected string + expectOk bool + }{ + {"Bearer validToken", "validToken", true}, + {"bearer validToken", "validToken", true}, + {"InvalidToken", "", false}, + {"Bearer", "", false}, + {"", "", false}, + } + + for _, test := range tests { + t.Run(test.auth, func(t *testing.T) { + token, ok := parseBearerAuth(test.auth) + assert.Equal(t, test.expected, token) + assert.Equal(t, test.expectOk, ok) + }) + } +} diff --git a/pkg/types/authtypes/uuid.go b/pkg/types/authtypes/uuid.go new file mode 100644 index 0000000000..c387a8cd2e --- /dev/null +++ b/pkg/types/authtypes/uuid.go @@ -0,0 +1,40 @@ +package authtypes + +import ( + "context" + "errors" +) + +type uuidKey struct{} + +type UUID struct { +} + +func NewUUID() *UUID { + return &UUID{} +} + +func (u *UUID) ContextFromRequest(ctx context.Context, values ...string) (context.Context, error) { + var value string + for _, v := range values { + if v != "" { + value = v + break + } + } + + if value == "" { + return ctx, errors.New("missing Authorization header") + } + + return NewContextWithUUID(ctx, value), nil +} + +func NewContextWithUUID(ctx context.Context, uuid string) context.Context { + return context.WithValue(ctx, uuidKey{}, uuid) +} + +func UUIDFromContext(ctx context.Context) (string, bool) { + uuid, ok := ctx.Value(uuidKey{}).(string) + return uuid, ok +}