Skip to content

Commit

Permalink
Merge pull request #14026 from dotty-staging/avoid-numbered-9
Browse files Browse the repository at this point in the history
Sound type avoidance (hopefully!)
  • Loading branch information
odersky authored Dec 14, 2021
2 parents 221fc71 + 629006b commit 7bf25f5
Show file tree
Hide file tree
Showing 48 changed files with 730 additions and 391 deletions.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
protected def rootContext(using Context): Context = {
ctx.initialize()
ctx.base.setPhasePlan(comp.phases)
val rootScope = new MutableScope
val rootScope = new MutableScope(0)
val bootstrap = ctx.fresh
.setPeriod(Period(comp.nextRunId, FirstPhaseId))
.setScope(rootScope)
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import typer.ProtoTypes
import transform.SymUtils._
import transform.TypeUtils._
import core._
import Scopes.newScope
import util.Spans._, Types._, Contexts._, Constants._, Names._, Flags._, NameOps._
import Symbols._, StdNames._, Annotations._, Trees._, Symbols._
import Decorators._, DenotTransformers._
Expand Down Expand Up @@ -344,7 +345,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
else parents
val cls = newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic | Final, parents1,
coord = fns.map(_.span).reduceLeft(_ union _))
newScope, coord = fns.map(_.span).reduceLeft(_ union _))
val constr = newConstructor(cls, Synthetic, Nil, Nil).entered
def forwarder(fn: TermSymbol, name: TermName) = {
val fwdMeth = fn.copy(cls, name, Synthetic | Method | Final).entered.asTerm
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/ScalaSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ private sealed trait YSettings:
val YprintSyms: Setting[Boolean] = BooleanSetting("-Yprint-syms", "When printing trees print info in symbols instead of corresponding info in trees.")
val YprintDebug: Setting[Boolean] = BooleanSetting("-Yprint-debug", "When printing trees, print some extra information useful for debugging.")
val YprintDebugOwners: Setting[Boolean] = BooleanSetting("-Yprint-debug-owners", "When printing trees, print owners of definitions.")
val YprintLevel: Setting[Boolean] = BooleanSetting("-Yprint-level", "print nesting levels of symbols and type variables.")
val YshowPrintErrors: Setting[Boolean] = BooleanSetting("-Yshow-print-errors", "Don't suppress exceptions thrown during tree printing.")
val YtestPickler: Setting[Boolean] = BooleanSetting("-Ytest-pickler", "Self-test for pickling functionality; should be used with -Ystop-after:pickler.")
val YcheckReentrant: Setting[Boolean] = BooleanSetting("-Ycheck-reentrant", "Check that compiled program does not contain vars that can be accessed from a global root.")
Expand Down
26 changes: 20 additions & 6 deletions compiler/src/dotty/tools/dotc/core/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ abstract class Constraint extends Showable {
/** A constraint that includes the relationship `p1 <: p2`.
* `<:` relationships between parameters ("edges") are propagated, but
* non-parameter bounds are left alone.
*
* @param direction Must be set to `KeepParam1` or `KeepParam2` when
* `p2 <: p1` is already true depending on which parameter
* the caller intends to keep. This will avoid propagating
* bounds that will be redundant after `p1` and `p2` are
* unified.
*/
def addLess(p1: TypeParamRef, p2: TypeParamRef)(using Context): This

/** A constraint resulting from adding p2 = p1 to this constraint, and at the same
* time transferring all bounds of p2 to p1
*/
def unify(p1: TypeParamRef, p2: TypeParamRef)(using Context): This
def addLess(p1: TypeParamRef, p2: TypeParamRef,
direction: UnificationDirection = UnificationDirection.NoUnification)(using Context): This

/** A new constraint which is derived from this constraint by removing
* the type parameter `param` from the domain and replacing all top-level occurrences
Expand Down Expand Up @@ -174,3 +176,15 @@ abstract class Constraint extends Showable {
*/
def checkConsistentVars()(using Context): Unit
}

/** When calling `Constraint#addLess(p1, p2, ...)`, the caller might end up
* unifying one parameter with the other, this enum lets `addLess` know which
* direction the unification will take.
*/
enum UnificationDirection:
/** Neither p1 nor p2 will be instantiated. */
case NoUnification
/** `p2 := p1`, p1 left uninstantiated. */
case KeepParam1
/** `p1 := p2`, p2 left uninstantiated. */
case KeepParam2
218 changes: 198 additions & 20 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import Flags._
import config.Config
import config.Printers.typr
import reporting.trace
import typer.ProtoTypes.newTypeVar
import typer.ProtoTypes.{newTypeVar, representedParamRef}
import StdNames.tpnme
import UnificationDirection.*
import NameKinds.AvoidNameKind

/** Methods for adding constraints and solving them.
*
Expand Down Expand Up @@ -56,20 +58,68 @@ trait ConstraintHandling {
*/
protected var comparedTypeLambdas: Set[TypeLambda] = Set.empty

protected var myNecessaryConstraintsOnly = false
/** When collecting the constraints needed for a particular subtyping
* judgment to be true, we sometimes need to approximate the constraint
* set (see `TypeComparer#either` for example).
*
* Normally, this means adding extra constraints which may not be necessary
* for the subtyping judgment to be true, but if this variable is set to true
* we will instead under-approximate and keep only the constraints that must
* always be present for the subtyping judgment to hold.
*
* This is needed for GADT bounds inference to be sound, but it is also used
* when constraining a method call based on its expected type to avoid adding
* constraints that would later prevent us from typechecking method
* arguments, see or-inf.scala and and-inf.scala for examples.
*/
protected def necessaryConstraintsOnly(using Context): Boolean =
ctx.mode.is(Mode.GadtConstraintInference) || myNecessaryConstraintsOnly

def checkReset() =
assert(addConstraintInvocations == 0)
assert(frozenConstraint == false)
assert(caseLambda == NoType)
assert(homogenizeArgs == false)
assert(comparedTypeLambdas == Set.empty)

def nestingLevel(param: TypeParamRef) = constraint.typeVarOfParam(param) match
case tv: TypeVar => tv.nestingLevel
case _ => Int.MaxValue

/** If `param` is nested deeper than `maxLevel`, try to instantiate it to a
* fresh type variable of level `maxLevel` and return the new variable.
* If this isn't possible, throw a TypeError.
*/
def atLevel(maxLevel: Int, param: TypeParamRef)(using Context): TypeParamRef =
if nestingLevel(param) <= maxLevel then return param
LevelAvoidMap(0, maxLevel)(param) match
case freshVar: TypeVar => freshVar.origin
case _ => throw new TypeError(
i"Could not decrease the nesting level of ${param} from ${nestingLevel(param)} to $maxLevel in $constraint")

def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds = constraint.nonParamBounds(param)

/** The full lower bound of `param` includes both the `nonParamBounds` and the
* params in the constraint known to be `<: param`, except that
* params with a `nestingLevel` higher than `param` will be instantiated
* to a fresh param at a legal level. See the documentation of `TypeVar`
* for details.
*/
def fullLowerBound(param: TypeParamRef)(using Context): Type =
constraint.minLower(param).foldLeft(nonParamBounds(param).lo)(_ | _)
val maxLevel = nestingLevel(param)
var loParams = constraint.minLower(param)
if maxLevel != Int.MaxValue then
loParams = loParams.mapConserve(atLevel(maxLevel, _))
loParams.foldLeft(nonParamBounds(param).lo)(_ | _)

/** The full upper bound of `param`, see the documentation of `fullLowerBounds` above. */
def fullUpperBound(param: TypeParamRef)(using Context): Type =
constraint.minUpper(param).foldLeft(nonParamBounds(param).hi)(_ & _)
val maxLevel = nestingLevel(param)
var hiParams = constraint.minUpper(param)
if maxLevel != Int.MaxValue then
hiParams = hiParams.mapConserve(atLevel(maxLevel, _))
hiParams.foldLeft(nonParamBounds(param).hi)(_ & _)

/** Full bounds of `param`, including other lower/upper params.
*
Expand All @@ -79,10 +129,111 @@ trait ConstraintHandling {
def fullBounds(param: TypeParamRef)(using Context): TypeBounds =
nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))

/** If true, eliminate wildcards in bounds by avoidance, otherwise replace
* them by fresh variables.
/** An approximating map that prevents types nested deeper than maxLevel as
* well as WildcardTypes from leaking into the constraint.
* Note that level-checking is turned off after typer and in uncommitable
* TyperState since these leaks should be safe.
*/
protected def approximateWildcards: Boolean = true
class LevelAvoidMap(topLevelVariance: Int, maxLevel: Int)(using Context) extends TypeOps.AvoidMap:
variance = topLevelVariance

/** Are we allowed to refer to types of the given `level`? */
private def levelOK(level: Int): Boolean =
level <= maxLevel || ctx.isAfterTyper || !ctx.typerState.isCommittable

def toAvoid(tp: NamedType): Boolean =
tp.prefix == NoPrefix && !tp.symbol.isStatic && !levelOK(tp.symbol.nestingLevel)

/** Return a (possibly fresh) type variable of a level no greater than `maxLevel` which is:
* - lower-bounded by `tp` if variance >= 0
* - upper-bounded by `tp` if variance <= 0
* If this isn't possible, return the empty range.
*/
def legalVar(tp: TypeVar): Type =
val oldParam = tp.origin
val nameKind =
if variance > 0 then AvoidNameKind.UpperBound
else if variance < 0 then AvoidNameKind.LowerBound
else AvoidNameKind.BothBounds

/** If it exists, return the first param in the list created in a previous call to `legalVar(tp)`
* with the appropriate level and variance.
*/
def findParam(params: List[TypeParamRef]): Option[TypeParamRef] =
params.find(p =>
nestingLevel(p) <= maxLevel && representedParamRef(p) == oldParam &&
(p.paramName.is(AvoidNameKind.BothBounds) ||
variance != 0 && p.paramName.is(nameKind)))

// First, check if we can reuse an existing parameter, this is more than an optimization
// since it avoids an infinite loop in tests/pos/i8900-cycle.scala
findParam(constraint.lower(oldParam)).orElse(findParam(constraint.upper(oldParam))) match
case Some(param) =>
constraint.typeVarOfParam(param)
case _ =>
// Otherwise, try to return a fresh type variable at `maxLevel` with
// the appropriate constraints.
val name = nameKind(oldParam.paramName.toTermName).toTypeName
val freshVar = newTypeVar(TypeBounds.upper(tp.topType), name,
nestingLevel = maxLevel, represents = oldParam)
val ok =
if variance < 0 then
addLess(freshVar.origin, oldParam)
else if variance > 0 then
addLess(oldParam, freshVar.origin)
else
unify(freshVar.origin, oldParam)
if ok then freshVar else emptyRange
end legalVar

override def apply(tp: Type): Type = tp match
case tp: TypeVar if !tp.isInstantiated && !levelOK(tp.nestingLevel) =>
legalVar(tp)
// TypeParamRef can occur in tl bounds
case tp: TypeParamRef =>
constraint.typeVarOfParam(tp) match
case tvar: TypeVar =>
apply(tvar)
case _ => super.apply(tp)
case _ =>
super.apply(tp)

override def mapWild(t: WildcardType) =
if ctx.mode.is(Mode.TypevarsMissContext) then super.mapWild(t)
else
val tvar = newTypeVar(apply(t.effectiveBounds).toBounds, nestingLevel = maxLevel)
tvar
end LevelAvoidMap

/** Approximate `rawBound` if needed to make it a legal bound of `param` by
* avoiding wildcards and types with a level strictly greater than its
* `nestingLevel`.
*
* Note that level-checking must be performed here and cannot be delayed
* until instantiation because if we allow level-incorrect bounds, then we
* might end up reasoning with bad bounds outside of the scope where they are
* defined. This can lead to level-correct but unsound instantiations as
* demonstrated by tests/neg/i8900.scala.
*/
protected def legalBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Type =
// Over-approximate for soundness.
var variance = if isUpper then -1 else 1
// ...unless we can only infer necessary constraints, in which case we
// flip the variance to under-approximate.
if necessaryConstraintsOnly then variance = -variance

val approx = new LevelAvoidMap(variance, nestingLevel(param)):
override def legalVar(tp: TypeVar): Type =
// `legalVar` will create a type variable whose bounds depend on
// `variance`, but whether the variance is positive or negative,
// we can still infer necessary constraints since just creating a
// type variable doesn't reduce the set of possible solutions.
// Therefore, we can safely "unflip" the variance flipped above.
// This is necessary for i8900-unflip.scala to typecheck.
val v = if necessaryConstraintsOnly then -this.variance else this.variance
atVariance(v)(super.legalVar(tp))
approx(rawBound)
end legalBound

protected def addOneBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Boolean =
if !constraint.contains(param) then true
Expand All @@ -91,12 +242,7 @@ trait ConstraintHandling {
// so we shouldn't allow them as constraints either.
false
else
val dropWildcards = new AvoidWildcardsMap:
if !isUpper then variance = -1
override def mapWild(t: WildcardType) =
if approximateWildcards then super.mapWild(t)
else newTypeVar(apply(t.effectiveBounds).toBounds)
val bound = dropWildcards(rawBound)
val bound = legalBound(param, rawBound, isUpper)
val oldBounds @ TypeBounds(lo, hi) = constraint.nonParamBounds(param)
val equalBounds = (if isUpper then lo else hi) eq bound
if equalBounds && !bound.existsPart(_ eq param, StopAt.Static) then
Expand Down Expand Up @@ -191,19 +337,50 @@ trait ConstraintHandling {

def location(using Context) = "" // i"in ${ctx.typerState.stateChainStr}" // use for debugging

/** Make p2 = p1, transfer all bounds of p2 to p1
* @pre less(p1)(p2)
/** Unify p1 with p2: one parameter will be kept in the constraint, the
* other will be removed and its bounds transferred to the remaining one.
*
* If p1 and p2 have different `nestingLevel`, the parameter with the lowest
* level will be kept and the transferred bounds from the other parameter
* will be adjusted for level-correctness.
*/
private def unify(p1: TypeParamRef, p2: TypeParamRef)(using Context): Boolean = {
constr.println(s"unifying $p1 $p2")
assert(constraint.isLess(p1, p2))
constraint = constraint.addLess(p2, p1)
if !constraint.isLess(p1, p2) then
constraint = constraint.addLess(p1, p2)

val level1 = nestingLevel(p1)
val level2 = nestingLevel(p2)
val pKept = if level1 <= level2 then p1 else p2
val pRemoved = if level1 <= level2 then p2 else p1

constraint = constraint.addLess(p2, p1, direction = if pKept eq p1 then KeepParam2 else KeepParam1)

val boundKept = constraint.nonParamBounds(pKept).substParam(pRemoved, pKept)
var boundRemoved = constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept)

if level1 != level2 then
boundRemoved = LevelAvoidMap(-1, math.min(level1, level2))(boundRemoved)
val TypeBounds(lo, hi) = boundRemoved
// After avoidance, the interval might be empty, e.g. in
// tests/pos/i8900-promote.scala:
// >: x.type <: Singleton
// becomes:
// >: Int <: Singleton
// In that case, we can still get a legal constraint
// by replacing the lower-bound to get:
// >: Int & Singleton <: Singleton
if !isSub(lo, hi) then
boundRemoved = TypeBounds(lo & hi, hi)

val down = constraint.exclusiveLower(p2, p1)
val up = constraint.exclusiveUpper(p1, p2)
constraint = constraint.unify(p1, p2)
val bounds = constraint.nonParamBounds(p1)
val lo = bounds.lo
val hi = bounds.hi

val newBounds = (boundKept & boundRemoved).bounds
constraint = constraint.updateEntry(pKept, newBounds).replace(pRemoved, pKept)

val lo = newBounds.lo
val hi = newBounds.hi
isSub(lo, hi) &&
down.forall(addOneBound(_, hi, isUpper = true)) &&
up.forall(addOneBound(_, lo, isUpper = false))
Expand Down Expand Up @@ -256,6 +433,7 @@ trait ConstraintHandling {
final def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
constraint.entry(param) match
case entry: TypeBounds =>
val maxLevel = nestingLevel(param)
val useLowerBound = fromBelow || param.occursIn(entry.hi)
val inst = if useLowerBound then fullLowerBound(param) else fullUpperBound(param)
typr.println(s"approx ${param.show}, from below = $fromBelow, inst = ${inst.show}")
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,10 @@ object Contexts {
if owner != null && owner.isClass then owner.asClass.unforcedDecls
else scope

def nestingLevel: Int =
val sc = effectiveScope
if sc != null then sc.nestingLevel else 0

/** Sourcefile corresponding to given abstract file, memoized */
def getSource(file: AbstractFile, codec: => Codec = Codec(settings.encoding.value)) = {
util.Stats.record("Context.getSource")
Expand Down
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ class Definitions {
private def newPermanentClassSymbol(owner: Symbol, name: TypeName, flags: FlagSet, infoFn: ClassSymbol => Type) =
newClassSymbol(owner, name, flags | Permanent | NoInits | Open, infoFn)

private def enterCompleteClassSymbol(owner: Symbol, name: TypeName, flags: FlagSet, parents: List[TypeRef], decls: Scope = newScope) =
private def enterCompleteClassSymbol(owner: Symbol, name: TypeName, flags: FlagSet, parents: List[TypeRef]): ClassSymbol =
enterCompleteClassSymbol(owner, name, flags, parents, newScope(owner.nestingLevel + 1))

private def enterCompleteClassSymbol(owner: Symbol, name: TypeName, flags: FlagSet, parents: List[TypeRef], decls: Scope) =
newCompleteClassSymbol(owner, name, flags | Permanent | NoInits | Open, parents, decls).entered

private def enterTypeField(cls: ClassSymbol, name: TypeName, flags: FlagSet, scope: MutableScope) =
Expand Down Expand Up @@ -433,7 +436,7 @@ class Definitions {
Any_toString, Any_##, Any_getClass, Any_isInstanceOf, Any_typeTest, Object_eq, Object_ne)

@tu lazy val AnyKindClass: ClassSymbol = {
val cls = newCompleteClassSymbol(ScalaPackageClass, tpnme.AnyKind, AbstractFinal | Permanent, Nil)
val cls = newCompleteClassSymbol(ScalaPackageClass, tpnme.AnyKind, AbstractFinal | Permanent, Nil, newScope(0))
if (!ctx.settings.YnoKindPolymorphism.value)
// Enable kind-polymorphism by exposing scala.AnyKind
cls.entered
Expand Down
Loading

0 comments on commit 7bf25f5

Please sign in to comment.