Skip to content

Commit 3a23ad8

Browse files
Rachel ClarkRachel Clark
authored andcommitted
Bug fixes, singular class names, and nullable operators toggle.
1 parent 8aa358e commit 3a23ad8

File tree

12 files changed

+179
-75
lines changed

12 files changed

+179
-75
lines changed

go.mod

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,7 @@ go 1.20
44

55
require github.com/tabbed/sqlc-go v1.16.0
66

7-
require google.golang.org/protobuf v1.28.1 // indirect
7+
require (
8+
github.com/jinzhu/inflection v1.0.0
9+
google.golang.org/protobuf v1.28.1 // indirect
10+
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
22
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
33
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
4+
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
5+
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
46
github.com/tabbed/sqlc-go v1.16.0 h1:EwPBXdGn5tyrLjcNiHRoQthWvJeF5NjG9Cx1WK5iFsY=
57
github.com/tabbed/sqlc-go v1.16.0/go.mod h1:mqMU5duZRGz5Wp/qJXwkERf+MXgGOZ8BmW/tH9KyvWA=
68
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

internal/core/class.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ type ClassMember struct {
1313
DBName string
1414
Type string
1515
Comment string
16+
NotNull bool
1617
Column *plugin.Column
1718
}
1819

internal/core/config.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
package core
22

33
type Config struct {
4-
Namespace string `json:"namespace"`
5-
QueryParamLimit int `json:"query_param_limit"`
6-
EmitAsync bool `json:"emit_async"`
7-
LogFile string `json:"log_file"`
4+
Namespace string `json:"namespace"`
5+
QueryParamLimit int `json:"query_param_limit"`
6+
EmitAsync bool `json:"emit_async"`
7+
EmitNullOperators bool `json:"emit_null_ops"`
8+
LogFile string `json:"log_file"`
9+
EmitExactTableNames bool `json:"emit_exact_table_names"`
10+
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"`
811
}

internal/core/cs_type.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
sdk "github.com/tabbed/sqlc-go/sdk"
66
)
77

