Skip to content

Return DBActions from Complete #902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/migrations/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type Operation interface {
// Complete will update the database schema to match the current version
// after calling Start.
// This method should be called once the previous version is no longer used.
Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error
Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error)

// Rollback will revert the changes made by Start. It is not possible to
// rollback a completed migration.
Expand Down
68 changes: 17 additions & 51 deletions pkg/migrations/op_add_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,69 +141,43 @@ func toSchemaColumn(c Column) *schema.Column {
return tmpColumn
}

func (o *OpAddColumn) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
func (o *OpAddColumn) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) {
l.LogOperationComplete(o)

err := NewRenameColumnAction(conn, o.Table, TemporaryName(o.Column.Name), o.Column.Name).Execute(ctx)
if err != nil {
return err
}

err = NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column.Name)).Execute(ctx)
if err != nil {
return err
}

removeBackfillColumn := NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn)
err = removeBackfillColumn.Execute(ctx)
if err != nil {
return err
dbActions := []DBAction{
NewRenameColumnAction(conn, o.Table, TemporaryName(o.Column.Name), o.Column.Name),
NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column.Name)),
NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn),
}

if !o.Column.IsNullable() && o.Column.Default == nil {
err = upgradeNotNullConstraintToNotNullAttribute(ctx, conn, o.Table, o.Column.Name)
if err != nil {
return err
}
dbActions = append(dbActions, upgradeNotNullConstraintToNotNullAttribute(conn, o.Table, o.Column.Name)...)
}

if o.Column.Check != nil {
err = NewValidateConstraintAction(conn, o.Table, o.Column.Check.Name).Execute(ctx)
if err != nil {
return err
}
dbActions = append(dbActions, NewValidateConstraintAction(conn, o.Table, o.Column.Check.Name))
}

if o.Column.Unique {
err := NewAddConstraintUsingUniqueIndex(conn,
dbActions = append(dbActions, NewAddConstraintUsingUniqueIndex(conn,
o.Table,
o.Column.Name,
UniqueIndexName(o.Column.Name),
).Execute(ctx)
if err != nil {
return err
}
UniqueIndexName(o.Column.Name)))
}

// If the column has a DEFAULT that could not be set using the fast-path
// optimization, set it here.
column := s.GetTable(o.Table).GetColumn(TemporaryName(o.Column.Name))
if o.Column.HasDefault() && column.Default == nil {
err := NewSetDefaultValueAction(conn, o.Table, o.Column.Name, *o.Column.Default).Execute(ctx)
if err != nil {
return err
}
dbActions = append(dbActions, NewSetDefaultValueAction(conn, o.Table, o.Column.Name, *o.Column.Default))

// Validate the `NOT NULL` constraint on the column if necessary
if !o.Column.IsNullable() {
err = upgradeNotNullConstraintToNotNullAttribute(ctx, conn, o.Table, o.Column.Name)
if err != nil {
return err
}
dbActions = append(dbActions, upgradeNotNullConstraintToNotNullAttribute(conn, o.Table, o.Column.Name)...)
}
}

return nil
return dbActions, nil
}

func (o *OpAddColumn) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
Expand Down Expand Up @@ -347,20 +321,12 @@ func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table,
// upgradeNotNullConstraintToNotNullAttribute validates and upgrades a NOT NULL
// constraint to a NOT NULL column attribute. The constraint is removed after
// the column attribute is added.
func upgradeNotNullConstraintToNotNullAttribute(ctx context.Context, conn db.DB, tableName, columnName string) error {
err := NewValidateConstraintAction(conn, tableName, NotNullConstraintName(columnName)).Execute(ctx)
if err != nil {
return err
}

err = NewSetNotNullAction(conn, tableName, columnName).Execute(ctx)
if err != nil {
return err
func upgradeNotNullConstraintToNotNullAttribute(conn db.DB, tableName, columnName string) []DBAction {
return []DBAction{
NewValidateConstraintAction(conn, tableName, NotNullConstraintName(columnName)),
NewSetNotNullAction(conn, tableName, columnName),
NewDropConstraintAction(conn, tableName, NotNullConstraintName(columnName)),
}

err = NewDropConstraintAction(conn, tableName, NotNullConstraintName(columnName)).Execute(ctx)

return err
}

// UniqueIndexName returns the name of the unique index for the given column
Expand Down
53 changes: 20 additions & 33 deletions pkg/migrations/op_alter_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,54 +91,41 @@ func (o *OpAlterColumn) Start(ctx context.Context, l Logger, conn db.DB, latestS
return backfill.NewTask(table), nil
}

func (o *OpAlterColumn) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
func (o *OpAlterColumn) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) {
l.LogOperationComplete(o)

ops := o.subOperations()

dbActions := []DBAction{}
// Perform any operation specific completion steps
for _, op := range ops {
if err := op.Complete(ctx, l, conn, s); err != nil {
return err
actions, err := op.Complete(l, conn, s)
if err != nil {
return []DBAction{}, err
}
}

if err := NewAlterSequenceOwnerAction(conn, o.Table, o.Column, TemporaryName(o.Column)).Execute(ctx); err != nil {
return err
}

removeOldColumn := NewDropColumnAction(conn, o.Table, o.Column)
err := removeOldColumn.Execute(ctx)
if err != nil {
return err
}

// Remove the up and down function and trigger
err = NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column), backfill.TriggerFunctionName(o.Table, TemporaryName(o.Column))).Execute(ctx)
if err != nil {
return err
}

removeBackfillColumn := NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn)
err = removeBackfillColumn.Execute(ctx)
if err != nil {
return err
dbActions = append(dbActions, actions...)
}

