Skip to content

Improve GADT reasoning for pattern alternatives #23205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import annotation.constructorOnly
import cc.*
import NameKinds.WildcardParamName
import MatchTypes.isConcrete
import scala.util.boundary, boundary.break

/** Provides methods to compare types.
*/
Expand Down Expand Up @@ -2054,6 +2055,45 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
else op2
end necessaryEither

/** Finds the necessary (the weakest) GADT constraint among a list of them.
* It returns the one being subsumed by all others if exists, and `None` otherwise.
*
* This is used when typechecking pattern alternatives, for instance:
*
* enum Expr[+T]:
* case I1(x: Int) extends Expr[Int]
* case I2(x: Int) extends Expr[Int]
* case B(x: Boolean) extends Expr[Boolean]
* import Expr.*
*
* The following function should compile:
*
* def foo[T](e: Expr[T]): T = e match
* case I1(_) | I2(_) => 42
*
* since `T >: Int` is subsumed by both alternatives in the first match clause.
*
* However, the following should not:
*
* def foo[T](e: Expr[T]): T = e match
* case I1(_) | B(_) => 42
*
* since the `I1(_)` case gives the constraint `T >: Int` while `B(_)` gives `T >: Boolean`.
* Neither of the constraints is subsumed by the other.
*/
def necessaryGadtConstraint(constrs: List[GadtConstraint], preGadt: GadtConstraint)(using Context): Option[GadtConstraint] = boundary:
constrs match
case Nil => break(None)
case c0 :: constrs =>
var weakest = c0
for c <- constrs do
if subsumes(weakest.constraint, c.constraint, preGadt.constraint) then
weakest = c
else if !subsumes(c.constraint, weakest.constraint, preGadt.constraint) then
// this two constraints are disjoint
break(None)
break(Some(weakest))

inline def rollbackConstraintsUnless(inline op: Boolean): Boolean =
val saved = constraint
var result = false
Expand Down Expand Up @@ -3376,6 +3416,9 @@ object TypeComparer {
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false)(using Context): Boolean =
comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement))

def necessaryGadtConstraint(constrs: List[GadtConstraint], preGadt: GadtConstraint)(using Context): Option[GadtConstraint] =
comparing(_.necessaryGadtConstraint(constrs, preGadt))

def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean = false)(using Context): String =
comparing(_.explained(op, header, short))

Expand Down
16 changes: 13 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2826,10 +2826,20 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
else
assert(ctx.reporter.errorsReported)
tree.withType(defn.AnyType)
val savedGadt = nestedCtx.gadt
val trees1 = tree.trees.mapconserve(typed(_, pt)(using nestedCtx))
val preGadt = nestedCtx.gadt
var gadtConstrs: mutable.ArrayBuffer[GadtConstraint] = mutable.ArrayBuffer.empty
val trees1 = tree.trees.mapconserve: t =>
nestedCtx.gadtState.restore(preGadt)
val res = typed(t, pt)(using nestedCtx)
gadtConstrs += ctx.gadt
Copy link
Preview

Copilot AI May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recording ctx.gadt likely captures the outer context rather than the updated nested context; it should use nestedCtx.gadt to collect the actual constraints produced by each alternative.

Suggested change
gadtConstrs += ctx.gadt
gadtConstrs += nestedCtx.gadt

Copilot uses AI. Check for mistakes.

res
.mapconserve(ensureValueTypeOrWildcard)
nestedCtx.gadtState.restore(savedGadt) // Disable GADT reasoning for pattern alternatives
// Look for the necessary constraint that is subsumed by all alternatives.
// Use that constraint as the outcome if possible, otherwise fallback to not using
// GADT reasoning for soundness.
TypeComparer.necessaryGadtConstraint(gadtConstrs.toList, preGadt) match
case Some(constr) => nestedCtx.gadtState.restore(constr)
case None => nestedCtx.gadtState.restore(preGadt)
assignType(cpy.Alternative(tree)(trees1), trees1)
}

