Skip to content

Commit 4d894cb

Browse files
committed
Add support for nulls_not_distinct option in create_constraint operations
1 parent c15326b commit 4d894cb

15 files changed

+83
-10
lines changed

docs/operations/create_constraint.mdx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Required fields: `name`, `table`, `type`, `up`, `down`.
1818
"type": "unique"| "check" | "foreign_key",
1919
"check": "SQL expression for CHECK constraint",
2020
"no_inherit": "true|false",
21+
"nulls_not_distinct": "true|false",
2122
"references": {
2223
"name": "name of foreign key reference",
2324
"table": "name of referenced table",

docs/operations/create_index.mdx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ description: A create index operation creates a new index on a set of columns.
2727
"predicate": "conditional expression for defining a partial index",
2828
"storage_parameters": "comma-separated list of storage parameters",
2929
"unique": true | false,
30+
"nulls_not_distinct": true | false,
3031
"method": "btree"
3132
}
3233
}
@@ -35,6 +36,7 @@ description: A create index operation creates a new index on a set of columns.
3536
* The field `method` can be `btree`, `hash`, `gist`, `spgist`, `gin`, `brin`.
3637
* You can also specify storage parameters for the index in `storage_parameters`.
3738
* To create a unique index set `unique` to `true`.
39+
* If the index is unique, you can configure `nulls_not_distinct`.
3840

3941
## Examples
4042

examples/44_add_table_unique_constraint.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"sellers_name",
1111
"sellers_zip"
1212
],
13+
"nulls_not_distinct": false,
1314
"up": {
1415
"sellers_name": "sellers_name",
1516
"sellers_zip": "sellers_zip"
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
This is an invalid 'create_constraint' migration.
2+
Check constraints cannot configure nulls_not_distinct.
3+
4+
-- create_constraint.json --
5+
{
6+
"name": "migration_name",
7+
"operations": [
8+
{
9+
"create_constraint": {
10+
"name": "my_invalid_check",
11+
"table": "my_table",
12+
"type": "check",
13+
"check": "my_column > 5",
14+
"nulls_not_distinct": true,
15+
"columns": [
16+
"my_column"
17+
],
18+
"up": {
19+
"my_column": "my_column"
20+
},
21+
"down": {
22+
"my_column": "my_column"
23+
}
24+
}
25+
}
26+
]
27+
}
28+
29+
-- valid --
30+
false

pkg/migrations/duplicate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
136136
continue
137137
}
138138
if duplicatedMember, constraintColumns := d.stmtBuilder.allConstraintColumns(uc.Columns, colNames...); duplicatedMember {
139-
if err := createUniqueIndexConcurrently(ctx, d.conn, "", DuplicationName(uc.Name), d.stmtBuilder.table.Name, constraintColumns); err != nil {
139+
if err := createUniqueIndexConcurrently(ctx, d.conn, "", DuplicationName(uc.Name), d.stmtBuilder.table.Name, constraintColumns, uc.NullsNotDistinct); err != nil {
140140
return err
141141
}
142142
}

pkg/migrations/duplicate_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ func TestCreateIndexConcurrentlySqlGeneration(t *testing.T) {
231231
},
232232
} {
233233
t.Run(name, func(t *testing.T) {
234-
stmt := getCreateUniqueIndexConcurrentlySQL(testCases.indexName, testCases.schemaName, testCases.tableName, testCases.columns)
234+
stmt := getCreateUniqueIndexConcurrentlySQL(testCases.indexName, testCases.schemaName, testCases.tableName, testCases.columns, false)
235235
assert.Equal(t, testCases.expectedStmt, stmt)
236236
})
237237
}

pkg/migrations/index.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@ import (
1212
"github.com/xataio/pgroll/pkg/db"
1313
)
1414