// Rename the new column to the old column name
table := s.GetTable(o.Table)
if table == nil {
return TableDoesNotExistError{Name: o.Table}
return []DBAction{}, TableDoesNotExistError{Name: o.Table}
}
column := table.GetColumn(o.Column)
if column == nil {
return ColumnDoesNotExistError{Table: o.Table, Name: o.Column}
}
if err := NewRenameDuplicatedColumnAction(conn, table, column.Name).Execute(ctx); err != nil {
return err
}

return nil
return []DBAction{}, ColumnDoesNotExistError{Table: o.Table, Name: o.Column}
}

return append(dbActions, []DBAction{
NewAlterSequenceOwnerAction(conn, o.Table, o.Column, TemporaryName(o.Column)),
NewDropColumnAction(conn, o.Table, o.Column),
NewDropFunctionAction(conn,
backfill.TriggerFunctionName(o.Table, o.Column),
backfill.TriggerFunctionName(o.Table, TemporaryName(o.Column)),
),
NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn),
NewRenameDuplicatedColumnAction(conn, table, column.Name),
}...), nil
}

func (o *OpAlterColumn) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
Expand Down
4 changes: 2 additions & 2 deletions pkg/migrations/op_change_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ func (o *OpChangeType) Start(ctx context.Context, l Logger, conn db.DB, latestSc
return backfill.NewTask(table), nil
}

func (o *OpChangeType) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
func (o *OpChangeType) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) {
l.LogOperationComplete(o)

return nil
return []DBAction{}, nil
}

func (o *OpChangeType) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
Expand Down
60 changes: 25 additions & 35 deletions pkg/migrations/op_create_constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,83 +101,73 @@ func (o *OpCreateConstraint) Start(ctx context.Context, l Logger, conn db.DB, la
return task, nil
}

func (o *OpCreateConstraint) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
func (o *OpCreateConstraint) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) {
l.LogOperationComplete(o)

dbActions := make([]DBAction, 0)
switch o.Type {
case OpCreateConstraintTypeUnique:
uniqueOp := &OpSetUnique{
Table: o.Table,
Name: o.Name,
}
err := uniqueOp.Complete(ctx, l, conn, s)
actions, err := uniqueOp.Complete(l, conn, s)
if err != nil {
return err
return []DBAction{}, err
}
dbActions = append(dbActions, actions...)
case OpCreateConstraintTypeCheck:
checkOp := &OpSetCheckConstraint{
Table: o.Table,
Check: CheckConstraint{
Name: o.Name,
},
}
err := checkOp.Complete(ctx, l, conn, s)
actions, err := checkOp.Complete(l, conn, s)
if err != nil {
return err
return []DBAction{}, err
}
dbActions = append(dbActions, actions...)
case OpCreateConstraintTypeForeignKey:
fkOp := &OpSetForeignKey{
Table: o.Table,
References: ForeignKeyReference{
Name: o.Name,
},
}
err := fkOp.Complete(ctx, l, conn, s)
actions, err := fkOp.Complete(l, conn, s)
if err != nil {
return err
return []DBAction{}, err
}
dbActions = append(dbActions, actions...)
case OpCreateConstraintTypePrimaryKey:
err := NewAddPrimaryKeyAction(conn, o.Table, o.Name).Execute(ctx)
if err != nil {
return err
}
dbActions = append(dbActions, NewAddPrimaryKeyAction(conn, o.Table, o.Name))
}

for _, col := range o.Columns {
if err := NewAlterSequenceOwnerAction(conn, o.Table, col, TemporaryName(col)).Execute(ctx); err != nil {
return err
}
dbActions = append(dbActions, NewAlterSequenceOwnerAction(conn, o.Table, col, TemporaryName(col)))
}

removeOldColumns := NewDropColumnAction(conn, o.Table, o.Columns...)
err := removeOldColumns.Execute(ctx)
if err != nil {
return err
}
dbActions = append(dbActions, NewDropColumnAction(conn, o.Table, o.Columns...))

// rename new columns to old name
table := s.GetTable(o.Table)
if table == nil {
return TableDoesNotExistError{Name: o.Table}
return []DBAction{}, TableDoesNotExistError{Name: o.Table}
}
for _, col := range o.Columns {
column := table.GetColumn(col)
if column == nil {
return ColumnDoesNotExistError{Table: o.Table, Name: col}
return []DBAction{}, ColumnDoesNotExistError{Table: o.Table, Name: col}
}
if err := NewRenameDuplicatedColumnAction(conn, table, column.Name).Execute(ctx); err != nil {
return err
}
}

if err := o.removeTriggers(ctx, conn); err != nil {
return err
dbActions = append(dbActions, NewRenameDuplicatedColumnAction(conn, table, column.Name))
}
dbActions = append(dbActions,
o.removeTriggers(conn),
NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn),
)

