Skip to content

Logic subtyping and overloading #286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 43 commits into
base: hkmc2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
9707a3f
disjunctive subtyping
auht Feb 25, 2025
cc87059
Changes from meeting
LPTK Feb 26, 2025
0298463
disjoint upperbound
auht Mar 5, 2025
bf4c763
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht Mar 5, 2025
ec5c55e
no disjoint upperbound
auht Mar 5, 2025
aa956a7
multiple disjointness
auht Mar 5, 2025
8bad1a5
Changes from meeting
LPTK Mar 6, 2025
86d300d
Add test case and move tests to logicsub folder
LPTK Mar 6, 2025
3b34199
constraints solving nested disjsub
auht Mar 7, 2025
ae28b42
ues linkedhashset
auht Mar 10, 2025
2f1f5f3
traverse disjsub
auht Mar 12, 2025
3d383a2
Changes from meeting
LPTK Mar 13, 2025
00e397b
fix pretty printer type traverse and tv subst
auht Mar 13, 2025
bdd4b7a
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht Mar 13, 2025
0bbfeca
wip rcdtype implementation
auht Mar 20, 2025
3a67c1f
wip rcdtype implementation and fun args disjointness
auht Mar 22, 2025
bd53f6f
intersections wf check
auht Mar 24, 2025
92dc44a
fix nested record wf check
auht Mar 26, 2025
69bcc5a
wip refined if
auht Mar 28, 2025
081de12
Update hkmc2/shared/src/test/mlscript/logicsub/Records.mls
auht Mar 28, 2025
52c6941
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht Mar 28, 2025
55b9066
wip fix
auht Apr 2, 2025
d99f3af
wip else branch disjointness
auht Apr 4, 2025
93d7ff6
fix if disjsub
auht Apr 6, 2025
3cd0f96
wip if
auht Apr 8, 2025
4996431
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht Apr 8, 2025
3928da0
test
auht Apr 8, 2025
88e3518
Pretty printer changes from meeting
LPTK Apr 11, 2025
f77ddbb
dnf disjointness
auht Apr 14, 2025
f009711
fix if
auht Apr 16, 2025
fd21398
fix if
auht Apr 16, 2025
4dc8cd7
fix missing constraints
auht Apr 18, 2025
0ed746e
test
auht Apr 18, 2025
836f3bb
elim branches
auht Apr 20, 2025
a9c4221
negtype and wip test
auht Apr 22, 2025
3bfe1c4
wip test explanation
auht Apr 22, 2025
91ac0d4
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht Apr 22, 2025
99b465b
rcd union neg disjointess
auht Apr 28, 2025
9e13a34
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht Apr 28, 2025
afb08c5
wf check and typing function intersection
auht May 4, 2025
2cfa234
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht May 4, 2025
69feb99
test
auht May 4, 2025
4a335ea
modify disjointness signature and impl
auht May 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ case class Config(
sanityChecks: Opt[SanityChecks],
effectHandlers: Opt[EffectHandlers],
liftDefns: Opt[LiftDefns],
simplifyTypes: Bool,
):

def stackSafety: Opt[StackSafety] = effectHandlers.flatMap(_.stackSafety)
Expand All @@ -26,6 +27,7 @@ object Config:
// sanityChecks = S(SanityChecks(light = true)),
effectHandlers = N,
liftDefns = N,
simplifyTypes = true,
)

