diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 6fbced1bd377..cd2813c807eb 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1594,6 +1594,9 @@ class Definitions { yield nme.apply.specializedFunction(r, List(t1, t2)).asTermName + @tu lazy val FunctionSpecializedApplyNames: collection.Set[Name] = + Function0SpecializedApplyNames ++ Function1SpecializedApplyNames ++ Function2SpecializedApplyNames + def functionArity(tp: Type)(using Context): Int = tp.dropDependentRefinement.dealias.argInfos.length - 1 /** Return underlying context function type (i.e. instance of an ContextFunctionN class) diff --git a/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala b/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala index b2aab80e98b4..49d7d0009f42 100644 --- a/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala +++ b/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala @@ -6,6 +6,8 @@ import MegaPhase.MiniPhase import core.* import Symbols.*, Contexts.*, Types.*, Decorators.* import StdNames.nme +import SymUtils.* +import NameKinds.AdaptedClosureName /** Rewrite `(x1, ... xN) => f(x1, ... xN)` for N >= 0 to `f`, * provided `f` is a pure path of function type. @@ -15,6 +17,11 @@ import StdNames.nme * where a context function is expected, unless that value has the * syntactic form of a context function literal. * + * Also handle variants of eta-expansions where + * - result f.apply(X_1,...,X_n) is subject to a synthetic cast, or + * - the application uses a specialized apply method, or + * - the closure is adapted (see Erasure#adaptClosure) + * * Without this phase, when a contextual function is passed as an argument to a * recursive function, that would have the unfortunate effect of a linear growth * in transient thunks of identical type wrapped around each other, leading @@ -27,20 +34,36 @@ class EtaReduce extends MiniPhase: override def description: String = EtaReduce.description - override def transformBlock(tree: Block)(using Context): Tree = tree match - case Block((meth : DefDef) :: Nil, closure: Closure) - if meth.symbol == closure.meth.symbol => - meth.rhs match - case Apply(Select(fn, nme.apply), args) - if meth.paramss.head.corresponds(args)((param, arg) => + override def transformBlock(tree: Block)(using Context): Tree = + + def tryReduce(mdef: DefDef, rhs: Tree): Tree = rhs match + case Apply(Select(fn, name), args) + if (name == nme.apply || defn.FunctionSpecializedApplyNames.contains(name)) + && mdef.paramss.head.corresponds(args)((param, arg) => arg.isInstanceOf[Ident] && arg.symbol == param.symbol) - && isPurePath(fn) - && fn.tpe <:< tree.tpe - && defn.isFunctionClass(fn.tpe.widen.typeSymbol) => - report.log(i"eta reducing $tree --> $fn") - fn - case _ => tree - case _ => tree + && isPurePath(fn) + && fn.tpe <:< tree.tpe + && defn.isFunctionClass(fn.tpe.widen.typeSymbol) => + report.log(i"eta reducing $tree --> $fn") + fn + case TypeApply(Select(qual, _), _) if rhs.symbol.isTypeCast && rhs.span.isSynthetic => + tryReduce(mdef, qual) + case _ => + tree + + tree match + case Block((meth: DefDef) :: Nil, expr) if meth.symbol.isAnonymousFunction => + expr match + case closure: Closure if meth.symbol == closure.meth.symbol => + tryReduce(meth, meth.rhs) + case Block((adapted: DefDef) :: Nil, closure: Closure) + if adapted.name.is(AdaptedClosureName) && adapted.symbol == closure.meth.symbol => + tryReduce(meth, meth.rhs) + case _ => + tree + case _ => + tree + end transformBlock end EtaReduce diff --git a/tests/run/i14623.scala b/tests/run/i14623.scala new file mode 100644 index 000000000000..6f231448d1f5 --- /dev/null +++ b/tests/run/i14623.scala @@ -0,0 +1,15 @@ +object Thunk { + private[this] val impl = + ((x: Any) => x).asInstanceOf[(=> Any) => Function0[Any]] + + def asFunction0[A](thunk: => A): Function0[A] = impl(thunk).asInstanceOf[Function0[A]] +} + +@main def Test = + var i = 0 + val f1 = { () => i += 1; "" } + assert(Thunk.asFunction0(f1()) eq f1) + val f2 = { () => i += 1; i } + assert(Thunk.asFunction0(f2()) eq f2) + val f3 = { () => i += 1 } + assert(Thunk.asFunction0(f3()) eq f3)