diff --git a/internal/api/api.go b/internal/api/api.go index 534f168b2..a09d045cb 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -10,7 +10,9 @@ import ( "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/hooks/hookshttp" + "github.com/supabase/auth/internal/hooks/hookspgfunc" + "github.com/supabase/auth/internal/hooks/v0hooks" "github.com/supabase/auth/internal/mailer" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" @@ -33,7 +35,7 @@ type API struct { config *conf.GlobalConfiguration version string - hooksMgr *hooks.Manager + hooksMgr *v0hooks.Manager hibpClient *hibp.PwnedClient // overrideTime can be used to override the clock used by handlers. Should only be used in tests! @@ -87,7 +89,9 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne api.limiterOpts = NewLimiterOptions(globalConfig) } if api.hooksMgr == nil { - api.hooksMgr = hooks.NewManager(db, globalConfig) + httpDr := hookshttp.New() + pgfuncDr := hookspgfunc.New(db) + api.hooksMgr = v0hooks.NewManager(globalConfig, httpDr, pgfuncDr) } if api.config.Password.HIBP.Enabled { httpClient := &http.Client{ diff --git a/internal/api/hooks_test.go b/internal/api/hooks_test.go index 9a3097de8..8195f1104 100644 --- a/internal/api/hooks_test.go +++ b/internal/api/hooks_test.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "net/http" "testing" @@ -12,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks/hookserrors" "github.com/supabase/auth/internal/hooks/v0hooks" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/storage" @@ -70,7 +70,7 @@ func (ts *HooksTestSuite) TestRunHTTPHook() { testURL := "http://localhost:54321/functions/v1/custom-sms-sender" ts.Config.Hook.SendSMS.URI = testURL - unsuccessfulResponse := v0hooks.AuthHookError{ + unsuccessfulResponse := hookserrors.Error{ HTTPCode: http.StatusUnprocessableEntity, Message: "test error", } @@ -78,12 +78,12 @@ func (ts *HooksTestSuite) TestRunHTTPHook() { testCases := []struct { description string expectError bool - mockResponse v0hooks.AuthHookError + mockResponse hookserrors.Error }{ { description: "Hook returns success", expectError: false, - mockResponse: v0hooks.AuthHookError{}, + mockResponse: hookserrors.Error{}, }, { description: "Hook returns error", @@ -102,23 +102,22 @@ func (ts *HooksTestSuite) TestRunHTTPHook() { Post("/"). MatchType("json"). Reply(http.StatusUnprocessableEntity). - JSON(v0hooks.SendSMSOutput{HookError: unsuccessfulResponse}) + JSON(struct { + Error *hookserrors.Error `json:"error,omitempty"` + }{Error: &unsuccessfulResponse}) for _, tc := range testCases { ts.Run(tc.description, func() { req, _ := http.NewRequest("POST", ts.Config.Hook.SendSMS.URI, nil) - body, err := ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input) + + var output v0hooks.SendSMSOutput + err := ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output) if !tc.expectError { require.NoError(ts.T(), err) } else { require.Error(ts.T(), err) - if body != nil { - var output v0hooks.SendSMSOutput - require.NoError(ts.T(), json.Unmarshal(body, &output)) - require.Equal(ts.T(), unsuccessfulResponse.HTTPCode, output.HookError.HTTPCode) - require.Equal(ts.T(), unsuccessfulResponse.Message, output.HookError.Message) - } + require.Equal(ts.T(), output, v0hooks.SendSMSOutput{}) } }) } @@ -154,12 +153,9 @@ func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) require.NoError(ts.T(), err) - body, err := ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input) - require.NoError(ts.T(), err) - var output v0hooks.SendSMSOutput - err = json.Unmarshal(body, &output) - require.NoError(ts.T(), err, "Unmarshal should not fail") + err = ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output) + require.NoError(ts.T(), err) // Ensure that all expected HTTP interactions (mocks) have been called require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called including retry") @@ -186,10 +182,10 @@ func (ts *HooksTestSuite) TestShouldReturnErrorForNonJSONContentType() { req, err := http.NewRequest("POST", "http://localhost:9999/otp", nil) require.NoError(ts.T(), err) - _, err = ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input) + var output v0hooks.SendSMSOutput + err = ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output) require.Error(ts.T(), err, "Expected an error due to wrong content type") require.Contains(ts.T(), err.Error(), "Invalid JSON response.") - require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called") } diff --git a/internal/api/token.go b/internal/api/token.go index c2efcae1a..0f041faf5 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -2,6 +2,7 @@ package api import ( "context" + "fmt" "net/http" "net/url" "strconv" @@ -9,6 +10,7 @@ import ( "github.com/gofrs/uuid" "github.com/golang-jwt/jwt/v5" + "github.com/xeipuuv/gojsonschema" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/hooks/v0hooks" @@ -369,6 +371,9 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user if err != nil { return "", 0, err } + if err := validateTokenClaims(output.Claims); err != nil { + return "", 0, err + } gotrueClaims = jwt.MapClaims(output.Claims) } @@ -376,7 +381,6 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user if err != nil { return "", 0, err } - return signed, expiresAt.Unix(), nil } @@ -491,3 +495,86 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, User: user, }, nil } + +var schemaLoader = gojsonschema.NewStringLoader(MinimumViableTokenSchema) + +func validateTokenClaims(outputClaims map[string]interface{}) error { + documentLoader := gojsonschema.NewGoLoader(outputClaims) + result, err := gojsonschema.Validate(schemaLoader, documentLoader) + if err != nil { + return err + } + + if !result.Valid() { + var errorMessages string + + for _, desc := range result.Errors() { + errorMessages += fmt.Sprintf("- %s\n", desc) + fmt.Printf("- %s\n", desc) + } + return fmt.Errorf( + "output claims do not conform to the expected schema: \n%s", errorMessages) + + } + + return nil +} + +// #nosec +const MinimumViableTokenSchema = `{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "aud": { + "type": ["string", "array"] + }, + "exp": { + "type": "integer" + }, + "jti": { + "type": "string" + }, + "iat": { + "type": "integer" + }, + "iss": { + "type": "string" + }, + "nbf": { + "type": "integer" + }, + "sub": { + "type": "string" + }, + "email": { + "type": "string" + }, + "phone": { + "type": "string" + }, + "app_metadata": { + "type": "object", + "additionalProperties": true + }, + "user_metadata": { + "type": "object", + "additionalProperties": true + }, + "role": { + "type": "string" + }, + "aal": { + "type": "string" + }, + "amr": { + "type": "array", + "items": { + "type": "object" + } + }, + "session_id": { + "type": "string" + } + }, + "required": ["aud", "exp", "iat", "sub", "email", "phone", "role", "aal", "session_id", "is_anonymous"] +}` diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index c9b7aff58..5cb2f50a4 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -640,6 +640,9 @@ type HookConfiguration struct { CustomAccessToken ExtensibilityPointConfiguration `json:"custom_access_token" split_words:"true"` SendEmail ExtensibilityPointConfiguration `json:"send_email" split_words:"true"` SendSMS ExtensibilityPointConfiguration `json:"send_sms" split_words:"true"` + + BeforeUserCreated ExtensibilityPointConfiguration `json:"before_user_created" split_words:"true"` + AfterUserCreated ExtensibilityPointConfiguration `json:"after_user_created" split_words:"true"` } type HTTPHookSecrets []string @@ -671,6 +674,8 @@ func (h *HookConfiguration) Validate() error { h.CustomAccessToken, h.SendSMS, h.SendEmail, + h.BeforeUserCreated, + h.AfterUserCreated, } for _, point := range points { if err := point.ValidateExtensibilityPoint(); err != nil { @@ -888,6 +893,18 @@ func populateGlobal(config *GlobalConfiguration) error { } } + if config.Hook.BeforeUserCreated.Enabled { + if err := config.Hook.BeforeUserCreated.PopulateExtensibilityPoint(); err != nil { + return err + } + } + + if config.Hook.AfterUserCreated.Enabled { + if err := config.Hook.AfterUserCreated.PopulateExtensibilityPoint(); err != nil { + return err + } + } + if config.SAML.Enabled { if err := config.SAML.PopulateFields(config.API.ExternalURL); err != nil { return err diff --git a/internal/conf/configuration_test.go b/internal/conf/configuration_test.go index b54d7e9da..488a1636f 100644 --- a/internal/conf/configuration_test.go +++ b/internal/conf/configuration_test.go @@ -176,6 +176,36 @@ func TestGlobal(t *testing.T) { os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") } + { + os.Setenv("API_EXTERNAL_URL", "") + cfg := new(GlobalConfiguration) + cfg.Hook = HookConfiguration{ + BeforeUserCreated: ExtensibilityPointConfiguration{ + Enabled: true, + URI: "\n", + }, + } + + err := populateGlobal(cfg) + require.Error(t, err) + os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + } + + { + os.Setenv("API_EXTERNAL_URL", "") + cfg := new(GlobalConfiguration) + cfg.Hook = HookConfiguration{ + AfterUserCreated: ExtensibilityPointConfiguration{ + Enabled: true, + URI: "\n", + }, + } + + err := populateGlobal(cfg) + require.Error(t, err) + os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + } + { os.Setenv("API_EXTERNAL_URL", "") cfg := new(GlobalConfiguration) @@ -490,6 +520,11 @@ func TestValidate(t *testing.T) { err: `conf: session timebox duration must` + ` be positive when set, was -1`, }, + { + val: &SessionsConfiguration{InactivityTimeout: toPtr(time.Duration(-1))}, + err: `conf: session inactivity timeout duration must` + + ` be positive when set, was -1ns`, + }, { val: &SessionsConfiguration{AllowLowAAL: nil}, }, @@ -532,6 +567,17 @@ func TestValidate(t *testing.T) { err: `conf: mailer validation headers not a map[string][]string format:` + ` invalid character 'i' looking for beginning of value`, }, + { + val: &MailerConfiguration{EmailValidationBlockedMX: "invalid"}, + err: `conf: email_validation_blocked_mx`, + }, + { + val: &MailerConfiguration{EmailValidationBlockedMX: `["foo.com"]`}, + check: func(t *testing.T, v any) { + got := (v.(*MailerConfiguration)).GetEmailValidationBlockedMXRecords() + require.True(t, got["foo.com"]) + }, + }, { val: &CaptchaConfiguration{Enabled: false}, diff --git a/internal/e2e/e2e.go b/internal/e2e/e2e.go index 201bd47f8..471a10652 100644 --- a/internal/e2e/e2e.go +++ b/internal/e2e/e2e.go @@ -16,8 +16,14 @@ var ( configPath string ) +var isTesting func() bool = testing.Testing + func init() { - if testing.Testing() { + initPackage() +} + +func initPackage() { + if isTesting() { _, thisFile, _, _ := runtime.Caller(0) projectRoot = filepath.Join(filepath.Dir(thisFile), "../..") configPath = filepath.Join(GetProjectRoot(), "hack", "test.env") diff --git a/internal/e2e/e2e_test.go b/internal/e2e/e2e_test.go index ba5b0990c..e10d07a6a 100644 --- a/internal/e2e/e2e_test.go +++ b/internal/e2e/e2e_test.go @@ -94,4 +94,27 @@ func TestUtils(t *testing.T) { t.Fatal("exp non-nil err") } }() + + // block init from main() + func() { + restore := isTesting + defer func() { + isTesting = restore + }() + isTesting = func() bool { return false } + + var errStr string + func() { + defer func() { + errStr = recover().(string) + }() + + initPackage() + }() + + exp := "package e2e may not be used in a main package" + if errStr != exp { + t.Fatalf("exp %v; got %v", exp, errStr) + } + }() } diff --git a/internal/e2e/e2eapi/e2eapi.go b/internal/e2e/e2eapi/e2eapi.go index 7c53c09f4..f28e89e51 100644 --- a/internal/e2e/e2eapi/e2eapi.go +++ b/internal/e2e/e2eapi/e2eapi.go @@ -90,12 +90,18 @@ func Do( if err != nil { return err } - if err := json.Unmarshal(data, res); err != nil { - return err + if len(data) > 0 { + if err := json.Unmarshal(data, res); err != nil { + return err + } } return nil } +const responseLimit = 1e6 + +var defaultClient = http.DefaultClient + func do( ctx context.Context, method string, @@ -113,7 +119,7 @@ func do( h.Add("Content-Type", "application/json") h.Add("Accept", "application/json") - httpRes, err := http.DefaultClient.Do(httpReq) + httpRes, err := defaultClient.Do(httpReq) if err != nil { return nil, err } @@ -124,7 +130,7 @@ func do( return nil, nil case sc >= 400: - data, err := io.ReadAll(io.LimitReader(httpRes.Body, 1e8)) + data, err := io.ReadAll(io.LimitReader(httpRes.Body, responseLimit)) if err != nil { return nil, err } @@ -142,7 +148,7 @@ func do( return nil, err default: - data, err := io.ReadAll(io.LimitReader(httpRes.Body, 1e8)) + data, err := io.ReadAll(io.LimitReader(httpRes.Body, responseLimit)) if err != nil { return nil, err } diff --git a/internal/e2e/e2eapi/e2eapi_test.go b/internal/e2e/e2eapi/e2eapi_test.go index ed5095218..391de9514 100644 --- a/internal/e2e/e2eapi/e2eapi_test.go +++ b/internal/e2e/e2eapi/e2eapi_test.go @@ -2,8 +2,13 @@ package e2eapi import ( "context" + "errors" + "fmt" + "io" "net/http" + "net/http/httptest" "testing" + "testing/iotest" "time" "github.com/gofrs/uuid" @@ -17,42 +22,34 @@ func TestInstance(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*4) defer cancel() - globalCfg := e2e.Must(e2e.Config()) - inst, err := New(globalCfg) - if err != nil { - t.Fatalf("exp nil err; got %v", err) - } - defer inst.Close() + t.Run("New", func(t *testing.T) { + t.Run("Success", func(t *testing.T) { + globalCfg := e2e.Must(e2e.Config()) + inst, err := New(globalCfg) + require.NoError(t, err) + defer inst.Close() - { - email := "e2etesthooks_" + uuid.Must(uuid.NewV4()).String() + "@localhost" - req := &api.SignupParams{ - Email: email, - Password: "password", - } - res := new(models.User) - err := Do(ctx, http.MethodPost, inst.APIServer.URL+"/signup", req, res) - if err != nil { - t.Fatalf("exp nil err; got %v", err) - } - require.Equal(t, email, res.Email.String()) - } -} + email := "e2etesthooks_" + uuid.Must(uuid.NewV4()).String() + "@localhost" + req := &api.SignupParams{ + Email: email, + Password: "password", + } + res := new(models.User) + err = Do(ctx, http.MethodPost, inst.APIServer.URL+"/signup", req, res) + require.NoError(t, err) + require.Equal(t, email, res.Email.String()) + }) -func TestNew(t *testing.T) { - { - globalCfg := e2e.Must(e2e.Config()) - globalCfg.DB.Driver = "" - globalCfg.DB.URL = "invalid" + t.Run("Failure", func(t *testing.T) { + globalCfg := e2e.Must(e2e.Config()) + globalCfg.DB.Driver = "" + globalCfg.DB.URL = "invalid" - inst, err := New(globalCfg) - if err == nil { - t.Fatal("exp non-nil err") - } - if inst != nil { - t.Fatal("exp nil *Instance") - } - } + inst, err := New(globalCfg) + require.Error(t, err) + require.Nil(t, inst) + }) + }) } func TestDo(t *testing.T) { @@ -61,59 +58,130 @@ func TestDo(t *testing.T) { globalCfg := e2e.Must(e2e.Config()) inst, err := New(globalCfg) - if err != nil { - t.Fatalf("exp nil err; got %v", err) - } + require.NoError(t, err) defer inst.Close() - { + // Covers calls to Do with a `req` param type which can't marshaled + t.Run("InvalidRequestType", func(t *testing.T) { req := make(chan string) err := Do(ctx, http.MethodPost, "http://localhost", &req, nil) - if err == nil { - t.Fatal("exp non-nil err") - } + require.Error(t, err) require.ErrorContains(t, err, "json: unsupported type: chan string") - } + }) - { - res := make(chan string) - err := Do(ctx, http.MethodGet, inst.APIServer.URL+"/user", nil, &res) - if err == nil { - t.Fatal("exp non-nil err") - } - require.ErrorContains(t, err, "401: This endpoint requires a Bearer token") - } - - { + // Covers calls to Do with a `res` param type which can't marshaled + t.Run("InvalidResponseType", func(t *testing.T) { res := make(chan string) err := Do(ctx, http.MethodGet, inst.APIServer.URL+"/settings", nil, &res) - if err == nil { - t.Fatal("exp non-nil err") - } + require.Error(t, err) require.ErrorContains(t, err, "json: cannot unmarshal object into Go value of type chan string") - } + }) + + // Covers status code >= 400 error handling switch statement + t.Run("api.HTTPErrorResponse_to_apierrors.HTTPError", func(t *testing.T) { + res := make(chan string) + err := Do(ctx, http.MethodGet, inst.APIServer.URL+"/user", nil, &res) + require.Error(t, err) + require.ErrorContains(t, err, "401: This endpoint requires a Bearer token") + }) - { + // Covers http.NewRequestWithContext + t.Run("InvalidHTTPMethod", func(t *testing.T) { err := Do(ctx, "\x01", "http://localhost", nil, nil) - if err == nil { - t.Fatal("exp non-nil err") - } + require.Error(t, err) require.ErrorContains(t, err, "net/http: invalid method") - } + }) - { + // Covers status code >= 400 error handling switch statement json.Unmarshal + // by hitting the default error handler that returns html + t.Run("InvalidResponse", func(t *testing.T) { err := Do(ctx, http.MethodGet, inst.APIServer.URL+"/404", nil, nil) - if err == nil { - t.Fatal("exp non-nil err") - } + require.Error(t, err) require.ErrorContains(t, err, "invalid character") - } + }) - { + // Covers defaultClient.Do failure + t.Run("InvalidURL", func(t *testing.T) { err := Do(ctx, http.MethodPost, "invalid", nil, nil) - if err == nil { - t.Fatal("exp non-nil err") - } + require.Error(t, err) require.ErrorContains(t, err, "unsupported protocol") - } + }) + + // Covers http.StatusNoContent handling + t.Run("InvalidRequestType", func(t *testing.T) { + hr := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + }) + + ts := httptest.NewServer(hr) + defer ts.Close() + + err := Do(ctx, http.MethodPost, ts.URL, nil, nil) + require.NoError(t, err) + }) + + // Covers IO errors + t.Run("IOError", func(t *testing.T) { + + for _, statusCode := range []int{http.StatusBadRequest, http.StatusOK} { + + // Covers IO errors for the sc >= 400 and default status code + // handling in the switch statement within do. + testName := fmt.Sprintf("Status=%v", http.StatusText(statusCode)) + t.Run(testName, func(t *testing.T) { + + // We assign a sentinel error to ensure propagation. + sentinel := errors.New("sentinel") + + // This implementation of the http.RoundTripper is a way to + // cover the io.ReadAll(io.LimitReader(...)) lines in the switch + // statements inside do(). + rtFn := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + + // Call the default http.RoundTripper implementation provided + // by the http.Default client to build a valid http.Response. + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + + // Wrap the res.Body in an io.ErrReader using our sentinel + // error. This causes the first call to read the response + // body to return our sentinel error. + res.Body = io.NopCloser(iotest.ErrReader(sentinel)) + return res, nil + }) + + // We need to swap the defaultClient with a new client which has + // the (*Client).Transport set to our http.RoundTripper above. + prev := defaultClient + defer func() { + defaultClient = prev + }() + defaultClient = new(http.Client) + defaultClient.Transport = rtFn + + hr := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(statusCode) + }) + + ts := httptest.NewServer(hr) + defer ts.Close() + + // We send the request and expect back our sentinel error. + err := Do(ctx, http.MethodPost, ts.URL, nil, nil) + require.Error(t, err) + require.Equal(t, sentinel, err) + + }) + } + }) +} + +// roundTripperFunc is like http.HandlerFunc for a http.RoundTripper +type roundTripperFunc func(*http.Request) (*http.Response, error) + +// RoundTrip implements http.RoundTripper by calling itself. +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) } diff --git a/internal/e2e/e2ehooks/e2ehooks.go b/internal/e2e/e2ehooks/e2ehooks.go new file mode 100644 index 000000000..0bd0bd522 --- /dev/null +++ b/internal/e2e/e2ehooks/e2ehooks.go @@ -0,0 +1,201 @@ +// Package e2ehooks provides utilities for end-to-end testing of hooks. +package e2ehooks + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/http/httputil" + "slices" + "sync" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/e2e/e2eapi" + "github.com/supabase/auth/internal/hooks/v0hooks" +) + +type Instance struct { + *e2eapi.Instance + + HookServer *httptest.Server + HookRecorder *HookRecorder +} + +func (o *Instance) Close() error { + defer o.Instance.Close() + defer o.HookServer.Close() + return nil +} + +func New(globalCfg *conf.GlobalConfiguration) (*Instance, error) { + hookRec := NewHookRecorder() + hookSrv := httptest.NewServer(hookRec) + hookRec.Register(&globalCfg.Hook, hookSrv.URL) + + test, err := e2eapi.New(globalCfg) + if err != nil { + defer hookSrv.Close() + + return nil, err + } + + o := &Instance{ + Instance: test, + HookServer: hookSrv, + HookRecorder: hookRec, + } + return o, nil +} + +func HandleSuccess() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("content-type", "application/json") + _, _ = io.WriteString(w, "{}") + }) +} + +type Hook struct { + mu sync.Mutex + name v0hooks.Name + calls []*HookCall + + hr http.Handler +} + +func NewHook(name v0hooks.Name) *Hook { + o := &Hook{ + name: name, + } + o.SetHandler(HandleSuccess()) + return o +} + +func (o *Hook) ClearCalls() { + o.mu.Lock() + defer o.mu.Unlock() + o.calls = nil +} + +func (o *Hook) GetCalls() []*HookCall { + o.mu.Lock() + defer o.mu.Unlock() + return slices.Clone(o.calls) +} + +func (o *Hook) SetHandler(hr http.Handler) { + o.mu.Lock() + defer o.mu.Unlock() + o.hr = hr +} + +func (o *Hook) ServeHTTP(w http.ResponseWriter, r *http.Request) { + o.mu.Lock() + defer o.mu.Unlock() + + dump, _ := httputil.DumpRequest(r, true) + body, err := io.ReadAll(r.Body) + if err != nil { + code := http.StatusInternalServerError + http.Error(w, http.StatusText(code), code) + return + } + r.Body = io.NopCloser(bytes.NewReader(body)) + + hc := &HookCall{ + Dump: string(dump), + Body: string(body), + Header: r.Header.Clone(), + } + o.calls = append(o.calls, hc) + + o.hr.ServeHTTP(w, r) +} + +type HookCall struct { + Header http.Header + Body string + Dump string +} + +func (o *HookCall) Unmarshal(v any) error { + return json.Unmarshal([]byte(o.Body), v) +} + +type HookRecorder struct { + mux *http.ServeMux + BeforeUserCreated *Hook + AfterUserCreated *Hook + CustomizeAccessToken *Hook + MFAVerification *Hook + PasswordVerification *Hook + SendEmail *Hook + SendSMS *Hook +} + +func NewHookRecorder() *HookRecorder { + o := &HookRecorder{ + mux: http.NewServeMux(), + BeforeUserCreated: NewHook(v0hooks.BeforeUserCreated), + AfterUserCreated: NewHook(v0hooks.AfterUserCreated), + CustomizeAccessToken: NewHook(v0hooks.CustomizeAccessToken), + MFAVerification: NewHook(v0hooks.MFAVerification), + PasswordVerification: NewHook(v0hooks.PasswordVerification), + SendEmail: NewHook(v0hooks.SendEmail), + SendSMS: NewHook(v0hooks.SendSMS), + } + + o.mux.HandleFunc("POST /hooks/{hook}", func(w http.ResponseWriter, r *http.Request) { + //exhaustive:ignore + switch v0hooks.Name(r.PathValue("hook")) { + case v0hooks.BeforeUserCreated: + o.BeforeUserCreated.ServeHTTP(w, r) + + case v0hooks.AfterUserCreated: + o.AfterUserCreated.ServeHTTP(w, r) + + case v0hooks.CustomizeAccessToken: + o.CustomizeAccessToken.ServeHTTP(w, r) + + case v0hooks.MFAVerification: + o.MFAVerification.ServeHTTP(w, r) + + case v0hooks.PasswordVerification: + o.PasswordVerification.ServeHTTP(w, r) + + case v0hooks.SendEmail: + o.SendEmail.ServeHTTP(w, r) + + case v0hooks.SendSMS: + o.SendSMS.ServeHTTP(w, r) + + default: + http.NotFound(w, r) + } + }) + return o +} + +func (o *HookRecorder) Register( + hookCfg *conf.HookConfiguration, + baseURL string, +) { + set := func(cfg *conf.ExtensibilityPointConfiguration, name v0hooks.Name) { + *cfg = conf.ExtensibilityPointConfiguration{ + Enabled: true, + URI: baseURL + "/hooks/" + string(name), + } + } + set(&hookCfg.BeforeUserCreated, v0hooks.BeforeUserCreated) + set(&hookCfg.AfterUserCreated, v0hooks.AfterUserCreated) + set(&hookCfg.CustomAccessToken, v0hooks.CustomizeAccessToken) + set(&hookCfg.MFAVerificationAttempt, v0hooks.MFAVerification) + set(&hookCfg.PasswordVerificationAttempt, v0hooks.PasswordVerification) + set(&hookCfg.SendEmail, v0hooks.SendEmail) + set(&hookCfg.SendSMS, v0hooks.SendSMS) +} + +func (o *HookRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request) { + o.mux.ServeHTTP(w, r) +} diff --git a/internal/e2e/e2ehooks/e2ehooks_test.go b/internal/e2e/e2ehooks/e2ehooks_test.go new file mode 100644 index 000000000..5c209e6c0 --- /dev/null +++ b/internal/e2e/e2ehooks/e2ehooks_test.go @@ -0,0 +1,186 @@ +package e2ehooks + +import ( + "context" + "errors" + "net/http/httptest" + "strings" + "testing" + "testing/iotest" + "time" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks/v0hooks" +) + +func TestInstance(t *testing.T) { + { + globalCfg, err := conf.LoadGlobal("../../../hack/test.env") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + globalCfg.DB.Driver = "" + globalCfg.DB.URL = "invalid" + + inst, err := New(globalCfg) + if err == nil { + t.Fatal("exp non-nil err") + } + if inst != nil { + t.Fatal("exp nil *Instance") + } + } + + { + globalCfg, err := conf.LoadGlobal("../../../hack/test.env") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + inst, err := New(globalCfg) + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + if inst == nil { + t.Fatal("exp non-nil *Instance") + } + if err := inst.Close(); err != nil { + t.Fatalf("exp nil err from Close; got %v", err) + } + } +} + +func TestHook(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + hook := NewHook(v0hooks.AfterUserCreated) + + { + calls := hook.GetCalls() + if exp, got := 0, len(calls); exp != got { + t.Fatalf("exp %v; got %v", exp, got) + } + + u := "http://localhost" + rdr := strings.NewReader("12345") + req := httptest.NewRequestWithContext(ctx, "POST", u, rdr) + res := httptest.NewRecorder() + + hook.ServeHTTP(res, req) + + calls = hook.GetCalls() + if exp, got := 1, len(calls); exp != got { + t.Fatalf("exp %v; got %v", exp, got) + } + call := calls[0] + + var got int + if err := call.Unmarshal(&got); err != nil { + t.Fatalf("exp nil err; got %v", err) + } + if exp := 12345; exp != got { + t.Fatalf("exp %v; got %v", exp, got) + } + } + + { + u := "http://localhost/hooks/before-user-created" + sentinel := errors.New("sentinel") + rdr := iotest.ErrReader(sentinel) + req := httptest.NewRequestWithContext(ctx, "POST", u, rdr) + res := httptest.NewRecorder() + + hook.ServeHTTP(res, req) + } +} + +func TestHookRecorder(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + hookRec := NewHookRecorder() + tests := []struct { + name v0hooks.Name + hook *Hook + }{ + { + name: v0hooks.BeforeUserCreated, + hook: hookRec.BeforeUserCreated, + }, + { + name: v0hooks.AfterUserCreated, + hook: hookRec.AfterUserCreated, + }, + { + name: v0hooks.CustomizeAccessToken, + hook: hookRec.CustomizeAccessToken, + }, + { + name: v0hooks.MFAVerification, + hook: hookRec.MFAVerification, + }, + { + name: v0hooks.PasswordVerification, + hook: hookRec.PasswordVerification, + }, + { + name: v0hooks.SendEmail, + hook: hookRec.SendEmail, + }, + { + name: v0hooks.SendSMS, + hook: hookRec.SendSMS, + }, + } + + for _, test := range tests { + + { + calls := test.hook.GetCalls() + if exp, got := 0, len(calls); exp != got { + t.Fatalf("exp %v; got %v", exp, got) + } + } + + u := "http://localhost/hooks/" + string(test.name) + rdr := strings.NewReader("12345") + req := httptest.NewRequestWithContext(ctx, "POST", u, rdr) + res := httptest.NewRecorder() + hookRec.ServeHTTP(res, req) + + { + calls := test.hook.GetCalls() + if exp, got := 1, len(calls); exp != got { + t.Fatalf("exp %v; got %v", exp, got) + } + call := calls[0] + + test.hook.ClearCalls() + if exp, got := 0, len(test.hook.GetCalls()); exp != got { + t.Fatalf("exp %v; got %v", exp, got) + } + + var got int + if err := call.Unmarshal(&got); err != nil { + t.Fatalf("exp nil err; got %v", err) + } + if exp := 12345; exp != got { + t.Fatalf("exp %v; got %v", exp, got) + } + } + } + + // not found + { + u := "http://localhost/hooks/__invalid-hook-name__" + rdr := strings.NewReader("12345") + req := httptest.NewRequestWithContext(ctx, "POST", u, rdr) + res := httptest.NewRecorder() + hookRec.ServeHTTP(res, req) + + if exp, got := 404, res.Result().StatusCode; exp != got { + t.Fatalf("exp %v; got %v", exp, got) + } + } +} diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go deleted file mode 100644 index 797e7d75c..000000000 --- a/internal/hooks/hooks.go +++ /dev/null @@ -1,42 +0,0 @@ -package hooks - -import ( - "net/http" - - "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/hooks/hookshttp" - "github.com/supabase/auth/internal/hooks/hookspgfunc" - "github.com/supabase/auth/internal/hooks/v0hooks" - "github.com/supabase/auth/internal/storage" -) - -type Manager struct { - v0mgr *v0hooks.Manager -} - -func NewManager( - db *storage.Connection, - config *conf.GlobalConfiguration, -) *Manager { - httpDr := hookshttp.New() - pgfuncDr := hookspgfunc.New(db) - return &Manager{ - v0mgr: v0hooks.NewManager(config, httpDr, pgfuncDr), - } -} - -func (o *Manager) InvokeHook( - conn *storage.Connection, - r *http.Request, - input, output any, -) error { - return o.v0mgr.InvokeHook(conn, r, input, output) -} - -func (o *Manager) RunHTTPHook( - r *http.Request, - hookConfig conf.ExtensibilityPointConfiguration, - input any, -) ([]byte, error) { - return o.v0mgr.RunHTTPHook(r, hookConfig, input) -} diff --git a/internal/hooks/hooks_test.go b/internal/hooks/hooks_test.go deleted file mode 100644 index 656142400..000000000 --- a/internal/hooks/hooks_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package hooks_test - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/gofrs/uuid" - "github.com/supabase/auth/internal/api" - "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/hooks" - "github.com/supabase/auth/internal/hooks/v0hooks" - "github.com/supabase/auth/internal/models" - "github.com/supabase/auth/internal/storage" - "github.com/supabase/auth/internal/storage/test" -) - -const ( - apiTestVersion = "1" - apiTestConfig = "../../hack/test.env" -) - -func TestNewManager(t *testing.T) { - { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*4) - defer cancel() - - config := helpConfig(t, apiTestConfig) - conn := helpConn(t, config) - - hr := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("content-type", "application/json") - - fmt.Fprintln(w, `{}`) - }) - - ts := httptest.NewServer(hr) - defer ts.Close() - - config.Hook.SendEmail.Enabled = true - config.Hook.SendEmail.URI = ts.URL + "/SendEmail" - - a := newAPI(config, conn) - mgr := hooks.NewManager(a.GetDB(), a.GetConfig()) - - { - in := &v0hooks.SendEmailInput{ - User: &models.User{ - ID: uuid.Must(uuid.NewV4()), - }, - } - buf := new(bytes.Buffer) - err := json.NewEncoder(buf).Encode(in) - if err != nil { - t.Fatalf("exp nil err; got %v", err) - } - - out := &v0hooks.SendEmailOutput{} - req, err := http.NewRequestWithContext( - ctx, "POST", config.Hook.SendEmail.URI, buf) - if err != nil { - t.Fatalf("exp nil err; got %v", err) - } - - err = mgr.InvokeHook(nil, req, in, out) - if err != nil { - t.Fatalf("exp nil err; got %v", err) - } - if exp, got := "", out.HookError.Message; exp != got { - t.Fatalf("exp %v; got %v", exp, got) - } - } - - { - in := &v0hooks.SendEmailInput{ - User: &models.User{ - ID: uuid.Must(uuid.NewV4()), - }, - } - buf := new(bytes.Buffer) - err := json.NewEncoder(buf).Encode(in) - if err != nil { - t.Fatalf("exp nil err; got %v", err) - } - - req, err := http.NewRequestWithContext( - ctx, "POST", config.Hook.SendEmail.URI, buf) - if err != nil { - t.Fatalf("exp nil err; got %v", err) - } - - res, err := mgr.RunHTTPHook(req, config.Hook.SendEmail, in) - if err != nil { - t.Fatalf("exp nil err; got %v", err) - } - - out := &v0hooks.SendEmailOutput{} - if err := json.Unmarshal(res, out); err != nil { - t.Fatalf("exp nil err; got %v", err) - } - if exp, got := "", out.HookError.Message; exp != got { - t.Fatalf("exp %v; got %v", exp, got) - } - } - } -} - -func newAPI( - config *conf.GlobalConfiguration, - conn *storage.Connection, -) *api.API { - limiterOpts := api.NewLimiterOptions(config) - return api.NewAPIWithVersion(config, conn, apiTestVersion, limiterOpts) -} - -func helpConfig(tb testing.TB, configPath string) *conf.GlobalConfiguration { - tb.Helper() - - config, err := conf.LoadGlobal(configPath) - if err != nil { - tb.Fatalf("error loading config %q; got %v", configPath, err) - } - return config -} - -func helpConn(tb testing.TB, config *conf.GlobalConfiguration) *storage.Connection { - tb.Helper() - - conn, err := test.SetupDBConnection(config) - if err != nil { - tb.Fatalf("error setting up db connection: %v", err) - } - return conn -} diff --git a/internal/hooks/hookserrors/hookserrors.go b/internal/hooks/hookserrors/hookserrors.go new file mode 100644 index 000000000..b0de4c9bd --- /dev/null +++ b/internal/hooks/hookserrors/hookserrors.go @@ -0,0 +1,106 @@ +// Package hookserrors holds the Error type and some functions to Check +// responses for errors. +package hookserrors + +import ( + "encoding/json" + "net/http" + + "github.com/supabase/auth/internal/api/apierrors" +) + +// Error is the type propagated by hook endpoints to communicate failure. +type Error struct { + HTTPCode int `json:"http_code,omitempty"` + Message string `json:"message,omitempty"` +} + +// Error implements the error interface by returning e.Message. +func (e *Error) Error() string { return e.Message } + +// As implements the errors.As interface to allow unwrapping as either an +// Error or apierrors.HTTPError, depending on the needs of the caller. +func (e *Error) As(target any) bool { + switch T := target.(type) { + case **Error: + v := (*T) + if v == nil { + return false + } + v.HTTPCode = e.HTTPCode + v.Message = e.Message + return true + case *Error: + T.HTTPCode = e.HTTPCode + T.Message = e.Message + return true + case **apierrors.HTTPError: + v := (*T) + if v == nil { + return false + } + v.HTTPStatus = e.HTTPCode + v.Message = e.Message + return true + case *apierrors.HTTPError: + T.HTTPStatus = e.HTTPCode + T.Message = e.Message + return true + default: + return false + } +} + +// Check will attempt to extract a hook Error from a byte slice and return a +// non-nil error, otherwise Check returns nil if no error was found. +func Check(b []byte) error { + e, ok := fromBytes(b) + if !ok { + return nil + } + return check(e) +} + +func check(e *Error) error { + if e == nil { + return nil + } + + // TODO(cstockton): Changing this would be a BC break, but it also + // doesn't seem to be the best API. For example returning an error object + // with an http_code field set to 500 would not count as an error. + if e.Message == "" { + return nil + } + + httpCode := e.HTTPCode + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &apierrors.HTTPError{ + HTTPStatus: httpCode, + Message: e.Message, + } + return httpError.WithInternalError(e) +} + +func fromBytes(b []byte) (*Error, bool) { + var dst struct { + Error *struct { + HTTPCode int `json:"http_code,omitempty"` + Message string `json:"message,omitempty"` + } `json:"error,omitempty"` + } + if err := json.Unmarshal(b, &dst); err != nil { + return nil, false + } + if dst.Error == nil { + return nil, false + } + e := &Error{ + HTTPCode: dst.Error.HTTPCode, + Message: dst.Error.Message, + } + return e, true +} diff --git a/internal/hooks/hookserrors/hookserrors_test.go b/internal/hooks/hookserrors/hookserrors_test.go new file mode 100644 index 000000000..fb3b2f3ea --- /dev/null +++ b/internal/hooks/hookserrors/hookserrors_test.go @@ -0,0 +1,255 @@ +package hookserrors + +import ( + "errors" + "testing" + + "github.com/supabase/auth/internal/api/apierrors" +) + +func TestFromBytes(t *testing.T) { + cases := []struct { + from string + ok bool + exp *Error + }{ + {from: `text`}, + {from: `null`}, + {from: `{}`}, + {from: `{"key": "val"}`}, + {from: `{"error": null}`}, + + { + from: `{"error": {"message": "failed"}}`, + ok: true, exp: &Error{HTTPCode: 0, Message: "failed"}, + }, + { + from: `{"error": {"http_code": 400}}`, + ok: true, exp: &Error{HTTPCode: 400}, + }, + { + from: `{"error": {"message": "failed", "http_code": 400}}`, + ok: true, exp: &Error{HTTPCode: 400, Message: "failed"}, + }, + { + from: `{"error": {"message": "failed", "http_code": 403}}`, + ok: true, exp: &Error{HTTPCode: 403, Message: "failed"}, + }, + } + for idx, tc := range cases { + t.Logf("test #%v - exp Check(%v) = (%#v, %v)", idx, tc.from, tc.exp, tc.ok) + + e, ok := fromBytes([]byte(tc.from)) + if exp, got := tc.ok, ok; exp != got { + t.Fatalf("exp %v; got %v", exp, got) + } + if !tc.ok { + if e != nil { + t.Fatalf("exp nil; got %v", e) + } + continue + } + if exp, got := tc.exp.HTTPCode, e.HTTPCode; exp != got { + t.Fatalf("exp HTTPCode %v; got %v", exp, got) + } + if exp, got := tc.exp.Message, e.Message; exp != got { + t.Fatalf("exp Message %q; got %q", exp, got) + } + + err := (error)(e) + if exp, got := tc.exp.Message, err.Error(); exp != got { + t.Fatalf("exp Error() %q; got %q", exp, got) + } + } +} + +func TestCheck(t *testing.T) { + { + if err := Check([]byte(`invalidjson`)); err != nil { + t.Fatalf("exp nil err; got %v", err) + } + if err := Check([]byte(`{"error": nil}`)); err != nil { + t.Fatalf("exp nil err; got %v", err) + } + if err := Check([]byte(`{"error": {"message": "failed"}}`)); err == nil { + t.Fatal("exp non-nil err") + } + + { + data := `{"error": {"message": "failed", "http_code": 403}}` + err := Check([]byte(data)) + if err == nil { + t.Fatal("exp non-nil err") + } + + e, ok := err.(*apierrors.HTTPError) + if !ok { + t.Fatal("exp error to be http.Error") + } + if exp, got := e.HTTPStatus, 403; exp != got { + t.Fatalf("exp HTTPCode %v; got %v", exp, got) + } + if exp, got := e.Message, "failed"; exp != got { + t.Fatalf("exp Message %q; got %q", exp, got) + } + } + } + + { + if err := check(nil); err != nil { + t.Fatalf("exp nil err; got %v", err) + } + if err := check(&Error{Message: ""}); err != nil { + t.Fatalf("exp nil err; got %v", err) + } + if err := check(&Error{Message: "failed"}); err == nil { + t.Fatal("exp non-nil err") + } + } +} + +func TestAs(t *testing.T) { + + { + err := &Error{ + Message: "failed", + HTTPCode: 403, + } + e := new(apierrors.HTTPError) + if !errors.As(err, &e) { + t.Fatal("exp errors.As to return true") + } + if exp, got := e.HTTPStatus, 403; exp != got { + t.Fatalf("exp HTTPCode %v; got %v", exp, got) + } + if exp, got := e.Message, "failed"; exp != got { + t.Fatalf("exp Message %q; got %q", exp, got) + } + } + + { + err := &Error{ + Message: "failed", + HTTPCode: 403, + } + e := new(Error) + if !errors.As(err, &e) { + t.Fatal("exp errors.As to return true") + } + if exp, got := e.HTTPCode, 403; exp != got { + t.Fatalf("exp HTTPCode %v; got %v", exp, got) + } + if exp, got := e.Message, "failed"; exp != got { + t.Fatalf("exp Message %q; got %q", exp, got) + } + } + + { + err := &Error{ + Message: "failed", + HTTPCode: 403, + } + e := new(apierrors.HTTPError) + if !err.As(&e) { + t.Fatal("exp errors.As to return true") + } + if exp, got := e.HTTPStatus, 403; exp != got { + t.Fatalf("exp HTTPCode %v; got %v", exp, got) + } + if exp, got := e.Message, "failed"; exp != got { + t.Fatalf("exp Message %q; got %q", exp, got) + } + } + + { + err := &Error{ + Message: "failed", + HTTPCode: 403, + } + e := new(Error) + if !err.As(&e) { + t.Fatal("exp errors.As to return true") + } + if exp, got := e.HTTPCode, 403; exp != got { + t.Fatalf("exp HTTPCode %v; got %v", exp, got) + } + if exp, got := e.Message, "failed"; exp != got { + t.Fatalf("exp Message %q; got %q", exp, got) + } + } + + { + err := &Error{ + Message: "failed", + HTTPCode: 403, + } + e := new(apierrors.HTTPError) + if !err.As(e) { + t.Fatal("exp errors.As to return true") + } + if exp, got := e.HTTPStatus, 403; exp != got { + t.Fatalf("exp HTTPCode %v; got %v", exp, got) + } + if exp, got := e.Message, "failed"; exp != got { + t.Fatalf("exp Message %q; got %q", exp, got) + } + } + + { + err := &Error{ + Message: "failed", + HTTPCode: 403, + } + e := new(Error) + if !err.As(e) { + t.Fatal("exp errors.As to return true") + } + if exp, got := e.HTTPCode, 403; exp != got { + t.Fatalf("exp HTTPCode %v; got %v", exp, got) + } + if exp, got := e.Message, "failed"; exp != got { + t.Fatalf("exp Message %q; got %q", exp, got) + } + } + + { + err := errors.New("sentinel") + e := new(Error) + if errors.As(err, &e) { + t.Fatal("exp errors.As to return false") + } + } + + { + err := &Error{ + Message: "failed", + HTTPCode: 403, + } + e := (*Error)(nil) + if err.As(&e) { + t.Fatal("exp errors.As to return false") + } + } + + { + err := &Error{ + Message: "failed", + HTTPCode: 403, + } + e := (*apierrors.HTTPError)(nil) + if err.As(&e) { + t.Fatal("exp errors.As to return false") + } + } + + { + err := &Error{ + Message: "failed", + HTTPCode: 403, + } + e := (*error)(nil) + if err.As(&e) { + t.Fatal("exp errors.As to return false") + } + } +} diff --git a/internal/hooks/hookshttp/hookshttp.go b/internal/hooks/hookshttp/hookshttp.go index 96d13b741..e8c39397b 100644 --- a/internal/hooks/hookshttp/hookshttp.go +++ b/internal/hooks/hookshttp/hookshttp.go @@ -82,11 +82,11 @@ func New(opts ...Option) *Dispatcher { func (o *Dispatcher) Dispatch( ctx context.Context, - cfg conf.ExtensibilityPointConfiguration, + cfg *conf.ExtensibilityPointConfiguration, req any, res any, ) error { - data, err := o.RunHTTPHook(ctx, cfg, req) + data, err := o.runHTTPHook(ctx, cfg, req) if err != nil { return err } @@ -99,9 +99,9 @@ func (o *Dispatcher) Dispatch( return nil } -func (o *Dispatcher) RunHTTPHook( +func (o *Dispatcher) runHTTPHook( ctx context.Context, - hookConfig conf.ExtensibilityPointConfiguration, + hookConfig *conf.ExtensibilityPointConfiguration, input any, ) ([]byte, error) { client := http.Client{ diff --git a/internal/hooks/hookshttp/hookshttp_test.go b/internal/hooks/hookshttp/hookshttp_test.go index 84090579a..efce86af4 100644 --- a/internal/hooks/hookshttp/hookshttp_test.go +++ b/internal/hooks/hookshttp/hookshttp_test.go @@ -267,7 +267,7 @@ func TestDispatch(t *testing.T) { } res := M{} - err := dr.Dispatch(testCtx, cfg, tc.req, &res) + err := dr.Dispatch(testCtx, &cfg, tc.req, &res) if tc.err != nil { require.Error(t, err) require.Equal(t, tc.err, err) diff --git a/internal/hooks/hookspgfunc/hookspgfunc.go b/internal/hooks/hookspgfunc/hookspgfunc.go index 165c548b9..5d26a006d 100644 --- a/internal/hooks/hookspgfunc/hookspgfunc.go +++ b/internal/hooks/hookspgfunc/hookspgfunc.go @@ -8,6 +8,7 @@ import ( "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks/hookserrors" "github.com/supabase/auth/internal/storage" ) @@ -47,12 +48,11 @@ func New(db *storage.Connection, opts ...Option) *Dispatcher { func (o *Dispatcher) Dispatch( ctx context.Context, - cfg conf.ExtensibilityPointConfiguration, + cfg *conf.ExtensibilityPointConfiguration, tx *storage.Connection, - req any, - res any, + req, res any, ) error { - data, err := o.RunPostgresHook(ctx, cfg, tx, req) + data, err := o.runPostgresHook(ctx, *cfg, tx, req) if err != nil { return err } @@ -65,7 +65,7 @@ func (o *Dispatcher) Dispatch( return nil } -func (o *Dispatcher) RunPostgresHook( +func (o *Dispatcher) runPostgresHook( ctx context.Context, hookConfig conf.ExtensibilityPointConfiguration, tx *storage.Connection, @@ -108,5 +108,8 @@ func (o *Dispatcher) RunPostgresHook( return nil, err } } + if err := hookserrors.Check(response); err != nil { + return nil, err + } return response, nil } diff --git a/internal/hooks/hookspgfunc/hookspgfunc_test.go b/internal/hooks/hookspgfunc/hookspgfunc_test.go index e9cacb23a..0ffd26125 100644 --- a/internal/hooks/hookspgfunc/hookspgfunc_test.go +++ b/internal/hooks/hookspgfunc/hookspgfunc_test.go @@ -182,6 +182,38 @@ func TestDispatch(t *testing.T) { end; $$ language plpgsql;`, errStr: "500: Error unmarshaling JSON output.", }, + + { + desc: "fail - returned error", + cfg: conf.ExtensibilityPointConfiguration{ + URI: `pg-functions://postgres/auth/v0pgfunc_test_return_input`, + HookName: `"auth"."v0pgfunc_test_return_input"`, + }, + req: M{"error": M{"message": "failed"}}, + sql: ` + create or replace function v0pgfunc_test_return_input(input jsonb) + returns json as $$ + begin + return input; + end; $$ language plpgsql;`, + errStr: "500: failed", + }, + + { + desc: "fail - returned error with status", + cfg: conf.ExtensibilityPointConfiguration{ + URI: `pg-functions://postgres/auth/v0pgfunc_test_return_input`, + HookName: `"auth"."v0pgfunc_test_return_input"`, + }, + req: M{"error": M{"message": "failed", "http_code": 403}}, + sql: ` + create or replace function v0pgfunc_test_return_input(input jsonb) + returns json as $$ + begin + return input; + end; $$ language plpgsql;`, + errStr: "403: failed", + }, } for idx, tc := range cases { @@ -210,7 +242,7 @@ func TestDispatch(t *testing.T) { tx := tc.tx cfg := tc.cfg res := M{} - err := dr.Dispatch(testCtx, cfg, tx, tc.req, &res) + err := dr.Dispatch(testCtx, &cfg, tx, tc.req, &res) if tc.err != nil { require.Error(t, err) require.Equal(t, tc.err, err) diff --git a/internal/hooks/v0hooks/manager.go b/internal/hooks/v0hooks/manager.go index 1be2cc7c3..87fccc2e5 100644 --- a/internal/hooks/v0hooks/manager.go +++ b/internal/hooks/v0hooks/manager.go @@ -1,13 +1,14 @@ package v0hooks import ( + "context" + "errors" "fmt" "net/http" "strings" "time" "github.com/sirupsen/logrus" - "github.com/xeipuuv/gojsonschema" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/conf" @@ -18,9 +19,9 @@ import ( ) type Manager struct { - config *conf.GlobalConfiguration - v0http *hookshttp.Dispatcher - v0pgfunc *hookspgfunc.Dispatcher + config *conf.GlobalConfiguration + http *hookshttp.Dispatcher + pgfunc *hookspgfunc.Dispatcher } func NewManager( @@ -29,9 +30,40 @@ func NewManager( pgfuncDr *hookspgfunc.Dispatcher, ) *Manager { return &Manager{ - config: config, - v0http: httpDr, - v0pgfunc: pgfuncDr, + config: config, + http: httpDr, + pgfunc: pgfuncDr, + } +} + +func (o *Manager) Enabled(name Name) bool { + if cfg, ok := configByName(&o.config.Hook, name); ok { + return cfg.Enabled + } + return false +} + +func configByName( + cfg *conf.HookConfiguration, + name Name, +) (*conf.ExtensibilityPointConfiguration, bool) { + switch name { + case SendSMS: + return &cfg.SendSMS, true + case SendEmail: + return &cfg.SendEmail, true + case CustomizeAccessToken: + return &cfg.CustomAccessToken, true + case MFAVerification: + return &cfg.MFAVerificationAttempt, true + case PasswordVerification: + return &cfg.PasswordVerificationAttempt, true + case BeforeUserCreated: + return &cfg.BeforeUserCreated, true + case AfterUserCreated: + return &cfg.AfterUserCreated, true + default: + return nil, false } } @@ -43,14 +75,6 @@ func (o *Manager) InvokeHook( return o.invokeHook(conn, r, input, output) } -func (o *Manager) RunHTTPHook( - r *http.Request, - hookConfig conf.ExtensibilityPointConfiguration, - input any, -) ([]byte, error) { - return o.v0http.RunHTTPHook(r.Context(), hookConfig, input) -} - // invokeHook invokes the hook code. conn can be nil, in which case a new // transaction is opened. If calling invokeHook within a transaction, always // pass the current transaction, as pool-exhaustion deadlocks are very easy to @@ -60,103 +84,88 @@ func (o *Manager) invokeHook( r *http.Request, input, output any, ) error { - var err error switch input.(type) { default: return apierrors.NewInternalServerError( "Unknown hook type %T.", input) case *SendSMSInput: - hookOutput, ok := output.(*SendSMSOutput) - if !ok { + if _, ok := output.(*SendSMSOutput); !ok { return apierrors.NewInternalServerError( "output should be *hooks.SendSMSOutput") } - if err = o.runHook(r, conn, o.config.Hook.SendSMS, input, hookOutput); err != nil { - return err - } - return checkError(hookOutput) + return o.dispatch( + r.Context(), &o.config.Hook.SendSMS, conn, input, output) case *SendEmailInput: - hookOutput, ok := output.(*SendEmailOutput) - if !ok { + if _, ok := output.(*SendEmailOutput); !ok { return apierrors.NewInternalServerError( "output should be *hooks.SendEmailOutput") } - if err := o.runHook(r, conn, o.config.Hook.SendEmail, input, hookOutput); err != nil { - return err - } - return checkError(hookOutput) + return o.dispatch( + r.Context(), &o.config.Hook.SendEmail, conn, input, output) case *MFAVerificationAttemptInput: - hookOutput, ok := output.(*MFAVerificationAttemptOutput) - if !ok { + if _, ok := output.(*MFAVerificationAttemptOutput); !ok { return apierrors.NewInternalServerError( "output should be *hooks.MFAVerificationAttemptOutput") } - if err := o.runHook(r, conn, o.config.Hook.MFAVerificationAttempt, input, hookOutput); err != nil { - return err - } - return checkError(hookOutput) + return o.dispatch( + r.Context(), &o.config.Hook.MFAVerificationAttempt, conn, input, output) case *PasswordVerificationAttemptInput: - hookOutput, ok := output.(*PasswordVerificationAttemptOutput) - if !ok { + if _, ok := output.(*PasswordVerificationAttemptOutput); !ok { return apierrors.NewInternalServerError( "output should be *hooks.PasswordVerificationAttemptOutput") } - if err := o.runHook(r, conn, o.config.Hook.PasswordVerificationAttempt, input, hookOutput); err != nil { - return err - } - return checkError(hookOutput) + return o.dispatch( + r.Context(), &o.config.Hook.PasswordVerificationAttempt, conn, input, output) case *CustomAccessTokenInput: - hookOutput, ok := output.(*CustomAccessTokenOutput) + _, ok := output.(*CustomAccessTokenOutput) if !ok { return apierrors.NewInternalServerError( "output should be *hooks.CustomAccessTokenOutput") } - if err := o.runHook(r, conn, o.config.Hook.CustomAccessToken, input, hookOutput); err != nil { - return err - } - if err := checkError(hookOutput); err != nil { - return err + return o.dispatch( + r.Context(), &o.config.Hook.CustomAccessToken, conn, input, output) + + case *BeforeUserCreatedInput: + if _, ok := output.(*BeforeUserCreatedOutput); !ok { + return apierrors.NewInternalServerError( + "output should be *hooks.BeforeUserCreatedOutput") } - if err := validateTokenClaims(hookOutput.Claims); err != nil { - httpCode := hookOutput.HookError.HTTPCode - - if httpCode == 0 { - httpCode = http.StatusInternalServerError - } - httpError := &apierrors.HTTPError{ - HTTPStatus: httpCode, - Message: err.Error(), - } - return httpError + return o.dispatch( + r.Context(), &o.config.Hook.BeforeUserCreated, conn, input, output) + + case *AfterUserCreatedInput: + _, ok := output.(*AfterUserCreatedOutput) + if !ok { + return apierrors.NewInternalServerError( + "output should be *hooks.AfterUserCreatedOutput") } - return nil + return o.dispatch( + r.Context(), &o.config.Hook.AfterUserCreated, conn, input, output) } } -func (o *Manager) runHook( - r *http.Request, +func (o *Manager) dispatch( + ctx context.Context, + hookConfig *conf.ExtensibilityPointConfiguration, conn *storage.Connection, - hookConfig conf.ExtensibilityPointConfiguration, input, output any, ) error { - ctx := r.Context() - - logEntry := observability.GetLogEntry(r) + logEntry := observability.GetLogEntryFromContext(ctx) hookStart := time.Now() var err error switch { case strings.HasPrefix(hookConfig.URI, "http:") || strings.HasPrefix(hookConfig.URI, "https:"): - err = o.v0http.Dispatch(ctx, hookConfig, input, output) + err = o.http.Dispatch(ctx, hookConfig, input, output) case strings.HasPrefix(hookConfig.URI, "pg-functions:"): - err = o.v0pgfunc.Dispatch(ctx, hookConfig, conn, input, output) + err = o.pgfunc.Dispatch(ctx, hookConfig, conn, input, output) default: return fmt.Errorf( @@ -174,6 +183,10 @@ func (o *Manager) runHook( "duration": duration.Microseconds(), }).WithError(err).Warn("Hook errored out") + e := new(apierrors.HTTPError) + if errors.As(err, &e) { + return e + } return apierrors.NewInternalServerError( "Error running hook URI: %v", hookConfig.URI).WithInternalError(err) } @@ -187,48 +200,3 @@ func (o *Manager) runHook( return nil } - -func checkError( - hookOutput HookOutput, -) error { - if hookOutput.IsError() { - he := hookOutput.GetHookError() - httpCode := he.HTTPCode - - if httpCode == 0 { - httpCode = http.StatusInternalServerError - } - - httpError := &apierrors.HTTPError{ - HTTPStatus: httpCode, - Message: he.Message, - } - return httpError.WithInternalError(&he) - } - return nil -} - -func validateTokenClaims(outputClaims map[string]interface{}) error { - schemaLoader := gojsonschema.NewStringLoader(MinimumViableTokenSchema) - - documentLoader := gojsonschema.NewGoLoader(outputClaims) - - result, err := gojsonschema.Validate(schemaLoader, documentLoader) - if err != nil { - return err - } - - if !result.Valid() { - var errorMessages string - - for _, desc := range result.Errors() { - errorMessages += fmt.Sprintf("- %s\n", desc) - fmt.Printf("- %s\n", desc) - } - return fmt.Errorf( - "output claims do not conform to the expected schema: \n%s", errorMessages) - - } - - return nil -} diff --git a/internal/hooks/v0hooks/manager_test.go b/internal/hooks/v0hooks/manager_test.go index 626b1b43c..4077b5b5b 100644 --- a/internal/hooks/v0hooks/manager_test.go +++ b/internal/hooks/v0hooks/manager_test.go @@ -13,6 +13,7 @@ import ( "github.com/supabase/auth/internal/e2e" "github.com/supabase/auth/internal/hooks/hookshttp" "github.com/supabase/auth/internal/hooks/hookspgfunc" + "github.com/supabase/auth/internal/models" ) type M = map[string]any @@ -31,28 +32,8 @@ func TestHooks(t *testing.T) { mr := NewManager(globalCfg, httpDr, pgfuncDr) now := time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC) - // cover RunHTTPHook - { - globalCfg.Hook.SendSMS = - conf.ExtensibilityPointConfiguration{ - URI: `http://0.0.0.0:12345`, - } - - req := &SendSMSInput{} - htr := httptest.NewRequestWithContext(ctx, "POST", "/api", nil) - _, err := mr.RunHTTPHook(htr, globalCfg.Hook.SendSMS, req) - if err == nil { - t.Fatal("exp non-nil err") - } - } - - // Cover auth hook errors single method - { - ae := &AuthHookError{Message: "test"} - if exp, got := "test", ae.Error(); exp != got { - t.Fatalf("exp %v; got %v", exp, got) - } - } + httpReq := httptest.NewRequestWithContext( + ctx, "GET", "http://localhost/test", nil) type testCase struct { desc string @@ -77,7 +58,7 @@ func TestHooks(t *testing.T) { req: &SendSMSInput{}, res: &SendSMSOutput{}, exp: &SendSMSOutput{}, - errStr: "500: Error running hook URI: http://0.0.0.0:12345", + errStr: "422: Failed to reach hook within maximum time of 0.100000 seconds", }, { @@ -241,62 +222,92 @@ func TestHooks(t *testing.T) { $$ language plpgsql;`, }, - // fail - missing required claims { - desc: "fail - customize_access_token - missing required claims", + desc: "pass - before_user_created", setup: func() { - globalCfg.Hook.CustomAccessToken = + globalCfg.Hook.BeforeUserCreated = conf.ExtensibilityPointConfiguration{ URI: `pg-functions://postgres/auth/` + - `v0hooks_test_customize_access_token_fail_missing`, - HookName: `"auth"."v0hooks_test_customize_access_token_fail_missing"`, + `v0hooks_test_before_user_created`, + HookName: `"auth"."v0hooks_test_before_user_created"`, } }, - req: &CustomAccessTokenInput{ - Claims: &AccessTokenClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Audience: []string{"myaudience"}, - ExpiresAt: jwt.NewNumericDate(now), - IssuedAt: jwt.NewNumericDate(now), - Subject: "mysubject", - }, - Email: "valid.email@supabase.co", - AuthenticatorAssuranceLevel: "aal1", - SessionId: "sid", - Phone: "1234567890", - AppMetaData: M{"appmeta": "val2"}, - Role: "myrole", - }, + req: NewBeforeUserCreatedInput(httpReq, &models.User{}), + res: &BeforeUserCreatedOutput{}, + exp: &BeforeUserCreatedOutput{}, + sql: ` + create or replace function + v0hooks_test_before_user_created(input jsonb) + returns json as $$ + begin + return '{}'::jsonb; + end; $$ language plpgsql;`, + }, + + { + desc: "pass - before_user_created reject", + setup: func() { + globalCfg.Hook.BeforeUserCreated = + conf.ExtensibilityPointConfiguration{ + URI: `pg-functions://postgres/auth/` + + `v0hooks_test_before_user_created_reject`, + HookName: `"auth"."v0hooks_test_before_user_created_reject"`, + } }, - res: &CustomAccessTokenOutput{}, - exp: &CustomAccessTokenOutput{ - Claims: M{ - "aud": []interface{}{"myaudience"}, - "email": "valid.email@supabase.co", - "exp": 1.7040672e+09, - "iat": 1.7040672e+09, - "sub": "mysubject", - "aal": "aal1", - "session_id": "sid", - "is_anonymous": false, - "phone": "1234567890", - "app_metadata": M{"appmeta": "val2"}, - "custom_claim": "custom_value", - "role": "myrole", - }, + req: NewBeforeUserCreatedInput(httpReq, &models.User{}), + res: &BeforeUserCreatedOutput{}, + exp: &BeforeUserCreatedOutput{Decision: "reject"}, + sql: ` + create or replace function + v0hooks_test_before_user_created_reject(input jsonb) + returns json as $$ + begin + return '{"decision": "reject"}'::jsonb; + end; $$ language plpgsql;`, + }, + + { + desc: "pass - before_user_created reject with message", + setup: func() { + globalCfg.Hook.BeforeUserCreated = + conf.ExtensibilityPointConfiguration{ + URI: `pg-functions://postgres/auth/` + + `v0hooks_test_before_user_created_reject_msg`, + HookName: `"auth"."v0hooks_test_before_user_created_reject_msg"`, + } }, + req: NewBeforeUserCreatedInput(httpReq, &models.User{}), + res: &BeforeUserCreatedOutput{}, + exp: &BeforeUserCreatedOutput{Decision: "reject", Message: "test case"}, sql: ` create or replace function - v0hooks_test_customize_access_token_fail_missing(input jsonb) - returns json as $$ - declare - claims jsonb; - begin - claims := input->'claims' || '{"custom_claim": "custom_value"}'::jsonb; - return jsonb_build_object('claims', claims); - end; - $$ language plpgsql;`, - errStr: "500: output claims do not conform to the expected schema", + v0hooks_test_before_user_created_reject_msg(input jsonb) + returns json as $$ + begin + return '{"decision": "reject", "message": "test case"}'::jsonb; + end; $$ language plpgsql;`, + }, + + { + desc: "pass - after_user_created", + setup: func() { + globalCfg.Hook.AfterUserCreated = + conf.ExtensibilityPointConfiguration{ + URI: `pg-functions://postgres/auth/` + + `v0hooks_test_after_user_created`, + HookName: `"auth"."v0hooks_test_after_user_created"`, + } + }, + req: NewAfterUserCreatedInput(httpReq, &models.User{}), + res: &AfterUserCreatedOutput{}, + exp: &AfterUserCreatedOutput{}, + sql: ` + create or replace function + v0hooks_test_after_user_created(input jsonb) + returns json as $$ + begin + return '{}'::jsonb; + end; $$ language plpgsql;`, }, // fail @@ -310,13 +321,8 @@ func TestHooks(t *testing.T) { HookName: `"auth"."v0hooks_test_customize_access_token_failure"`, } }, - req: &CustomAccessTokenInput{}, - res: &CustomAccessTokenOutput{}, - exp: &CustomAccessTokenOutput{ - HookError: AuthHookError{ - Message: "failed hook", - }, - }, + req: &CustomAccessTokenInput{}, + res: &CustomAccessTokenOutput{}, errStr: "500: failed hook", sql: ` create or replace function @@ -327,6 +333,28 @@ func TestHooks(t *testing.T) { end; $$ language plpgsql;`, }, + { + desc: "fail - customize_access_token - error propagation http code", + setup: func() { + globalCfg.Hook.CustomAccessToken = + conf.ExtensibilityPointConfiguration{ + URI: `pg-functions://postgres/auth/` + + `v0hooks_test_customize_access_token_failure`, + HookName: `"auth"."v0hooks_test_customize_access_token_failure"`, + } + }, + req: &CustomAccessTokenInput{}, + res: &CustomAccessTokenOutput{}, + errStr: "403: auth failure", + sql: ` + create or replace function + v0hooks_test_customize_access_token_failure(input jsonb) + returns json as $$ + begin + return '{"error": {"message": "auth failure", "http_code": 403}}'::jsonb; + end; $$ language plpgsql;`, + }, + // fail - invalid URI type { desc: "fail - password_verification_attempt - run hook failure", @@ -380,6 +408,18 @@ func TestHooks(t *testing.T) { res: M{}, errStr: "500: output should be *hooks.PasswordVerificationAttemptOutput", }, + { + desc: "fail - before_user_created - invalid output type", + req: &BeforeUserCreatedInput{}, + res: M{}, + errStr: "500: output should be *hooks.BeforeUserCreatedOutput", + }, + { + desc: "fail - after_user_created - invalid output type", + req: &AfterUserCreatedInput{}, + res: M{}, + errStr: "500: output should be *hooks.AfterUserCreatedOutput", + }, // fail - invalid query { @@ -476,11 +516,77 @@ func TestHooks(t *testing.T) { } require.NoError(t, err) require.Equal(t, tc.exp, tc.res) + } +} + +func TestConfig(t *testing.T) { + globalCfg := &conf.GlobalConfiguration{ + Hook: conf.HookConfiguration{ + SendSMS: conf.ExtensibilityPointConfiguration{ + URI: "http:localhost/" + string(SendSMS), + }, + SendEmail: conf.ExtensibilityPointConfiguration{ + URI: "http:localhost/" + string(SendEmail), + }, + CustomAccessToken: conf.ExtensibilityPointConfiguration{ + URI: "http:localhost/" + string(CustomizeAccessToken), + }, + MFAVerificationAttempt: conf.ExtensibilityPointConfiguration{ + URI: "http:localhost/" + string(MFAVerification), + }, + PasswordVerificationAttempt: conf.ExtensibilityPointConfiguration{ + URI: "http:localhost/" + string(PasswordVerification), + }, + BeforeUserCreated: conf.ExtensibilityPointConfiguration{ + URI: "http:localhost/" + string(BeforeUserCreated), + }, + AfterUserCreated: conf.ExtensibilityPointConfiguration{ + URI: "http:localhost/" + string(AfterUserCreated), + }, + }, + } + cfg := &globalCfg.Hook + + mr := new(Manager) + mr.config = globalCfg + + tests := []struct { + cfg *conf.HookConfiguration + name Name + exp *conf.ExtensibilityPointConfiguration + ok bool + }{ + {}, + {cfg: cfg, ok: true, + name: SendSMS, exp: &cfg.SendSMS}, + {cfg: cfg, ok: true, + name: SendEmail, exp: &cfg.SendEmail}, + {cfg: cfg, ok: true, + name: CustomizeAccessToken, exp: &cfg.CustomAccessToken}, + {cfg: cfg, ok: true, + name: MFAVerification, exp: &cfg.MFAVerificationAttempt}, + {cfg: cfg, ok: true, + name: PasswordVerification, exp: &cfg.PasswordVerificationAttempt}, + {cfg: cfg, ok: true, + name: BeforeUserCreated, exp: &cfg.BeforeUserCreated}, + {cfg: cfg, ok: true, + name: AfterUserCreated, exp: &cfg.AfterUserCreated}, + } + for idx, test := range tests { + t.Logf("test #%v - exp ok %v with cfg %v from name %v", + idx, test.ok, test.exp, test.name) - if h, ok := tc.res.(HookOutput); ok { - _ = h.Error() - _ = h.GetHookError() - _ = h.IsError() + require.Equal(t, false, mr.Enabled(test.name)) + + got, ok := configByName(test.cfg, test.name) + require.Equal(t, test.ok, ok) + require.Equal(t, test.exp, got) + + if got == nil { + continue } + + got.Enabled = true + require.Equal(t, true, mr.Enabled(test.name)) } } diff --git a/internal/hooks/v0hooks/v0hooks.go b/internal/hooks/v0hooks/v0hooks.go index 046f003ca..33a054ec1 100644 --- a/internal/hooks/v0hooks/v0hooks.go +++ b/internal/hooks/v0hooks/v0hooks.go @@ -1,98 +1,100 @@ package v0hooks import ( + "net/http" + "time" + "github.com/gofrs/uuid" "github.com/golang-jwt/jwt/v5" "github.com/supabase/auth/internal/mailer" "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/utilities" ) -type HookType string +type Name string const ( - PostgresHook HookType = "pg-functions" + SendSMS Name = "send-sms" + SendEmail Name = "send-email" + CustomizeAccessToken Name = "customize-access-token" + MFAVerification Name = "mfa-verification" + PasswordVerification Name = "password-verification" + BeforeUserCreated Name = "before-user-created" + AfterUserCreated Name = "after-user-created" ) -// Hook Names const ( HookRejection = "reject" ) -type HTTPHookInput interface { - IsHTTPHook() +const ( + DefaultMFAHookRejectionMessage = "Further MFA verification attempts will be rejected." + DefaultPasswordHookRejectionMessage = "Further password verification attempts will be rejected." +) + +type Metadata struct { + UUID uuid.UUID `json:"uuid"` + Time time.Time `json:"time"` + + // Hook name + Name Name `json:"name,omitempty"` + + // IP Address of the request, if present + IPAddress string `json:"ip_address,omitempty"` +} + +func NewMetadata(r *http.Request, name Name) *Metadata { + return &Metadata{ + UUID: uuid.Must(uuid.NewV4()), + Time: time.Now(), + IPAddress: utilities.GetIPAddress(r), + Name: name, + } +} + +type BeforeUserCreatedInput struct { + Metadata *Metadata `json:"metadata"` + User *models.User `json:"user"` +} + +func NewBeforeUserCreatedInput( + r *http.Request, + user *models.User, +) *BeforeUserCreatedInput { + return &BeforeUserCreatedInput{ + Metadata: NewMetadata(r, BeforeUserCreated), + User: user, + } } -type HookOutput interface { - IsError() bool - Error() string - GetHookError() AuthHookError +type BeforeUserCreatedOutput struct { + Decision string `json:"decision"` + Message string `json:"message"` } +type AfterUserCreatedInput struct { + Metadata *Metadata `json:"metadata"` + User *models.User `json:"user"` +} + +func NewAfterUserCreatedInput( + r *http.Request, + user *models.User, +) *AfterUserCreatedInput { + return &AfterUserCreatedInput{ + Metadata: NewMetadata(r, AfterUserCreated), + User: user, + } +} + +type AfterUserCreatedOutput struct{} + // TODO(joel): Move this to phone package type SMS struct { OTP string `json:"otp,omitempty"` SMSType string `json:"sms_type,omitempty"` } -// #nosec -const MinimumViableTokenSchema = `{ - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": { - "aud": { - "type": ["string", "array"] - }, - "exp": { - "type": "integer" - }, - "jti": { - "type": "string" - }, - "iat": { - "type": "integer" - }, - "iss": { - "type": "string" - }, - "nbf": { - "type": "integer" - }, - "sub": { - "type": "string" - }, - "email": { - "type": "string" - }, - "phone": { - "type": "string" - }, - "app_metadata": { - "type": "object", - "additionalProperties": true - }, - "user_metadata": { - "type": "object", - "additionalProperties": true - }, - "role": { - "type": "string" - }, - "aal": { - "type": "string" - }, - "amr": { - "type": "array", - "items": { - "type": "object" - } - }, - "session_id": { - "type": "string" - } - }, - "required": ["aud", "exp", "iat", "sub", "email", "phone", "role", "aal", "session_id", "is_anonymous"] -}` - // AccessTokenClaims is a struct thats used for JWT claims type AccessTokenClaims struct { jwt.RegisteredClaims @@ -115,9 +117,8 @@ type MFAVerificationAttemptInput struct { } type MFAVerificationAttemptOutput struct { - Decision string `json:"decision"` - Message string `json:"message"` - HookError AuthHookError `json:"error"` + Decision string `json:"decision"` + Message string `json:"message"` } type PasswordVerificationAttemptInput struct { @@ -126,10 +127,9 @@ type PasswordVerificationAttemptInput struct { } type PasswordVerificationAttemptOutput struct { - Decision string `json:"decision"` - Message string `json:"message"` - ShouldLogoutUser bool `json:"should_logout_user"` - HookError AuthHookError `json:"error"` + Decision string `json:"decision"` + Message string `json:"message"` + ShouldLogoutUser bool `json:"should_logout_user"` } type CustomAccessTokenInput struct { @@ -139,8 +139,7 @@ type CustomAccessTokenInput struct { } type CustomAccessTokenOutput struct { - Claims map[string]interface{} `json:"claims"` - HookError AuthHookError `json:"error,omitempty"` + Claims map[string]interface{} `json:"claims"` } type SendSMSInput struct { @@ -149,7 +148,6 @@ type SendSMSInput struct { } type SendSMSOutput struct { - HookError AuthHookError `json:"error,omitempty"` } type SendEmailInput struct { @@ -158,69 +156,4 @@ type SendEmailInput struct { } type SendEmailOutput struct { - HookError AuthHookError `json:"error,omitempty"` -} - -func (mf *MFAVerificationAttemptOutput) IsError() bool { - return mf.HookError.Message != "" -} - -func (mf *MFAVerificationAttemptOutput) Error() string { - return mf.HookError.Message -} - -func (mf *MFAVerificationAttemptOutput) GetHookError() AuthHookError { return mf.HookError } - -func (p *PasswordVerificationAttemptOutput) IsError() bool { - return p.HookError.Message != "" } - -func (p *PasswordVerificationAttemptOutput) Error() string { - return p.HookError.Message -} - -func (p *PasswordVerificationAttemptOutput) GetHookError() AuthHookError { return p.HookError } - -func (ca *CustomAccessTokenOutput) IsError() bool { - return ca.HookError.Message != "" -} - -func (ca *CustomAccessTokenOutput) Error() string { - return ca.HookError.Message -} - -func (ca *CustomAccessTokenOutput) GetHookError() AuthHookError { return ca.HookError } - -func (cs *SendSMSOutput) IsError() bool { - return cs.HookError.Message != "" -} - -func (cs *SendSMSOutput) Error() string { - return cs.HookError.Message -} - -func (cs *SendSMSOutput) GetHookError() AuthHookError { return cs.HookError } - -func (cs *SendEmailOutput) IsError() bool { - return cs.HookError.Message != "" -} - -func (cs *SendEmailOutput) Error() string { - return cs.HookError.Message -} - -func (cs *SendEmailOutput) GetHookError() AuthHookError { return cs.HookError } - -type AuthHookError struct { - HTTPCode int `json:"http_code,omitempty"` - Message string `json:"message,omitempty"` -} - -func (a *AuthHookError) Error() string { - return a.Message -} - -const ( - DefaultMFAHookRejectionMessage = "Further MFA verification attempts will be rejected." - DefaultPasswordHookRejectionMessage = "Further password verification attempts will be rejected." -)