case class SanityChecks(light: Bool)
Expand Down
91 changes: 76 additions & 15 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/ConstraintSolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
else
v.state.lowerBounds ::= nv
nv.state.upperBounds = v.state.upperBounds.map(extrude) // * propagate
nv.state.disjsub ++= v.state.disjsub.map:
case DisjSub(ds, dss, cs) =>
val d = ds.mapKeys(v0 => if v === v0 then nv else v0)
DisjSub(mutable.LinkedHashSet.from(d), dss, cs)
nv.state.disjsub.foreach(_.commit())
nv
})
case ft @ FunType(args, ret, eff) =>
Expand Down Expand Up @@ -97,26 +102,82 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
cctx.nest(bd -> v) givenIn:
v.state.lowerBounds ::= bd
v.state.upperBounds.foreach(ub => constrainImpl(bd, ub))
v.state.disjsub.toList.flatMap(_.check(v)).foreach:
case (a, b) => constrainImpl(a, b)
case Conj(i, u, Nil) => (conj.i, conj.u) match
case (_, Union(N, Nil)) =>
case (_, Union(N, Nil, Nil)) =>
// raise(ErrorReport(msg"Cannot solve ${conj.i.toString()} ∧ ¬⊥" -> N :: Nil))
cctx.err
case (Inter(S(ClassLikeType(cls1, targs1))), Union(f, ClassLikeType(cls2, targs2) :: rest)) =>
case (Inter(S(ClassLikeType(cls1, targs1))), Union(f, ClassLikeType(cls2, targs2) :: rest, rcd)) =>
if cls1.uid === cls2.uid then
targs1.zip(targs2).foreach: (ta1, ta2) =>
constrainArgs(ta1, ta2)
else constrainConj(Conj(conj.i, Union(f, rest), Nil))
case (int: Inter, Union(f, _ :: rest)) => constrainConj(Conj(int, Union(f, rest), Nil))
case (Inter(S(FunType(args1, ret1, eff1))), Union(S(FunType(args2, ret2, eff2)), Nil)) =>
if args1.length =/= args2.length then
// raise(ErrorReport(msg"Cannot constrain ${conj.i.toString()} <: ${conj.u.toString()}" -> N :: Nil))
cctx.err
else
args1.zip(args2).foreach {
case (a1, a2) => constrainImpl(a2, a1)
}
constrainImpl(ret1, ret2)
constrainImpl(eff1, eff2)
else constrainConj(Conj(conj.i, Union(f, rest, rcd), Nil))
case (int: Inter, Union(f, _ :: rest, rcd)) => constrainConj(Conj(int, Union(f, rest, rcd), Nil))
case (Inter(S(RcdType(u))), Union(f, Nil, RcdType(w) :: Nil)) =>
val um = u.toMap
val wm = w.toMap
val k = w.keys.toSet
if k.subsetOf(um.keySet) then
k.foreach(k => constrainImpl(um(k), wm(k)))
else cctx.err
case (Inter(S(u: RcdType)), Union(f, Nil, rs@(RcdType(w) :: _))) =>
val ws = rs.foldLeft(Nil): (x, w) =>
val d = Type.disjoint(w, u)
if d === S(Set.empty) then x else Ls(d -> w) ++ x
ws match
case Nil => cctx.err
case (_, w) :: Nil => constrainImpl(u, w)
case ((_, RcdType(w)) :: (_, RcdType(z)) :: _) =>
val (wm, zm) = (w.toMap, z.toMap)
val k = wm.keySet & zm.keySet
val dk = k.find(k => Type.disjoint(wm(k), zm(k)) === S(Set.empty)).get
val ku = ws.foldLeft(Bot: Type):
case (ku, (_, w)) => ku | w.fields.find(_._1 === dk).get._2
ws.foreach:
case (S(k), w) => k.foreach: k =>
DisjSub(mutable.LinkedHashSet.from(k), Nil, Ls(u -> RcdType(w.fields.filter(_._1 =/= k)))).commit()
case _ =>
constrainImpl(u.fields.find(_._1 === dk).get._2, ku)
ws.foreach:
case (N, w) => constrainImpl(u, RcdType(w.fields.filter(_._1 =/= dk)))
case _ =>
case (Inter(S(fs: Ls[FunType])), Union(S(FunType(args2, ret2, eff2)), Nil, Nil)) =>
val k = args2.flatMap(x => Type.disjoint(x, x))
if k.forall(_.nonEmpty) then
val f = fs.filter(_.args.length === args2.length)
if args2.isEmpty then
if f.isEmpty then
cctx.err
else f.foreach: f =>
constrainImpl(f.ret, ret2)
constrainImpl(f.eff, eff2)
else
val args = f.map(x => Type.discriminant(x.args))
val args2r = args2.zipWithIndex.map(u => (s"${u._2}", u._1))
val args2q = RcdType(args2r)
val (cs, dss) = (args.iterator.zip(f).map:
case ((q, r), f) =>
val rm = r.fields.toMap
val rcs = args2r.flatMap(u => rm.get(u._1).filter(_ =/= Top).map(u._2 -> _))
val cs = (f.ret, ret2) :: (f.eff, eff2) :: rcs
Type.disjoint(q, args2q) match
case N => (cs, Nil)
case S(k) =>
(Nil, k.map(k => DisjSub(mutable.LinkedHashSet.from(k), Nil, cs)))).toList.unzip
val c = (args2q, args.foldLeft(Bot: Type) { case (t, (q, _)) => t | q })
if k.isEmpty then
if f.isEmpty then
cctx.err
else
dss.flatten.foreach(_.commit())
constrainImpl(c._1, c._2)
cs.flatten.foreach(u => constrainImpl(u._1, u._2))
else
val cs0 = c :: cs.flatten
val dss0 = dss.flatten
k.reduce((x, y) => y.flatMap(y => x.map(_ ++ y))).foreach: k =>
DisjSub(mutable.LinkedHashSet.from(k), dss0, cs0).commit()
case _ =>
// raise(ErrorReport(msg"Cannot solve ${conj.i.toString()} <: ${conj.u.toString()}" -> N :: Nil))
cctx.err
Expand All @@ -134,7 +195,7 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
inlineSkolemBounds(if pol then state.upperBounds.foldLeft[Type](v)(_ & _) else state.lowerBounds.foldLeft[Type](v)(_ | _), pol)
case ComposedType(lhs, rhs, p) => ComposedType(inlineSkolemBounds(lhs, pol), inlineSkolemBounds(rhs, pol), p)
case NegType(ty) => NegType(inlineSkolemBounds(ty, !pol))
case _: ClassLikeType | _: FunType | _: InfVar | Top | Bot => ty
case _: ClassLikeType | _: FunType | _: RcdType |_: InfVar | Top | Bot => ty

