Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backend/internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,9 @@ func syncSchemaUnlocked(database *gorm.DB) error {
if err := ensureMonthlyEventPartitions(database, time.Now().UTC()); err != nil {
return err
}
if err := ensureCollabUpdateBatchHashPartitions(database); err != nil {
return err
}

if err := database.AutoMigrate(
&models.User{},
Expand Down
40 changes: 40 additions & 0 deletions backend/internal/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,16 @@ func TestMonthlyPartitionedEventModelsUsePartitionCompatiblePrimaryKeys(t *testi
require.True(t, database.Migrator().HasTable(&models.ExtensionExecutionEventClaim{}))
}

// TestCollabUpdateBatchModelUsesHashPartitionCompatiblePrimaryKey syncs the
// schema into an in-memory SQLite database and checks that the batch table's
// primary key is made up of exactly the id and document_id columns.
func TestCollabUpdateBatchModelUsesHashPartitionCompatiblePrimaryKey(t *testing.T) {
	const batchTable = "collab_document_update_batches"
	wantColumns := []string{"id", "document_id"}

	conn, openErr := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	require.NoError(t, openErr)

	syncErr := syncSchema(conn)
	require.NoError(t, syncErr)

	gotColumns := sqlitePrimaryKeyColumns(t, conn, batchTable)
	require.ElementsMatch(t, wantColumns, gotColumns)
}

func TestCreateMonthlyPartitionSQLUsesMonthRange(t *testing.T) {
start := time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC)

Expand All @@ -171,6 +181,36 @@ func TestCreateDefaultMonthlyPartitionSQLUsesDefaultPartition(t *testing.T) {
require.Contains(t, sql, " DEFAULT")
}

// TestCreateHashPartitionSQLUsesDocumentHashRemainder checks that the generated
// DDL names the partition with a zero-padded index, attaches it to the parent
// table, and carries the requested modulus and remainder.
func TestCreateHashPartitionSQLUsesDocumentHashRemainder(t *testing.T) {
	statement := createHashPartitionSQL("collab_document_update_batches", 3, 16)

	expectedFragments := []string{
		`"collab_document_update_batches_p03"`,
		`PARTITION OF "collab_document_update_batches"`,
		"FOR VALUES WITH (MODULUS 16, REMAINDER 3)",
	}
	for _, fragment := range expectedFragments {
		require.Contains(t, statement, fragment)
	}
}

func TestCollabUpdateBatchHashPartitionDefinition(t *testing.T) {
require.Equal(t, "collab_document_update_batches", collabUpdateBatchHashPartitionedTable.name)
require.Equal(t, "document_id", collabUpdateBatchHashPartitionedTable.partitionKey)
require.Equal(t, 16, collabUpdateBatchHashPartitionedTable.partitionCount)
require.Contains(t, collabUpdateBatchHashPartitionedTable.createSQL, "PARTITION BY HASH (document_id)")
require.Contains(t, collabUpdateBatchHashPartitionedTable.createSQL, "PRIMARY KEY (id, document_id)")
require.Contains(t, collabUpdateBatchHashPartitionedTable.columns, "document_id")
}

// TestRenameLegacySerialSequenceSQLAvoidsBigserialCollision checks that the
// generated DO block looks up the legacy table's serial sequence, compares it
// with the default sequence name of the live table, and renames it to the
// legacy-prefixed sequence name.
func TestRenameLegacySerialSequenceSQLAvoidsBigserialCollision(t *testing.T) {
	const (
		liveTable   = "collab_document_update_batches"
		legacyTable = "collab_document_update_batches_legacy_20260623120000"
	)

	statement := renameLegacySerialSequenceSQL(liveTable, legacyTable, "id")

	expectedFragments := []string{
		"pg_get_serial_sequence('collab_document_update_batches_legacy_20260623120000', 'id')",
		"to_regclass('collab_document_update_batches_id_seq')",
		"ALTER SEQUENCE %s RENAME TO %I",
		"'collab_document_update_batches_legacy_20260623120000_id_seq'",
	}
	for _, fragment := range expectedFragments {
		require.Contains(t, statement, fragment)
	}
}

func TestBackfillExtensionExecutionEventClaims(t *testing.T) {
database, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
TranslateError: true,
Expand Down
192 changes: 192 additions & 0 deletions backend/internal/db/hash_partitions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
package db

import (
"fmt"
"strings"
"time"

"gorm.io/gorm"
)

// collabUpdateBatchHashPartitionCount is the number of hash partitions created
// for collab_document_update_batches; it is used both as the loop bound when
// creating partitions and as the MODULUS in each partition's bound.
// NOTE(review): changing this on a database that already has partitions
// presumably needs a manual repartition — confirm before editing.
const collabUpdateBatchHashPartitionCount = 16

// hashPartitionedTable describes a PostgreSQL table that this package keeps
// hash-partitioned, including everything needed to create it from scratch or
// to migrate a pre-existing regular table of the same name.
type hashPartitionedTable struct {
	// name is the parent table name; partition and legacy table names are
	// derived from it.
	name string
	// partitionKey is the column the table is expected to be hash-partitioned
	// by; it is compared against the existing partition key definition.
	partitionKey string
	// partitionCount is how many partitions are created and is the MODULUS
	// used for each of them.
	partitionCount int
	// createSQL is the DDL executed to create the partitioned parent table.
	createSQL string
	// columns lists the columns copied from the legacy table during migration.
	columns []string
}

// collabUpdateBatchHashPartitionedTable is the partitioning definition for
// collab_document_update_batches: hash-partitioned by document_id with a
// composite (id, document_id) primary key so the partition key is part of it.
var collabUpdateBatchHashPartitionedTable = hashPartitionedTable{
	name:           "collab_document_update_batches",
	partitionKey:   "document_id",
	partitionCount: collabUpdateBatchHashPartitionCount,
	// The DDL is idempotent (IF NOT EXISTS) and must list the same columns as
	// the columns slice below, since those are what the migration copies.
	createSQL: `
CREATE TABLE IF NOT EXISTS collab_document_update_batches (
id bigserial NOT NULL,
document_id uuid NOT NULL,
from_seq bigint NOT NULL,
to_seq bigint NOT NULL,
update_payload bytea NOT NULL,
update_count integer NOT NULL,
payload_size_bytes integer NOT NULL,
actor_user_id uuid,
created_at timestamptz NOT NULL,
CONSTRAINT pk_collab_document_update_batches_partitioned PRIMARY KEY (id, document_id)
) PARTITION BY HASH (document_id)
`,
	// Columns copied verbatim from the legacy table when migrating.
	columns: []string{
		"id",
		"document_id",
		"from_seq",
		"to_seq",
		"update_payload",
		"update_count",
		"payload_size_bytes",
		"actor_user_id",
		"created_at",
	},
}

// ensureCollabUpdateBatchHashPartitions makes sure the collab update batch
// table exists as a hash-partitioned table with all of its partitions.
// It is a no-op for every dialect other than "postgres".
func ensureCollabUpdateBatchHashPartitions(database *gorm.DB) error {
	const postgresDialect = "postgres"

	dialect := database.Name()
	if dialect != postgresDialect {
		// Hash partitioning is only set up on PostgreSQL.
		return nil
	}

	definition := collabUpdateBatchHashPartitionedTable
	return ensureHashPartitionedTable(database, definition)
}

// ensureHashPartitionedTable brings the given table into its expected
// hash-partitioned shape and then makes sure every partition exists.
//
// Depending on the current state it will:
//   - create the partitioned parent when the table does not exist,
//   - migrate the table when it exists but has no partition key,
//   - fail when it is partitioned by anything other than HASH (partitionKey),
//   - leave the parent untouched when it is already partitioned as expected.
//
// It returns the first error from the lookup, the migration, or any DDL.
func ensureHashPartitionedTable(database *gorm.DB, table hashPartitionedTable) error {
	currentKey, tableExists, lookupErr := postgresPartitionedTableKey(database, table.name)
	if lookupErr != nil {
		return lookupErr
	}

	wantKey := fmt.Sprintf("HASH (%s)", table.partitionKey)

	switch {
	case !tableExists:
		if createErr := database.Exec(table.createSQL).Error; createErr != nil {
			return createErr
		}
	case currentKey == "":
		// An empty key means the table exists but is a regular table.
		if migrateErr := migrateRegularTableToHashPartitions(database, table); migrateErr != nil {
			return migrateErr
		}
	case !strings.EqualFold(currentKey, wantKey):
		return fmt.Errorf("%s is partitioned by %s, expected %s", table.name, currentKey, wantKey)
	}

	// Partition DDL uses IF NOT EXISTS, so running it on every call is safe.
	for i := 0; i < table.partitionCount; i++ {
		partitionDDL := createHashPartitionSQL(table.name, i, table.partitionCount)
		if execErr := database.Exec(partitionDDL).Error; execErr != nil {
			return execErr
		}
	}
	return nil
}

// postgresPartitionedTableKey reports whether tableName resolves to an
// existing relation and, if so, its partition key definition.
//
// exists is true when to_regclass finds the table. partitionKey is the text
// returned by pg_get_partkeydef, or the empty string when that call yields
// NULL; callers treat the empty string as "not partitioned".
func postgresPartitionedTableKey(database *gorm.DB, tableName string) (partitionKey string, exists bool, err error) {
	const lookupQuery = `
SELECT
to_regclass(?) IS NOT NULL AS exists,
COALESCE(pg_get_partkeydef(to_regclass(?)), '') AS partition_key
`
	// The table name is bound twice, once for each placeholder.
	row := database.Raw(lookupQuery, tableName, tableName).Row()
	err = row.Scan(&exists, &partitionKey)
	return partitionKey, exists, err
}

// migrateRegularTableToHashPartitions replaces an existing regular
// (non-partitioned) table with a hash-partitioned table of the same name and
// moves all rows across.
//
// Steps, in order:
//  1. rename the existing table to <name>_legacy_<UTC timestamp>,
//  2. rename its serial sequence so the new bigserial column can claim the
//     default sequence name,
//  3. create the partitioned parent and all of its partitions,
//  4. copy every row from the legacy table,
//  5. advance the new id sequence past the copied ids,
//  6. drop the legacy table.
//
// All steps run inside one transaction. Without it, a failure after step 3
// would leave an empty partitioned table under the real name; the next
// ensureHashPartitionedTable call would then see the table as already
// partitioned, skip the migration, and never copy the legacy rows.
//
// It returns the first failing step's error wrapped with the table name.
func migrateRegularTableToHashPartitions(database *gorm.DB, table hashPartitionedTable) error {
	legacyName := fmt.Sprintf("%s_legacy_%s", table.name, time.Now().UTC().Format("20060102150405"))

	err := database.Transaction(func(tx *gorm.DB) error {
		renameTable := fmt.Sprintf(
			"ALTER TABLE %s RENAME TO %s",
			quotePostgresIdentifier(table.name),
			quotePostgresIdentifier(legacyName),
		)
		if err := tx.Exec(renameTable).Error; err != nil {
			return err
		}
		// Must happen before createSQL runs so the new bigserial column does
		// not collide with the legacy table's sequence name.
		if err := renameLegacySerialSequence(tx, table.name, legacyName, "id"); err != nil {
			return err
		}

		if err := tx.Exec(table.createSQL).Error; err != nil {
			return err
		}
		// Partitions must exist before the copy, otherwise rows have no
		// partition to land in.
		for partitionIndex := range table.partitionCount {
			if err := tx.Exec(createHashPartitionSQL(table.name, partitionIndex, table.partitionCount)).Error; err != nil {
				return err
			}
		}
		if err := copyLegacyRowsIntoHashPartitionedTable(tx, table, legacyName); err != nil {
			return err
		}
		// Copied rows carry explicit ids, so the fresh sequence has to be
		// moved past them.
		if err := syncSerialSequenceToMaxID(tx, table.name, "id"); err != nil {
			return err
		}
		return tx.Exec(fmt.Sprintf("DROP TABLE %s", quotePostgresIdentifier(legacyName))).Error
	})
	if err != nil {
		return fmt.Errorf("migrating %s to hash partitions: %w", table.name, err)
	}
	return nil
}

// copyLegacyRowsIntoHashPartitionedTable copies every row of the legacy table
// into the partitioned table with a single INSERT ... SELECT, using the
// table definition's column list for both sides.
func copyLegacyRowsIntoHashPartitionedTable(database *gorm.DB, table hashPartitionedTable, legacyName string) error {
	columns := quotedColumnList(table.columns)
	target := quotePostgresIdentifier(table.name)
	source := quotePostgresIdentifier(legacyName)

	statement := "INSERT INTO " + target + " (" + columns + ") SELECT " + columns + " FROM " + source
	return database.Exec(statement).Error
}

// renameLegacySerialSequence executes the statement built by
// renameLegacySerialSequenceSQL and returns any execution error.
func renameLegacySerialSequence(database *gorm.DB, tableName string, legacyName string, columnName string) error {
	statement := renameLegacySerialSequenceSQL(tableName, legacyName, columnName)
	result := database.Exec(statement)
	return result.Error
}

// renameLegacySerialSequenceSQL builds a PL/pgSQL DO block that renames the
// serial sequence owned by legacyName.columnName.
//
// The rename only happens when that sequence is the one named
// <tableName>_<columnName>_seq; it is then renamed to
// <legacyName>_<columnName>_seq. All four values are embedded as quoted
// string literals.
func renameLegacySerialSequenceSQL(tableName string, legacyName string, columnName string) string {
	// %%s / %%I survive Sprintf as %s / %I for PostgreSQL's format().
	const doBlock = `
DO $$
DECLARE
legacy_sequence text;
BEGIN
SELECT pg_get_serial_sequence(%s, %s) INTO legacy_sequence;
IF legacy_sequence IS NOT NULL
AND to_regclass(legacy_sequence) = to_regclass(%s)
THEN
EXECUTE format('ALTER SEQUENCE %%s RENAME TO %%I', legacy_sequence, %s);
END IF;
END $$`

	sequenceSuffix := "_" + columnName + "_seq"

	legacyTableLiteral := quotePostgresStringLiteral(legacyName)
	columnLiteral := quotePostgresStringLiteral(columnName)
	currentSequenceLiteral := quotePostgresStringLiteral(tableName + sequenceSuffix)
	renamedSequenceLiteral := quotePostgresStringLiteral(legacyName + sequenceSuffix)

	return fmt.Sprintf(doBlock, legacyTableLiteral, columnLiteral, currentSequenceLiteral, renamedSequenceLiteral)
}

// syncSerialSequenceToMaxID moves the serial sequence behind
// tableName.columnName so it continues after the largest value currently
// stored in that column.
//
// When the table has rows, the sequence is set to MAX(column) and marked as
// called, so the next value is MAX+1. When the table is empty, the sequence
// is set to 1 and marked as not called, so the next value is 1.
//
// The table and column names are embedded as string literals via
// quotePostgresStringLiteral and as identifiers via quotePostgresIdentifier.
// Any execution error is returned wrapped with the table and column name.
func syncSerialSequenceToMaxID(database *gorm.DB, tableName string, columnName string) error {
	quotedTable := quotePostgresIdentifier(tableName)
	statement := fmt.Sprintf(
		`SELECT setval(
pg_get_serial_sequence(%s, %s),
COALESCE((SELECT MAX(%s) FROM %s), 1),
EXISTS (SELECT 1 FROM %s)
)`,
		quotePostgresStringLiteral(tableName),
		quotePostgresStringLiteral(columnName),
		quotePostgresIdentifier(columnName),
		quotedTable,
		quotedTable,
	)
	if err := database.Exec(statement).Error; err != nil {
		return fmt.Errorf("syncing serial sequence for %s.%s: %w", tableName, columnName, err)
	}
	return nil
}

// createHashPartitionSQL builds the idempotent DDL for one hash partition of
// tableName.
//
// The partition is named <tableName>_pNN, where NN is partitionIndex
// zero-padded to at least two digits. partitionCount becomes the MODULUS and
// partitionIndex the REMAINDER of the partition bound.
func createHashPartitionSQL(tableName string, partitionIndex int, partitionCount int) string {
	child := quotePostgresIdentifier(fmt.Sprintf("%s_p%02d", tableName, partitionIndex))
	parent := quotePostgresIdentifier(tableName)
	bound := fmt.Sprintf("FOR VALUES WITH (MODULUS %d, REMAINDER %d)", partitionCount, partitionIndex)

	return "CREATE TABLE IF NOT EXISTS " + child + " PARTITION OF " + parent + " " + bound
}

// quotePostgresStringLiteral wraps value in single quotes and doubles every
// single quote inside it, producing a SQL string literal.
func quotePostgresStringLiteral(value string) string {
	const quote = "'"
	escaped := strings.ReplaceAll(value, quote, quote+quote)
	return quote + escaped + quote
}
4 changes: 2 additions & 2 deletions backend/internal/models/collab.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ type CollabDocumentState struct {
}

type CollabDocumentUpdateBatch struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
DocumentID uuid.UUID `gorm:"type:uuid;not null;uniqueIndex:ux_collab_update_batch_doc_seq,priority:1;index:idx_collab_update_batches_doc_seq,priority:1"`
ID int64 `gorm:"primaryKey;autoIncrement:false"`
DocumentID uuid.UUID `gorm:"type:uuid;primaryKey;not null;uniqueIndex:ux_collab_update_batch_doc_seq,priority:1;index:idx_collab_update_batches_doc_seq,priority:1"`
FromSeq int64 `gorm:"not null;uniqueIndex:ux_collab_update_batch_doc_seq,priority:2"`
ToSeq int64 `gorm:"not null;uniqueIndex:ux_collab_update_batch_doc_seq,priority:3;index:idx_collab_update_batches_doc_seq,priority:2,sort:desc"`
UpdatePayload []byte `gorm:"type:bytea;not null"`
Expand Down
Loading
Loading