Skip to content

Commit 98aa1b9

Browse files
committed
Fix ORDER BY with NULL for MS SQL
1 parent b9e52bc commit 98aa1b9

File tree

2 files changed

+141
-3
lines changed

2 files changed

+141
-3
lines changed

scalasql/src/dialects/MsSqlDialect.scala

+121-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package scalasql.dialects
22

3-
import scalasql.core.{Aggregatable, DbApi, DialectTypeMappers, Expr, TypeMapper}
4-
import scalasql.operations
5-
import scalasql.core.SqlStr.SqlStringSyntax
3+
import scalasql.query.{AscDesc, GroupBy, Join, Nulls, OrderBy, SubqueryRef, Table}
4+
import scalasql.core.{Aggregatable, Context, DbApi, DialectTypeMappers, Expr, Queryable, TypeMapper, SqlStr}
5+
import scalasql.{Sc, operations}
6+
import scalasql.core.SqlStr.{Renderable, SqlStringSyntax}
67
import scalasql.operations.{ConcatOps, MathOps, TrimOps}
78

89
import java.time.{Instant, LocalDateTime, OffsetDateTime}
@@ -43,6 +44,9 @@ trait MsSqlDialect extends Dialect {
4344
): MsSqlDialect.ExprStringLikeOps[geny.Bytes] =
4445
new MsSqlDialect.ExprStringLikeOps(v)
4546

47+
override implicit def TableOpsConv[V[_[_]]](t: Table[V]): scalasql.dialects.TableOps[V] =
48+
new MsSqlDialect.TableOps(t)
49+
4650
implicit def ExprAggOpsConv[T](v: Aggregatable[Expr[T]]): operations.ExprAggOps[T] =
4751
new MsSqlDialect.ExprAggOps(v)
4852

@@ -95,4 +99,118 @@ object MsSqlDialect extends MsSqlDialect {
9599
def indexOf(x: Expr[T]): Expr[Int] = Expr { implicit ctx => sql"CHARINDEX($x, $v)" }
96100
def reverse: Expr[T] = Expr { implicit ctx => sql"REVERSE($v)" }
97101
}
102+
103+
class TableOps[V[_[_]]](t: Table[V]) extends scalasql.dialects.TableOps[V](t) {
104+
105+
protected override def joinableToSelect: Select[V[Expr], V[Sc]] = {
106+
val ref = Table.ref(t)
107+
new SimpleSelect(
108+
Table.metadata(t).vExpr(ref, dialectSelf).asInstanceOf[V[Expr]],
109+
None,
110+
false,
111+
Seq(ref),
112+
Nil,
113+
Nil,
114+
None
115+
)(
116+
t.containerQr
117+
)
118+
}
119+
}
120+
121+
trait Select[Q, R] extends scalasql.query.Select[Q, R] {
122+
override def newCompoundSelect[Q, R](
123+
lhs: scalasql.query.SimpleSelect[Q, R],
124+
compoundOps: Seq[scalasql.query.CompoundSelect.Op[Q, R]],
125+
orderBy: Seq[OrderBy],
126+
limit: Option[Int],
127+
offset: Option[Int]
128+
)(
129+
implicit qr: Queryable.Row[Q, R],
130+
dialect: scalasql.core.DialectTypeMappers
131+
): scalasql.query.CompoundSelect[Q, R] = {
132+
new CompoundSelect(lhs, compoundOps, orderBy, limit, offset)
133+
}
134+
135+
override def newSimpleSelect[Q, R](
136+
expr: Q,
137+
exprPrefix: Option[Context => SqlStr],
138+
preserveAll: Boolean,
139+
from: Seq[Context.From],
140+
joins: Seq[Join],
141+
where: Seq[Expr[?]],
142+
groupBy0: Option[GroupBy]
143+
)(
144+
implicit qr: Queryable.Row[Q, R],
145+
dialect: scalasql.core.DialectTypeMappers
146+
): scalasql.query.SimpleSelect[Q, R] = {
147+
new SimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0)
148+
}
149+
}
150+
151+
class SimpleSelect[Q, R](
152+
expr: Q,
153+
exprPrefix: Option[Context => SqlStr],
154+
preserveAll: Boolean,
155+
from: Seq[Context.From],
156+
joins: Seq[Join],
157+
where: Seq[Expr[?]],
158+
groupBy0: Option[GroupBy]
159+
)(implicit qr: Queryable.Row[Q, R])
160+
extends scalasql.query.SimpleSelect(
161+
expr,
162+
exprPrefix,
163+
preserveAll,
164+
from,
165+
joins,
166+
where,
167+
groupBy0
168+
)
169+
with Select[Q, R]
170+
171+
class CompoundSelect[Q, R](
172+
lhs: scalasql.query.SimpleSelect[Q, R],
173+
compoundOps: Seq[scalasql.query.CompoundSelect.Op[Q, R]],
174+
orderBy: Seq[OrderBy],
175+
limit: Option[Int],
176+
offset: Option[Int]
177+
)(implicit qr: Queryable.Row[Q, R])
178+
extends scalasql.query.CompoundSelect(lhs, compoundOps, orderBy, limit, offset)
179+
with Select[Q, R] {
180+
protected override def selectRenderer(prevContext: Context): SubqueryRef.Wrapped.Renderer =
181+
new CompoundSelectRenderer(this, prevContext)
182+
}
183+
184+
class CompoundSelectRenderer[Q, R](
185+
query: scalasql.query.CompoundSelect[Q, R],
186+
prevContext: Context
187+
) extends scalasql.query.CompoundSelect.Renderer(query, prevContext) {
188+
189+
override lazy val limitOpt = SqlStr
190+
.flatten(CompoundSelectRendererForceLimit.limitToSqlStr(query.limit, query.offset))
191+
192+
override def orderToSqlStr(newCtx: Context) = {
193+
SqlStr.optSeq(query.orderBy) { orderBys =>
194+
val orderStr = SqlStr.join(
195+
orderBys.map { orderBy =>
196+
val exprStr = Renderable.renderSql(orderBy.expr)(newCtx)
197+
198+
(orderBy.ascDesc, orderBy.nulls) match {
199+
case (Some(AscDesc.Asc), None | Some(Nulls.First)) => sql"$exprStr ASC"
200+
case (Some(AscDesc.Desc), Some(Nulls.First)) =>
201+
sql"IIF($exprStr IS NULL, 0, 1), $exprStr DESC"
202+
case (Some(AscDesc.Asc), Some(Nulls.Last)) => sql"IIF($exprStr IS NULL, 1, 0), $exprStr ASC"
203+
case (Some(AscDesc.Desc), None | Some(Nulls.Last)) => sql"$exprStr DESC"
204+
case (None, None) => exprStr
205+
case (None, Some(Nulls.First)) => sql"IIF($exprStr IS NULL, 0, 1), $exprStr"
206+
case (None, Some(Nulls.Last)) => sql"IIF($exprStr IS NULL, 1, 0), $exprStr"
207+
}
208+
},
209+
SqlStr.commaSep
210+
)
211+
212+
sql" ORDER BY $orderStr"
213+
}
214+
}
215+
}
98216
}

