Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialization check for methods with non-hot parameters #13999

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ object SymUtils:

def isVolatile(using Context): Boolean = self.hasAnnotation(defn.VolatileAnnot)

def isNonHotParams(using Context): Boolean =
// for now constructor proxy of a case class = takes non-hot parameters
self.isSyntheticApply

// constructor proxy of a case class
def isSyntheticApply(using Context): Boolean =
self.is(Flags.Synthetic) && self.owner.is(Flags.Module) && self.owner.companionClass.is(Flags.Case)

def isAnyOverride(using Context): Boolean = self.is(Override) || self.is(AbsOverride)
// careful: AbsOverride is a term only flag. combining with Override would catch only terms.

Expand Down
140 changes: 83 additions & 57 deletions compiler/src/dotty/tools/dotc/transform/init/Semantic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import ast.tpd._
import util.EqHashMap
import config.Printers.init as printer
import reporting.trace as log
import dotty.tools.dotc.transform.SymUtils._

import Errors._

Expand Down Expand Up @@ -121,7 +122,7 @@ object Semantic {
}

/** A function value */
case class Fun(expr: Tree, thisV: Ref, klass: ClassSymbol, env: Env) extends Value
case class Fun(expr: Tree, thisV: Value, klass: ClassSymbol, env: Env) extends Value

/** A value which represents a set of addresses
*
Expand Down Expand Up @@ -253,7 +254,7 @@ object Semantic {
*/