Expand Down
13 changes: 13 additions & 0 deletions tests/neg/gadt-alt-expr1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
enum Expr[+T]:
case I1() extends Expr[Int]
case I2() extends Expr[Int]
case B() extends Expr[Boolean]
import Expr.*
def foo[T](e: Expr[T]): T =
e match
case I1() | I2() => 42 // ok
case B() => true
def bar[T](e: Expr[T]): T =
e match
case I1() | B() => 42 // error
case I2() => 0
15 changes: 15 additions & 0 deletions tests/neg/gadt-alt-expr2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
enum Expr[+T]:
case I1() extends Expr[Int]
case I2() extends Expr[Int]
case I3() extends Expr[Int]
case I4() extends Expr[Int]
case I5() extends Expr[Int]
case B() extends Expr[Boolean]
import Expr.*
def test1[T](e: Expr[T]): T =
e match
case I1() | I2() | I3() | I4() | I5() => 42 // ok
case B() => true
def test2[T](e: Expr[T]): T =
e match
case I1() | I2() | I3() | I4() | I5() | B() => 42 // error
34 changes: 34 additions & 0 deletions tests/neg/gadt-alt-expr3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
trait A
trait B extends A
trait C extends B
trait D
enum Expr[+T]:
case IsA() extends Expr[A]
case IsB() extends Expr[B]
case IsC() extends Expr[C]
case IsD() extends Expr[D]
import Expr.*
def test1[T](e: Expr[T]): T = e match
case IsA() => new A {}
case IsB() => new B {}
case IsC() => new C {}
def test2[T](e: Expr[T]): T = e match
case IsA() | IsB() =>
// IsA() implies T >: A
// IsB() implies T >: B
// So T >: B is chosen
new B {}
case IsC() => new C {}
def test3[T](e: Expr[T]): T = e match
case IsA() | IsB() | IsC() =>
// T >: C is chosen
new C {}
def test4[T](e: Expr[T]): T = e match
case IsA() | IsB() | IsC() =>
new B {} // error
def test5[T](e: Expr[T]): T = e match
case IsA() | IsB() =>
new A {} // error
def test6[T](e: Expr[T]): T = e match
case IsA() | IsC() | IsD() =>
new C {} // error
19 changes: 19 additions & 0 deletions tests/neg/gadt-alt-expr4.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
trait A
trait B extends A
trait C extends B
enum Expr[T]:
case IsA() extends Expr[A]
case IsB() extends Expr[B]
case IsC() extends Expr[C]
import Expr.*
def test1[T](e: Expr[T]): T = e match
case IsA() => new A {}
case IsB() => new B {}
case IsC() => new C {}
def test2[T](e: Expr[T]): T = e match
case IsA() | IsB() =>
// IsA() implies T =:= A
// IsB() implies T =:= B
// No necessary constraint can be found
new B {} // error
case IsC() => new C {}
14 changes: 14 additions & 0 deletions tests/neg/gadt-alt-expr5.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
trait A
trait B extends A
trait C extends B
enum Expr[-T]:
case IsA() extends Expr[A]
case IsB() extends Expr[B]
case IsC() extends Expr[C]
import Expr.*
def test1[T](e: Expr[T]): Unit = e match
case IsA() | IsB() =>
val t1: T = ???
val t2: A = t1
val t3: B = t1 // error
case IsC() =>
2 changes: 1 addition & 1 deletion tests/neg/gadt-alternatives.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ import Expr.*
def eval[T](e: Expr[T]): T = e match
case StringVal(_) | IntVal(_) => "42" // error
def eval1[T](e: Expr[T]): T = e match
case IntValAlt(_) | IntVal(_) => 42 // error // limitation
case IntValAlt(_) | IntVal(_) => 42 // previously error, now ok
8 changes: 8 additions & 0 deletions tests/pos/gadt-alt-doc1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
trait Document[Doc <: Document[Doc]]
sealed trait Conversion[Doc, V]

case class C[Doc <: Document[Doc]]() extends Conversion[Doc, Doc]

def Test[Doc <: Document[Doc], V](conversion: Conversion[Doc, V]) =
conversion match
case C() | C() => ??? // error
Loading