removeBackfillColumn := NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn)
err = removeBackfillColumn.Execute(ctx)

return err
return dbActions, nil
}

func (o *OpCreateConstraint) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
Expand All @@ -194,7 +184,7 @@ func (o *OpCreateConstraint) Rollback(ctx context.Context, l Logger, conn db.DB,
return err
}

if err := o.removeTriggers(ctx, conn); err != nil {
if err := o.removeTriggers(conn).Execute(ctx); err != nil {
return err
}

Expand All @@ -204,13 +194,13 @@ func (o *OpCreateConstraint) Rollback(ctx context.Context, l Logger, conn db.DB,
return err
}

func (o *OpCreateConstraint) removeTriggers(ctx context.Context, conn db.DB) error {
func (o *OpCreateConstraint) removeTriggers(conn db.DB) DBAction {
dropFuncs := make([]string, 0, len(o.Columns)*2)
for _, column := range o.Columns {
dropFuncs = append(dropFuncs, backfill.TriggerFunctionName(o.Table, column))
dropFuncs = append(dropFuncs, backfill.TriggerFunctionName(o.Table, TemporaryName(column)))
}
return NewDropFunctionAction(conn, dropFuncs...).Execute(ctx)
return NewDropFunctionAction(conn, dropFuncs...)
}

func (o *OpCreateConstraint) Validate(ctx context.Context, s *schema.Schema) error {
Expand Down
4 changes: 2 additions & 2 deletions pkg/migrations/op_create_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ func (o *OpCreateIndex) Start(ctx context.Context, l Logger, conn db.DB, latestS
return nil, err
}

func (o *OpCreateIndex) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
func (o *OpCreateIndex) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) {
l.LogOperationComplete(o)

// No-op
return nil
return []DBAction{}, nil
}

func (o *OpCreateIndex) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
Expand Down
4 changes: 2 additions & 2 deletions pkg/migrations/op_create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ func (o *OpCreateTable) Start(ctx context.Context, l Logger, conn db.DB, latestS
return nil, nil
}

func (o *OpCreateTable) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
func (o *OpCreateTable) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) {
l.LogOperationComplete(o)

// No-op
return nil
return []DBAction{}, nil
}

func (o *OpCreateTable) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
Expand Down
25 changes: 6 additions & 19 deletions pkg/migrations/op_drop_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,14 @@ func (o *OpDropColumn) Start(ctx context.Context, l Logger, conn db.DB, latestSc
return nil, nil
}

func (o *OpDropColumn) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
func (o *OpDropColumn) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) {
l.LogOperationComplete(o)

removeColumn := NewDropColumnAction(conn, o.Table, o.Column)
err := removeColumn.Execute(ctx)
if err != nil {
return err
}

err = NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column)).Execute(ctx)
if err != nil {
return err
}

removeBackfillColumn := NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn)
err = removeBackfillColumn.Execute(ctx)
if err != nil {
return err
}

return nil
return []DBAction{
NewDropColumnAction(conn, o.Table, o.Column),
NewDropFunctionAction(conn, backfill.TriggerFunctionName(o.Table, o.Column)),
NewDropColumnAction(conn, o.Table, backfill.CNeedsBackfillColumn),
}, nil
}

func (o *OpDropColumn) Rollback(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error {
Expand Down
Loading
Loading