private def constrainImpl(lhs: Type, rhs: Type)(using BbCtx, CCtx, TL): Unit =
if cctx.cache((lhs, rhs)) then log(s"Cached!")
Expand Down
46 changes: 28 additions & 18 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/NormalForm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,28 +76,36 @@ object Conj:
}){}
lazy val empty: Conj = Conj(Inter.empty, Union.empty, Nil)
def mkVar(v: InfVar, pol: Bool) = Conj(Inter.empty, Union.empty, (v, pol) :: Nil)
def mkInter(inter: ClassLikeType | FunType) =
def mkInter(inter: ClassLikeType | Ls[FunType] | RcdType) =
Conj(Inter(S(inter)), Union.empty, Nil)
def mkUnion(union: ClassLikeType | FunType) =
def mkUnion(union: ClassLikeType | FunType | RcdType) =
Conj(Inter.empty, union match {
case cls: ClassLikeType => Union(N, cls :: Nil)
case fun: FunType => Union(S(fun), Nil)
case cls: ClassLikeType => Union(N, cls :: Nil, Nil)
case fun: FunType => Union(S(fun), Nil, Nil)
case r: RcdType => Union(N, Nil, Ls(r))
}, Nil)

// * Some(ClassType) -> C[in D_i out D_i], Some(FunType) -> D_1 ->{D_2} D_3, None -> Top
final case class Inter(v: Opt[ClassLikeType | FunType]) extends NormalForm:
final case class Inter(v: Opt[ClassLikeType | Ls[FunType] | RcdType]) extends NormalForm:
def isTop: Bool = v.isEmpty
def merge(other: Inter): Option[Inter] = (v, other.v) match
case (S(ClassLikeType(cls1, targs1)), S(ClassLikeType(cls2, targs2))) if cls1.uid === cls2.uid =>
S(Inter(S(ClassLikeType(cls1, targs1.lazyZip(targs2).map(_ & _)))))
case (S(_: ClassLikeType), S(_: ClassLikeType)) => N
case (S(FunType(a1, r1, e1)), S(FunType(a2, r2, e2))) =>
S(Inter(S(FunType(a1.lazyZip(a2).map(_ | _), r1 & r2, e1 & e2))))
// case (S(FunType(a1, r1, e1)), S(FunType(a2, r2, e2))) =>
// S(Inter(S(FunType(a1.lazyZip(a2).map(_ | _), r1 & r2, e1 & e2))))
case (S(a: Ls[FunType]), S(b: Ls[FunType])) => S(Inter(S(a ++ b)))
case (S(a: RcdType), S(b: RcdType)) => S(Inter(S(a & b)))
case (S(v), N) => S(Inter(S(v)))
case (N, v) => S(Inter(v))
case _ => N
def toBasic: BasicType = v.getOrElse(Top)
def toDnf(using TL): Disj = Disj(Conj(this, Union(N, Nil), Nil) :: Nil)
def toBasic: BasicType = v match
case N => Top
case S(x: ClassLikeType) => x
case S(Nil) => Top
case S(x: Ls[FunType]) => x.reduce[Type](_&_).toBasic
case S(x: RcdType) => x
def toDnf(using TL): Disj = Disj(Conj(this, Union(N, Nil, Nil), Nil) :: Nil)
override def show(using Scope): Str =
toBasic.show

