Skip to content

Commit aa956a7

Browse files
committed
multiple disjointness
1 parent ec5c55e commit aa956a7

File tree

4 files changed

+20
-15
lines changed

4 files changed

+20
-15
lines changed

hkmc2/shared/src/main/scala/hkmc2/bbml/ConstraintSolver.scala

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,17 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
9898
v.state.lowerBounds ::= bd
9999
v.state.upperBounds.foreach(ub => constrainImpl(bd, ub))
100100
v.state.disjsub.foreach: d =>
101-
Type.disjoint(d.disjoint(v), bd.toBasic.simp.toBasic)(Set.empty)(using c = mutable.Map.empty) match
102-
case N =>
103-
d.remove(v)
104-
if d.disjoint.isEmpty then
105-
d.dss.foreach(_.commit())
106-
d.cs.foreach((a, b) => constrainImpl(a, b))
107-
case S(k) =>
108-
k.foreach(k => DisjSub(d.disjoint ++ k, d.dss, d.cs).commit())
101+
val u = d.disjoint(v).flatMap: t =>
102+
Type.disjoint(t, bd.toBasic.simp.toBasic)(Set.empty)(using c = mutable.Map.empty)
103+
if u.isEmpty then
104+
d.remove(v)
105+
if d.disjoint.isEmpty then
106+
d.dss.foreach(_.commit())
107+
d.cs.foreach((a, b) => constrainImpl(a, b))
108+
else
109+
d.clear()
110+
u.reduce((x, y) => y.flatMap(y => x.map(_ ++ y))).foreach: k =>
111+
DisjSub(mutable.Map.from(k.groupMap(_._1)(_._2)), d.dss, d.cs).commit()
109112
case Conj(i, u, Nil) => (conj.i, conj.u) match
110113
case (_, Union(N, Nil)) =>
111114
// raise(ErrorReport(msg"Cannot solve ${conj.i.toString()} ∧ ¬⊥" -> N :: Nil))
@@ -133,7 +136,7 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
133136
else
134137
val cs = (ret1, ret2) :: (eff1, eff2) :: args2.zip(args1)
135138
k.reduce((x, y) => y.flatMap(y => x.map(_ ++ y))).foreach: k =>
136-
DisjSub(mutable.Map.from(k), Nil, cs).commit()
139+
DisjSub(mutable.Map.from(k.groupMap(_._1)(_._2)), Nil, cs).commit()
137140
case (Inter(S(fs:Ls[FunType])), Union(S(FunType(args2, ret2, eff2)), Nil)) =>
138141
val f = fs.filter(_.args.length === args2.length)
139142
val args = f.map(_.args).transpose
@@ -152,7 +155,7 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
152155
constrainImpl(b.eff, eff2)
153156
case S(k) =>
154157
val cs = (b.ret,ret2) :: (b.eff,eff2) :: s
155-
k.foreach(k => DisjSub(mutable.Map.from(k), Nil, cs).commit())
158+
k.foreach(k => DisjSub(mutable.Map.from(k.groupMap(_._1)(_._2)), Nil, cs).commit())
156159
case _ =>
157160
// raise(ErrorReport(msg"Cannot solve ${conj.i.toString()} <: ${conj.u.toString()}" -> N :: Nil))
158161
cctx.err

hkmc2/shared/src/main/scala/hkmc2/bbml/PrettyPrinter.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ object PrettyPrinter:
3838
apply(false)(bd)
3939
(v, bd)
4040
res ++= state.disjsub.map: d =>
41-
val ds = d.disjoint.iterator
41+
val ds = d.disjoint.iterator.flatMap:
42+
case (v, u) => u.map(v -> _)
4243
val k = ds.next()
4344
(k, ds.toList, d.cs.toList)
4445
super.apply(pol)(ty)

hkmc2/shared/src/main/scala/hkmc2/bbml/types.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,9 @@ class VarState:
422422
val disjsub: MutSet[DisjSub] = MutSet.empty
423423
override def toString = "<>"
424424

425-
case class DisjSub(disjoint: MutMap[InfVar,BasicType], dss:Ls[DisjSub], cs:Ls[Type->Type]):
425+
case class DisjSub(disjoint: MutMap[InfVar, Set[BasicType]], dss:Ls[DisjSub], cs:Ls[Type->Type]):
426426
def commit() = disjoint.keys.foreach(_.state.disjsub += this)
427+
def clear() = disjoint.keys.foreach(_.state.disjsub -= this)
427428
def remove(v:InfVar)=
428-
v.state.disjsub-=this
429-
disjoint-=v
429+
v.state.disjsub -= this
430+
disjoint -= v

hkmc2/shared/src/test/mlscript/bbml/DisjSub.mls

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ x=>
3737
//│ Where:
3838
//│ 'x <: Int ∨ Bool
3939
//│ 'x <: Int
40-
//│ 'x#Int ∨ Int<:'app ∧ ⊥<:'eff}
4140
//│ 'x#Bool ∨ Bool<:'app ∧ ⊥<:'eff}
41+
//│ 'x#Int ∨ Int<:'app ∧ ⊥<:'eff}
4242

4343
fun ap1(f)=f(1)
4444
ap1(idIB)

0 commit comments

Comments
 (0)