Skip to content

feat: hooks round 2 - remove indirection and simplify error handling #2025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/hooks"
"github.com/supabase/auth/internal/hooks/hookshttp"
"github.com/supabase/auth/internal/hooks/hookspgfunc"
"github.com/supabase/auth/internal/hooks/v0hooks"
"github.com/supabase/auth/internal/mailer"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
Expand All @@ -33,7 +35,7 @@ type API struct {
config *conf.GlobalConfiguration
version string

hooksMgr *hooks.Manager
hooksMgr *v0hooks.Manager
hibpClient *hibp.PwnedClient

// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
Expand Down Expand Up @@ -87,7 +89,9 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
api.limiterOpts = NewLimiterOptions(globalConfig)
}
if api.hooksMgr == nil {
api.hooksMgr = hooks.NewManager(db, globalConfig)
httpDr := hookshttp.New()
pgfuncDr := hookspgfunc.New(db)
api.hooksMgr = v0hooks.NewManager(globalConfig, httpDr, pgfuncDr)
}
if api.config.Password.HIBP.Enabled {
httpClient := &http.Client{
Expand Down
34 changes: 15 additions & 19 deletions internal/api/hooks_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"net/http"
"testing"

Expand All @@ -12,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/hooks/hookserrors"
"github.com/supabase/auth/internal/hooks/v0hooks"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
Expand Down Expand Up @@ -70,20 +70,20 @@ func (ts *HooksTestSuite) TestRunHTTPHook() {
testURL := "http://localhost:54321/functions/v1/custom-sms-sender"
ts.Config.Hook.SendSMS.URI = testURL

unsuccessfulResponse := v0hooks.AuthHookError{
unsuccessfulResponse := hookserrors.Error{
HTTPCode: http.StatusUnprocessableEntity,
Message: "test error",
}

testCases := []struct {
description string
expectError bool
mockResponse v0hooks.AuthHookError
mockResponse hookserrors.Error
}{
{
description: "Hook returns success",
expectError: false,
mockResponse: v0hooks.AuthHookError{},
mockResponse: hookserrors.Error{},
},
{
description: "Hook returns error",
Expand All @@ -102,23 +102,22 @@ func (ts *HooksTestSuite) TestRunHTTPHook() {
Post("/").
MatchType("json").
Reply(http.StatusUnprocessableEntity).
JSON(v0hooks.SendSMSOutput{HookError: unsuccessfulResponse})
JSON(struct {
Error *hookserrors.Error `json:"error,omitempty"`
}{Error: &unsuccessfulResponse})

for _, tc := range testCases {
ts.Run(tc.description, func() {
req, _ := http.NewRequest("POST", ts.Config.Hook.SendSMS.URI, nil)
body, err := ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input)

var output v0hooks.SendSMSOutput
err := ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output)

if !tc.expectError {
require.NoError(ts.T(), err)
} else {
require.Error(ts.T(), err)
if body != nil {
var output v0hooks.SendSMSOutput
require.NoError(ts.T(), json.Unmarshal(body, &output))
require.Equal(ts.T(), unsuccessfulResponse.HTTPCode, output.HookError.HTTPCode)
require.Equal(ts.T(), unsuccessfulResponse.Message, output.HookError.Message)
}
require.Equal(ts.T(), output, v0hooks.SendSMSOutput{})
}
})
}
Expand Down Expand Up @@ -154,12 +153,9 @@ func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() {
req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil)
require.NoError(ts.T(), err)

body, err := ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input)
require.NoError(ts.T(), err)

var output v0hooks.SendSMSOutput
err = json.Unmarshal(body, &output)
require.NoError(ts.T(), err, "Unmarshal should not fail")
err = ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output)
require.NoError(ts.T(), err)

// Ensure that all expected HTTP interactions (mocks) have been called
require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called including retry")
Expand All @@ -186,10 +182,10 @@ func (ts *HooksTestSuite) TestShouldReturnErrorForNonJSONContentType() {
req, err := http.NewRequest("POST", "http://localhost:9999/otp", nil)
require.NoError(ts.T(), err)

_, err = ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input)
var output v0hooks.SendSMSOutput
err = ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output)
require.Error(ts.T(), err, "Expected an error due to wrong content type")
require.Contains(ts.T(), err.Error(), "Invalid JSON response.")

require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called")
}

Expand Down
89 changes: 88 additions & 1 deletion internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package api

import (
"context"
"fmt"
"net/http"
"net/url"
"strconv"
"time"

"github.com/gofrs/uuid"
"github.com/golang-jwt/jwt/v5"
"github.com/xeipuuv/gojsonschema"

"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/hooks/v0hooks"
Expand Down Expand Up @@ -369,14 +371,16 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user
if err != nil {
return "", 0, err
}
if err := validateTokenClaims(output.Claims); err != nil {
return "", 0, err
}
gotrueClaims = jwt.MapClaims(output.Claims)
}

signed, err := signJwt(&config.JWT, gotrueClaims)
if err != nil {
return "", 0, err
}

return signed, expiresAt.Unix(), nil
}