Expand All @@ -106,11 +114,11 @@ object Inter:
lazy val empty: Inter = Inter(N)

// * fun: Some(FunType) -> D_1 ->{D_2} D_3, None -> bot
final case class Union(fun: Opt[FunType], cls: Ls[ClassLikeType])
final case class Union(fun: Opt[FunType], cls: Ls[ClassLikeType], rcd: Ls[RcdType])
extends NormalForm with CachedBasicType:
def isBot = fun.isEmpty && cls.isEmpty
def isBot = fun.isEmpty && cls.isEmpty && rcd.isEmpty
def toType = fun.getOrElse(Bot) |
cls.foldLeft[Type](Bot)(_ | _)
cls.foldLeft[Type](Bot)(_ | _) | rcd.foldLeft[Type](Bot)(_ | _)
def merge(other: Union): Union = Union((fun, other.fun) match {
case (S(FunType(a1, r1, e1)), S(FunType(a2, r2, e2))) =>
S(FunType(a1.lazyZip(a2).map(_ & _), r1 | r2, e1 | e2))
Expand All @@ -121,19 +129,19 @@ extends NormalForm with CachedBasicType:
case (cls1, cls2) => cls1.name.uid <= cls2.name.uid
}.foldLeft[Ls[ClassLikeType]](Nil)((res, cls) => (res, cls) match {
case (Nil, cls) => cls :: Nil
case (ClassLikeType(cls1, targs1) :: tail, ClassLikeType(cls2, targs2)) if cls1.uid === cls2.uid =>
case (ClassLikeType(cls1, targs1) :: tail, ClassLikeType(cls2, targs2)) if cls1.uid === cls2.uid =>
ClassLikeType(cls1, targs1.lazyZip(targs2).map(_ | _)) :: tail
case (head :: tail, cls) => cls :: head :: tail
}))
}), rcd ++ other.rcd)
def mkBasic: BasicType =
BasicType.union(fun.toList ::: cls)
BasicType.union(fun.toList ::: cls ::: rcd)
def toDnf(using TL): Disj = NormalForm.neg(this)
override def show(using Scope): Str =
toType.show

override def showDbg: Str = toType.showDbg
object Union:
val empty: Union = Union(N, Nil)
val empty: Union = Union(N, Nil, Nil)