15-
func createUniqueIndexConcurrently(ctx context.Context, conn db.DB, schemaName string, indexName string, tableName string, columnNames []string) error {
15+
func createUniqueIndexConcurrently(ctx context.Context, conn db.DB, schemaName string, indexName string, tableName string, columnNames []string, nullsNotDistinct bool) error {
1616
quotedQualifiedIndexName := pq.QuoteIdentifier(indexName)
1717
if schemaName != "" {
1818
quotedQualifiedIndexName = fmt.Sprintf("%s.%s", pq.QuoteIdentifier(schemaName), pq.QuoteIdentifier(indexName))
1919
}
2020
for retryCount := 0; retryCount < 5; retryCount++ {
2121
// Add a unique index to the new column
2222
// Indexes are created in the same schema with the table automatically. Instead of the qualified one, just pass the index name.
23-
createIndexSQL := getCreateUniqueIndexConcurrentlySQL(indexName, schemaName, tableName, columnNames)
23+
createIndexSQL := getCreateUniqueIndexConcurrentlySQL(indexName, schemaName, tableName, columnNames, nullsNotDistinct)
2424
if _, err := conn.ExecContext(ctx, createIndexSQL); err != nil {
2525
return fmt.Errorf("failed to add unique index %q: %w", indexName, err)
2626
}
@@ -64,7 +64,7 @@ func createUniqueIndexConcurrently(ctx context.Context, conn db.DB, schemaName s
6464
return fmt.Errorf("failed to create unique index %q", indexName)
6565
}
6666

67-
func getCreateUniqueIndexConcurrentlySQL(indexName string, schemaName string, tableName string, columnNames []string) string {
67+
func getCreateUniqueIndexConcurrentlySQL(indexName string, schemaName string, tableName string, columnNames []string, nullsNotDistinct bool) string {
6868
// create unique index concurrently
6969
qualifiedTableName := pq.QuoteIdentifier(tableName)
7070
if schemaName != "" {
@@ -77,6 +77,9 @@ func getCreateUniqueIndexConcurrentlySQL(indexName string, schemaName string, ta
7777
qualifiedTableName,
7878
strings.Join(quoteColumnNames(columnNames), ", "),
7979
)
80+
if nullsNotDistinct {
81+
indexQuery += " NULLS NOT DISTINCT"
82+
}
8083

8184
return indexQuery
8285
}

pkg/migrations/op_add_column.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, latestSchema string
6262
}
6363

6464
if o.Column.Unique {
65-
if err := createUniqueIndexConcurrently(ctx, conn, s.Name, UniqueIndexName(o.Column.Name), table.Name, []string{TemporaryName(o.Column.Name)}); err != nil {
65+
if err := createUniqueIndexConcurrently(ctx, conn, s.Name, UniqueIndexName(o.Column.Name), table.Name, []string{TemporaryName(o.Column.Name)}, false); err != nil {
6666
return nil, fmt.Errorf("failed to add unique index: %w", err)
6767
}
6868
}

pkg/migrations/op_create_constraint.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema
8282

8383
switch o.Type {
8484
case OpCreateConstraintTypeUnique:
85-
return table, createUniqueIndexConcurrently(ctx, conn, s.Name, o.Name, table.Name, temporaryNames(o.Columns))
85+
return table, createUniqueIndexConcurrently(ctx, conn, s.Name, o.Name, table.Name, temporaryNames(o.Columns), o.NullsNotDistinct)
8686
case OpCreateConstraintTypeCheck:
8787
return table, o.addCheckConstraint(ctx, conn, table.Name)
8888
case OpCreateConstraintTypeForeignKey:
@@ -279,6 +279,10 @@ func (o *OpCreateConstraint) Validate(ctx context.Context, s *schema.Schema) err
279279
}
280280
}
281281

282+
if o.Type != OpCreateConstraintTypeUnique && o.NullsNotDistinct {
283+
return fmt.Errorf("nulls_not_distinct can only be true for unique constraints")
284+
}
285+
282286
return nil
283287
}
284288

pkg/migrations/op_create_index.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ func (o *OpCreateIndex) Start(ctx context.Context, conn db.DB, latestSchema stri
6161
}
6262
stmt += fmt.Sprintf(" (%s)", strings.Join(colSQLs, ", "))
6363

64+
if o.Unique && o.NullsNotDistinct {
65+
stmt += " NULLS NOT DISTINCT"
66+
}
67+
6468
if o.StorageParameters != "" {
6569
stmt += fmt.Sprintf(" WITH (%s)", o.StorageParameters)
6670
}

0 commit comments

Comments
 (0)