8-
func CsType(req *plugin.CodeGenRequest, col *plugin.Column) string {
8+
func CsType(req *plugin.CodeGenRequest, col *plugin.Column, conf *Config) string {
99
for _, oride := range req.Settings.Overrides {
1010
if oride.CodeType == "" {
1111
continue
@@ -17,15 +17,15 @@ func CsType(req *plugin.CodeGenRequest, col *plugin.Column) string {
1717
}
1818
}
1919

20-
typ := csInnerType(req, col)
20+
typ := csInnerType(req, col, conf)
2121
if col.IsArray {
2222
return typ + "[]"
2323
}
2424

2525
return typ
2626
}
2727

28-
func csInnerType(req *plugin.CodeGenRequest, col *plugin.Column) string {
28+
func csInnerType(req *plugin.CodeGenRequest, col *plugin.Column, conf *Config) string {
2929
columnType := sdk.DataType(col.Type)
3030
notNull := col.NotNull || col.IsArray
3131

@@ -41,7 +41,7 @@ func csInnerType(req *plugin.CodeGenRequest, col *plugin.Column) string {
4141

4242
switch req.Settings.Engine {
4343
case "postgresql":
44-
return PostgresType(req, col)
44+
return PostgresType(req, col, conf)
4545
default:
4646
return "object"
4747
}

internal/core/gen.go

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
plugin "github.com/tabbed/sqlc-go/codegen"
1010
"github.com/tabbed/sqlc-go/metadata"
1111
"github.com/tabbed/sqlc-go/sdk"
12+
13+
"github.com/hyperbeam/sqlc-gen-cs/internal/inflection"
1214
)
1315

1416
// Column tagged with an ID for matching parameters used multiple times in queries
@@ -64,7 +66,7 @@ func BuildEnums(req *plugin.CodeGenRequest) []Enum {
6466
return enums
6567
}
6668

67-
func BuildClasses(req *plugin.CodeGenRequest) []Class {
69+
func BuildClasses(req *plugin.CodeGenRequest, conf Config) []Class {
6870
log.Println("Building classes...")
6971
var classes []Class
7072
for _, schema := range req.Catalog.Schemas {
@@ -80,18 +82,33 @@ func BuildClasses(req *plugin.CodeGenRequest) []Class {
8082
}
8183
className := tableName
8284

85+
if !conf.EmitExactTableNames {
86+
className = inflection.Singular(inflection.SingularParams{
87+
Name: tableName,
88+
Exclusions: conf.InflectionExcludeTableNames,
89+
})
90+
}
91+
8392
c := Class{
8493
Table: &plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name},
8594
Name: ClassName(className, req.Settings),
8695
Comment: table.Comment,
8796
}
8897

8998
for _, column := range table.Columns {
90-
c.Members = append(c.Members, ClassMember{
99+
member := ClassMember{
91100
Name: ClassName(column.Name, req.Settings),
92-
Type: CsType(req, column),
101+
Type: CsType(req, column, &conf),
93102
Comment: column.Comment,
94-
})
103+
}
104+
105+
if conf.EmitNullOperators {
106+
member.NotNull = column.NotNull
107+
} else {
108+
member.NotNull = false
109+
}
110+
111+
c.Members = append(c.Members, member)
95112
}
96113

97114
classes = append(classes, c)
@@ -133,9 +150,14 @@ func BuildQueries(req *plugin.CodeGenRequest, conf Config, classes []Class) ([]Q
133150
gq.Arg = QueryValue{
134151
Name: paramName(p),
135152
DBName: p.Column.Name,
136-
Typ: CsType(req, p.Column),
153+
Typ: CsType(req, p.Column, &conf),
137154
Column: p.Column,
138155
}
156+
if conf.EmitNullOperators {
157+
gq.Arg.NotNull = p.Column.NotNull
158+
} else {
159+
gq.Arg.NotNull = false
160+
}
139161
} else if len(query.Params) >= 1 {
140162
var cols []codeColumn
141163
for _, p := range query.Params {
@@ -144,7 +166,7 @@ func BuildQueries(req *plugin.CodeGenRequest, conf Config, classes []Class) ([]Q
144166
Column: p.Column,
145167
})
146168
}
147-
c, err := columnsToClass(req, gq.MethodName+"Params", cols, false)
169+
c, err := columnsToClass(&conf, req, gq.MethodName+"Params", cols, false)
148170
if err != nil {
149171
log.Println("Error in arguments: ", err)
150172
return nil, err
@@ -169,7 +191,13 @@ func BuildQueries(req *plugin.CodeGenRequest, conf Config, classes []Class) ([]Q
169191
gq.Ret = QueryValue{
170192
Name: name,
171193
DBName: name,
172-
Typ: CsType(req, c),
194+
Typ: CsType(req, c, &conf),
195+
}
196+
197+
if conf.EmitNullOperators && !strings.HasSuffix(gq.Ret.Typ, "?") {
198+
gq.Ret.NotNull = true
199+
} else {
200+
gq.Ret.NotNull = false
173201
}
174202
} else if putOutColumns(query) {
175203
var gs *Class
@@ -183,7 +211,7 @@ func BuildQueries(req *plugin.CodeGenRequest, conf Config, classes []Class) ([]Q
183211
for i, f := range class.Members {
184212
c := query.Columns[i]
185213
sameName := f.Name == ClassName(columnName(c, i), req.Settings)
186-
sameType := f.Type == CsType(req, c)
214+
sameType := f.Type == CsType(req, c, &conf)
187215
sameTable := sdk.SameTableName(c.Table, class.Table, req.Catalog.DefaultSchema)
188216
if !sameName || !sameType || !sameTable {
189217
same = false
@@ -204,7 +232,7 @@ func BuildQueries(req *plugin.CodeGenRequest, conf Config, classes []Class) ([]Q
204232
})
205233
}
206234
var err error
207-
gs, err = columnsToClass(req, gq.MethodName+"Row", columns, true)
235+
gs, err = columnsToClass(&conf, req, gq.MethodName+"Row", columns, true)
208236
if err != nil {
209237
return nil, err
210238
}
@@ -264,7 +292,7 @@ func putOutColumns(query *plugin.Query) bool {
264292
return false
265293
}
266294

267-
func columnsToClass(req *plugin.CodeGenRequest, name string, columns []codeColumn, useID bool) (*Class, error) {
295+
func columnsToClass(conf *Config, req *plugin.CodeGenRequest, name string, columns []codeColumn, useID bool) (*Class, error) {
268296
class := Class{
269297
Name: name,
270298
}
@@ -292,7 +320,13 @@ func columnsToClass(req *plugin.CodeGenRequest, name string, columns []codeColum
292320
Name: memberName,
293321
DBName: colName,
294322
Column: c.Column,
295-
Type: CsType(req, c.Column),
323+
Type: CsType(req, c.Column, conf),
324+
}
325+
326+
if conf.EmitNullOperators {
327+
member.NotNull = c.Column.NotNull
328+
} else {
329+
member.NotNull = false
296330
}
297331

298332
class.Members = append(class.Members, member)

internal/core/postgresql_type.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
sdk "github.com/tabbed/sqlc-go/sdk"
66
)
77

8-
func PostgresType(req *plugin.CodeGenRequest, col *plugin.Column) string {
8+
func PostgresType(req *plugin.CodeGenRequest, col *plugin.Column, conf *Config) string {
99
var csType string
1010
columnType := sdk.DataType(col.Type)
1111

@@ -137,7 +137,7 @@ func PostgresType(req *plugin.CodeGenRequest, col *plugin.Column) string {
137137

138138
if col.IsArray {
139139
return csType + "[]"
140-
} else if !col.NotNull {
140+
} else if !col.NotNull && conf.EmitNullOperators {
141141
return csType + "?"
142142
} else {
143143
return csType

internal/core/query.go

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@ func (q Query) HasArgs() bool {
2727
// QueryValue is the holder for our IO part of the query
2828
// It exists to hold a new class, or an existing one.
2929
type QueryValue struct {
30-
Emit bool
31-
Name string
32-
DBName string
33-
Class *Class
34-
Typ string
30+
Emit bool
31+
Name string
32+
DBName string
33+
Class *Class
34+
Typ string
35+
NotNull bool
3536

3637
Column *plugin.Column
3738
}
@@ -58,6 +59,19 @@ func (v QueryValue) Type() string {
5859
panic("no type for QueryValue: " + v.Name)
5960
}
6061

62+
func (v QueryValue) EmitReturnType(emitNull bool) string {
63+
if !emitNull {
64+
return v.Type()
65+
}
66+
67+
// Return types may always be null
68+
if strings.HasSuffix(v.Type(), "?") {
69+
return v.Type()
70+
} else {
71+
return v.Type() + "?"
72+
}
73+
}
74+
6175
func (v QueryValue) Pair() string {
6276
log.Println("Arg value pair: ", v)
6377
if v.isEmpty() {
@@ -70,10 +84,10 @@ func (v QueryValue) Pair() string {
7084
out = append(out, f.Type+" "+strings.ToLower(f.Name))
7185
}
7286

73-
return ", " + strings.Join(out, ",")
87+
return strings.Join(out, ", ")
7488
}
7589

76-
return ", " + v.Type() + " " + v.Name
90+
return v.Type() + " " + v.Name
7791
}
7892

7993
func (v QueryValue) UniqueMembers() []ClassMember {

internal/gen.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ var version string
2222

2323
type TemplateCtx struct {
2424
EmitAsync bool
25+
EmitNulls bool
2526
SqlcVersion string
2627
CsGenVersion string
2728
Namespace string
@@ -40,6 +41,10 @@ func (t *TemplateCtx) ClassName() {
4041
}
4142

4243
func Generate(ctx context.Context, req *plugin.Request) (*plugin.Response, error) {
44+
if version == "" {
45+
version = "0.1.0"
46+
}
47+
4348
var conf core.Config
4449
if len(req.PluginOptions) > 0 {
4550
if err := json.Unmarshal(req.PluginOptions, &conf); err != nil {
@@ -59,7 +64,7 @@ func Generate(ctx context.Context, req *plugin.Request) (*plugin.Response, error
5964

6065
log.Println("Beginning generation with config: ", conf)
6166
enums := core.BuildEnums(req)
62-
classes := core.BuildClasses(req)
67+
classes := core.BuildClasses(req, conf)
6368
queries, err := core.BuildQueries(req, conf, classes)
6469
log.Println("queries built: ", queries)
6570
if err != nil {
@@ -68,6 +73,7 @@ func Generate(ctx context.Context, req *plugin.Request) (*plugin.Response, error
6873

6974
tctx := TemplateCtx{
7075
EmitAsync: conf.EmitAsync,
76+
EmitNulls: conf.EmitNullOperators,
7177
SqlcVersion: req.SqlcVersion,
7278
CsGenVersion: version,
7379
Namespace: conf.Namespace,

internal/inflection/singular.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package inflection
2+
3+
import (
4+
"strings"
5+
6+
upstream "github.com/jinzhu/inflection"
7+
)
8+
9+
type SingularParams struct {
10+
Name string
11+
Exclusions []string
12+
}
13+
14+
func Singular(s SingularParams) string {
15+
for _, exclusion := range s.Exclusions {
16+
if strings.EqualFold(s.Name, exclusion) {
17+
return s.Name
18+
}
19+
}
20+
21+
// Manual fix for incorrect handling of "campus"
22+
//
23+
// https://github.com/kyleconroy/sqlc/issues/430
24+
// https://github.com/jinzhu/inflection/issues/13
25+
if strings.ToLower(s.Name) == "campus" {
26+
return s.Name
27+
}
28+
// Manual fix for incorrect handling of "meta"
29+
//
30+
// https://github.com/kyleconroy/sqlc/issues/1217
31+
// https://github.com/jinzhu/inflection/issues/21
32+
if strings.ToLower(s.Name) == "meta" {
33+
return s.Name
34+
}
35+
return upstream.Singular(s.Name)
36+
}

0 commit comments

Comments
 (0)