From acc59fc55b6489b3c3ac486ca151822edd902f74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?No=C3=A9mi=20V=C3=A1nyi?= Date: Thu, 12 Jun 2025 18:39:26 +0200 Subject: [PATCH] Return list of `DBAction` from `Complete` functions --- pkg/migrations/migrations.go | 2 +- pkg/migrations/op_add_column.go | 68 +++++-------------- pkg/migrations/op_alter_column.go | 53 ++++++--------- pkg/migrations/op_change_type.go | 4 +- pkg/migrations/op_create_constraint.go | 60 +++++++--------- pkg/migrations/op_create_index.go | 4 +- pkg/migrations/op_create_table.go | 4 +- pkg/migrations/op_drop_column.go | 25 ++----- pkg/migrations/op_drop_constraint.go | 41 +++-------- pkg/migrations/op_drop_index.go | 6 +- .../op_drop_multicolumn_constraint.go | 40 ++++------- pkg/migrations/op_drop_not_null.go | 4 +- pkg/migrations/op_drop_table.go | 10 +-- pkg/migrations/op_raw_sql.go | 6 +- pkg/migrations/op_rename_column.go | 7 +- pkg/migrations/op_rename_constraint.go | 8 ++- pkg/migrations/op_rename_table.go | 4 +- pkg/migrations/op_set_check.go | 13 ++-- pkg/migrations/op_set_comment.go | 6 +- pkg/migrations/op_set_default.go | 4 +- pkg/migrations/op_set_fk.go | 13 ++-- pkg/migrations/op_set_notnull.go | 39 ++++------- pkg/migrations/op_set_replica_identity.go | 4 +- pkg/migrations/op_set_unique.go | 4 +- pkg/roll/execute.go | 10 ++- 25 files changed, 163 insertions(+), 276 deletions(-) diff --git a/pkg/migrations/migrations.go b/pkg/migrations/migrations.go index 57b6daf67..2be5a4daf 100644 --- a/pkg/migrations/migrations.go +++ b/pkg/migrations/migrations.go @@ -25,7 +25,7 @@ type Operation interface { // Complete will update the database schema to match the current version // after calling Start. // This method should be called once the previous version is no longer used. - Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error + Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) // Rollback will revert the changes made by Start. It is not possible to // rollback a completed migration. diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index bcc124077..eaa0349b8 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -141,69 +141,43 @@ func toSchemaColumn(c Column) *schema.Column { return tmpColumn } -func (o *OpAddColumn) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpAddColumn) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - err := NewRenameColumnAction(conn, o.Table, TemporaryName(o.Column.Name), o.Column.Name).Execute(ctx) - if err != nil { - return err - } - - err = NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column.Name)).Execute(ctx) - if err != nil { - return err - } - - removeBackfillColumn := NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn) - err = removeBackfillColumn.Execute(ctx) - if err != nil { - return err + dbActions := []DBAction{ + NewRenameColumnAction(conn, o.Table, TemporaryName(o.Column.Name), o.Column.Name), + NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column.Name)), + NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn), } if !o.Column.IsNullable() && o.Column.Default == nil { - err = upgradeNotNullConstraintToNotNullAttribute(ctx, conn, o.Table, o.Column.Name) - if err != nil { - return err - } + dbActions = append(dbActions, upgradeNotNullConstraintToNotNullAttribute(conn, o.Table, o.Column.Name)...) } if o.Column.Check != nil { - err = NewValidateConstraintAction(conn, o.Table, o.Column.Check.Name).Execute(ctx) - if err != nil { - return err - } + dbActions = append(dbActions, NewValidateConstraintAction(conn, o.Table, o.Column.Check.Name)) } if o.Column.Unique { - err := NewAddConstraintUsingUniqueIndex(conn, + dbActions = append(dbActions, NewAddConstraintUsingUniqueIndex(conn, o.Table, o.Column.Name, - UniqueIndexName(o.Column.Name), - ).Execute(ctx) - if err != nil { - return err - } + UniqueIndexName(o.Column.Name))) } // If the column has a DEFAULT that could not be set using the fast-path // optimization, set it here. column := s.GetTable(o.Table).GetColumn(TemporaryName(o.Column.Name)) if o.Column.HasDefault() && column.Default == nil { - err := NewSetDefaultValueAction(conn, o.Table, o.Column.Name, *o.Column.Default).Execute(ctx) - if err != nil { - return err - } + dbActions = append(dbActions, NewSetDefaultValueAction(conn, o.Table, o.Column.Name, *o.Column.Default)) // Validate the `NOT NULL` constraint on the column if necessary if !o.Column.IsNullable() { - err = upgradeNotNullConstraintToNotNullAttribute(ctx, conn, o.Table, o.Column.Name) - if err != nil { - return err - } + dbActions = append(dbActions, upgradeNotNullConstraintToNotNullAttribute(conn, o.Table, o.Column.Name)...) } } - return nil + return dbActions, nil } func (o *OpAddColumn) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { @@ -347,20 +321,12 @@ func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table, // upgradeNotNullConstraintToNotNullAttribute validates and upgrades a NOT NULL // constraint to a NOT NULL column attribute. The constraint is removed after // the column attribute is added. -func upgradeNotNullConstraintToNotNullAttribute(ctx context.Context, conn db.DB, tableName, columnName string) error { - err := NewValidateConstraintAction(conn, tableName, NotNullConstraintName(columnName)).Execute(ctx) - if err != nil { - return err - } - - err = NewSetNotNullAction(conn, tableName, columnName).Execute(ctx) - if err != nil { - return err +func upgradeNotNullConstraintToNotNullAttribute(conn db.DB, tableName, columnName string) []DBAction { + return []DBAction{ + NewValidateConstraintAction(conn, tableName, NotNullConstraintName(columnName)), + NewSetNotNullAction(conn, tableName, columnName), + NewDropConstraintAction(conn, tableName, NotNullConstraintName(columnName)), } - - err = NewDropConstraintAction(conn, tableName, NotNullConstraintName(columnName)).Execute(ctx) - - return err } // UniqueIndexName returns the name of the unique index for the given column diff --git a/pkg/migrations/op_alter_column.go b/pkg/migrations/op_alter_column.go index df639e2e3..2e3fb085a 100644 --- a/pkg/migrations/op_alter_column.go +++ b/pkg/migrations/op_alter_column.go @@ -91,54 +91,41 @@ func (o *OpAlterColumn) Start(ctx context.Context, l Logger, conn db.DB, latestS return backfill.NewTask(table), nil } -func (o *OpAlterColumn) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpAlterColumn) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) ops := o.subOperations() + dbActions := []DBAction{} // Perform any operation specific completion steps for _, op := range ops { - if err := op.Complete(ctx, l, conn, s); err != nil { - return err + actions, err := op.Complete(l, conn, s) + if err != nil { + return []DBAction{}, err } - } - - if err := NewAlterSequenceOwnerAction(conn, o.Table, o.Column, TemporaryName(o.Column)).Execute(ctx); err != nil { - return err - } - - removeOldColumn := NewDropColumnAction(conn, o.Table, o.Column) - err := removeOldColumn.Execute(ctx) - if err != nil { - return err - } - - // Remove the up and down function and trigger - err = NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column), backfill.TriggerFunctionName(o.Table, TemporaryName(o.Column))).Execute(ctx) - if err != nil { - return err - } - - removeBackfillColumn := NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn) - err = removeBackfillColumn.Execute(ctx) - if err != nil { - return err + dbActions = append(dbActions, actions...) } // Rename the new column to the old column name table := s.GetTable(o.Table) if table == nil { - return TableDoesNotExistError{Name: o.Table} + return []DBAction{}, TableDoesNotExistError{Name: o.Table} } column := table.GetColumn(o.Column) if column == nil { - return ColumnDoesNotExistError{Table: o.Table, Name: o.Column} - } - if err := NewRenameDuplicatedColumnAction(conn, table, column.Name).Execute(ctx); err != nil { - return err - } - - return nil + return []DBAction{}, ColumnDoesNotExistError{Table: o.Table, Name: o.Column} + } + + return append(dbActions, []DBAction{ + NewAlterSequenceOwnerAction(conn, o.Table, o.Column, TemporaryName(o.Column)), + NewDropColumnAction(conn, o.Table, o.Column), + NewDropFunctionAction(conn, + backfill.TriggerFunctionName(o.Table, o.Column), + backfill.TriggerFunctionName(o.Table, TemporaryName(o.Column)), + ), + NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn), + NewRenameDuplicatedColumnAction(conn, table, column.Name), + }...), nil } func (o *OpAlterColumn) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_change_type.go b/pkg/migrations/op_change_type.go index 609099043..1ea765afa 100644 --- a/pkg/migrations/op_change_type.go +++ b/pkg/migrations/op_change_type.go @@ -31,10 +31,10 @@ func (o *OpChangeType) Start(ctx context.Context, l Logger, conn db.DB, latestSc return backfill.NewTask(table), nil } -func (o *OpChangeType) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpChangeType) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - return nil + return []DBAction{}, nil } func (o *OpChangeType) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_create_constraint.go b/pkg/migrations/op_create_constraint.go index f146524c1..f382f94e3 100644 --- a/pkg/migrations/op_create_constraint.go +++ b/pkg/migrations/op_create_constraint.go @@ -101,19 +101,21 @@ func (o *OpCreateConstraint) Start(ctx context.Context, l Logger, conn db.DB, la return task, nil } -func (o *OpCreateConstraint) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpCreateConstraint) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) + dbActions := make([]DBAction, 0) switch o.Type { case OpCreateConstraintTypeUnique: uniqueOp := &OpSetUnique{ Table: o.Table, Name: o.Name, } - err := uniqueOp.Complete(ctx, l, conn, s) + actions, err := uniqueOp.Complete(l, conn, s) if err != nil { - return err + return []DBAction{}, err } + dbActions = append(dbActions, actions...) case OpCreateConstraintTypeCheck: checkOp := &OpSetCheckConstraint{ Table: o.Table, @@ -121,10 +123,11 @@ func (o *OpCreateConstraint) Complete(ctx context.Context, l Logger, conn db.DB, Name: o.Name, }, } - err := checkOp.Complete(ctx, l, conn, s) + actions, err := checkOp.Complete(l, conn, s) if err != nil { - return err + return []DBAction{}, err } + dbActions = append(dbActions, actions...) case OpCreateConstraintTypeForeignKey: fkOp := &OpSetForeignKey{ Table: o.Table, @@ -132,52 +135,39 @@ func (o *OpCreateConstraint) Complete(ctx context.Context, l Logger, conn db.DB, Name: o.Name, }, } - err := fkOp.Complete(ctx, l, conn, s) + actions, err := fkOp.Complete(l, conn, s) if err != nil { - return err + return []DBAction{}, err } + dbActions = append(dbActions, actions...) case OpCreateConstraintTypePrimaryKey: - err := NewAddPrimaryKeyAction(conn, o.Table, o.Name).Execute(ctx) - if err != nil { - return err - } + dbActions = append(dbActions, NewAddPrimaryKeyAction(conn, o.Table, o.Name)) } for _, col := range o.Columns { - if err := NewAlterSequenceOwnerAction(conn, o.Table, col, TemporaryName(col)).Execute(ctx); err != nil { - return err - } + dbActions = append(dbActions, NewAlterSequenceOwnerAction(conn, o.Table, col, TemporaryName(col))) } - removeOldColumns := NewDropColumnAction(conn, o.Table, o.Columns...) - err := removeOldColumns.Execute(ctx) - if err != nil { - return err - } + dbActions = append(dbActions, NewDropColumnAction(conn, o.Table, o.Columns...)) // rename new columns to old name table := s.GetTable(o.Table) if table == nil { - return TableDoesNotExistError{Name: o.Table} + return []DBAction{}, TableDoesNotExistError{Name: o.Table} } for _, col := range o.Columns { column := table.GetColumn(col) if column == nil { - return ColumnDoesNotExistError{Table: o.Table, Name: col} + return []DBAction{}, ColumnDoesNotExistError{Table: o.Table, Name: col} } - if err := NewRenameDuplicatedColumnAction(conn, table, column.Name).Execute(ctx); err != nil { - return err - } - } - - if err := o.removeTriggers(ctx, conn); err != nil { - return err + dbActions = append(dbActions, NewRenameDuplicatedColumnAction(conn, table, column.Name)) } + dbActions = append(dbActions, + o.removeTriggers(conn), + NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn), + ) - removeBackfillColumn := NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn) - err = removeBackfillColumn.Execute(ctx) - - return err + return dbActions, nil } func (o *OpCreateConstraint) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { @@ -194,7 +184,7 @@ func (o *OpCreateConstraint) Rollback(ctx context.Context, l Logger, conn db.DB, return err } - if err := o.removeTriggers(ctx, conn); err != nil { + if err := o.removeTriggers(conn).Execute(ctx); err != nil { return err } @@ -204,13 +194,13 @@ func (o *OpCreateConstraint) Rollback(ctx context.Context, l Logger, conn db.DB, return err } -func (o *OpCreateConstraint) removeTriggers(ctx context.Context, conn db.DB) error { +func (o *OpCreateConstraint) removeTriggers(conn db.DB) DBAction { dropFuncs := make([]string, 0, len(o.Columns)*2) for _, column := range o.Columns { dropFuncs = append(dropFuncs, backfill.TriggerFunctionName(o.Table, column)) dropFuncs = append(dropFuncs, backfill.TriggerFunctionName(o.Table, TemporaryName(column))) } - return NewDropFunctionAction(conn, dropFuncs...).Execute(ctx) + return NewDropFunctionAction(conn, dropFuncs...) } func (o *OpCreateConstraint) Validate(ctx context.Context, s *schema.Schema) error { diff --git a/pkg/migrations/op_create_index.go b/pkg/migrations/op_create_index.go index 513f94d2d..08eda1262 100644 --- a/pkg/migrations/op_create_index.go +++ b/pkg/migrations/op_create_index.go @@ -79,11 +79,11 @@ func (o *OpCreateIndex) Start(ctx context.Context, l Logger, conn db.DB, latestS return nil, err } -func (o *OpCreateIndex) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpCreateIndex) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) // No-op - return nil + return []DBAction{}, nil } func (o *OpCreateIndex) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_create_table.go b/pkg/migrations/op_create_table.go index 5d3d6a615..b93700997 100644 --- a/pkg/migrations/op_create_table.go +++ b/pkg/migrations/op_create_table.go @@ -61,11 +61,11 @@ func (o *OpCreateTable) Start(ctx context.Context, l Logger, conn db.DB, latestS return nil, nil } -func (o *OpCreateTable) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpCreateTable) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) // No-op - return nil + return []DBAction{}, nil } func (o *OpCreateTable) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_column.go b/pkg/migrations/op_drop_column.go index 1f6c4654f..c6883c07c 100644 --- a/pkg/migrations/op_drop_column.go +++ b/pkg/migrations/op_drop_column.go @@ -49,27 +49,14 @@ func (o *OpDropColumn) Start(ctx context.Context, l Logger, conn db.DB, latestSc return nil, nil } -func (o *OpDropColumn) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpDropColumn) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - removeColumn := NewDropColumnAction(conn, o.Table, o.Column) - err := removeColumn.Execute(ctx) - if err != nil { - return err - } - - err = NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column)).Execute(ctx) - if err != nil { - return err - } - - removeBackfillColumn := NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn) - err = removeBackfillColumn.Execute(ctx) - if err != nil { - return err - } - - return nil + return []DBAction{ + NewDropColumnAction(conn, o.Table, o.Column), + NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column)), + NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn), + }, nil } func (o *OpDropColumn) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_constraint.go b/pkg/migrations/op_drop_constraint.go index d768e4e04..cd47fa641 100644 --- a/pkg/migrations/op_drop_constraint.go +++ b/pkg/migrations/op_drop_constraint.go @@ -79,43 +79,22 @@ func (o *OpDropConstraint) Start(ctx context.Context, l Logger, conn db.DB, late return backfill.NewTask(table), nil } -func (o *OpDropConstraint) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpDropConstraint) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) // We have already validated that there is single column related to this constraint. table := s.GetTable(o.Table) column := table.GetColumn(table.GetConstraintColumns(o.Name)[0]) - // Remove the up and down function and trigger - err := NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, column.Name), backfill.TriggerFunctionName(o.Table, TemporaryName(column.Name))).Execute(ctx) - if err != nil { - return err - } - - if err := NewAlterSequenceOwnerAction(conn, o.Table, column.Name, TemporaryName(column.Name)).Execute(ctx); err != nil { - return err - } - - removeBackfillColumn := NewDropColumnAction(conn, table.Name, backfill.CNeedsBackfillColumn) - err = removeBackfillColumn.Execute(ctx) - if err != nil { - return err - } - - removeOldColumn := NewDropColumnAction(conn, - o.Table, - column.Name) - err = removeOldColumn.Execute(ctx) - if err != nil { - return err - } - - // Rename the new column to the old column name - if err := NewRenameDuplicatedColumnAction(conn, table, column.Name).Execute(ctx); err != nil { - return err - } - - return err + return []DBAction{ + NewDropFunctionAction(conn, + backfill.TriggerFunctionName(o.Table, column.Name), + backfill.TriggerFunctionName(o.Table, TemporaryName(column.Name))), + NewAlterSequenceOwnerAction(conn, o.Table, column.Name, TemporaryName(column.Name)), + NewDropColumnAction(conn, table.Name, backfill.CNeedsBackfillColumn), + NewDropColumnAction(conn, o.Table, column.Name), + NewRenameDuplicatedColumnAction(conn, table, column.Name), + }, nil } func (o *OpDropConstraint) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_index.go b/pkg/migrations/op_drop_index.go index 8971af0c5..ce9a54b59 100644 --- a/pkg/migrations/op_drop_index.go +++ b/pkg/migrations/op_drop_index.go @@ -22,10 +22,12 @@ func (o *OpDropIndex) Start(ctx context.Context, l Logger, conn db.DB, latestSch return nil, nil } -func (o *OpDropIndex) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpDropIndex) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - return NewDropIndexAction(conn, o.Name).Execute(ctx) + return []DBAction{ + NewDropIndexAction(conn, o.Name), + }, nil } func (o *OpDropIndex) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_multicolumn_constraint.go b/pkg/migrations/op_drop_multicolumn_constraint.go index 9168b630d..995bc542b 100644 --- a/pkg/migrations/op_drop_multicolumn_constraint.go +++ b/pkg/migrations/op_drop_multicolumn_constraint.go @@ -97,42 +97,26 @@ func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, l Logger, conn return backfill.NewTask(table), nil } -func (o *OpDropMultiColumnConstraint) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpDropMultiColumnConstraint) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) table := s.GetTable(o.Table) + dbActions := make([]DBAction, 0) for _, columnName := range table.GetConstraintColumns(o.Name) { - // Remove the up and down function and trigger - err := NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, columnName), backfill.TriggerFunctionName(o.Table, TemporaryName(columnName))).Execute(ctx) - if err != nil { - return err - } - - if err := NewAlterSequenceOwnerAction(conn, o.Table, columnName, TemporaryName(columnName)).Execute(ctx); err != nil { - return err - } - - removeBackfillColumn := NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn) - err = removeBackfillColumn.Execute(ctx) - if err != nil { - return err - } - - removeOldColumn := NewDropColumnAction(conn, o.Table, columnName) - err = removeOldColumn.Execute(ctx) - if err != nil { - return err - } - - // Rename the new column to the old column name column := table.GetColumn(columnName) - if err := NewRenameDuplicatedColumnAction(conn, table, column.Name).Execute(ctx); err != nil { - return err - } + dbActions = append(dbActions, + NewDropFunctionAction(conn, + backfill.TriggerFunctionName(o.Table, columnName), + backfill.TriggerFunctionName(o.Table, TemporaryName(columnName))), + NewAlterSequenceOwnerAction(conn, o.Table, columnName, TemporaryName(columnName)), + NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn), + NewDropColumnAction(conn, o.Table, columnName), + NewRenameDuplicatedColumnAction(conn, table, column.Name), + ) } - return nil + return dbActions, nil } func (o *OpDropMultiColumnConstraint) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_not_null.go b/pkg/migrations/op_drop_not_null.go index 8d7332f0e..937dffb83 100644 --- a/pkg/migrations/op_drop_not_null.go +++ b/pkg/migrations/op_drop_not_null.go @@ -31,9 +31,9 @@ func (o *OpDropNotNull) Start(ctx context.Context, l Logger, conn db.DB, latestS return backfill.NewTask(table), nil } -func (o *OpDropNotNull) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpDropNotNull) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - return nil + return []DBAction{}, nil } func (o *OpDropNotNull) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_table.go b/pkg/migrations/op_drop_table.go index e3a838f51..014e5c935 100644 --- a/pkg/migrations/op_drop_table.go +++ b/pkg/migrations/op_drop_table.go @@ -35,13 +35,13 @@ func (o *OpDropTable) Start(ctx context.Context, l Logger, conn db.DB, latestSch return nil, nil } -func (o *OpDropTable) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpDropTable) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - deletionName := DeletionName(o.Name) - - // Perform the actual deletion of the soft-deleted table - return NewDropTableAction(conn, deletionName).Execute(ctx) + return []DBAction{ + // Perform the actual deletion of the soft-deleted table + NewDropTableAction(conn, DeletionName(o.Name)), + }, nil } func (o *OpDropTable) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_raw_sql.go b/pkg/migrations/op_raw_sql.go index 407255e1c..a9baaca01 100644 --- a/pkg/migrations/op_raw_sql.go +++ b/pkg/migrations/op_raw_sql.go @@ -25,14 +25,14 @@ func (o *OpRawSQL) Start(ctx context.Context, l Logger, conn db.DB, latestSchema return nil, NewRawSQLAction(conn, o.Up).Execute(ctx) } -func (o *OpRawSQL) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpRawSQL) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) if !o.OnComplete { - return nil + return []DBAction{}, nil } - return NewRawSQLAction(conn, o.Up).Execute(ctx) + return []DBAction{NewRawSQLAction(conn, o.Up)}, nil } func (o *OpRawSQL) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_rename_column.go b/pkg/migrations/op_rename_column.go index 6f0f0b907..60a67404f 100644 --- a/pkg/migrations/op_rename_column.go +++ b/pkg/migrations/op_rename_column.go @@ -33,9 +33,12 @@ func (o *OpRenameColumn) Start(ctx context.Context, l Logger, conn db.DB, latest return nil, nil } -func (o *OpRenameColumn) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpRenameColumn) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - return NewRenameColumnAction(conn, o.Table, o.From, o.To).Execute(ctx) + + return []DBAction{ + NewRenameColumnAction(conn, o.Table, o.From, o.To), + }, nil } func (o *OpRenameColumn) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_rename_constraint.go b/pkg/migrations/op_rename_constraint.go index 567dd4078..7b3345b37 100644 --- a/pkg/migrations/op_rename_constraint.go +++ b/pkg/migrations/op_rename_constraint.go @@ -19,11 +19,13 @@ func (o *OpRenameConstraint) Start(ctx context.Context, l Logger, conn db.DB, la return nil, nil } -func (o *OpRenameConstraint) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpRenameConstraint) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - // rename the constraint in the underlying table - return NewRenameConstraintAction(conn, o.Table, o.From, o.To).Execute(ctx) + return []DBAction{ + // rename the constraint in the underlying table + NewRenameConstraintAction(conn, o.Table, o.From, o.To), + }, nil } func (o *OpRenameConstraint) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_rename_table.go b/pkg/migrations/op_rename_table.go index 7033fbb43..21658dff8 100644 --- a/pkg/migrations/op_rename_table.go +++ b/pkg/migrations/op_rename_table.go @@ -18,10 +18,10 @@ func (o *OpRenameTable) Start(ctx context.Context, l Logger, conn db.DB, latestS return nil, s.RenameTable(o.From, o.To) } -func (o *OpRenameTable) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpRenameTable) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - return NewRenameTableAction(conn, o.From, o.To).Execute(ctx) + return []DBAction{NewRenameTableAction(conn, o.From, o.To)}, nil } func (o *OpRenameTable) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_set_check.go b/pkg/migrations/op_set_check.go index e7200a2cf..e9bf0a194 100644 --- a/pkg/migrations/op_set_check.go +++ b/pkg/migrations/op_set_check.go @@ -37,16 +37,13 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, l Logger, conn db.DB, return backfill.NewTask(table), nil } -func (o *OpSetCheckConstraint) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpSetCheckConstraint) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - // Validate the check constraint - err := NewValidateConstraintAction(conn, o.Table, o.Check.Name).Execute(ctx) - if err != nil { - return err - } - - return nil + return []DBAction{ + // Validate the check constraint + NewValidateConstraintAction(conn, o.Table, o.Check.Name), + }, nil } func (o *OpSetCheckConstraint) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_set_comment.go b/pkg/migrations/op_set_comment.go index cb644d7f2..0381d4e95 100644 --- a/pkg/migrations/op_set_comment.go +++ b/pkg/migrations/op_set_comment.go @@ -32,10 +32,12 @@ func (o *OpSetComment) Start(ctx context.Context, l Logger, conn db.DB, latestSc return backfill.NewTask(tbl), NewCommentColumnAction(conn, o.Table, TemporaryName(o.Column), o.Comment).Execute(ctx) } -func (o *OpSetComment) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpSetComment) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - return NewCommentColumnAction(conn, o.Table, o.Column, o.Comment).Execute(ctx) + return []DBAction{ + NewCommentColumnAction(conn, o.Table, o.Column, o.Comment), + }, nil } func (o *OpSetComment) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_set_default.go b/pkg/migrations/op_set_default.go index 48921d457..78aaef2ac 100644 --- a/pkg/migrations/op_set_default.go +++ b/pkg/migrations/op_set_default.go @@ -47,10 +47,10 @@ func (o *OpSetDefault) Start(ctx context.Context, l Logger, conn db.DB, latestSc return backfill.NewTask(table), nil } -func (o *OpSetDefault) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpSetDefault) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - return nil + return []DBAction{}, nil } func (o *OpSetDefault) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_set_fk.go b/pkg/migrations/op_set_fk.go index 365fba45b..3a69dd9ff 100644 --- a/pkg/migrations/op_set_fk.go +++ b/pkg/migrations/op_set_fk.go @@ -64,16 +64,13 @@ func (o *OpSetForeignKey) Start(ctx context.Context, l Logger, conn db.DB, lates return backfill.NewTask(table), nil } -func (o *OpSetForeignKey) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpSetForeignKey) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - // Validate the foreign key constraint - err := NewValidateConstraintAction(conn, o.Table, o.References.Name).Execute(ctx) - if err != nil { - return err - } - - return nil + return []DBAction{ + // Validate the foreign key constraint + NewValidateConstraintAction(conn, o.Table, o.References.Name), + }, nil } func (o *OpSetForeignKey) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_set_notnull.go b/pkg/migrations/op_set_notnull.go index e6341bf97..f75a4ab21 100644 --- a/pkg/migrations/op_set_notnull.go +++ b/pkg/migrations/op_set_notnull.go @@ -6,8 +6,6 @@ import ( "context" "fmt" - "github.com/lib/pq" - "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" @@ -52,33 +50,20 @@ func (o *OpSetNotNull) Start(ctx context.Context, l Logger, conn db.DB, latestSc return backfill.NewTask(table), nil } -func (o *OpSetNotNull) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpSetNotNull) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) - // Validate the NOT NULL constraint on the old column. - // The constraint must be valid because: - // * Existing NULL values in the old column were rewritten using the `up` SQL during backfill. - // * New NULL values written to the old column during the migration period were also rewritten using `up` SQL. - err := NewValidateConstraintAction(conn, o.Table, NotNullConstraintName(o.Column)).Execute(ctx) - if err != nil { - return err - } - - // Use the validated constraint to add `NOT NULL` to the new column - _, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s ALTER COLUMN %s SET NOT NULL", - pq.QuoteIdentifier(o.Table), - pq.QuoteIdentifier(TemporaryName(o.Column)))) - if err != nil { - return err - } - - // Drop the NOT NULL constraint - err = NewDropConstraintAction(conn, o.Table, NotNullConstraintName(o.Column)).Execute(ctx) - if err != nil { - return err - } - - return nil + return []DBAction{ + // Validate the NOT NULL constraint on the old column. + // The constraint must be valid because: + // * Existing NULL values in the old column were rewritten using the `up` SQL during backfill. + // * New NULL values written to the old column during the migration period were also rewritten using `up` SQL. + NewValidateConstraintAction(conn, o.Table, NotNullConstraintName(o.Column)), + // Use the validated constraint to add `NOT NULL` to the new column + NewSetNotNullAction(conn, o.Table, TemporaryName(o.Column)), + // Drop the NOT NULL constraint + NewDropConstraintAction(conn, o.Table, NotNullConstraintName(o.Column)), + }, nil } func (o *OpSetNotNull) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_set_replica_identity.go b/pkg/migrations/op_set_replica_identity.go index 43e450880..22684c4f8 100644 --- a/pkg/migrations/op_set_replica_identity.go +++ b/pkg/migrations/op_set_replica_identity.go @@ -33,11 +33,11 @@ func (o *OpSetReplicaIdentity) Start(ctx context.Context, l Logger, conn db.DB, return nil, err } -func (o *OpSetReplicaIdentity) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpSetReplicaIdentity) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) // No-op - return nil + return []DBAction{}, nil } func (o *OpSetReplicaIdentity) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_set_unique.go b/pkg/migrations/op_set_unique.go index 9f50e76c5..aae29739e 100644 --- a/pkg/migrations/op_set_unique.go +++ b/pkg/migrations/op_set_unique.go @@ -35,11 +35,11 @@ func (o *OpSetUnique) Start(ctx context.Context, l Logger, conn db.DB, latestSch return backfill.NewTask(table), NewCreateUniqueIndexConcurrentlyAction(conn, s.Name, o.Name, table.Name, column.Name).Execute(ctx) } -func (o *OpSetUnique) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { +func (o *OpSetUnique) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) { l.LogOperationComplete(o) // Create a unique constraint using the unique index - return NewAddConstraintUsingUniqueIndex(conn, o.Table, o.Name, o.Name).Execute(ctx) + return []DBAction{NewAddConstraintUsingUniqueIndex(conn, o.Table, o.Name, o.Name)}, nil } func (o *OpSetUnique) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/roll/execute.go b/pkg/roll/execute.go index 2c4a2c9b1..9de1b4a9f 100644 --- a/pkg/roll/execute.go +++ b/pkg/roll/execute.go @@ -204,9 +204,15 @@ func (m *Roll) Complete(ctx context.Context) error { // execute operations refreshViews := false for _, op := range migration.Operations { - err := op.Complete(ctx, m.logger, m.pgConn, currentSchema) + actions, err := op.Complete(m.logger, m.pgConn, currentSchema) if err != nil { - return fmt.Errorf("unable to execute complete operation: %w", err) + return fmt.Errorf("unable to collect actions for complete operation: %w", err) + } + + for _, action := range actions { + if err := action.Execute(ctx); err != nil { + return fmt.Errorf("unable to execute complete operation: %w", err) + } } currentSchema, err = m.state.ReadSchema(ctx, m.schema)