diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowRecoverWithSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowRecoverWithSpec.scala index 022ac54bc0..db582146ff 100644 --- a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowRecoverWithSpec.scala +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowRecoverWithSpec.scala @@ -23,6 +23,8 @@ import pekko.stream.testkit.StreamSpec import pekko.stream.testkit.Utils._ import pekko.stream.testkit.scaladsl.TestSink +import scala.concurrent.Future + @nowarn // tests deprecated APIs class FlowRecoverWithSpec extends StreamSpec { @@ -62,6 +64,32 @@ class FlowRecoverWithSpec extends StreamSpec { .expectComplete() } + "recover with a completed future source" in { + Source.failed(ex) + .recoverWith { case _: Throwable => Source.future(Future.successful(3)) } + .runWith(TestSink[Int]()) + .request(1) + .expectNext(3) + .expectComplete() + } + + "recover with a failed future source" in { + Source.failed(ex) + .recoverWith { case _: Throwable => Source.future(Future.failed(ex)) } + .runWith(TestSink[Int]()) + .request(1) + .expectError(ex) + } + + "recover with a java stream source" in { + Source.failed(ex) + .recoverWith { case _: Throwable => Source.fromJavaStream(() => java.util.stream.Stream.of(1, 2, 3)) } + .runWith(TestSink[Int]()) + .request(3) + .expectNextN(1 to 3) + .expectComplete() + } + "recover with single source" in { Source(1 to 4) .map { a => diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala index 737116fd22..0e4133b561 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala @@ -24,7 +24,6 @@ import scala.concurrent.duration.{ FiniteDuration, _ } import scala.util.{ Failure, Success, Try } import scala.util.control.{ NoStackTrace, NonFatal } import scala.util.control.Exception.Catcher - import org.apache.pekko import pekko.actor.{ ActorRef, Terminated } import pekko.annotation.InternalApi @@ -36,9 +35,16 @@ import pekko.stream.Attributes.{ InputBuffer, LogLevels } import pekko.stream.Attributes.SourceLocation import pekko.stream.OverflowStrategies._ import pekko.stream.Supervision.Decider -import pekko.stream.impl.{ Buffer => BufferImpl, ContextPropagation, ReactiveStreamsCompliance, TraversalBuilder } +import pekko.stream.impl.{ + Buffer => BufferImpl, + ContextPropagation, + FailedSource, + JavaStreamSource, + ReactiveStreamsCompliance, + TraversalBuilder +} import pekko.stream.impl.Stages.DefaultAttributes -import pekko.stream.impl.fusing.GraphStages.SimpleLinearGraphStage +import pekko.stream.impl.fusing.GraphStages.{ FutureSource, SimpleLinearGraphStage, SingleSource } import pekko.stream.scaladsl.{ DelayStrategy, Source } import pekko.stream.stage._ import pekko.util.{ unused, ConstantFun, OptionVal } @@ -2173,12 +2179,28 @@ private[pekko] object TakeWithin { case _: NotApplied.type => failStage(ex) case source: Graph[SourceShape[T] @unchecked, M @unchecked] if TraversalBuilder.isEmptySource(source) => completeStage() - case other: Graph[SourceShape[T] @unchecked, M @unchecked] => - TraversalBuilder.getSingleSource(other) match { - case OptionVal.Some(singleSource) => - emit(out, singleSource.elem.asInstanceOf[T], () => completeStage()) + case source: Graph[SourceShape[T] @unchecked, M @unchecked] => + TraversalBuilder.getValuePresentedSource(source) match { + case OptionVal.Some(graph) => graph match { + case singleSource: SingleSource[T @unchecked] => emit(out, singleSource.elem, () => completeStage()) + case failed: FailedSource[T @unchecked] => failStage(failed.failure) + case futureSource: FutureSource[T @unchecked] => futureSource.future.value match { + case Some(Success(elem)) => emit(out, elem, () => completeStage()) + case Some(Failure(ex)) => failStage(ex) + case None => + switchTo(source) + attempt += 1 + } + case iterableSource: IterableSource[T @unchecked] => + emitMultiple(out, iterableSource.elements, () => completeStage()) + case javaStreamSource: JavaStreamSource[T @unchecked, _] => + emitMultiple(out, javaStreamSource.open().iterator(), () => completeStage()) + case _ => + switchTo(source) + attempt += 1 + } case _ => - switchTo(other) + switchTo(source) attempt += 1 } case _ => throw new IllegalStateException() // won't happen, compiler exhaustiveness check pleaser