Skip to content
7 changes: 4 additions & 3 deletions backend/internal/handlers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/kurodakayn/mpp-backend/internal/middleware"
"github.com/kurodakayn/mpp-backend/internal/models"
"github.com/kurodakayn/mpp-backend/internal/pkg/rediskey"
"github.com/kurodakayn/mpp-backend/internal/services/email"
)

Expand Down Expand Up @@ -217,15 +218,15 @@ func (h *AuthHandler) generateRandomCode(length int) (string, error) {
}

func verificationCodeKey(scene, email string) string {
return fmt.Sprintf("auth:code:%s:%s", scene, verificationEmailKeyDigest(email))
return fmt.Sprintf("auth:code:%s:%s", rediskey.Tag("email", verificationEmailKeyDigest(email)), rediskey.Part(scene))
}

func verificationAttemptKey(scene, email string) string {
return fmt.Sprintf("auth:code_attempts:%s:%s", scene, verificationEmailKeyDigest(email))
return fmt.Sprintf("auth:code_attempts:%s:%s", rediskey.Tag("email", verificationEmailKeyDigest(email)), rediskey.Part(scene))
}

func verificationLastSendKey(scene, email string) string {
return fmt.Sprintf("auth:last_send:%s:%s", scene, verificationEmailKeyDigest(email))
return fmt.Sprintf("auth:last_send:%s:%s", rediskey.Tag("email", verificationEmailKeyDigest(email)), rediskey.Part(scene))
}

func canonicalVerificationEmail(email string) string {
Expand Down
14 changes: 14 additions & 0 deletions backend/internal/handlers/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"gorm.io/gorm"

"github.com/kurodakayn/mpp-backend/internal/models"
"github.com/kurodakayn/mpp-backend/internal/pkg/rediskey"
"github.com/kurodakayn/mpp-backend/internal/services/email"
)

Expand Down Expand Up @@ -49,6 +50,19 @@ func storeVerificationCode(t *testing.T, rdb *redis.Client, scene, email, code s
require.NoError(t, rdb.Set(context.Background(), verificationCodeKey(scene, email), code, 0).Err())
}

func TestVerificationRedisKeysShareEmailHashTag(t *testing.T) {
email := "Person@example.com"

codeKey := verificationCodeKey("register", email)
attemptKey := verificationAttemptKey("register", email)
lastSendKey := verificationLastSendKey("register", email)

require.True(t, rediskey.ShareTag(codeKey, attemptKey, lastSendKey))
tag, ok := rediskey.ExtractTag(codeKey)
require.True(t, ok)
require.Equal(t, "email:"+verificationEmailKeyDigest(email), tag)
}

func assertNoRedisKeyContains(t *testing.T, keys []string, values ...string) {
t.Helper()
for _, key := range keys {
Expand Down
65 changes: 65 additions & 0 deletions backend/internal/pkg/rediskey/rediskey.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package rediskey

import "strings"

const Unknown = "unknown"

func Part(value string) string {
value = strings.ToLower(strings.TrimSpace(value))
if value == "" {
return Unknown
}

var builder strings.Builder
lastDash := false
for _, r := range value {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' || r == '_' || r == ':' || r == '.' {
builder.WriteRune(r)
lastDash = false
continue
}
if !lastDash {
builder.WriteByte('-')
lastDash = true
}
}

result := strings.Trim(builder.String(), "-")
if result == "" {
return Unknown
}
return result
}

func Tag(scope string, value string) string {
return "{" + Part(scope) + ":" + Part(value) + "}"
}

func ExtractTag(key string) (string, bool) {
start := strings.IndexByte(key, '{')
if start < 0 {
return "", false
}
end := strings.IndexByte(key[start+1:], '}')
if end <= 0 {
return "", false
}
return key[start+1 : start+1+end], true
}

func ShareTag(keys ...string) bool {
if len(keys) == 0 {
return true
}
expected, ok := ExtractTag(keys[0])
if !ok {
return false
}
for _, key := range keys[1:] {
tag, ok := ExtractTag(key)
if !ok || tag != expected {
return false
}
}
return true
}
32 changes: 32 additions & 0 deletions backend/internal/pkg/rediskey/rediskey_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package rediskey

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestPartNormalizesUnsafeCharacters(t *testing.T) {
require.Equal(t, "tenant:workspace-1", Part(" Tenant:Workspace 1 "))
require.Equal(t, Unknown, Part(" @@@ "))
}

func TestTagBuildsRedisHashTag(t *testing.T) {
require.Equal(t, "{session:11111111-1111-4111-8111-111111111111}", Tag("session", "11111111-1111-4111-8111-111111111111"))
require.Equal(t, "{tenant:unknown}", Tag("tenant", ""))
}

func TestShareTagRequiresMatchingHashTags(t *testing.T) {
require.True(t, ShareTag(
"mpp:browser:stream-current:{session:abc}",
"mpp:browser:stream-token:{session:abc}:hash",
))
require.False(t, ShareTag(
"mpp:browser:stream-current:{session:abc}",
"mpp:browser:stream-token:{session:def}:hash",
))
require.False(t, ShareTag(
"mpp:browser:stream-current:abc",
"mpp:browser:stream-token:{session:abc}:hash",
))
}
11 changes: 6 additions & 5 deletions backend/internal/services/browser_session/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/redis/go-redis/v9"

"github.com/kurodakayn/mpp-backend/internal/models"
"github.com/kurodakayn/mpp-backend/internal/pkg/rediskey"
)