object Cache {
opaque type CacheStore = mutable.Map[Value, EqHashMap[Tree, Value]]
opaque type CacheStore = mutable.Map[Value, EqHashMap[Tree, mutable.Map[Env, Value]]]
private type Heap = Map[Ref, Objekt]

class Cache {
Expand Down Expand Up @@ -290,33 +291,33 @@ object Semantic {

def hasChanged = changed

def contains(value: Value, expr: Tree) =
current.contains(value, expr) || stable.contains(value, expr)
def contains(value: Value, expr: Tree, env: Env) =
current.contains(value, expr, env) || stable.contains(value, expr, env)

def apply(value: Value, expr: Tree) =
if current.contains(value, expr) then current(value)(expr)
else stable(value)(expr)
def apply(value: Value, expr: Tree, env: Env) =
if current.contains(value, expr, env) then current(value)(expr)(env)
else stable(value)(expr)(env)

/** Copy the value of `(value, expr)` from the last cache to the current cache
* (assuming it's `Hot` if it doesn't exist in the cache).
*
* Then, runs `fun` and update the caches if the values change.
*/
def assume(value: Value, expr: Tree, cacheResult: Boolean)(fun: => Result): Contextual[Result] =
def assume(value: Value, expr: Tree, env: Env, cacheResult: Boolean)(fun: => Result): Contextual[Result] =
val assumeValue: Value =
if last.contains(value, expr) then
last.get(value, expr)
if last.contains(value, expr, env) then
last.get(value, expr, env)
else
last.put(value, expr, Hot)
last.put(value, expr, env, Hot)
Hot
end if
current.put(value, expr, assumeValue)
current.put(value, expr, env, assumeValue)

val actual = fun
if actual.value != assumeValue then
this.changed = true
last.put(value, expr, actual.value)
current.put(value, expr, actual.value)
last.put(value, expr, env, actual.value)
current.put(value, expr, env, actual.value)
else
// It's tempting to cache the value in stable, but it's unsound.
// The reason is that the current value may depend on other values
Expand All @@ -333,8 +334,10 @@ object Semantic {
private def commitToStableCache() =
current.foreach { (v, m) =>
// It's useless to cache value for ThisRef.
if v.isWarm then m.iterator.foreach { (e, res) =>
stable.put(v, e, res)
if v.isWarm then m.iterator.foreach { (e, c) =>
c.iterator.foreach { (env, res) =>
stable.put(v, e, env, res)
}
}
}

Expand Down Expand Up @@ -383,12 +386,13 @@ object Semantic {
}

extension (cache: CacheStore)
def contains(value: Value, expr: Tree) = cache.contains(value) && cache(value).contains(expr)
def get(value: Value, expr: Tree): Value = cache(value)(expr)
def remove(value: Value, expr: Tree) = cache(value).remove(expr)
def put(value: Value, expr: Tree, result: Value): Unit = {
val innerMap = cache.getOrElseUpdate(value, new EqHashMap[Tree, Value])
innerMap(expr) = result
def contains(value: Value, expr: Tree, env: Env) = cache.contains(value) && cache(value).contains(expr) && cache(value)(expr).contains(env)
def get(value: Value, expr: Tree, env: Env): Value = cache(value)(expr)(env)
def remove(value: Value, expr: Tree, env: Env) = cache(value)(expr).remove(env)
def put(value: Value, expr: Tree, env: Env, result: Value): Unit = {
val treeMap = cache.getOrElseUpdate(value, new EqHashMap[Tree, mutable.Map[Env, Value]])
val innerMap = treeMap.getOrElseUpdate(expr, mutable.Map.empty)
innerMap(env) = result
}
end extension
}
Expand Down Expand Up @@ -587,21 +591,51 @@ object Semantic {
}
}

def call(meth: Symbol, args: List[ArgInfo], superType: Type, source: Tree, needResolve: Boolean = true): Contextual[Result] = log("call " + meth.show + ", args = " + args, printer, (_: Result).show) {
def call(meth: Symbol, args: List[ArgInfo], superType: Type, source: Tree, needResolve: Boolean = true): Contextual[Result] = log("call " + meth.show + ", args = " + args + ", value = " + value, printer, (_: Result).show) {
def checkArgs = args.flatMap(_.promote)

// Perform the analysis of the `target` symbol's rhs.
def performCall(target: Symbol) =
val isLocal = !meth.owner.isClass
val trace1 = trace.add(source)
if target.hasSource then
given Trace = trace1
val cls = target.owner.enclosingClass.asClass
val ddef = target.defTree.asInstanceOf[DefDef]
val env2 = Env(ddef, args.map(_.value).widenArgs)
// normal method call
withEnv((if isLocal then env else Env.empty).union(if target.isNonHotParams then env2 else Env.empty)) {
eval(ddef.rhs, value, cls, cacheResult = true) ++ (if target.isNonHotParams then Errors.empty else checkArgs)
}
else if value.canIgnoreMethodCall(target) then
Result(Hot, Nil)
else
// no source code available
val error = CallUnknown(target, source, trace.toVector)
Result(Hot, error :: checkArgs)

// fast track if the current object is already initialized
if promoted.isCurrentObjectPromoted then Result(Hot, Nil)
else value match {
case Hot =>
Result(Hot, checkArgs)
val tryPromote = checkArgs
if tryPromote.isEmpty then
Result(Hot, Errors.empty)
// If we cannot resolve the meth (if it's not effectively final),
// then just stop.
else if needResolve && !meth.isEffectivelyFinal then
Result(Hot, tryPromote)
// If the method requires hot parameters, stop.
else if !meth.isNonHotParams then
Result(Hot, tryPromote)
else
performCall(meth)

case Cold =>
val error = CallCold(meth, source, trace.toVector)
Result(Hot, error :: checkArgs)

case ref: Ref =>
val isLocal = !meth.owner.isClass
val target =
if !needResolve then
meth
Expand All @@ -611,22 +645,7 @@ object Semantic {
resolve(ref.klass, meth)

if target.isOneOf(Flags.Method) then
val trace1 = trace.add(source)
if target.hasSource then
given Trace = trace1
val cls = target.owner.enclosingClass.asClass
val ddef = target.defTree.asInstanceOf[DefDef]
val env2 = Env(ddef, args.map(_.value).widenArgs)
// normal method call
withEnv(if isLocal then env else Env.empty) {
eval(ddef.rhs, ref, cls, cacheResult = true) ++ checkArgs
}
else if ref.canIgnoreMethodCall(target) then
Result(Hot, Nil)
else
// no source code available
val error = CallUnknown(target, source, trace.toVector)
Result(Hot, error :: checkArgs)
performCall(target)
else
// method call resolves to a field
val obj = ref.objekt
Expand Down Expand Up @@ -765,7 +784,7 @@ object Semantic {
}
end extension

extension (ref: Ref)
extension (value: Value)
def accessLocal(tmref: TermRef, klass: ClassSymbol, source: Tree): Contextual[Result] =
val sym = tmref.symbol

Expand All @@ -774,8 +793,8 @@ object Semantic {
if sym.is(Flags.Param) && sym.owner.isConstructor then
// if we can get the field from the Ref (which can only possibly be
// a secondary constructor parameter), then use it.
if (ref.objekt.hasField(sym))
Result(ref.objekt.field(sym), Errors.empty)
if (value.isInstanceOf[Ref] && value.asInstanceOf[Ref].objekt.hasField(sym))
Result(value.asInstanceOf[Ref].objekt.field(sym), Errors.empty)
// instances of local classes inside secondary constructors cannot
// reach here, as those values are abstracted by Cold instead of Warm.
// This enables us to simplify the domain without sacrificing
Expand All @@ -788,13 +807,17 @@ object Semantic {
// It's always safe to approximate them with `Cold`.
Result(Cold, Nil)
else if sym.is(Flags.Param) then
default()
// If the symbol is a method that takes non-hot parameters, we look it up in the environment.
if sym.isNonHotParams then
Result(env.lookup(sym), Nil)
else
default()
else
sym.defTree match {
case vdef: ValDef =>
// resolve this for local variable
val enclosingClass = sym.owner.enclosingClass.asClass
val thisValue2 = resolveThis(enclosingClass, ref, klass, source)
val thisValue2 = resolveThis(enclosingClass, value, klass, source)
thisValue2 match {
case Hot => Result(Hot, Errors.empty)

Expand Down Expand Up @@ -941,7 +964,7 @@ object Semantic {
end extension

// ----- Policies ------------------------------------------------------
extension (value: Ref)
extension (value: Value)
/** Can the method call on `value` be ignored?
*
* Note: assume overriding resolution has been performed.
Expand Down Expand Up @@ -1044,17 +1067,17 @@ object Semantic {
*
* This method only handles cache logic and delegates the work to `cases`.
*/
def eval(expr: Tree, thisV: Ref, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Result] = log("evaluating " + expr.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Result).show) {
if (cache.contains(thisV, expr)) Result(cache(thisV, expr), Errors.empty)
else cache.assume(thisV, expr, cacheResult) { cases(expr, thisV, klass) }
def eval(expr: Tree, thisV: Value, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Result] = log("evaluating " + expr.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Result).show) {
if (cache.contains(thisV, expr, env)) Result(cache(thisV, expr, env), Errors.empty)
else cache.assume(thisV, expr, env, cacheResult) { cases(expr, thisV, klass) }
}

/** Evaluate a list of expressions */
def eval(exprs: List[Tree], thisV: Ref, klass: ClassSymbol): Contextual[List[Result]] =
def eval(exprs: List[Tree], thisV: Value, klass: ClassSymbol): Contextual[List[Result]] =
exprs.map { expr => eval(expr, thisV, klass) }

/** Evaluate arguments of methods */
def evalArgs(args: List[Arg], thisV: Ref, klass: ClassSymbol): Contextual[(List[Error], List[ArgInfo])] =
def evalArgs(args: List[Arg], thisV: Value, klass: ClassSymbol): Contextual[(List[Error], List[ArgInfo])] =
val errors = new mutable.ArrayBuffer[Error]
val argInfos = new mutable.ArrayBuffer[ArgInfo]
args.foreach { arg =>
Expand All @@ -1074,7 +1097,7 @@ object Semantic {
*
* Note: Recursive call should go to `eval` instead of `cases`.
*/
def cases(expr: Tree, thisV: Ref, klass: ClassSymbol): Contextual[Result] =
def cases(expr: Tree, thisV: Value, klass: ClassSymbol): Contextual[Result] =
expr match {
case Ident(nme.WILDCARD) =>
// TODO: disallow `var x: T = _`
Expand Down Expand Up @@ -1239,7 +1262,10 @@ object Semantic {
else Result(Hot, checkTermUsage(tdef.rhs, thisV, klass))

case tpl: Template =>
init(tpl, thisV, klass)
thisV match {
case ref: Ref => init(tpl, ref, klass)
case _ => throw new Exception("initialization called on a non-ref value: " + thisV)
}

case _: Import | _: Export =>
Result(Hot, Errors.empty)
Expand All @@ -1249,7 +1275,7 @@ object Semantic {
}

/** Handle semantics of leaf nodes */
def cases(tp: Type, thisV: Ref, klass: ClassSymbol, source: Tree): Contextual[Result] = log("evaluating " + tp.show, printer, (_: Result).show) {
def cases(tp: Type, thisV: Value, klass: ClassSymbol, source: Tree): Contextual[Result] = log("evaluating " + tp.show, printer, (_: Result).show) {
tp match {
case _: ConstantType =>
Result(Hot, Errors.empty)
Expand Down Expand Up @@ -1338,7 +1364,7 @@ object Semantic {
}

/** Compute the outer value that correspond to `tref.prefix` */
def outerValue(tref: TypeRef, thisV: Ref, klass: ClassSymbol, source: Tree): Contextual[Result] =
def outerValue(tref: TypeRef, thisV: Value, klass: ClassSymbol, source: Tree): Contextual[Result] =
val cls = tref.classSymbol.asClass
if tref.prefix == NoPrefix then
val enclosing = cls.owner.lexicallyEnclosingClass.asClass
Expand Down Expand Up @@ -1471,7 +1497,7 @@ object Semantic {
*
* This is intended to avoid type soundness issues in Dotty.
*/
def checkTermUsage(tpt: Tree, thisV: Ref, klass: ClassSymbol): Contextual[List[Error]] =
def checkTermUsage(tpt: Tree, thisV: Value, klass: ClassSymbol): Contextual[List[Error]] =
val buf = new mutable.ArrayBuffer[Error]
val traverser = new TypeTraverser {
def traverse(tp: Type): Unit = tp match {
Expand Down
8 changes: 8 additions & 0 deletions tests/init/pos/case-class-apply.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
case class Case(b: Base)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This tests the Hot case in the call method but not Ref. How do I add such a case?


class Base {
val n = 10
val f = Case(this)

val m = 10
}