Expand Down Expand Up @@ -491,3 +495,86 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection,
User: user,
}, nil
}

var schemaLoader = gojsonschema.NewStringLoader(MinimumViableTokenSchema)

func validateTokenClaims(outputClaims map[string]interface{}) error {
documentLoader := gojsonschema.NewGoLoader(outputClaims)
result, err := gojsonschema.Validate(schemaLoader, documentLoader)
if err != nil {
return err
}

if !result.Valid() {
var errorMessages string

for _, desc := range result.Errors() {
errorMessages += fmt.Sprintf("- %s\n", desc)
fmt.Printf("- %s\n", desc)
}
return fmt.Errorf(
"output claims do not conform to the expected schema: \n%s", errorMessages)

}

return nil
}

// #nosec
const MinimumViableTokenSchema = `{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"aud": {
"type": ["string", "array"]
},
"exp": {
"type": "integer"
},
"jti": {
"type": "string"
},
"iat": {
"type": "integer"
},
"iss": {
"type": "string"
},
"nbf": {
"type": "integer"
},
"sub": {
"type": "string"
},
"email": {
"type": "string"
},
"phone": {
"type": "string"
},
"app_metadata": {
"type": "object",
"additionalProperties": true
},
"user_metadata": {
"type": "object",
"additionalProperties": true
},
"role": {
"type": "string"
},
"aal": {
"type": "string"
},
"amr": {
"type": "array",
"items": {
"type": "object"
}
},
"session_id": {
"type": "string"
}
},
"required": ["aud", "exp", "iat", "sub", "email", "phone", "role", "aal", "session_id", "is_anonymous"]
}`
17 changes: 17 additions & 0 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,9 @@ type HookConfiguration struct {
CustomAccessToken ExtensibilityPointConfiguration `json:"custom_access_token" split_words:"true"`
SendEmail ExtensibilityPointConfiguration `json:"send_email" split_words:"true"`
SendSMS ExtensibilityPointConfiguration `json:"send_sms" split_words:"true"`

BeforeUserCreated ExtensibilityPointConfiguration `json:"before_user_created" split_words:"true"`
AfterUserCreated ExtensibilityPointConfiguration `json:"after_user_created" split_words:"true"`
}

