Skip to content

Commit 48518e1

Browse files
committed
Document & test GADT reasoning for pattern alternatives
1 parent aaeedd3 commit 48518e1

File tree

8 files changed

+131
-1
lines changed

8 files changed

+131
-1
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2056,7 +2056,31 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
20562056
end necessaryEither
20572057

20582058
/** Finds the necessary (the weakest) GADT constraint among a list of them.
2059-
* It returns the one being subsumed by all others if exists, and `None` otherwise. */
2059+
* It returns the one being subsumed by all others if exists, and `None` otherwise.
2060+
*
2061+
* This is used when typechecking pattern alternatives, for instance:
2062+
*
2063+
* enum Expr[+T]:
2064+
* case I1(x: Int) extends Expr[Int]
2065+
* case I2(x: Int) extends Expr[Int]
2066+
* case B(x: Boolean) extends Expr[Boolean]
2067+
* import Expr.*
2068+
*
2069+
* The following function should compile:
2070+
*
2071+
* def foo[T](e: Expr[T]): T = e match
2072+
* case I1(_) | I2(_) => 42
2073+
*
2074+
* since `T >: Int` is subsumed by both alternatives in the first match clause.
2075+
*
2076+
* However, the following should not:
2077+
*
2078+
* def foo[T](e: Expr[T]): T = e match
2079+
* case I1(_) | B(_) => 42
2080+
*
2081+
* since the `I1(_)` case gives the constraint `T >: Int` while `B(_)` gives `T >: Boolean`.
2082+
* Neither of the constraints is subsumed by the other.
2083+
*/
20602084
def necessaryGadtConstraint(constrs: List[GadtConstraint], preGadt: GadtConstraint)(using Context): Option[GadtConstraint] = boundary:
20612085
constrs match
20622086
case Nil => break(None)

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2834,6 +2834,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
28342834
gadtConstrs += ctx.gadt
28352835
res
28362836
.mapconserve(ensureValueTypeOrWildcard)
2837+
// Look for the necessary constraint that is subsumed by all alternatives.
2838+
// Use that constraint as the outcome if possible, otherwise fallback to not using
2839+
// GADT reasoning for soundness.
28372840
TypeComparer.necessaryGadtConstraint(gadtConstrs.toList, preGadt) match
28382841
case Some(constr) => nestedCtx.gadtState.restore(constr)
28392842
case None => nestedCtx.gadtState.restore(preGadt)

tests/neg/gadt-alt-expr1.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
enum Expr[+T]:
2+
case I1() extends Expr[Int]
3+
case I2() extends Expr[Int]
4+
case B() extends Expr[Boolean]
5+
import Expr.*
6+
def foo[T](e: Expr[T]): T =
7+
e match
8+
case I1() | I2() => 42 // ok
9+
case B() => true
10+
def bar[T](e: Expr[T]): T =
11+
e match
12+
case I1() | B() => 42 // error
13+
case I2() => 0

tests/neg/gadt-alt-expr2.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
enum Expr[+T]:
2+
case I1() extends Expr[Int]
3+
case I2() extends Expr[Int]
4+
case I3() extends Expr[Int]
5+
case I4() extends Expr[Int]
6+
case I5() extends Expr[Int]
7+
case B() extends Expr[Boolean]
8+
import Expr.*
9+
def test1[T](e: Expr[T]): T =
10+
e match
11+
case I1() | I2() | I3() | I4() | I5() => 42 // ok
12+
case B() => true
13+
def test2[T](e: Expr[T]): T =
14+
e match
15+
case I1() | I2() | I3() | I4() | I5() | B() => 42 // error

tests/neg/gadt-alt-expr3.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
trait A
2+
trait B extends A
3+
trait C extends B
4+
trait D
5+
enum Expr[+T]:
6+
case IsA() extends Expr[A]
7+
case IsB() extends Expr[B]
8+
case IsC() extends Expr[C]
9+
case IsD() extends Expr[D]
10+
import Expr.*
11+
def test1[T](e: Expr[T]): T = e match
12+
case IsA() => new A {}
13+
case IsB() => new B {}
14+
case IsC() => new C {}
15+
def test2[T](e: Expr[T]): T = e match
16+
case IsA() | IsB() =>
17+
// IsA() implies T >: A
18+
// IsB() implies T >: B
19+
// So T >: B is chosen
20+
new B {}
21+
case IsC() => new C {}
22+
def test3[T](e: Expr[T]): T = e match
23+
case IsA() | IsB() | IsC() =>
24+
// T >: C is chosen
25+
new C {}
26+
def test4[T](e: Expr[T]): T = e match
27+
case IsA() | IsB() | IsC() =>
28+
new B {} // error
29+
def test5[T](e: Expr[T]): T = e match
30+
case IsA() | IsB() =>
31+
new A {} // error
32+
def test6[T](e: Expr[T]): T = e match
33+
case IsA() | IsC() | IsD() =>
34+
new C {} // error

tests/neg/gadt-alt-expr4.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
trait A
2+
trait B extends A
3+
trait C extends B
4+
enum Expr[T]:
5+
case IsA() extends Expr[A]
6+
case IsB() extends Expr[B]
7+
case IsC() extends Expr[C]
8+
import Expr.*
9+
def test1[T](e: Expr[T]): T = e match
10+
case IsA() => new A {}
11+
case IsB() => new B {}
12+
case IsC() => new C {}
13+
def test2[T](e: Expr[T]): T = e match
14+
case IsA() | IsB() =>
15+
// IsA() implies T =:= A
16+
// IsB() implies T =:= B
17+
// No necessary constraint can be found
18+
new B {} // error
19+
case IsC() => new C {}

tests/neg/gadt-alt-expr5.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
trait A
2+
trait B extends A
3+
trait C extends B
4+
enum Expr[-T]:
5+
case IsA() extends Expr[A]
6+
case IsB() extends Expr[B]
7+
case IsC() extends Expr[C]
8+
import Expr.*
9+
def test1[T](e: Expr[T]): Unit = e match
10+
case IsA() | IsB() =>
11+
val t1: T = ???
12+
val t2: A = t1
13+
val t3: B = t1 // error
14+
case IsC() =>

tests/pos/gadt-alt-doc1.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
trait Document[Doc <: Document[Doc]]
2+
sealed trait Conversion[Doc, V]
3+
4+
case class C[Doc <: Document[Doc]]() extends Conversion[Doc, Doc]
5+
6+
def Test[Doc <: Document[Doc], V](conversion: Conversion[Doc, V]) =
7+
conversion match
8+
case C() | C() => ??? // error

0 commit comments

Comments
 (0)