Skip to content

Commit

Permalink
Optimize failures, retries and stack trace generation (zio#9020)
Browse files Browse the repository at this point in the history
* Optimize failures and retries

* Fix compiling errors

* PR comments

* Improve performance of methods on Cause

* PR comments

* More PR comments
  • Loading branch information
kyri-petrou authored Jul 19, 2024
1 parent 81f3f8b commit 0e799a9
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 117 deletions.
165 changes: 116 additions & 49 deletions core/shared/src/main/scala/zio/Cause.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ sealed abstract class Cause[+E] extends Product with Serializable { self =>
/**
* Adds the specified annotations.
*/
final def annotated(annotations: Map[String, String]): Cause[E] =
mapAnnotations(_ ++ annotations)
final def annotated(anns: Map[String, String]): Cause[E] =
if (anns.isEmpty) self else mapAnnotations(_ ++ anns)

/**
* Grabs the annotations from the cause.
Expand All @@ -54,6 +54,24 @@ sealed abstract class Cause[+E] extends Product with Serializable { self =>
case (z, _) => z
}

private[zio] final def applyAll(
trace: StackTrace,
spans: List[LogSpan],
annotations: Map[String, String]
): Cause[E] = {
val isEmptyTrace = trace.isEmpty
val isEmptySpans = spans.isEmpty
val isEmptyAnns = annotations.isEmpty

if (isEmptyTrace && isEmptySpans && isEmptyAnns) self
else
mapAll(
if (isEmptyTrace) ZIO.identityFn else _ ++ trace,
if (isEmptySpans) ZIO.identityFn else _ ::: spans,
if (isEmptyAnns) ZIO.identityFn else _ ++ annotations
)
}

/**
* Maps the error value of this cause to the specified constant value.
*/
Expand All @@ -73,9 +91,7 @@ sealed abstract class Cause[+E] extends Product with Serializable { self =>
*/
final def defects: List[Throwable] =
self
.foldLeft(List.empty[Throwable]) { case (z, Die(v, _)) =>
v :: z
}
.foldLeft(List.empty[Throwable]) { case (z, Die(v, _)) => v :: z }
.reverse

/**
Expand Down Expand Up @@ -145,21 +161,19 @@ sealed abstract class Cause[+E] extends Product with Serializable { self =>
*/
final def find[Z](f: PartialFunction[Cause[E], Z]): Option[Z] = {
@tailrec
def loop(cause: Cause[E], stack: List[Cause[E]]): Option[Z] =
f.lift(cause) match {
case Some(z) => Some(z)
case None =>
cause match {
case Then(left, right) => loop(left, right :: stack)
case Both(left, right) => loop(left, right :: stack)
case Stackless(cause, _) => loop(cause, stack)
case _ =>
stack match {
case hd :: tl => loop(hd, tl)
case Nil => None
}
}
def loop(cause: Cause[E], stack: List[Cause[E]]): Option[Z] = {
val out = f.lift(cause)
if (out.isDefined) out
else {
cause match {
case Then(left, right) => loop(left, right :: stack)
case Both(left, right) => loop(left, right :: stack)
case Stackless(cause, _) => loop(cause, stack)
case _ if stack.nonEmpty => loop(stack.head, stack.tail)
case _ => None
}
}
}
loop(self, Nil)
}

Expand All @@ -170,7 +184,7 @@ sealed abstract class Cause[+E] extends Product with Serializable { self =>
final def flatMap[E2](f: E => Cause[E2]): Cause[E2] =
foldLog[Cause[E2]](
Empty,
(e, trace, spans, annotations) => f(e).traced(trace).spanned(spans).annotated(annotations),
(e, trace, spans, annotations) => f(e).applyAll(trace, spans, annotations),
(t, trace, spans, annotations) => Die(t, trace, spans, annotations),
(fiberId, trace, spans, annotations) => Interrupt(fiberId, trace, spans, annotations)
)(
Expand Down Expand Up @@ -245,26 +259,28 @@ sealed abstract class Cause[+E] extends Product with Serializable { self =>
stacklessCase(context, cause, stackless) :: causes
}
}
loop(List(self), List.empty).head
if (self eq Empty) empty(context)
else loop(List(self), List.empty).head
}

/**
* Folds over the cause to statefully compute a value.
*/
final def foldLeft[Z](z: Z)(f: PartialFunction[(Z, Cause[E]), Z]): Z = {
@tailrec
def loop(z: Z, cause: Cause[E], stack: List[Cause[E]]): Z =
(f.applyOrElse[(Z, Cause[E]), Z](z -> cause, _ => z), cause) match {
case (z, Then(left, right)) => loop(z, left, right :: stack)
case (z, Both(left, right)) => loop(z, left, right :: stack)
case (z, Stackless(cause, _)) => loop(z, cause, stack)
case (z, _) =>
stack match {
case hd :: tl => loop(z, hd, tl)
case Nil => z
}
def loop(z0: Z, cause: Cause[E], stack: List[Cause[E]]): Z = {
val z = f.applyOrElse[(Z, Cause[E]), Z](z0 -> cause, _._1)
cause match {
case Then(left, right) => loop(z, left, right :: stack)
case Both(left, right) => loop(z, left, right :: stack)
case Stackless(cause, _) => loop(z, cause, stack)
case _ if stack.nonEmpty => loop(z, stack.head, stack.tail)
case _ => z
}
loop(z, self, Nil)
}

if (self eq Empty) z
else loop(z, self, Nil)
}

final def foldLog[Z](
Expand Down Expand Up @@ -329,12 +345,12 @@ sealed abstract class Cause[+E] extends Product with Serializable { self =>
@tailrec
def loop(cause: Cause[E], stack: List[Cause[E]]): Boolean =
cause match {
case Fail(value, trace) => false
case Die(value, trace) => false
case Interrupt(fiberId, trace) => false
case Then(left, right) => loop(left, right :: stack)
case Both(left, right) => loop(left, right :: stack)
case Stackless(cause, _) => loop(cause, stack)
case _: Fail[?] => false
case _: Die => false
case _: Interrupt => false
case Then(left, right) => loop(left, right :: stack)
case Both(left, right) => loop(left, right :: stack)
case Stackless(cause, _) => loop(cause, stack)
case _ =>
stack match {
case hd :: tl => loop(hd, tl)
Expand Down Expand Up @@ -364,9 +380,9 @@ sealed abstract class Cause[+E] extends Product with Serializable { self =>
*/
final def isTraced: Boolean =
find {
case Die(_, trace) if trace != StackTrace.none => ()
case Fail(_, trace) if trace != StackTrace.none => ()
case Interrupt(_, trace) if trace != StackTrace.none => ()
case Die(_, trace) if !trace.isEmpty => ()
case Fail(_, trace) if !trace.isEmpty => ()
case Interrupt(_, trace) if !trace.isEmpty => ()
}.isDefined

/**
Expand Down Expand Up @@ -414,9 +430,25 @@ sealed abstract class Cause[+E] extends Product with Serializable { self =>
/**
* Transforms the error type of this cause with the specified function.
*/
final def map[E1](f: E => E1): Cause[E1] =
def map[E1](f: E => E1): Cause[E1] =
flatMap(e => Fail(f(e), StackTrace.none))

protected def mapAll(
ft: StackTrace => StackTrace,
fs: List[LogSpan] => List[LogSpan],
fa: Map[String, String] => Map[String, String]
): Cause[E] =
foldLog[Cause[E]](
Empty,
(e, trace, spans, annotations) => Fail(e, ft(trace), fs(spans), fa(annotations)),
(t, trace, spans, annotations) => Die(t, ft(trace), fs(spans), fa(annotations)),
(fiberId, trace, spans, annotations) => Interrupt(fiberId, ft(trace), fs(spans), fa(annotations))
)(
(left, right) => Then(left, right),
(left, right) => Both(left, right),
(cause, stackless) => Stackless(cause, stackless)
)

/**
* Transforms the annotations in this cause with the specified function.
*/
Expand Down Expand Up @@ -517,7 +549,7 @@ sealed abstract class Cause[+E] extends Product with Serializable { self =>
* Adds the specified spans.
*/
def spanned(spans: List[LogSpan]): Cause[E] =
mapSpans(_ ::: spans)
if (spans.isEmpty) self else mapSpans(_ ::: spans)

/**
* Grabs a complete, linearized list of log spans for the cause. Note: This
Expand Down Expand Up @@ -638,7 +670,7 @@ sealed abstract class Cause[+E] extends Product with Serializable { self =>
* Adds the specified execution trace to traces.
*/
final def traced(trace: StackTrace): Cause[E] =
mapTrace(_ ++ trace)
if (trace.isEmpty) self else mapTrace(_ ++ trace)

/**
* Returns a homogenized list of failures for the cause. This homogenization
Expand Down Expand Up @@ -703,7 +735,8 @@ sealed abstract class Cause[+E] extends Product with Serializable { self =>
}
}

loop(self :: Nil, FiberId.None, false, Nil).reverse
if (self eq Empty) Nil
else loop(self :: Nil, FiberId.None, false, Nil).reverse
}

/**
Expand Down Expand Up @@ -848,11 +881,27 @@ object Cause extends Serializable {
(causeOption, stackless) => causeOption.map(Stackless(_, stackless))
)

case object Empty extends Cause[Nothing]
case object Empty extends Cause[Nothing] { self =>
override def map[E1](f: Nothing => E1): Cause[E1] = self
override protected def mapAll(
ft: StackTrace => StackTrace,
fs: List[LogSpan] => List[LogSpan],
fa: Map[String, String] => Map[String, String]
): Cause[Nothing] = self
}

sealed case class Fail[+E](value: E, override val trace: StackTrace) extends Cause[E] { self =>
override def annotations: Map[String, String] = Map.empty
override def spans: List[LogSpan] = List.empty

final override def map[E1](f: E => E1): Cause[E1] =
Fail(f(value), trace, spans, annotations)

final override protected def mapAll(
ft: StackTrace => StackTrace,
fs: List[LogSpan] => List[LogSpan],
fa: Map[String, String] => Map[String, String]
): Cause[E] = Fail(value, ft(trace), fs(spans), fa(annotations))
}

object Fail {
Expand All @@ -866,6 +915,14 @@ object Cause extends Serializable {
sealed case class Die(value: Throwable, override val trace: StackTrace) extends Cause[Nothing] { self =>
override def annotations: Map[String, String] = Map.empty
override def spans: List[LogSpan] = List.empty

final override def map[E1](f: Nothing => E1): Cause[E1] = self

final override protected def mapAll(
ft: StackTrace => StackTrace,
fs: List[LogSpan] => List[LogSpan],
fa: Map[String, String] => Map[String, String]
): Cause[Nothing] = Die(value, ft(trace), fs(spans), fa(annotations))
}

object Die extends AbstractFunction2[Throwable, StackTrace, Die] {
Expand All @@ -879,6 +936,14 @@ object Cause extends Serializable {
sealed case class Interrupt(fiberId: FiberId, override val trace: StackTrace) extends Cause[Nothing] { self =>
override def annotations: Map[String, String] = Map.empty
override def spans: List[LogSpan] = List.empty

final override def map[E1](f: Nothing => E1): Cause[E1] = self

final override protected def mapAll(
ft: StackTrace => StackTrace,
fs: List[LogSpan] => List[LogSpan],
fa: Map[String, String] => Map[String, String]
): Cause[Nothing] = Interrupt(fiberId, ft(trace), fs(spans), fa(annotations))
}

object Interrupt extends AbstractFunction2[FiberId, StackTrace, Interrupt] {
Expand Down Expand Up @@ -919,7 +984,7 @@ object Cause extends Serializable {
else loop(leftSequential, rightSequential)
}

loop(List(left), List(right))
(left eq right) || loop(List(left), List(right))
}

/**
Expand All @@ -941,7 +1006,8 @@ object Cause extends Serializable {
else loop(sequential, updated)
}

loop(List(c), List.empty)
if (c eq Empty) Nil
else loop(List(c), List.empty)
}

/**
Expand Down Expand Up @@ -984,7 +1050,8 @@ object Cause extends Serializable {
else loop(stack.head, stack.tail, parallel, sequential)
}

loop(c, List.empty, Set.empty, List.empty)
if (c eq Empty) (Set.empty, Nil)
else loop(c, List.empty, Set.empty, List.empty)
}

private case class FiberTrace(trace: String) extends Throwable(null, null, true, false) {
Expand Down
5 changes: 3 additions & 2 deletions core/shared/src/main/scala/zio/Schedule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -487,15 +487,16 @@ trait Schedule[-Env, -In, +Out] extends Serializable { self =>
now <- Clock.currentDateTime
dec <- self.step(now, in, state)
v <- dec match {
case (state, out, Done) => ref.set((Some(out), state)) *> ZIO.fail(None)
case (state, out, Done) =>
ref.set((Some(out), state)) *> Exit.failNone.asInstanceOf[Exit[None.type, Out]]
case (state, out, Continue(interval)) =>
ref.set((Some(out), state)) *> ZIO.sleep(Duration.fromInterval(now, interval.start)) as out
}
} yield v

val last = ref.get.flatMap {
case (None, _) => ZIO.fail(new NoSuchElementException("There is no value left"))
case (Some(b), _) => ZIO.succeed(b)
case (Some(b), _) => Exit.succeed(b)
}

val reset = ref.set((None, self.initial))
Expand Down
8 changes: 5 additions & 3 deletions core/shared/src/main/scala/zio/StackTrace.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@ package zio

import zio.stacktracer.TracingImplicits.disableAutoTrace

import scala.annotation.tailrec

final case class StackTrace(
fiberId: FiberId,
stackTrace: Chunk[Trace]
) { self =>

def ++(that: StackTrace): StackTrace =
StackTrace(self.fiberId combine that.fiberId, self.stackTrace ++ that.stackTrace)
if ((self eq that) || self.isEmpty) that
else if (that.isEmpty) self
else StackTrace(self.fiberId combine that.fiberId, self.stackTrace ++ that.stackTrace)

def isEmpty: Boolean = (fiberId eq FiberId.None) && stackTrace.isEmpty

def size: Int = stackTrace.length

Expand Down
Loading

0 comments on commit 0e799a9

Please sign in to comment.