Skip to content

Commit

Permalink
Allow providing alternative implementations for generating FiberIds (
Browse files Browse the repository at this point in the history
…zio#8778)

* Allow providing alternative implementations for generating `FiberId`s

* Revert bootstrap override

* Change usages of deprecated `FiberId.make` method

* Rename `Random` to `Live`
  • Loading branch information
kyri-petrou authored Apr 23, 2024
1 parent 17896de commit 5d72fd0
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 17 deletions.
14 changes: 7 additions & 7 deletions core-tests/shared/src/test/scala/zio/FiberRefsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ object FiberRefsSpec extends ZIOBaseSpec {
} yield assertTrue(value)
} +
test("interruptedCause") {
val parent = Unsafe.unsafe(implicit unsafe => FiberId.make(Trace.empty))
val child = Unsafe.unsafe(implicit unsafe => FiberId.make(Trace.empty))
val parent = Unsafe.unsafe(implicit unsafe => FiberId.Gen.Live.make(Trace.empty))
val child = Unsafe.unsafe(implicit unsafe => FiberId.Gen.Live.make(Trace.empty))

val parentFiberRefs = FiberRefs.empty
val childFiberRefs = parentFiberRefs.updatedAs(child)(FiberRef.interruptedCause, Cause.interrupt(parent))
Expand All @@ -33,7 +33,7 @@ object FiberRefsSpec extends ZIOBaseSpec {
*/
suite("optimizations") {
implicit val unsafe: Unsafe = Unsafe.unsafe
val fiberId = FiberId.make(implicitly)
val fiberId = FiberId.Gen.Live.make(implicitly)

val fiberRefs = List(
FiberRef.unsafe.make[Int](0, join = (a, b) => a + b),
Expand All @@ -53,24 +53,24 @@ object FiberRefsSpec extends ZIOBaseSpec {
} +
test("forkAs returns the same map if no fibers are modified during fork") {
val fr = makeFiberRefs(fiberRefs.take(4))
val isEq = fr.forkAs(FiberId.make(implicitly)) eq fr
val isEq = fr.forkAs(FiberId.Gen.Live.make(implicitly)) eq fr
assertTrue(isEq)
} +
test("joinAs returns the same map when fiber refs are unchanged after joining") {
val fr1 = makeFiberRefs(fiberRefs.drop(1))
val fr2 = makeFiberRefs(fiberRefs.drop(2))
val isEq = fr1.joinAs(FiberId.make(implicitly))(fr2) eq fr1
val isEq = fr1.joinAs(FiberId.Gen.Live.make(implicitly))(fr2) eq fr1
assertTrue(isEq)
} +
// Sanity checks
test("forkAs returns a different map if forked fibers are modified") {
val fr = makeFiberRefs(fiberRefs)
val isEq = fr.forkAs(FiberId.make(implicitly)) ne fr
val isEq = fr.forkAs(FiberId.Gen.Live.make(implicitly)) ne fr
assertTrue(isEq)
} +
test("joinAs returns a different map when fiber refs are changed after joining") {
val fr1, fr2 = makeFiberRefs(fiberRefs)
val isEq = fr1.joinAs(FiberId.make(implicitly))(fr2) ne fr1
val isEq = fr1.joinAs(FiberId.Gen.Live.make(implicitly))(fr2) ne fr1
assertTrue(isEq)
}
}
Expand Down
7 changes: 5 additions & 2 deletions core/shared/src/main/scala/zio/Fiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -628,8 +628,11 @@ object Fiber extends FiberPlatformSpecific {

private[zio] object Runtime {

implicit val fiberOrdering: Ordering[Fiber.Runtime[?, ?]] =
Ordering.by[Fiber.Runtime[?, ?], Long](_.id.startTimeMillis)
implicit val fiberOrdering: Ordering[Fiber.Runtime[?, ?]] = { (x, y) =>
val byTime = x.id.startTimeMillis.compare(y.id.startTimeMillis)
if (byTime == 0) x.id.id.compare(y.id.id)
else byTime
}

abstract class Internal[+E, +A] extends Runtime[E, A]
}
Expand Down
44 changes: 40 additions & 4 deletions core/shared/src/main/scala/zio/FiberId.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,48 @@ object FiberId {
def apply(id: Int, startTimeSeconds: Int, location: Trace): FiberId =
Runtime(id, startTimeSeconds * 1000L, location)

private[zio] def make(location: Trace)(implicit unsafe: Unsafe): FiberId.Runtime = {
val id = ThreadLocalRandom.current().nextInt(Int.MaxValue)
FiberId.Runtime(id, java.lang.System.currentTimeMillis(), location)
}
@deprecated("use `generate` instead", "1.0.0")
private[zio] def make(location: Trace)(implicit unsafe: Unsafe): FiberId.Runtime =
Gen.Live.make(location)

private[zio] def generate(fiberRefs: FiberRefs)(location: Trace)(implicit unsafe: Unsafe): FiberId.Runtime =
fiberRefs.getOrDefault(FiberRef.currentFiberIdGenerator).make(location)

case object None extends FiberId
final case class Runtime(id: Int, startTimeMillis: Long, location: Trace) extends FiberId
final case class Composite(left: FiberId, right: FiberId) extends FiberId

private[zio] trait Gen {
def make(location: Trace)(implicit unsafe: Unsafe): FiberId.Runtime
}

private[zio] object Gen {

/**
* Generates a fiber ID where the `id` is a random integer.
*
* This is more performant than using `FiberId.Gen.Ordered`, but cannot be
* used in cases that rely on strict ordering of fibers (e.g., in zio-test)
*/
object Live extends Gen {
def make(location: Trace)(implicit unsafe: Unsafe): FiberId.Runtime = {
val id = ThreadLocalRandom.current().nextInt(Int.MaxValue)
FiberId.Runtime(id, java.lang.System.currentTimeMillis(), location)
}
}

/**
* Generates a fiber ID where the `id` is a monotonically increasing
* integer.
*
* This is less performant than generating IDs randomly, but is required for
* cases that rely on strict ordering of fibers (e.g., in zio-test)
*/
object Monotonic extends Gen {
private[this] val counter = new java.util.concurrent.atomic.AtomicInteger(0)
def make(location: Trace)(implicit unsafe: Unsafe): FiberId.Runtime =
FiberId.Runtime(counter.getAndIncrement(), java.lang.System.currentTimeMillis(), location)
}
}

}
3 changes: 3 additions & 0 deletions core/shared/src/main/scala/zio/FiberRef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,9 @@ object FiberRef {
private[zio] val currentFatal: FiberRef.WithPatch[IsFatal, IsFatal.Patch] =
FiberRef.unsafe.makeIsFatal(Runtime.defaultFatal)(Unsafe.unsafe)

private[zio] val currentFiberIdGenerator: FiberRef[FiberId.Gen] =
FiberRef.unsafe.make[FiberId.Gen](FiberId.Gen.Live)(Unsafe.unsafe)

private[zio] val currentLoggers: FiberRef.WithPatch[Set[ZLogger[String, Any]], SetPatch[ZLogger[String, Any]]] =
FiberRef.unsafe.makeSet(Runtime.defaultLoggers)(Unsafe.unsafe)

Expand Down
6 changes: 4 additions & 2 deletions core/shared/src/main/scala/zio/Runtime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ trait Runtime[+R] { self =>

protected abstract class UnsafeAPIV1 extends UnsafeAPI with UnsafeAPI3 {

private val fiberIdGen = self.fiberRefs.getOrDefault(FiberRef.currentFiberIdGenerator)

def fork[E, A](zio: ZIO[R, E, A])(implicit trace: Trace, unsafe: Unsafe): internal.FiberRuntime[E, A] = {
val fiber = makeFiber(zio)
fiber.startConcurrently(zio)
Expand All @@ -145,7 +147,7 @@ trait Runtime[+R] { self =>
)(implicit trace: Trace, unsafe: Unsafe): Either[internal.FiberRuntime[E, A], Exit[E, A]] = {
import internal.FiberRuntime

val fiberId = FiberId.make(trace)
val fiberId = fiberIdGen.make(trace)
val fiberRefs = self.fiberRefs.updatedAs(fiberId)(FiberRef.currentEnvironment, environment)
val fiber = FiberRuntime[E, A](fiberId, fiberRefs.forkAs(fiberId), runtimeFlags)

Expand Down Expand Up @@ -191,7 +193,7 @@ trait Runtime[+R] { self =>
private def makeFiber[E, A](
zio: ZIO[R, E, A]
)(implicit trace: Trace, unsafe: Unsafe): internal.FiberRuntime[E, A] = {
val fiberId = FiberId.make(trace)
val fiberId = fiberIdGen.make(trace)
val fiberRefs = self.fiberRefs.updatedAs(fiberId)(FiberRef.currentEnvironment, environment)
val fiber = FiberRuntime[E, A](fiberId, fiberRefs.forkAs(fiberId), runtimeFlags)

Expand Down
2 changes: 1 addition & 1 deletion core/shared/src/main/scala/zio/ZIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2598,8 +2598,8 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific
parentRuntimeFlags: RuntimeFlags,
overrideScope: FiberScope
)(implicit unsafe: Unsafe): internal.FiberRuntime[E1, A] = {
val childId = FiberId.make(trace)
val parentFiberRefs = parentFiber.getFiberRefs()
val childId = FiberId.generate(parentFiberRefs)(trace)
val childFiberRefs = parentFiberRefs.forkAs(childId) // TODO: Optimize

val childFiber = internal.FiberRuntime[E1, A](childId, childFiberRefs, parentRuntimeFlags)
Expand Down
5 changes: 4 additions & 1 deletion test/shared/src/main/scala/zio/test/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,12 @@ package object test extends CompileVariants {
)
}

private def testFiberRefGen(implicit trace: Trace): ULayer[Unit] =
ZLayer.scoped(FiberRef.currentFiberIdGenerator.locallyScoped(FiberId.Gen.Monotonic))

val testEnvironment: ZLayer[Any, Nothing, TestEnvironment] = {
implicit val trace = Tracer.newTrace
liveEnvironment >>> TestEnvironment.live
liveEnvironment >>> (TestEnvironment.live ++ testFiberRefGen)
}

/**
Expand Down

0 comments on commit 5d72fd0

Please sign in to comment.