type browserSessionLiveState struct {
Expand Down Expand Up @@ -67,7 +68,7 @@ func (s *BrowserSessionService) cleanupRedisSessionForTenant(ctx context.Context
}

func browserSessionActiveKey(userID uuid.UUID, platform string) string {
return browserSessionActiveKeyPrefix + userID.String() + ":" + platform
return browserSessionActiveKeyPrefix + rediskey.Tag("user", userID.String()) + ":" + rediskey.Part(platform)
}

func browserSessionQuotaUserKey(userID uuid.UUID) string {
Expand All @@ -79,19 +80,19 @@ func browserSessionQuotaTenantKey(tenantID string) string {
}

func browserSessionKey(sessionID uuid.UUID) string {
return browserSessionKeyPrefix + sessionID.String()
return browserSessionKeyPrefix + rediskey.Tag("session", sessionID.String())
}

func browserSessionStreamTokenKey(sessionID uuid.UUID, tokenHash string) string {
return browserSessionStreamTokenPrefix + sessionID.String() + ":" + tokenHash
return browserSessionStreamTokenPrefix + rediskey.Tag("session", sessionID.String()) + ":" + rediskey.Part(tokenHash)
}

func browserSessionStreamTokenKeyPrefixFor(sessionID uuid.UUID) string {
return browserSessionStreamTokenPrefix + sessionID.String() + ":"
return browserSessionStreamTokenPrefix + rediskey.Tag("session", sessionID.String()) + ":"
}

func browserSessionStreamCurrentKey(sessionID uuid.UUID) string {
return browserSessionStreamCurrentPrefix + sessionID.String()
return browserSessionStreamCurrentPrefix + rediskey.Tag("session", sessionID.String())
}

func browserSessionWorkerHeartbeatKey(workerSessionRef string) string {
Expand Down
32 changes: 32 additions & 0 deletions backend/internal/services/browser_session/redis_keys_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package browsersession

import (
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/kurodakayn/mpp-backend/internal/pkg/rediskey"
)

func TestBrowserSessionRedisKeysUseApprovedHashTags(t *testing.T) {
sessionID := uuid.New()
userID := uuid.New()

require.Equal(t, "session:"+sessionID.String(), mustRedisTag(t, browserSessionKey(sessionID)))
require.Equal(t, "user:"+userID.String(), mustRedisTag(t, browserSessionActiveKey(userID, "Douyin")))
require.True(t, rediskey.ShareTag(
browserSessionKey(sessionID),
browserSessionStreamCurrentKey(sessionID),
browserSessionStreamTokenKey(sessionID, "TOKEN-HASH"),
))
require.Equal(t, browserSessionStreamTokenPrefix+rediskey.Tag("session", sessionID.String())+":", browserSessionStreamTokenKeyPrefixFor(sessionID))
}

func mustRedisTag(t *testing.T, key string) string {
t.Helper()

tag, ok := rediskey.ExtractTag(key)
require.True(t, ok, key)
return tag
}
54 changes: 39 additions & 15 deletions backend/internal/services/browser_session/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,31 @@ func setRedisLiveSession(t *testing.T, client *redis.Client, state map[string]an
require.True(t, ok)
payload, err := json.Marshal(state)
require.NoError(t, err)
require.NoError(t, client.Set(context.Background(), "mpp:browser:session:"+sessionID, payload, ttl).Err())
require.NoError(t, client.Set(context.Background(), browserSessionTestRedisKey(sessionID), payload, ttl).Err())
}

func browserSessionTestRedisKey(sessionID string) string {
return "mpp:browser:session:" + browserSessionTestRedisTag("session", sessionID)
}

func browserSessionTestActiveKey(userID uuid.UUID, platform string) string {
return "mpp:browser:active:" + browserSessionTestRedisTag("user", userID.String()) + ":" + strings.ToLower(platform)
}

func browserSessionTestStreamTokenKey(sessionID uuid.UUID, tokenHash string) string {
return browserSessionTestStreamTokenPrefix(sessionID) + strings.ToLower(tokenHash)
}

func browserSessionTestStreamTokenPrefix(sessionID uuid.UUID) string {
return "mpp:browser:stream-token:" + browserSessionTestRedisTag("session", sessionID.String()) + ":"
}

func browserSessionTestStreamCurrentKey(sessionID uuid.UUID) string {
return "mpp:browser:stream-current:" + browserSessionTestRedisTag("session", sessionID.String())
}

func browserSessionTestRedisTag(scope string, value string) string {
return "{" + strings.ToLower(strings.TrimSpace(scope)) + ":" + strings.ToLower(strings.TrimSpace(value)) + "}"
}

type dashboardAccountCacheInvalidation struct {
Expand Down Expand Up @@ -493,12 +517,12 @@ func TestBrowserSessionService_GetSessionReturnsGoneForExpiredRedisSession(t *te
ExpiresAt: now.Add(-15 * time.Minute),
}
require.NoError(t, db.Create(&session).Error)
require.NoError(t, client.Set(context.Background(), "mpp:browser:active:"+userID.String()+":"+platform, session.ID.String(), time.Hour).Err())
require.NoError(t, client.Set(context.Background(), browserSessionTestActiveKey(userID, platform), session.ID.String(), time.Hour).Err())

_, err := svc.GetSession(context.Background(), userID, session.ID)

require.ErrorIs(t, err, browsersession.ErrSessionGone)
assert.Equal(t, int64(0), client.Exists(context.Background(), "mpp:browser:active:"+userID.String()+":"+platform).Val())
assert.Equal(t, int64(0), client.Exists(context.Background(), browserSessionTestActiveKey(userID, platform)).Val())

var savedSession models.RemoteBrowserSession
require.NoError(t, db.First(&savedSession, session.ID).Error)
Expand Down Expand Up @@ -548,7 +572,7 @@ func TestBrowserSessionService_RedisLiveStateOmitsInternalEndpointRefs(t *testin
resp, err := svc.StartSession(context.Background(), userID, platform)
require.NoError(t, err)

raw, err := client.Get(context.Background(), "mpp:browser:session:"+resp.SessionID.String()).Bytes()
raw, err := client.Get(context.Background(), browserSessionTestRedisKey(resp.SessionID.String())).Bytes()
require.NoError(t, err)
var payload map[string]any
require.NoError(t, json.Unmarshal(raw, &payload))
Expand Down Expand Up @@ -577,15 +601,15 @@ func TestBrowserSessionService_CancelSessionDeletesAllRedisStreamTokens(t *testi

resp, err := svc.StartSession(context.Background(), userID, platform)
require.NoError(t, err)
strayTokenKey := "mpp:browser:stream-token:" + resp.SessionID.String() + ":stray-token-hash"
strayTokenKey := browserSessionTestStreamTokenKey(resp.SessionID, "stray-token-hash")
require.NoError(t, client.Set(context.Background(), strayTokenKey, "{}", time.Hour).Err())

require.NoError(t, svc.CancelSession(context.Background(), userID, resp.SessionID))

tokenKeys, err := client.Keys(context.Background(), "mpp:browser:stream-token:"+resp.SessionID.String()+":*").Result()
tokenKeys, err := client.Keys(context.Background(), browserSessionTestStreamTokenPrefix(resp.SessionID)+"*").Result()
require.NoError(t, err)
assert.Empty(t, tokenKeys)
assert.Equal(t, int64(0), client.Exists(context.Background(), "mpp:browser:stream-current:"+resp.SessionID.String()).Val())
assert.Equal(t, int64(0), client.Exists(context.Background(), browserSessionTestStreamCurrentKey(resp.SessionID)).Val())
}

func TestBrowserSessionService_UnsupportedPlatform(t *testing.T) {
Expand Down Expand Up @@ -746,7 +770,7 @@ func TestBrowserSessionService_StartSessionRecoversStaleRedisActiveLock(t *testi
ExpiresAt: time.Now().Add(13 * time.Minute),
}
require.NoError(t, db.Create(&staleSession).Error)
require.NoError(t, client.Set(context.Background(), "mpp:browser:active:"+userID.String()+":"+platform, staleSession.ID.String(), time.Hour).Err())
require.NoError(t, client.Set(context.Background(), browserSessionTestActiveKey(userID, platform), staleSession.ID.String(), time.Hour).Err())
setRedisLiveSession(t, client, map[string]any{
"session_id": staleSession.ID.String(),
"user_id": userID.String(),
Expand All @@ -764,10 +788,10 @@ func TestBrowserSessionService_StartSessionRecoversStaleRedisActiveLock(t *testi
assert.Equal(t, models.BrowserSessionStatusReady, resp.Status)
assert.NotEqual(t, staleSession.ID, resp.SessionID)

activeSessionID, err := client.Get(context.Background(), "mpp:browser:active:"+userID.String()+":"+platform).Result()
activeSessionID, err := client.Get(context.Background(), browserSessionTestActiveKey(userID, platform)).Result()
require.NoError(t, err)
assert.Equal(t, resp.SessionID.String(), activeSessionID)
assert.Equal(t, int64(0), client.Exists(context.Background(), "mpp:browser:session:"+staleSession.ID.String()).Val())
assert.Equal(t, int64(0), client.Exists(context.Background(), browserSessionTestRedisKey(staleSession.ID.String())).Val())

var savedStaleSession models.RemoteBrowserSession
require.NoError(t, db.First(&savedStaleSession, staleSession.ID).Error)
Expand All @@ -789,7 +813,7 @@ func TestBrowserSessionService_StartSessionPreservesReachableRedisActiveLock(t *
_, err = svc.StartSession(context.Background(), userID, platform)

require.ErrorIs(t, err, browsersession.ErrActiveSessionExists)
activeSessionID, err := client.Get(context.Background(), "mpp:browser:active:"+userID.String()+":"+platform).Result()
activeSessionID, err := client.Get(context.Background(), browserSessionTestActiveKey(userID, platform)).Result()
require.NoError(t, err)
assert.Equal(t, resp.SessionID.String(), activeSessionID)
}
Expand Down Expand Up @@ -864,7 +888,7 @@ func TestBrowserSessionService_GetSessionKeepsLiveRedisStateOnTransientWorkerRea
ExpiresAt: time.Now().Add(10 * time.Minute),
}
require.NoError(t, db.Create(&session).Error)
require.NoError(t, client.Set(context.Background(), "mpp:browser:active:"+userID.String()+":"+platform, session.ID.String(), time.Hour).Err())
require.NoError(t, client.Set(context.Background(), browserSessionTestActiveKey(userID, platform), session.ID.String(), time.Hour).Err())
require.NoError(t, client.Set(context.Background(), "mpp:browser:worker-heartbeat:"+workerSessionRef, session.ID.String(), time.Hour).Err())
setRedisLiveSession(t, client, map[string]any{
"session_id": session.ID.String(),
Expand All @@ -889,8 +913,8 @@ func TestBrowserSessionService_GetSessionKeepsLiveRedisStateOnTransientWorkerRea
require.NoError(t, db.First(&savedSession, session.ID).Error)
assert.Equal(t, models.BrowserSessionStatusReady, savedSession.Status)
assert.Equal(t, "stale-token", savedSession.ConnectTokenHash)
assert.Equal(t, int64(1), client.Exists(context.Background(), "mpp:browser:active:"+userID.String()+":"+platform).Val())
assert.Equal(t, int64(1), client.Exists(context.Background(), "mpp:browser:session:"+session.ID.String()).Val())
assert.Equal(t, int64(1), client.Exists(context.Background(), browserSessionTestActiveKey(userID, platform)).Val())
assert.Equal(t, int64(1), client.Exists(context.Background(), browserSessionTestRedisKey(session.ID.String())).Val())
assert.Equal(t, int64(1), client.Exists(context.Background(), "mpp:browser:worker-heartbeat:"+workerSessionRef).Val())
}

Expand Down Expand Up @@ -928,7 +952,7 @@ func TestBrowserSessionService_CancelSessionRemovesContinuityStateWhenCoordinati

require.NoError(t, svc.CancelSession(context.Background(), userID, session.ID))

assert.Equal(t, int64(0), continuityClient.Exists(context.Background(), "mpp:browser:session:"+session.ID.String()).Val())
assert.Equal(t, int64(0), continuityClient.Exists(context.Background(), browserSessionTestRedisKey(session.ID.String())).Val())
assert.Equal(t, int64(0), continuityClient.Exists(context.Background(), "mpp:browser:worker-heartbeat:"+workerSessionRef).Val())
assert.Equal(t, int64(0), continuityClient.ZCard(context.Background(), "mpp:browser:cleanup").Val())

Expand Down
2 changes: 1 addition & 1 deletion backend/internal/services/mediaasset/assets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ func (s *deadlinePresignStorage) PresignGetObject(ctx context.Context, input obj
func requireResolvedMediaAssetCacheKeys(t *testing.T, redisServer *miniredis.Miniredis, assetID uuid.UUID, count int) []string {
t.Helper()

prefix := "mpp:dashboard:media-assets:resolve:v1:" + assetID.String() + ":"
prefix := "mpp:dashboard:media-assets:resolve:v1:{asset:" + assetID.String() + "}:"
keys := make([]string, 0)
for _, key := range redisServer.Keys() {
if strings.HasPrefix(key, prefix) {
Expand Down
Loading
Loading