type HTTPHookSecrets []string
Expand Down Expand Up @@ -671,6 +674,8 @@ func (h *HookConfiguration) Validate() error {
h.CustomAccessToken,
h.SendSMS,
h.SendEmail,
h.BeforeUserCreated,
h.AfterUserCreated,
}
for _, point := range points {
if err := point.ValidateExtensibilityPoint(); err != nil {
Expand Down Expand Up @@ -888,6 +893,18 @@ func populateGlobal(config *GlobalConfiguration) error {
}
}

if config.Hook.BeforeUserCreated.Enabled {
if err := config.Hook.BeforeUserCreated.PopulateExtensibilityPoint(); err != nil {
return err
}
}

if config.Hook.AfterUserCreated.Enabled {
if err := config.Hook.AfterUserCreated.PopulateExtensibilityPoint(); err != nil {
return err
}
}

if config.SAML.Enabled {
if err := config.SAML.PopulateFields(config.API.ExternalURL); err != nil {
return err
Expand Down
46 changes: 46 additions & 0 deletions internal/conf/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,36 @@ func TestGlobal(t *testing.T) {
os.Setenv("API_EXTERNAL_URL", "http://localhost:9999")
}

{
os.Setenv("API_EXTERNAL_URL", "")
cfg := new(GlobalConfiguration)
cfg.Hook = HookConfiguration{
BeforeUserCreated: ExtensibilityPointConfiguration{
Enabled: true,
URI: "\n",
},
}

err := populateGlobal(cfg)
require.Error(t, err)
os.Setenv("API_EXTERNAL_URL", "http://localhost:9999")
}

{
os.Setenv("API_EXTERNAL_URL", "")
cfg := new(GlobalConfiguration)
cfg.Hook = HookConfiguration{
AfterUserCreated: ExtensibilityPointConfiguration{
Enabled: true,
URI: "\n",
},
}

err := populateGlobal(cfg)
require.Error(t, err)
os.Setenv("API_EXTERNAL_URL", "http://localhost:9999")
}

{
os.Setenv("API_EXTERNAL_URL", "")
cfg := new(GlobalConfiguration)
Expand Down Expand Up @@ -490,6 +520,11 @@ func TestValidate(t *testing.T) {
err: `conf: session timebox duration must` +
` be positive when set, was -1`,
},
{
val: &SessionsConfiguration{InactivityTimeout: toPtr(time.Duration(-1))},
err: `conf: session inactivity timeout duration must` +
` be positive when set, was -1ns`,
},
{
val: &SessionsConfiguration{AllowLowAAL: nil},
},
Expand Down Expand Up @@ -532,6 +567,17 @@ func TestValidate(t *testing.T) {
err: `conf: mailer validation headers not a map[string][]string format:` +
` invalid character 'i' looking for beginning of value`,
},
{
val: &MailerConfiguration{EmailValidationBlockedMX: "invalid"},
err: `conf: email_validation_blocked_mx`,
},
{
val: &MailerConfiguration{EmailValidationBlockedMX: `["foo.com"]`},
check: func(t *testing.T, v any) {
got := (v.(*MailerConfiguration)).GetEmailValidationBlockedMXRecords()
require.True(t, got["foo.com"])
},
},

{
val: &CaptchaConfiguration{Enabled: false},
Expand Down
8 changes: 7 additions & 1 deletion internal/e2e/e2e.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,14 @@ var (
configPath string
)

var isTesting func() bool = testing.Testing

func init() {
if testing.Testing() {
initPackage()
}

func initPackage() {
if isTesting() {
_, thisFile, _, _ := runtime.Caller(0)
projectRoot = filepath.Join(filepath.Dir(thisFile), "../..")
configPath = filepath.Join(GetProjectRoot(), "hack", "test.env")
Expand Down
23 changes: 23 additions & 0 deletions internal/e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,27 @@ func TestUtils(t *testing.T) {
t.Fatal("exp non-nil err")
}
}()

// block init from main()
func() {
restore := isTesting
defer func() {
isTesting = restore
}()
isTesting = func() bool { return false }

var errStr string
func() {
defer func() {
errStr = recover().(string)
}()

initPackage()
}()

exp := "package e2e may not be used in a main package"
if errStr != exp {
t.Fatalf("exp %v; got %v", exp, errStr)
}
}()
}
Loading
Loading