From 01c04efd4d0efa861fb63201260d56f3964e3665 Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Wed, 14 May 2025 10:13:06 +0200 Subject: [PATCH 1/2] fix: move email & sms send out of the `POST /user` transaction --- internal/api/mail.go | 19 +++++++++++++------ internal/api/phone.go | 4 ++-- internal/api/user.go | 32 ++++++++++++++++++++++++++------ 3 files changed, 41 insertions(+), 14 deletions(-) diff --git a/internal/api/mail.go b/internal/api/mail.go index f90d9a74c..acc8b29d0 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -303,7 +303,7 @@ func (a *API) sendConfirmation(r *http.Request, tx *storage.Connection, u *model maxFrequency := config.SMTP.MaxFrequency otpLength := config.Mailer.OtpLength - if err = validateSentWithinFrequencyLimit(u.ConfirmationSentAt, maxFrequency); err != nil { + if err = validateSentWithinFrequencyLimitEmail(u.ConfirmationSentAt, maxFrequency); err != nil { return err } oldToken := u.ConfirmationToken @@ -370,7 +370,7 @@ func (a *API) sendPasswordRecovery(r *http.Request, tx *storage.Connection, u *m config := a.config otpLength := config.Mailer.OtpLength - if err := validateSentWithinFrequencyLimit(u.RecoverySentAt, config.SMTP.MaxFrequency); err != nil { + if err := validateSentWithinFrequencyLimitEmail(u.RecoverySentAt, config.SMTP.MaxFrequency); err != nil { return err } @@ -407,7 +407,7 @@ func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u maxFrequency := config.SMTP.MaxFrequency otpLength := config.Mailer.OtpLength - if err := validateSentWithinFrequencyLimit(u.ReauthenticationSentAt, maxFrequency); err != nil { + if err := validateSentWithinFrequencyLimitEmail(u.ReauthenticationSentAt, maxFrequency); err != nil { return err } @@ -445,7 +445,7 @@ func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.U // since Magic Link is just a recovery with a different template and behaviour // around new users we will reuse the recovery db timer to prevent potential abuse - if err := validateSentWithinFrequencyLimit(u.RecoverySentAt, config.SMTP.MaxFrequency); err != nil { + if err := validateSentWithinFrequencyLimitEmail(u.RecoverySentAt, config.SMTP.MaxFrequency); err != nil { return err } @@ -482,7 +482,7 @@ func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models config := a.config otpLength := config.Mailer.OtpLength - if err := validateSentWithinFrequencyLimit(u.EmailChangeSentAt, config.SMTP.MaxFrequency); err != nil { + if err := validateSentWithinFrequencyLimitEmail(u.EmailChangeSentAt, config.SMTP.MaxFrequency); err != nil { return err } @@ -553,13 +553,20 @@ func (a *API) validateEmail(email string) (string, error) { return strings.ToLower(email), nil } -func validateSentWithinFrequencyLimit(sentAt *time.Time, frequency time.Duration) error { +func validateSentWithinFrequencyLimitEmail(sentAt *time.Time, frequency time.Duration) error { if sentAt != nil && sentAt.Add(frequency).After(time.Now()) { return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, frequency)) } return nil } +func validateSentWithinFrequencyLimitSMS(sentAt *time.Time, frequency time.Duration) error { + if sentAt != nil && sentAt.Add(frequency).After(time.Now()) { + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, frequency)) + } + return nil +} + var emailLabelPattern = regexp.MustCompile("[+][^@]+@") func (a *API) checkEmailAddressAuthorization(email string) bool { diff --git a/internal/api/phone.go b/internal/api/phone.go index 9a8662dbb..214226946 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -70,8 +70,8 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use // intentionally keeping this before the test OTP, so that the behavior // of regular and test OTPs is similar - if sentAt != nil && !sentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) { - return "", apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, config.Sms.MaxFrequency)) + if err := validateSentWithinFrequencyLimitSMS(sentAt, config.Sms.MaxFrequency); err != nil { + return "", err } now := time.Now() diff --git a/internal/api/user.go b/internal/api/user.go index 8da2e82d1..03df4227c 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -181,6 +181,9 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } } + var sendEmailChange, sendPhoneConfirmation bool + flowType := getFlowFromChallenge(params.CodeChallenge) + err := db.Transaction(func(tx *storage.Connection) error { var terr error if params.Password != nil { @@ -223,17 +226,18 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } } else { - flowType := getFlowFromChallenge(params.CodeChallenge) if isPKCEFlow(flowType) { _, terr := generateFlowState(tx, models.EmailChange.String(), models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) if terr != nil { return terr } - } - if terr = a.sendEmailChange(r, tx, user, params.Email, flowType); terr != nil { - return terr + + if err := validateSentWithinFrequencyLimitEmail(user.EmailChangeSentAt, config.SMTP.MaxFrequency); err != nil { + return err } + + sendEmailChange = true } } @@ -247,9 +251,11 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { return terr } } else { - if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneChangeVerification, params.Channel); terr != nil { - return terr + if err := validateSentWithinFrequencyLimitSMS(user.ReauthenticationSentAt, config.SMTP.MaxFrequency); err != nil { + return err } + + sendPhoneConfirmation = true } } @@ -263,5 +269,19 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { return err } + if sendEmailChange { + // email sending should not hold a database transaction open as latency incurred by SMTP or HTTP hooks can exhaust the database pool + if err := a.sendEmailChange(r, db, user, params.Email, flowType); err != nil { + return err + } + } + + if sendPhoneConfirmation { + // SMS sending should not hold a database transaction open as latency incurred by SMTP or HTTP hooks can exhaust the database pool + if _, err := a.sendPhoneConfirmation(r, db, user, params.Phone, phoneChangeVerification, params.Channel); err != nil { + return err + } + } + return sendJSON(w, http.StatusOK, user) } From 2e612286d1ee9db524c9dbe667a682bc8d33ad5b Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Mon, 19 May 2025 14:55:58 +0200 Subject: [PATCH 2/2] restructure more so updates and inserts are atomic --- internal/api/mail.go | 50 ++++++++++++++++++++++--------------------- internal/api/phone.go | 46 ++++++++++++++++++++------------------- 2 files changed, 50 insertions(+), 46 deletions(-) diff --git a/internal/api/mail.go b/internal/api/mail.go index acc8b29d0..b6ed02fc7 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -478,7 +478,7 @@ func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.U } // sendEmailChange sends out an email change token to the new email. -func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models.User, email string, flowType models.FlowType) error { +func (a *API) sendEmailChange(r *http.Request, db *storage.Connection, u *models.User, email string, flowType models.FlowType) error { config := a.config otpLength := config.Mailer.OtpLength @@ -503,7 +503,7 @@ func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models u.EmailChangeConfirmStatus = zeroConfirmation now := time.Now() - if err := a.sendEmail(r, tx, u, mail.EmailChangeVerification, otpCurrent, otpNew, u.EmailChangeTokenNew); err != nil { + if err := a.sendEmail(r, db, u, mail.EmailChangeVerification, otpCurrent, otpNew, u.EmailChangeTokenNew); err != nil { if errors.Is(err, EmailRateLimitExceeded) { return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { @@ -512,31 +512,33 @@ func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models return apierrors.NewInternalServerError("Error sending email change email").WithInternalError(err) } - u.EmailChangeSentAt = &now - if err := tx.UpdateOnly( - u, - "email_change_token_current", - "email_change_token_new", - "email_change", - "email_change_sent_at", - "email_change_confirm_status", - ); err != nil { - return apierrors.NewInternalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error updating user for email change")) - } + return db.Transaction(func(tx *storage.Connection) error { + u.EmailChangeSentAt = &now + if err := tx.UpdateOnly( + u, + "email_change_token_current", + "email_change_token_new", + "email_change", + "email_change_sent_at", + "email_change_confirm_status", + ); err != nil { + return apierrors.NewInternalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error updating user for email change")) + } - if u.EmailChangeTokenCurrent != "" { - if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent); err != nil { - return apierrors.NewInternalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token current")) + if u.EmailChangeTokenCurrent != "" { + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent); err != nil { + return apierrors.NewInternalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token current")) + } } - } - if u.EmailChangeTokenNew != "" { - if err := models.CreateOneTimeToken(tx, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew); err != nil { - return apierrors.NewInternalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token new")) + if u.EmailChangeTokenNew != "" { + if err := models.CreateOneTimeToken(tx, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew); err != nil { + return apierrors.NewInternalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token new")) + } } - } - return nil + return nil + }) } func (a *API) validateEmail(email string) (string, error) { @@ -586,7 +588,7 @@ func (a *API) checkEmailAddressAuthorization(email string) bool { return true } -func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, emailActionType, otp, otpNew, tokenHashWithPrefix string) error { +func (a *API) sendEmail(r *http.Request, db *storage.Connection, u *models.User, emailActionType, otp, otpNew, tokenHashWithPrefix string) error { ctx := r.Context() config := a.config referrerURL := utilities.GetReferrer(r, config) @@ -657,7 +659,7 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, EmailData: emailData, } output := v0hooks.SendEmailOutput{} - return a.hooksMgr.InvokeHook(tx, r, &input, &output) + return a.hooksMgr.InvokeHook(db, r, &input, &output) } mr := a.Mailer() diff --git a/internal/api/phone.go b/internal/api/phone.go index 214226946..fb3e73cf0 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -43,7 +43,7 @@ func formatPhoneNumber(phone string) string { } // sendPhoneConfirmation sends an otp to the user's phone number -func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) { +func (a *API) sendPhoneConfirmation(r *http.Request, db *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) { config := a.config var token *string @@ -102,7 +102,7 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use }, } output := v0hooks.SendSMSOutput{} - err := a.hooksMgr.InvokeHook(tx, r, &input, &output) + err := a.hooksMgr.InvokeHook(db, r, &input, &output) if err != nil { return "", err } @@ -133,29 +133,31 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use user.ReauthenticationSentAt = &now } - if err := tx.UpdateOnly(user, includeFields...); err != nil { - return messageID, errors.Wrap(err, "Database error updating user for phone") - } - - var ottErr error - switch otpType { - case phoneConfirmationOtp: - if err := models.CreateOneTimeToken(tx, user.ID, user.GetPhone(), user.ConfirmationToken, models.ConfirmationToken); err != nil { - ottErr = errors.Wrap(err, "Database error creating confirmation token for phone") + return messageID, db.Transaction(func(tx *storage.Connection) error { + if err := tx.UpdateOnly(user, includeFields...); err != nil { + return errors.Wrap(err, "Database error updating user for phone") } - case phoneChangeVerification: - if err := models.CreateOneTimeToken(tx, user.ID, user.PhoneChange, user.PhoneChangeToken, models.PhoneChangeToken); err != nil { - ottErr = errors.Wrap(err, "Database error creating phone change token") + + var ottErr error + switch otpType { + case phoneConfirmationOtp: + if err := models.CreateOneTimeToken(tx, user.ID, user.GetPhone(), user.ConfirmationToken, models.ConfirmationToken); err != nil { + ottErr = errors.Wrap(err, "Database error creating confirmation token for phone") + } + case phoneChangeVerification: + if err := models.CreateOneTimeToken(tx, user.ID, user.PhoneChange, user.PhoneChangeToken, models.PhoneChangeToken); err != nil { + ottErr = errors.Wrap(err, "Database error creating phone change token") + } + case phoneReauthenticationOtp: + if err := models.CreateOneTimeToken(tx, user.ID, user.GetPhone(), user.ReauthenticationToken, models.ReauthenticationToken); err != nil { + ottErr = errors.Wrap(err, "Database error creating reauthentication token for phone") + } } - case phoneReauthenticationOtp: - if err := models.CreateOneTimeToken(tx, user.ID, user.GetPhone(), user.ReauthenticationToken, models.ReauthenticationToken); err != nil { - ottErr = errors.Wrap(err, "Database error creating reauthentication token for phone") + if ottErr != nil { + return apierrors.NewInternalServerError("error creating one time token").WithInternalError(ottErr) } - } - if ottErr != nil { - return messageID, apierrors.NewInternalServerError("error creating one time token").WithInternalError(ottErr) - } - return messageID, nil + return nil + }) } func generateSMSFromTemplate(SMSTemplate *template.Template, otp string) (string, error) {