From fb7bbc41851374211676fbbcd9c778451ef42069 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?No=C3=A9mi=20V=C3=A1nyi?= Date: Fri, 20 Jun 2025 13:57:23 +0200 Subject: [PATCH] Add trigger configuration to `backfill.Task` && consolidate triggers per column into single trigger --- pkg/backfill/backfill.go | 72 +++++++++++++-- .../templates/function.go | 4 +- .../templates/trigger.go | 0 pkg/backfill/trigger.go | 92 ++++++++++++++++++- pkg/{migrations => backfill}/trigger_test.go | 80 ++++++++++++---- pkg/migrations/op_add_column.go | 12 +-- pkg/migrations/op_alter_column.go | 42 ++++----- pkg/migrations/op_create_constraint.go | 34 ++++--- pkg/migrations/op_drop_column.go | 15 ++- pkg/migrations/op_drop_constraint.go | 34 ++++--- .../op_drop_multicolumn_constraint.go | 33 +++---- pkg/migrations/trigger.go | 91 ------------------ pkg/roll/execute.go | 2 +- 13 files changed, 299 insertions(+), 212 deletions(-) rename pkg/{migrations => backfill}/templates/function.go (85%) rename pkg/{migrations => backfill}/templates/trigger.go (100%) rename pkg/{migrations => backfill}/trigger_test.go (68%) delete mode 100644 pkg/migrations/trigger.go diff --git a/pkg/backfill/backfill.go b/pkg/backfill/backfill.go index c8bd96191..14b43fcf5 100644 --- a/pkg/backfill/backfill.go +++ b/pkg/backfill/backfill.go @@ -7,6 +7,7 @@ import ( "database/sql" "errors" "fmt" + "strings" "time" "github.com/lib/pq" @@ -22,13 +23,16 @@ const CNeedsBackfillColumn = "_pgroll_needs_backfill" // Task represents a backfill task for a specific table from an operation. type Task struct { table *schema.Table - triggers []TriggerConfig + triggers []OperationTrigger } // Job is a collection of all tables that need to be backfilled and their associated triggers. type Job struct { - Tables []*schema.Table - triggers []TriggerConfig + schemaName string + latestSchema string + triggers map[string]triggerConfig + + Tables []*schema.Table } type Backfill struct { @@ -38,16 +42,62 @@ type Backfill struct { type CallbackFn func(done int64, total int64) -func NewTask(table *schema.Table, triggers ...TriggerConfig) *Task { +func NewTask(table *schema.Table, triggers ...OperationTrigger) *Task { return &Task{ table: table, triggers: triggers, } } +func NewJob(schemaName, latestSchema string) *Job { + return &Job{ + schemaName: schemaName, + latestSchema: latestSchema, + triggers: make(map[string]triggerConfig, 0), + Tables: make([]*schema.Table, 0), + } +} + +func (t *Task) AddTriggers(other *Task) { + t.triggers = append(t.triggers, other.triggers...) +} + func (j *Job) AddTask(t *Task) { - j.Tables = append(j.Tables, t.table) - j.triggers = append(j.triggers, t.triggers...) + if t.table != nil { + j.Tables = append(j.Tables, t.table) + } + + for _, trigger := range t.triggers { + if tg, exists := j.triggers[trigger.Name]; exists { + tg.SQL = append(tg.SQL, rewriteTriggerSQL(trigger.SQL, findColumnName(tg.Columns, tg.PhysicalColumn), tg.PhysicalColumn)) + j.triggers[trigger.Name] = tg + } else { + j.triggers[trigger.Name] = triggerConfig{ + Name: trigger.Name, + Direction: trigger.Direction, + Columns: trigger.Columns, + SchemaName: j.schemaName, + TableName: trigger.TableName, + PhysicalColumn: trigger.PhysicalColumn, + LatestSchema: j.latestSchema, + SQL: []string{trigger.SQL}, + NeedsBackfillColumn: CNeedsBackfillColumn, + } + } + } +} + +func rewriteTriggerSQL(sql string, from, to string) string { + return strings.ReplaceAll(sql, from, fmt.Sprintf("NEW.%s", pq.QuoteIdentifier(to))) +} + +func findColumnName(columns map[string]*schema.Column, columnName string) string { + for name, col := range columns { + if col.Name == columnName { + return name + } + } + return columnName } // New creates a new backfill operation with the given options. The backfill is @@ -63,7 +113,15 @@ func New(conn db.DB, c *Config) *Backfill { // CreateTriggers creates the triggers for the tables before starting the backfill. func (bf *Backfill) CreateTriggers(ctx context.Context, j *Job) error { - // Not yet implemented, triggers are loaded during the Start method. + for _, trigger := range j.triggers { + a := &createTriggerAction{ + conn: bf.conn, + cfg: trigger, + } + if err := a.execute(ctx); err != nil { + return fmt.Errorf("creating trigger %q: %w", trigger.Name, err) + } + } return nil } diff --git a/pkg/migrations/templates/function.go b/pkg/backfill/templates/function.go similarity index 85% rename from pkg/migrations/templates/function.go rename to pkg/backfill/templates/function.go index 2fb932ad3..cfafe7833 100644 --- a/pkg/migrations/templates/function.go +++ b/pkg/backfill/templates/function.go @@ -20,7 +20,9 @@ const Function = `CREATE OR REPLACE FUNCTION {{ .Name | qi }}() FROM current_setting('search_path'); IF search_path {{- if eq .Direction "up" }} != {{- else }} = {{- end }} {{ .LatestSchema | ql }} THEN - NEW.{{ .PhysicalColumn | qi }} = {{ .SQL }}; + {{- $physicalColumn := .PhysicalColumn | qi }}{{ range $s := .SQL }} + NEW.{{ $physicalColumn }} = {{ $s }}; + {{- end }} NEW.{{ .NeedsBackfillColumn | qi }} = false; END IF; diff --git a/pkg/migrations/templates/trigger.go b/pkg/backfill/templates/trigger.go similarity index 100% rename from pkg/migrations/templates/trigger.go rename to pkg/backfill/templates/trigger.go diff --git a/pkg/backfill/trigger.go b/pkg/backfill/trigger.go index eadc80027..9b543aa2c 100644 --- a/pkg/backfill/trigger.go +++ b/pkg/backfill/trigger.go @@ -3,6 +3,16 @@ package backfill import ( + "bytes" + "context" + "database/sql" + "fmt" + "text/template" + + "github.com/lib/pq" + + "github.com/xataio/pgroll/pkg/backfill/templates" + "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -13,7 +23,7 @@ const ( TriggerDirectionDown TriggerDirection = "down" ) -type TriggerConfig struct { +type triggerConfig struct { Name string Direction TriggerDirection Columns map[string]*schema.Column @@ -21,10 +31,88 @@ type TriggerConfig struct { TableName string PhysicalColumn string LatestSchema string - SQL string + SQL []string NeedsBackfillColumn string } +type OperationTrigger struct { + Name string + Direction TriggerDirection + Columns map[string]*schema.Column + TableName string + PhysicalColumn string + SQL string +} + +type createTriggerAction struct { + conn db.DB + cfg triggerConfig +} + +func (a *createTriggerAction) execute(ctx context.Context) error { + // Parenthesize the up/down SQL if it's not parenthesized already + for i, sql := range a.cfg.SQL { + if len(sql) > 0 && sql[0] != '(' { + a.cfg.SQL[i] = "(" + sql + ")" + } + } + + a.cfg.NeedsBackfillColumn = CNeedsBackfillColumn + + funcSQL, err := buildFunction(a.cfg) + if err != nil { + return err + } + + triggerSQL, err := buildTrigger(a.cfg) + if err != nil { + return err + } + + return a.conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + _, err := a.conn.ExecContext(ctx, + fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS %s boolean DEFAULT true", + pq.QuoteIdentifier(a.cfg.TableName), + pq.QuoteIdentifier(CNeedsBackfillColumn))) + if err != nil { + return err + } + + _, err = a.conn.ExecContext(ctx, funcSQL) + if err != nil { + return err + } + + _, err = a.conn.ExecContext(ctx, triggerSQL) + return err + }) +} + +func buildFunction(cfg triggerConfig) (string, error) { + return executeTemplate("function", templates.Function, cfg) +} + +func buildTrigger(cfg triggerConfig) (string, error) { + return executeTemplate("trigger", templates.Trigger, cfg) +} + +func executeTemplate(name, content string, cfg triggerConfig) (string, error) { + tmpl := template.Must(template. + New(name). + Funcs(template.FuncMap{ + "ql": pq.QuoteLiteral, + "qi": pq.QuoteIdentifier, + }). + Parse(content)) + + buf := bytes.Buffer{} + if err := tmpl.Execute(&buf, cfg); err != nil { + return "", err + } + + return buf.String(), nil +} + // TriggerFunctionName returns the name of the trigger function // for a given table and column. func TriggerFunctionName(tableName, columnName string) string { diff --git a/pkg/migrations/trigger_test.go b/pkg/backfill/trigger_test.go similarity index 68% rename from pkg/migrations/trigger_test.go rename to pkg/backfill/trigger_test.go index 727357f9c..d09abdc11 100644 --- a/pkg/migrations/trigger_test.go +++ b/pkg/backfill/trigger_test.go @@ -1,26 +1,25 @@ // SPDX-License-Identifier: Apache-2.0 -package migrations +package backfill import ( "testing" "github.com/stretchr/testify/assert" - "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/schema" ) func TestBuildFunction(t *testing.T) { testCases := []struct { name string - config backfill.TriggerConfig + config triggerConfig expected string }{ { name: "simple up trigger", - config: backfill.TriggerConfig{ + config: triggerConfig{ Name: "triggerName", - Direction: backfill.TriggerDirectionUp, + Direction: TriggerDirectionUp, Columns: map[string]*schema.Column{ "id": {Name: "id", Type: "int"}, "username": {Name: "username", Type: "text"}, @@ -31,8 +30,8 @@ func TestBuildFunction(t *testing.T) { LatestSchema: "public_01_migration_name", TableName: "reviews", PhysicalColumn: "_pgroll_new_review", - NeedsBackfillColumn: backfill.CNeedsBackfillColumn, - SQL: "product || 'is good'", + NeedsBackfillColumn: CNeedsBackfillColumn, + SQL: []string{"product || 'is good'"}, }, expected: `CREATE OR REPLACE FUNCTION "triggerName"() RETURNS TRIGGER @@ -55,15 +54,62 @@ func TestBuildFunction(t *testing.T) { NEW."_pgroll_needs_backfill" = false; END IF; + RETURN NEW; + END; $$ +`, + }, + { + name: "multiple up trigger", + config: triggerConfig{ + Name: "triggerName", + Direction: TriggerDirectionUp, + Columns: map[string]*schema.Column{ + "id": {Name: "id", Type: "int"}, + "username": {Name: "username", Type: "text"}, + "product": {Name: "product", Type: "text"}, + "review": {Name: "review", Type: "text"}, + }, + SchemaName: "public", + LatestSchema: "public_01_migration_name", + TableName: "reviews", + PhysicalColumn: "_pgroll_new_review", + NeedsBackfillColumn: CNeedsBackfillColumn, + SQL: []string{ + "product || 'is good'", + "CASE WHEN NEW.\"_pgroll_new_review\" = 'bad' THEN 'bad review' ELSE 'good review' END", + }, + }, + expected: `CREATE OR REPLACE FUNCTION "triggerName"() + RETURNS TRIGGER + LANGUAGE PLPGSQL + AS $$ + DECLARE + "id" "public"."reviews"."id"%TYPE := NEW."id"; + "product" "public"."reviews"."product"%TYPE := NEW."product"; + "review" "public"."reviews"."review"%TYPE := NEW."review"; + "username" "public"."reviews"."username"%TYPE := NEW."username"; + latest_schema text; + search_path text; + BEGIN + SELECT current_setting + INTO search_path + FROM current_setting('search_path'); + + IF search_path != 'public_01_migration_name' THEN + NEW."_pgroll_new_review" = product || 'is good'; + NEW."_pgroll_new_review" = CASE WHEN NEW."_pgroll_new_review" = 'bad' THEN 'bad review' ELSE 'good review' END; + NEW."_pgroll_needs_backfill" = false; + END IF; + RETURN NEW; END; $$ `, }, { name: "simple down trigger", - config: backfill.TriggerConfig{ + config: triggerConfig{ Name: "triggerName", - Direction: backfill.TriggerDirectionDown, + Direction: TriggerDirectionDown, Columns: map[string]*schema.Column{ "id": {Name: "id", Type: "int"}, "username": {Name: "username", Type: "text"}, @@ -74,8 +120,8 @@ func TestBuildFunction(t *testing.T) { LatestSchema: "public_01_migration_name", TableName: "reviews", PhysicalColumn: "review", - NeedsBackfillColumn: backfill.CNeedsBackfillColumn, - SQL: `NEW."_pgroll_new_review"`, + NeedsBackfillColumn: CNeedsBackfillColumn, + SQL: []string{`NEW."_pgroll_new_review"`}, }, expected: `CREATE OR REPLACE FUNCTION "triggerName"() RETURNS TRIGGER @@ -104,9 +150,9 @@ func TestBuildFunction(t *testing.T) { }, { name: "down trigger with aliased column", - config: backfill.TriggerConfig{ + config: triggerConfig{ Name: "triggerName", - Direction: backfill.TriggerDirectionDown, + Direction: TriggerDirectionDown, Columns: map[string]*schema.Column{ "id": {Name: "id", Type: "int"}, "username": {Name: "username", Type: "text"}, @@ -118,8 +164,8 @@ func TestBuildFunction(t *testing.T) { LatestSchema: "public_01_migration_name", TableName: "reviews", PhysicalColumn: "rating", - NeedsBackfillColumn: backfill.CNeedsBackfillColumn, - SQL: `CAST(rating as text)`, + NeedsBackfillColumn: CNeedsBackfillColumn, + SQL: []string{`CAST(rating as text)`}, }, expected: `CREATE OR REPLACE FUNCTION "triggerName"() RETURNS TRIGGER @@ -163,12 +209,12 @@ func TestBuildFunction(t *testing.T) { func TestBuildTrigger(t *testing.T) { testCases := []struct { name string - config backfill.TriggerConfig + config triggerConfig expected string }{ { name: "trigger", - config: backfill.TriggerConfig{ + config: triggerConfig{ Name: "triggerName", TableName: "reviews", }, diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index bcc124077..bd3709a69 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -102,22 +102,16 @@ func (o *OpAddColumn) Start(ctx context.Context, l Logger, conn db.DB, latestSch var task *backfill.Task if o.Up != "" { - err := NewCreateTriggerAction(conn, - backfill.TriggerConfig{ + task = backfill.NewTask(table, + backfill.OperationTrigger{ Name: backfill.TriggerName(o.Table, o.Column.Name), Direction: backfill.TriggerDirectionUp, Columns: table.Columns, - SchemaName: s.Name, - LatestSchema: latestSchema, TableName: table.Name, PhysicalColumn: TemporaryName(o.Column.Name), SQL: o.Up, }, - ).Execute(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create trigger: %w", err) - } - task = backfill.NewTask(table) + ) } tmpColumn := toSchemaColumn(o.Column) diff --git a/pkg/migrations/op_alter_column.go b/pkg/migrations/op_alter_column.go index df639e2e3..db3132568 100644 --- a/pkg/migrations/op_alter_column.go +++ b/pkg/migrations/op_alter_column.go @@ -38,22 +38,25 @@ func (o *OpAlterColumn) Start(ctx context.Context, l Logger, conn db.DB, latestS return nil, fmt.Errorf("failed to duplicate column: %w", err) } + // Copy the columns from table columns, so we can use it later + // in the down trigger with the physical name + upColumns := make(map[string]*schema.Column) + for name, col := range table.Columns { + upColumns[name] = col + } + // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. - err := NewCreateTriggerAction(conn, - backfill.TriggerConfig{ + triggers := make([]backfill.OperationTrigger, 0) + triggers = append(triggers, + backfill.OperationTrigger{ Name: backfill.TriggerName(o.Table, o.Column), Direction: backfill.TriggerDirectionUp, - Columns: table.Columns, - SchemaName: s.Name, - LatestSchema: latestSchema, TableName: table.Name, + Columns: upColumns, PhysicalColumn: TemporaryName(o.Column), SQL: o.upSQLForOperations(ops), }, - ).Execute(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create up trigger: %w", err) - } + ) // Add the new column to the internal schema representation. This is done // here, before creation of the down trigger, so that the trigger can declare @@ -65,30 +68,27 @@ func (o *OpAlterColumn) Start(ctx context.Context, l Logger, conn db.DB, latestS }) // Add a trigger to copy values from the new column to the old. - err = NewCreateTriggerAction(conn, - backfill.TriggerConfig{ + triggers = append(triggers, + backfill.OperationTrigger{ Name: backfill.TriggerName(o.Table, TemporaryName(o.Column)), Direction: backfill.TriggerDirectionDown, - Columns: table.Columns, - LatestSchema: latestSchema, - SchemaName: s.Name, TableName: table.Name, + Columns: table.Columns, PhysicalColumn: oldPhysicalColumn, SQL: o.downSQLForOperations(ops), }, - ).Execute(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create down trigger: %w", err) - } - + ) + task := backfill.NewTask(table, triggers...) // perform any operation specific start steps for _, op := range ops { - if _, err := op.Start(ctx, l, conn, latestSchema, s); err != nil { + bf, err := op.Start(ctx, l, conn, latestSchema, s) + if err != nil { return nil, err } + task.AddTriggers(bf) } - return backfill.NewTask(table), nil + return task, nil } func (o *OpAlterColumn) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_create_constraint.go b/pkg/migrations/op_create_constraint.go index f146524c1..9a9c537a5 100644 --- a/pkg/migrations/op_create_constraint.go +++ b/pkg/migrations/op_create_constraint.go @@ -41,24 +41,27 @@ func (o *OpCreateConstraint) Start(ctx context.Context, l Logger, conn db.DB, la return nil, fmt.Errorf("failed to duplicate columns for new constraint: %w", err) } + // Copy the columns from table columns, so we can use it later + // in the down trigger with the physical name + upColumns := make(map[string]*schema.Column) + for name, col := range table.Columns { + upColumns[name] = col + } + // Setup triggers + triggers := make([]backfill.OperationTrigger, 0) for _, colName := range o.Columns { upSQL := o.Up[colName] - err := NewCreateTriggerAction(conn, - backfill.TriggerConfig{ + triggers = append(triggers, + backfill.OperationTrigger{ Name: backfill.TriggerName(o.Table, colName), Direction: backfill.TriggerDirectionUp, - Columns: table.Columns, - SchemaName: s.Name, - LatestSchema: latestSchema, + Columns: upColumns, TableName: table.Name, PhysicalColumn: TemporaryName(colName), SQL: upSQL, }, - ).Execute(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create up trigger: %w", err) - } + ) // Add the new column to the internal schema representation. This is done // here, before creation of the down trigger, so that the trigger can declare @@ -70,24 +73,19 @@ func (o *OpCreateConstraint) Start(ctx context.Context, l Logger, conn db.DB, la }) downSQL := o.Down[colName] - err = NewCreateTriggerAction(conn, - backfill.TriggerConfig{ + triggers = append(triggers, + backfill.OperationTrigger{ Name: backfill.TriggerName(o.Table, TemporaryName(colName)), Direction: backfill.TriggerDirectionDown, Columns: table.Columns, - LatestSchema: latestSchema, - SchemaName: s.Name, TableName: table.Name, PhysicalColumn: oldPhysicalColumn, SQL: downSQL, }, - ).Execute(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create down trigger: %w", err) - } + ) } - task := backfill.NewTask(table) + task := backfill.NewTask(table, triggers...) switch o.Type { case OpCreateConstraintTypeUnique, OpCreateConstraintTypePrimaryKey: diff --git a/pkg/migrations/op_drop_column.go b/pkg/migrations/op_drop_column.go index 1f6c4654f..dc6f037b5 100644 --- a/pkg/migrations/op_drop_column.go +++ b/pkg/migrations/op_drop_column.go @@ -18,22 +18,18 @@ var ( func (o *OpDropColumn) Start(ctx context.Context, l Logger, conn db.DB, latestSchema string, s *schema.Schema) (*backfill.Task, error) { l.LogOperationStart(o) + var task *backfill.Task if o.Down != "" { - err := NewCreateTriggerAction(conn, - backfill.TriggerConfig{ + task = backfill.NewTask(nil, + backfill.OperationTrigger{ Name: backfill.TriggerName(o.Table, o.Column), Direction: backfill.TriggerDirectionDown, Columns: s.GetTable(o.Table).Columns, - SchemaName: s.Name, - LatestSchema: latestSchema, TableName: s.GetTable(o.Table).Name, PhysicalColumn: o.Column, SQL: o.Down, }, - ).Execute(ctx) - if err != nil { - return nil, err - } + ) } table := s.GetTable(o.Table) @@ -46,7 +42,8 @@ func (o *OpDropColumn) Start(ctx context.Context, l Logger, conn db.DB, latestSc } s.GetTable(o.Table).RemoveColumn(o.Column) - return nil, nil + + return task, nil } func (o *OpDropColumn) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_constraint.go b/pkg/migrations/op_drop_constraint.go index d768e4e04..e7ec07e8f 100644 --- a/pkg/migrations/op_drop_constraint.go +++ b/pkg/migrations/op_drop_constraint.go @@ -36,22 +36,25 @@ func (o *OpDropConstraint) Start(ctx context.Context, l Logger, conn db.DB, late return nil, fmt.Errorf("failed to duplicate column: %w", err) } + // Copy the columns from table columns, so we can use it later + // in the down trigger with the physical name + upColumns := make(map[string]*schema.Column) + for name, col := range table.Columns { + upColumns[name] = col + } + // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. - err := NewCreateTriggerAction(conn, - backfill.TriggerConfig{ + triggers := make([]backfill.OperationTrigger, 0) + triggers = append(triggers, + backfill.OperationTrigger{ Name: backfill.TriggerName(o.Table, column.Name), Direction: backfill.TriggerDirectionUp, - Columns: table.Columns, - SchemaName: s.Name, - LatestSchema: latestSchema, + Columns: upColumns, TableName: o.Table, PhysicalColumn: TemporaryName(column.Name), SQL: o.upSQL(column.Name), }, - ).Execute(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create up trigger: %w", err) - } + ) // Add the new column to the internal schema representation. This is done // here, before creation of the down trigger, so that the trigger can declare @@ -61,22 +64,17 @@ func (o *OpDropConstraint) Start(ctx context.Context, l Logger, conn db.DB, late }) // Add a trigger to copy values from the new column to the old, rewriting values using the `down` SQL. - err = NewCreateTriggerAction(conn, - backfill.TriggerConfig{ + triggers = append(triggers, + backfill.OperationTrigger{ Name: backfill.TriggerName(o.Table, TemporaryName(column.Name)), Direction: backfill.TriggerDirectionDown, Columns: table.Columns, - SchemaName: s.Name, - LatestSchema: latestSchema, TableName: o.Table, PhysicalColumn: column.Name, SQL: o.Down, }, - ).Execute(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create down trigger: %w", err) - } - return backfill.NewTask(table), nil + ) + return backfill.NewTask(table, triggers...), nil } func (o *OpDropConstraint) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_multicolumn_constraint.go b/pkg/migrations/op_drop_multicolumn_constraint.go index 9168b630d..d0ff597c8 100644 --- a/pkg/migrations/op_drop_multicolumn_constraint.go +++ b/pkg/migrations/op_drop_multicolumn_constraint.go @@ -49,23 +49,25 @@ func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, l Logger, conn } // Create triggers for each column covered by the constraint to be dropped + triggers := make([]backfill.OperationTrigger, 0) for _, columnName := range table.GetConstraintColumns(o.Name) { + // Copy the columns from table columns, so we can use it later + // in the down trigger with the physical name + upColumns := make(map[string]*schema.Column) + for name, col := range table.Columns { + upColumns[name] = col + } // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. - err := NewCreateTriggerAction(conn, - backfill.TriggerConfig{ + triggers = append(triggers, + backfill.OperationTrigger{ Name: backfill.TriggerName(o.Table, columnName), Direction: backfill.TriggerDirectionUp, - Columns: table.Columns, - SchemaName: s.Name, - LatestSchema: latestSchema, + Columns: upColumns, TableName: table.Name, PhysicalColumn: TemporaryName(columnName), SQL: o.upSQL(columnName), }, - ).Execute(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create up trigger: %w", err) - } + ) // Add the new column to the internal schema representation. This is done // here, before creation of the down trigger, so that the trigger can declare @@ -77,24 +79,19 @@ func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, l Logger, conn }) // Add a trigger to copy values from the new column to the old, rewriting values using the `down` SQL. - err = NewCreateTriggerAction(conn, - backfill.TriggerConfig{ + triggers = append(triggers, + backfill.OperationTrigger{ Name: backfill.TriggerName(o.Table, TemporaryName(columnName)), Direction: backfill.TriggerDirectionDown, Columns: table.Columns, - SchemaName: s.Name, - LatestSchema: latestSchema, TableName: table.Name, PhysicalColumn: oldPhysicalColumn, SQL: o.Down[columnName], }, - ).Execute(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create down trigger: %w", err) - } + ) } - return backfill.NewTask(table), nil + return backfill.NewTask(table, triggers...), nil } func (o *OpDropMultiColumnConstraint) Complete(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) error { diff --git a/pkg/migrations/trigger.go b/pkg/migrations/trigger.go deleted file mode 100644 index e64b7cbfa..000000000 --- a/pkg/migrations/trigger.go +++ /dev/null @@ -1,91 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package migrations - -import ( - "bytes" - "context" - "database/sql" - "fmt" - "text/template" - - "github.com/lib/pq" - - "github.com/xataio/pgroll/pkg/backfill" - "github.com/xataio/pgroll/pkg/db" - "github.com/xataio/pgroll/pkg/migrations/templates" -) - -type createTriggerAction struct { - conn db.DB - cfg backfill.TriggerConfig -} - -func NewCreateTriggerAction(conn db.DB, cfg backfill.TriggerConfig) DBAction { - return &createTriggerAction{ - conn: conn, - cfg: cfg, - } -} - -func (a *createTriggerAction) Execute(ctx context.Context) error { - // Parenthesize the up/down SQL if it's not parenthesized already - if len(a.cfg.SQL) > 0 && a.cfg.SQL[0] != '(' { - a.cfg.SQL = "(" + a.cfg.SQL + ")" - } - - a.cfg.NeedsBackfillColumn = backfill.CNeedsBackfillColumn - - funcSQL, err := buildFunction(a.cfg) - if err != nil { - return err - } - - triggerSQL, err := buildTrigger(a.cfg) - if err != nil { - return err - } - - return a.conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error { - _, err := a.conn.ExecContext(ctx, - fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS %s boolean DEFAULT true", - pq.QuoteIdentifier(a.cfg.TableName), - pq.QuoteIdentifier(backfill.CNeedsBackfillColumn))) - if err != nil { - return err - } - - _, err = a.conn.ExecContext(ctx, funcSQL) - if err != nil { - return err - } - - _, err = a.conn.ExecContext(ctx, triggerSQL) - return err - }) -} - -func buildFunction(cfg backfill.TriggerConfig) (string, error) { - return executeTemplate("function", templates.Function, cfg) -} - -func buildTrigger(cfg backfill.TriggerConfig) (string, error) { - return executeTemplate("trigger", templates.Trigger, cfg) -} - -func executeTemplate(name, content string, cfg backfill.TriggerConfig) (string, error) { - tmpl := template.Must(template. - New(name). - Funcs(template.FuncMap{ - "ql": pq.QuoteLiteral, - "qi": pq.QuoteIdentifier, - }). - Parse(content)) - - buf := bytes.Buffer{} - if err := tmpl.Execute(&buf, cfg); err != nil { - return "", err - } - - return buf.String(), nil -} diff --git a/pkg/roll/execute.go b/pkg/roll/execute.go index 2c4a2c9b1..d0ddefc9a 100644 --- a/pkg/roll/execute.go +++ b/pkg/roll/execute.go @@ -99,7 +99,7 @@ func (m *Roll) StartDDLOperations(ctx context.Context, migration *migrations.Mig } // execute operations - job := &backfill.Job{} + job := backfill.NewJob(m.schema, versionSchemaName) for _, op := range migration.Operations { task, err := op.Start(ctx, m.logger, m.pgConn, versionSchemaName, newSchema) if err != nil {