sealed abstract class NormalForm extends TypeExt:
def toBasic: BasicType
Expand Down Expand Up @@ -167,6 +175,7 @@ object NormalForm:
case v: InfVar => Disj(Conj.mkVar(v, false) :: Nil)
case ct: ClassLikeType => Disj(Conj.mkUnion(ct) :: Nil)
case ft: FunType => Disj(Conj.mkUnion(ft) :: Nil)
case r: RcdType => Disj(Conj.mkUnion(r) :: Nil)
case ComposedType(lhs, rhs, pol) =>
if pol then inter(neg(lhs), neg(rhs)) else union(neg(lhs), neg(rhs))
case NegType(ty) => dnf(ty)
Expand All @@ -176,13 +185,14 @@ object NormalForm:
ty match
case d: Disj => d
case c: Conj => Disj(c :: Nil)
case i: Inter => Disj(Conj(i, Union(N, Nil), Nil) :: Nil)
case i: Inter => Disj(Conj(i, Union(N, Nil, Nil), Nil) :: Nil)
case _ => ty.toBasic match
case Top => Disj.top
case Bot => Disj.bot
case v: InfVar => Disj(Conj.mkVar(v, true) :: Nil)
case ct: ClassLikeType => Disj(Conj.mkInter(ct.toNorm) :: Nil)
case ft: FunType => Disj(Conj.mkInter(ft.toNorm) :: Nil)
case ft: FunType => Disj(Conj.mkInter(Ls(ft.toNorm)) :: Nil)
case r: RcdType => Disj(Conj.mkInter(r.toNorm) :: Nil)
case ComposedType(lhs, rhs, pol) =>
if pol then union(dnf(lhs), dnf(rhs)) else inter(dnf(lhs), dnf(rhs))
case NegType(ty) => neg(ty)
Expand Down
15 changes: 13 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/PrettyPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,29 @@ import scala.collection.mutable.{Set => MutSet, ListBuffer}
import utils.Scope

class PrettyPrinter(output: String => Unit)(using Scope):
def showDisjSub(ds: DisjSub): String = ds match
case DisjSub(d, dss, cs) =>
val g = d.iterator.map { case (x, y) => s"${x.show}#${y.show} ∨ " }.mkString
val h = dss.iterator.map("(" + showDisjSub(_) + ")").mkString(" ∧ ")
val b = cs.map { case (x, y) => s" ∧ ${x.simp.show}<:${y.simp.show}"}.mkString
s" $g$h$b"
def print(ty: GeneralType): Unit =
output(s"Type: ${ty.show}")
val bounds = PrettyPrinter.collectBounds(ty).distinct
if !bounds.isEmpty then
output("Where:")
bounds.foreach {
case (lhs, rhs) => output(s" ${lhs.show} <: ${rhs.show}")
case ds: DisjSub => output(showDisjSub(ds))
}

object PrettyPrinter:
def apply(output: String => Unit)(using Scope): PrettyPrinter = new PrettyPrinter(output)

type Bound = (Type, Type) // * Type <: Type

private def collectBounds(ty: GeneralType): List[Bound] =
val res = ListBuffer[Bound]()
private def collectBounds(ty: GeneralType): List[Bound | DisjSub] =
val res = ListBuffer[Bound | DisjSub]()
val cache = MutSet[Uid[InfVar]]()
object CollectBounds extends TypeTraverser:
override def apply(pol: Boolean)(ty: GeneralType): Unit = ty match
Expand All @@ -31,6 +38,10 @@ object PrettyPrinter:
res ++= state.upperBounds.map: bd =>
apply(false)(bd)
(v, bd)
res ++= state.disjsub
val (p, n) = state.disjsub.map(_.children()).unzip
p.flatten.foreach(apply(true))
n.flatten.foreach(apply(false))
super.apply(pol)(ty)
case _ => super.apply(pol)(ty)
CollectBounds(true)(ty)
Expand Down
33 changes: 22 additions & 11 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/TypeSimplifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ final def printPol(pol: Bool): Str = pol match {
case false => "-"
}

class TypeSimplifier(tl: TraceLogger):
class TypeSimplifier(using tl: TL):
import tl.{trace, log}

def apply(pol: Bool, lvl: Int)(ty: GeneralType): GeneralType =
Expand Down Expand Up @@ -60,6 +60,8 @@ class TypeSimplifier(tl: TraceLogger):
val varSubst: MutMap[IV, IV] = MutMap.empty

val traversedTVs: MutSet[IV] = MutSet.empty

val traversedDisjSub: MutSet[DisjSub] = MutSet.empty