scalasql/test/src/datatypes/OptionalTests.scala

+20
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,11 @@ trait OptionalTests extends ScalaSqlSuite {
399399
SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2
400400
FROM opt_cols opt_cols0
401401
ORDER BY my_int IS NULL ASC, my_int
402+
""",
403+
"""
404+
SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2
405+
FROM opt_cols opt_cols0
406+
ORDER BY IIF(my_int IS NULL, 1, 0), my_int
402407
"""
403408
),
404409
value = Seq(
@@ -423,6 +428,11 @@ trait OptionalTests extends ScalaSqlSuite {
423428
SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2
424429
FROM opt_cols opt_cols0
425430
ORDER BY my_int IS NULL DESC, my_int
431+
""",
432+
"""
433+
SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2
434+
FROM opt_cols opt_cols0
435+
ORDER BY IIF(my_int IS NULL, 0, 1), my_int
426436
"""
427437
),
428438
value = Seq(
@@ -444,6 +454,11 @@ trait OptionalTests extends ScalaSqlSuite {
444454
SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2
445455
FROM opt_cols opt_cols0
446456
ORDER BY my_int IS NULL ASC, my_int ASC
457+
""",
458+
"""
459+
SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2
460+
FROM opt_cols opt_cols0
461+
ORDER BY IIF(my_int IS NULL, 1, 0), my_int ASC
447462
"""
448463
),
449464
value = Seq(
@@ -507,6 +522,11 @@ trait OptionalTests extends ScalaSqlSuite {
507522
SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2
508523
FROM opt_cols opt_cols0
509524
ORDER BY my_int IS NULL DESC, my_int DESC
525+
""",
526+
"""
527+
SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2
528+
FROM opt_cols opt_cols0
529+
ORDER BY IIF(my_int IS NULL, 0, 1), my_int DESC
510530
"""
511531
),
512532
value = Seq(

0 commit comments

Comments
 (0)