Skip to content

Commit

Permalink
Inferring tracked (#21628)
Browse files Browse the repository at this point in the history
Infer `tracked` for parameters that are referenced in the public
signatures of the defining class.
e.g.
```scala 3
class OrdSet(val ord: Ordering) {
  type Set = List[ord.T]
  def empty: Set = Nil

  implicit class helper(s: Set) {
    def add(x: ord.T): Set = x :: remove(x)
    def remove(x: ord.T): Set = s.filter(e => ord.compare(x, e) != 0)
    def member(x: ord.T): Boolean = s.exists(e => ord.compare(x, e) == 0)
  }
}
```
In the example above, `ord` is referenced in the signatures of the
public members of `OrdSet`, so a `tracked` modifier will be inserted
automatically.

Aldo generalize the condition for infering tracked for context bounds.
Explicit `using val` witnesses will now also be `tracked` by default.

This implementation should be safe with regards to not introducing
spurious cyclic reference errors.
Current limitations (I'll create separate issues for them, once this is
merged):
- Inferring `tracked` for given classes is done after the desugaring to
class + def, so the def doesn't know about `tracked` being set on the
original constructor parameter. This might be worked around by watching
the original symbol or adding an attachment pointer to the implicit
wrapper.
  ```scala 3
  given mInst: (c: C) => M:
  def foo: c.T = c.foo
  ```
- Passing parameters as an **inferred** `tracked` arguments in parents
doesn't work, since forcing a parent (term) isn't safe.
  This can be replaced with a lint that is checked after Namer.
  • Loading branch information
KacperFKorban authored Jan 15, 2025
1 parent 312c89a commit 019d203
Show file tree
Hide file tree
Showing 9 changed files with 347 additions and 34 deletions.
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2754,6 +2754,9 @@ object SymDenotations {
/** Sets all missing fields of given denotation */
def complete(denot: SymDenotation)(using Context): Unit

/** Is this a completer for an explicit type tree */
def isExplicit: Boolean = false

def apply(sym: Symbol): LazyType = this
def apply(module: TermSymbol, modcls: ClassSymbol): LazyType = this

Expand Down
134 changes: 101 additions & 33 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ class Namer { typer: Typer =>
if rhs.isEmpty || flags.is(Opaque) then flags |= Deferred
if flags.is(Param) then tree.rhs else analyzeRHS(tree.rhs)

def hasExplicitType(tree: ValOrDefDef): Boolean =
!tree.tpt.isEmpty || tree.mods.isOneOf(TermParamOrAccessor)

// to complete a constructor, move one context further out -- this
// is the context enclosing the class. Note that the context in which a
// constructor is recorded and the context in which it is completed are
Expand All @@ -291,6 +294,8 @@ class Namer { typer: Typer =>

val completer = tree match
case tree: TypeDef => TypeDefCompleter(tree)(cctx)
case tree: ValOrDefDef if Feature.enabled(Feature.modularity) && hasExplicitType(tree) =>
new Completer(tree, isExplicit = true)(cctx)
case _ => Completer(tree)(cctx)
val info = adjustIfModule(completer, tree)
createOrRefine[Symbol](tree, name, flags, ctx.owner, _ => info,
Expand Down Expand Up @@ -800,7 +805,7 @@ class Namer { typer: Typer =>
}

/** The completer of a symbol defined by a member def or import (except ClassSymbols) */
class Completer(val original: Tree)(ictx: Context) extends LazyType with SymbolLoaders.SecondCompleter {
class Completer(val original: Tree, override val isExplicit: Boolean = false)(ictx: Context) extends LazyType with SymbolLoaders.SecondCompleter {

protected def localContext(owner: Symbol): FreshContext = ctx.fresh.setOwner(owner).setTree(original)

Expand Down Expand Up @@ -1783,7 +1788,7 @@ class Namer { typer: Typer =>
sym.owner.typeParams.foreach(_.ensureCompleted())
completeTrailingParamss(constr, sym, indexingCtor = true)
if Feature.enabled(modularity) then
constr.termParamss.foreach(_.foreach(setTracked))
constr.termParamss.foreach(_.foreach(setTrackedConstrParam))

/** The signature of a module valdef.
* This will compute the corresponding module class TypeRef immediately
Expand Down Expand Up @@ -1923,22 +1928,24 @@ class Namer { typer: Typer =>
def wrapRefinedMethType(restpe: Type): Type =
wrapMethType(addParamRefinements(restpe, paramSymss))

def addTrackedIfNeeded(ddef: DefDef, owningSym: Symbol): Unit =
for params <- ddef.termParamss; param <- params do
val psym = symbolOfTree(param)
if needsTracked(psym, param, owningSym) then
psym.setFlag(Tracked)
setParamTrackedWithAccessors(psym, sym.maybeOwner.infoOrCompleter)

if Feature.enabled(modularity) then addTrackedIfNeeded(ddef, sym.maybeOwner)

if isConstructor then
// set result type tree to unit, but take the current class as result type of the symbol
typedAheadType(ddef.tpt, defn.UnitType)
val mt = wrapMethType(effectiveResultType(sym, paramSymss))
if sym.isPrimaryConstructor then checkCaseClassParamDependencies(mt, sym.owner)
mt
else if sym.isAllOf(Given | Method) && Feature.enabled(modularity) then
// set every context bound evidence parameter of a given companion method
// to be tracked, provided it has a type that has an abstract type member.
// Add refinements for all tracked parameters to the result type.
for params <- ddef.termParamss; param <- params do
val psym = symbolOfTree(param)
if needsTracked(psym, param) then psym.setFlag(Tracked)
valOrDefDefSig(ddef, sym, paramSymss, wrapRefinedMethType)
else
valOrDefDefSig(ddef, sym, paramSymss, wrapMethType)
val paramFn = if Feature.enabled(Feature.modularity) && sym.isAllOf(Given | Method) then wrapRefinedMethType else wrapMethType
valOrDefDefSig(ddef, sym, paramSymss, paramFn)
end defDefSig

/** Complete the trailing parameters of a DefDef,
Expand Down Expand Up @@ -1987,36 +1994,97 @@ class Namer { typer: Typer =>
cls.srcPos)
case _ =>

/** Under x.modularity, we add `tracked` to context bound witnesses
* that have abstract type members
private def setParamTrackedWithAccessors(psym: Symbol, ownerTpe: Type)(using Context): Unit =
for acc <- ownerTpe.decls.lookupAll(psym.name) if acc.is(ParamAccessor) do
acc.resetFlag(PrivateLocal)
psym.setFlag(Tracked)
acc.setFlag(Tracked)

/** `psym` needs tracked if it is referenced in any of the public signatures
* of the defining class or when `psym` is a context bound witness with an
* abstract type member
*/
def needsTracked(sym: Symbol, param: ValDef)(using Context) =
!sym.is(Tracked)
&& param.hasAttachment(ContextBoundParam)
&& sym.info.memberNames(abstractTypeNameFilter).nonEmpty

/** Under x.modularity, set every context bound evidence parameter of a class to be tracked,
* provided it has a type that has an abstract type member. Reset private and local flags
* so that the parameter becomes a `val`.
def needsTracked(psym: Symbol, param: ValDef, owningSym: Symbol)(using Context) =
lazy val abstractContextBound = isContextBoundWitnessWithAbstractMembers(psym, param, owningSym)
lazy val isRefInSignatures =
psym.maybeOwner.isPrimaryConstructor
&& isReferencedInPublicSignatures(psym)
!psym.is(Tracked)
&& psym.isTerm
&& (
abstractContextBound
|| isRefInSignatures
)

/** Under x.modularity, we add `tracked` to context bound witnesses and
* explicit evidence parameters that have abstract type members
*/
private def isContextBoundWitnessWithAbstractMembers(psym: Symbol, param: ValDef, owningSym: Symbol)(using Context): Boolean =
val accessorSyms = maybeParamAccessors(owningSym, psym)
(owningSym.isClass || owningSym.isAllOf(Given | Method))
&& (param.hasAttachment(ContextBoundParam) || (psym.isOneOf(GivenOrImplicit) && !accessorSyms.forall(_.isOneOf(PrivateLocal))))
&& psym.info.memberNames(abstractTypeNameFilter).nonEmpty

extension (sym: Symbol)
private def infoWithForceNonInferingCompleter(using Context): Type = sym.infoOrCompleter match
case tpe: LazyType if tpe.isExplicit => sym.info
case tpe if sym.isType => sym.info
case info => info

/** Under x.modularity, we add `tracked` to term parameters whose types are
* referenced in public signatures of the defining class
*/
private def isReferencedInPublicSignatures(sym: Symbol)(using Context): Boolean =
val owner = sym.maybeOwner.maybeOwner
val accessorSyms = maybeParamAccessors(owner, sym)
def checkOwnerMemberSignatures(owner: Symbol): Boolean =
owner.infoOrCompleter match
case info: ClassInfo =>
info.decls.filter(_.isPublic)
.filter(_ != sym.maybeOwner)
.exists { decl =>
tpeContainsSymbolRef(decl.infoWithForceNonInferingCompleter, accessorSyms)
}
case _ => false
checkOwnerMemberSignatures(owner)

/** Check if any of syms are referenced in tpe */
private def tpeContainsSymbolRef(tpe: Type, syms: List[Symbol])(using Context): Boolean =
val acc = new ExistsAccumulator(
{ tpe => tpe.termSymbol.exists && syms.contains(tpe.termSymbol) },
StopAt.Static,
forceLazy = false
) {
override def apply(acc: Boolean, tpe: Type): Boolean = super.apply(acc, tpe.safeDealias)
}
acc(false, tpe)

private def maybeParamAccessors(owner: Symbol, sym: Symbol)(using Context): List[Symbol] = owner.infoOrCompleter match
case info: ClassInfo =>
info.decls.lookupAll(sym.name).filter(d => d.is(ParamAccessor)).toList
case _ => List(sym)

/** Under x.modularity, set every context bound evidence parameter or public
* using parameter of a class to be tracked, provided it has a type that has
* an abstract type member. Reset private and local flags so that the
* parameter becomes a `val`.
*/
def setTracked(param: ValDef)(using Context): Unit =
def setTrackedConstrParam(param: ValDef)(using Context): Unit =
val sym = symbolOfTree(param)
sym.maybeOwner.maybeOwner.infoOrCompleter match
case info: ClassInfo if needsTracked(sym, param) =>
case info: ClassInfo
if !sym.is(Tracked) && isContextBoundWitnessWithAbstractMembers(sym, param, sym.maybeOwner.maybeOwner) =>
typr.println(i"set tracked $param, $sym: ${sym.info} containing ${sym.info.memberNames(abstractTypeNameFilter).toList}")
for acc <- info.decls.lookupAll(sym.name) if acc.is(ParamAccessor) do
acc.resetFlag(PrivateLocal)
acc.setFlag(Tracked)
sym.setFlag(Tracked)
setParamTrackedWithAccessors(sym, info)
case _ =>

def inferredResultType(
mdef: ValOrDefDef,
sym: Symbol,
paramss: List[List[Symbol]],
paramFn: Type => Type,
fallbackProto: Type
)(using Context): Type =
mdef: ValOrDefDef,
sym: Symbol,
paramss: List[List[Symbol]],
paramFn: Type => Type,
fallbackProto: Type
)(using Context): Type =
/** Is this member tracked? This is true if it is marked as `tracked` or if
* it overrides a `tracked` member. To account for the later, `isTracked`
* is overriden to `true` as a side-effect of computing `inherited`.
Expand Down
41 changes: 40 additions & 1 deletion docs/_docs/reference/experimental/modularity.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,46 @@ This works as it should now. Without the addition of `tracked` to the
parameter of `SetFunctor` typechecking would immediately lose track of
the element type `T` after an `add`, and would therefore fail.

**Discussion**
**Syntax Change**

```
ClsParam ::= {Annotation} [{Modifier | ‘tracked’} (‘val’ | ‘var’)] Param
```

The (soft) `tracked` modifier is only allowed for `val` parameters of classes.

### Tracked inference

In some cases `tracked` can be infered and doesn't have to be written
explicitly. A common such case is when a class parameter is referenced in the
signatures of the public members of the class. e.g.
```scala 3
class OrdSet(val ord: Ordering) {
type Set = List[ord.T]
def empty: Set = Nil

implicit class helper(s: Set) {
def add(x: ord.T): Set = x :: remove(x)
def remove(x: ord.T): Set = s.filter(e => ord.compare(x, e) != 0)
def member(x: ord.T): Boolean = s.exists(e => ord.compare(x, e) == 0)
}
}
```
In the example above, `ord` is referenced in the signatures of the public
members of `OrdSet`, so a `tracked` modifier will be inserted automatically.

Another common case is when a context bound has an associated type (i.e. an abstract type member) e.g.
```scala 3
trait TC:
type Self
type T

class Klass[A: {TC as tc}]
```

Here, `tc` is a context bound with an associated type `T`, so `tracked` will be inferred for `tc`.

### Discussion

Since `tracked` is so useful, why not assume it by default? First, `tracked` makes sense only for `val` parameters. If a class parameter is not also a field declared using `val` then there's nothing to refine in the constructor result type. One could think of at least making all `val` parameters tracked by default, but that would be a backwards incompatible change. For instance, the following code would break:

Expand Down
18 changes: 18 additions & 0 deletions tests/neg/infer-tracked-explicit-witness.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import scala.language.experimental.modularity

trait T:
type Self
type X
def foo: Self

class D[C](using wd: C is T)
class E(using we: Int is T)

def Test =
given w: Int is T:
def foo: Int = 42
type X = Long
val d = D(using w)
summon[d.wd.X =:= Long] // error
val e = E(using w)
summon[e.we.X =:= Long] // error
34 changes: 34 additions & 0 deletions tests/pos/infer-tracked-1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import scala.language.experimental.modularity
import scala.language.future

trait Ordering {
type T
def compare(t1:T, t2: T): Int
}

class SetFunctor(val ord: Ordering) {
type Set = List[ord.T]
def empty: Set = Nil

implicit class helper(s: Set) {
def add(x: ord.T): Set = x :: remove(x)
def remove(x: ord.T): Set = s.filter(e => ord.compare(x, e) != 0)
def member(x: ord.T): Boolean = s.exists(e => ord.compare(x, e) == 0)
}
}

object Test {
val orderInt = new Ordering {
type T = Int
def compare(t1: T, t2: T): Int = t1 - t2
}

val IntSet = new SetFunctor(orderInt)
import IntSet.*

def main(args: Array[String]) = {
val set = IntSet.empty.add(6).add(8).add(23)
assert(!set.member(7))
assert(set.member(8))
}
}
18 changes: 18 additions & 0 deletions tests/pos/infer-tracked-explicit-witness.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import scala.language.experimental.modularity

trait T:
type Self
type X
def foo: Self

class D[C](using val wd: C is T)
class E(using val we: Int is T)

def Test =
given w: Int is T:
def foo: Int = 42
type X = Long
val d = D(using w)
summon[d.wd.X =:= Long]
val e = E(using w)
summon[e.we.X =:= Long]
8 changes: 8 additions & 0 deletions tests/pos/infer-tracked-parent-refinements.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import scala.language.experimental.modularity
import scala.language.future

trait WithValue { type Value = Int }

case class Year(value: Int) extends WithValue {
val x: Value = 2
}
65 changes: 65 additions & 0 deletions tests/pos/infer-tracked-parsercombinators-expanded.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import scala.language.experimental.modularity
import scala.language.future

import collection.mutable

/// A parser combinator.
trait Combinator[T]:

/// The context from which elements are being parsed, typically a stream of tokens.
type Context
/// The element being parsed.
type Element

extension (self: T)
/// Parses and returns an element from `context`.
def parse(context: Context): Option[Element]
end Combinator

final case class Apply[C, E](action: C => Option[E])
final case class Combine[A, B](first: A, second: B)

object test:

class apply[C, E] extends Combinator[Apply[C, E]]:
type Context = C
type Element = E
extension(self: Apply[C, E])
def parse(context: C): Option[E] = self.action(context)

def apply[C, E]: apply[C, E] = new apply[C, E]

class combine[A, B](
val f: Combinator[A],
val s: Combinator[B] { type Context = f.Context}
) extends Combinator[Combine[A, B]]:
type Context = f.Context
type Element = (f.Element, s.Element)
extension(self: Combine[A, B])
def parse(context: Context): Option[Element] = ???

def combine[A, B](
_f: Combinator[A],
_s: Combinator[B] { type Context = _f.Context}
) = new combine[A, B](_f, _s)
// cast is needed since the type of new combine[A, B](_f, _s)
// drops the required refinement.

extension [A] (buf: mutable.ListBuffer[A]) def popFirst() =
if buf.isEmpty then None
else try Some(buf.head) finally buf.remove(0)

@main def hello: Unit = {
val source = (0 to 10).toList
val stream = source.to(mutable.ListBuffer)

val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst())
val m = Combine(n, n)

val c = combine(
apply[mutable.ListBuffer[Int], Int],
apply[mutable.ListBuffer[Int], Int]
)
val r = c.parse(m)(stream) // was type mismatch, now OK
val rc: Option[(Int, Int)] = r
}
Loading

0 comments on commit 019d203

Please sign in to comment.