def getRepr(tv: IV): IV = varSubst.get(tv) match {
case S(tv2) =>
Expand Down Expand Up @@ -129,6 +131,13 @@ class TypeSimplifier(tl: TraceLogger):
// traversingTVs += tv
// traversedTVs += tv
super.apply(pol)(ty)
val (p, n) = (tv.state.disjsub.flatMap: ds =>
ds.subDisjSub.map: k =>
if traversedDisjSub.add(ds) then
ds.children()
else (Nil, Nil)).unzip
p.flatten.foreach(apply(true))
n.flatten.foreach(apply(false))
// traversingTVs -= tv
curPath = oldPath
case pt @ PolyType(tvs, outer, _) => // Avoid simplify outer variables to Top unexpectedly
Expand Down Expand Up @@ -191,15 +200,17 @@ class TypeSimplifier(tl: TraceLogger):
tv.state.upperBounds = newUBs
val isPos = Analysis.posVars.contains(tv)
val isNeg = Analysis.negVars.contains(tv)
// if (isPos && !isNeg && (Analysis.occsNum(tv) === 1 && {newLBs match { case (tv: IV) :: Nil => true; case _ => false }} || newLBs.forall(_.isSmall))) {
if isPos && !isNeg && ({newLBs match { case (tv: IV) :: Nil => true; case _ => false }} || newLBs.forall(_ => true)) then {
// if (isPos && !isNeg && ({newLBs match { case (tv: IV) :: Nil => true; case _ => false }})) {
newLBs.foldLeft(Bot: Type)(_ | _)
} else
// if (isNeg && !isPos && (Analysis.occsNum(tv) === 1 && {newUBs match { case (tv: IV) :: Nil => true; case _ => false }} || newUBs.forall(_.isSmall))) {
if isNeg && !isPos && ({newUBs match { case (tv: IV) :: Nil => true; case _ => false }} || newUBs.forall(_ => true)) then
// if (isNeg && !isPos && ({newUBs match { case (tv: IV) :: Nil => true; case _ => false }})) {
newUBs.foldLeft(Top: Type)(_ & _)
if tv.state.disjsub.isEmpty then
// if (isPos && !isNeg && (Analysis.occsNum(tv) === 1 && {newLBs match { case (tv: IV) :: Nil => true; case _ => false }} || newLBs.forall(_.isSmall))) {
if isPos && !isNeg && ({newLBs match { case (tv: IV) :: Nil => true; case _ => false }} || newLBs.forall(_ => true)) then {
// if (isPos && !isNeg && ({newLBs match { case (tv: IV) :: Nil => true; case _ => false }})) {
newLBs.foldLeft(Bot: Type)(_ | _)
} else
// if (isNeg && !isPos && (Analysis.occsNum(tv) === 1 && {newUBs match { case (tv: IV) :: Nil => true; case _ => false }} || newUBs.forall(_.isSmall))) {
if isNeg && !isPos && ({newUBs match { case (tv: IV) :: Nil => true; case _ => false }} || newUBs.forall(_ => true)) then
// if (isNeg && !isPos && ({newUBs match { case (tv: IV) :: Nil => true; case _ => false }})) {
newUBs.foldLeft(Top: Type)(_ & _)
else tv
else
// tv.lowerBounds = newLBs
// tv.upperBounds = newUBs
Expand All @@ -211,7 +222,7 @@ class TypeSimplifier(tl: TraceLogger):

subst(ty)

def simplifyForall(ty: GeneralType): GeneralType = ty match
def simplifyForall(ty: GeneralType)(using TL): GeneralType = ty match
case PolyType(tvs, outer, body) =>
val newBody = simplifyForall(body)
val visited = PolyType.collectTVs(newBody)
Expand Down
1 change: 1 addition & 0 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/TypeTraverser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class TypeTraverser:
case ty: Type =>
apply(pol)(ty)
apply(!pol)(ty)
case RcdType(fields) => fields.values.foreach(apply(pol))
case InfVar(vlvl, uid, state, _) =>
if pol then state.lowerBounds.foreach(apply(true))
else state.upperBounds.foreach(apply(false))
Expand Down
Loading
Loading