From cc1e37ea1b0060d26b1360e1917e9576bd44c8d3 Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Wed, 7 Feb 2024 22:24:17 +0000 Subject: [PATCH 1/2] Add GADT symbols when typing typing-ahead lambda bodies --- .../src/dotty/tools/dotc/typer/Namer.scala | 8 ++++++- tests/pos/i19570.min1.scala | 23 ++++++++++++++++++ tests/pos/i19570.min2.scala | 24 +++++++++++++++++++ tests/pos/i19570.orig.scala | 14 +++++++++++ 4 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 tests/pos/i19570.min1.scala create mode 100644 tests/pos/i19570.min2.scala create mode 100644 tests/pos/i19570.orig.scala diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 819b43fcec2c..e5e05e02a7d7 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1733,8 +1733,14 @@ class Namer { typer: Typer => val tpe = (paramss: @unchecked) match case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams, vparams) case TermSymbols(vparams) :: Nil => tpFun(Nil, vparams) + val rhsCtx = (paramss: @unchecked) match + case TypeSymbols(tparams) :: TermSymbols(_) :: Nil => + val rhsCtx = ctx.fresh.setFreshGADTBounds + rhsCtx.gadtState.addToConstraint(tparams) + rhsCtx + case TermSymbols(_) :: Nil => ctx if (isFullyDefined(tpe, ForceDegree.none)) tpe - else typedAheadExpr(mdef.rhs, tpe).tpe + else typedAheadExpr(mdef.rhs, tpe)(using rhsCtx).tpe case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) => mdef match { diff --git a/tests/pos/i19570.min1.scala b/tests/pos/i19570.min1.scala new file mode 100644 index 000000000000..2cbc852641d3 --- /dev/null +++ b/tests/pos/i19570.min1.scala @@ -0,0 +1,23 @@ +enum Op[A]: + case Dup[T]() extends Op[(T, T)] + +def foo[R](f: [A] => Op[A] => R): R = ??? + +def test = + foo([A] => (o: Op[A]) => o match + case o: Op.Dup[u] => + summon[A =:= (u, u)] // Error: Cannot prove that A =:= (u, u) + () + ) + foo[Unit]([A] => (o: Op[A]) => o match + case o: Op.Dup[u] => + summon[A =:= (u, u)] // Ok + () + ) + foo({ + val f1 = [B] => (o: Op[B]) => o match + case o: Op.Dup[u] => + summon[B =:= (u, u)] // Also ok + () + f1 + }) diff --git a/tests/pos/i19570.min2.scala b/tests/pos/i19570.min2.scala new file mode 100644 index 000000000000..b1450d7e2d1a --- /dev/null +++ b/tests/pos/i19570.min2.scala @@ -0,0 +1,24 @@ +sealed trait Op[A, B] { def giveA: A; def giveB: B } +final case class Dup[T](x: T) extends Op[T, (T, T)] { def giveA: T = x; def giveB: (T, T) = (x, x) } + +class Test: + def foo[R](f: [A, B] => (o: Op[A, B]) => R): R = ??? + + def m1: Unit = + foo([A, B] => (o: Op[A, B]) => o match + case o: Dup[t] => + var a1: t = o.giveA + var a2: A = o.giveA + a1 = a2 + a2 = a1 + + var b1: (t, t) = o.giveB + var b2: B = o.giveB + b1 = b2 + b2 = b1 + + summon[A =:= t] // ERROR: Cannot prove that A =:= t. + summon[B =:= (t, t)] // ERROR: Cannot prove that B =:= (t, t). + + () + ) diff --git a/tests/pos/i19570.orig.scala b/tests/pos/i19570.orig.scala new file mode 100644 index 000000000000..6e574f52be91 --- /dev/null +++ b/tests/pos/i19570.orig.scala @@ -0,0 +1,14 @@ +enum Op[A, B]: + case Dup[T]() extends Op[T, (T, T)] + +def foo[R](f: [A, B] => (o: Op[A, B]) => R): R = + f(Op.Dup()) + +def test = + foo([A, B] => (o: Op[A, B]) => { + o match + case o: Op.Dup[t] => + summon[A =:= t] // ERROR: Cannot prove that A =:= t. + summon[B =:= (t, t)] // ERROR: Cannot prove that B =:= (t, t). + 42 + }) From 2c81588e620e0ae62aa6641db4aebf9683bd97d3 Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Thu, 22 Feb 2024 14:39:47 +0000 Subject: [PATCH 2/2] Extract shared prepareRhsCtx --- .../src/dotty/tools/dotc/typer/Namer.scala | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index e5e05e02a7d7..577907e243d9 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1733,12 +1733,7 @@ class Namer { typer: Typer => val tpe = (paramss: @unchecked) match case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams, vparams) case TermSymbols(vparams) :: Nil => tpFun(Nil, vparams) - val rhsCtx = (paramss: @unchecked) match - case TypeSymbols(tparams) :: TermSymbols(_) :: Nil => - val rhsCtx = ctx.fresh.setFreshGADTBounds - rhsCtx.gadtState.addToConstraint(tparams) - rhsCtx - case TermSymbols(_) :: Nil => ctx + val rhsCtx = prepareRhsCtx(ctx.fresh, paramss) if (isFullyDefined(tpe, ForceDegree.none)) tpe else typedAheadExpr(mdef.rhs, tpe)(using rhsCtx).tpe @@ -1938,14 +1933,7 @@ class Namer { typer: Typer => var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType) if sym.isInlineMethod then rhsCtx = rhsCtx.addMode(Mode.InlineableBody) if sym.is(ExtensionMethod) then rhsCtx = rhsCtx.addMode(Mode.InExtensionMethod) - val typeParams = paramss.collect { case TypeSymbols(tparams) => tparams }.flatten - if (typeParams.nonEmpty) { - // we'll be typing an expression from a polymorphic definition's body, - // so we must allow constraining its type parameters - // compare with typedDefDef, see tests/pos/gadt-inference.scala - rhsCtx.setFreshGADTBounds - rhsCtx.gadtState.addToConstraint(typeParams) - } + rhsCtx = prepareRhsCtx(rhsCtx, paramss) def typedAheadRhs(pt: Type) = PrepareInlineable.dropInlineIfError(sym, @@ -1990,4 +1978,15 @@ class Namer { typer: Typer => lhsType orElse WildcardType } end inferredResultType + + /** Prepare a GADT-aware context used to type the RHS of a ValOrDefDef. */ + def prepareRhsCtx(rhsCtx: FreshContext, paramss: List[List[Symbol]])(using Context): FreshContext = + val typeParams = paramss.collect { case TypeSymbols(tparams) => tparams }.flatten + if typeParams.nonEmpty then + // we'll be typing an expression from a polymorphic definition's body, + // so we must allow constraining its type parameters + // compare with typedDefDef, see tests/pos/gadt-inference.scala + rhsCtx.setFreshGADTBounds + rhsCtx.gadtState.addToConstraint(typeParams) + rhsCtx }