diff --git a/core/src/main/scala/cats/Traverse.scala b/core/src/main/scala/cats/Traverse.scala index b380f152297..dc85704e4c3 100644 --- a/core/src/main/scala/cats/Traverse.scala +++ b/core/src/main/scala/cats/Traverse.scala @@ -135,12 +135,21 @@ import simulacrum.typeclass override def map[A, B](fa: F[A])(f: A => B): F[B] = traverse[Id, A, B](fa)(f) + /** + * Akin to [[map]], but allows to keep track of a state value + * when calling the function. + */ + def mapAccumulate[A, S, B](fa: F[A], init: S)(f: (S, A) => (S, B)): (S, F[B]) = + traverse(fa) { a => + State(s => f(s, a)) + }.run(init).value + /** * Akin to [[map]], but also provides the value's index in structure * F when calling the function. */ def mapWithIndex[A, B](fa: F[A])(f: (A, Int) => B): F[B] = - traverse(fa)(a => State((s: Int) => (s + 1, f(a, s)))).runA(0).value + mapAccumulate(fa, 0)((i, a) => (i + 1) -> f(a, i))._2 /** * Akin to [[traverse]], but also provides the value's index in @@ -206,10 +215,14 @@ object Traverse { typeClassInstance.sequence[G, B](self.asInstanceOf[F[G[B]]]) def flatSequence[G[_], B](implicit ev$1: A <:< G[F[B]], G: Applicative[G], F: FlatMap[F]): G[F[B]] = typeClassInstance.flatSequence[G, B](self.asInstanceOf[F[G[F[B]]]])(G, F) - def mapWithIndex[B](f: (A, Int) => B): F[B] = typeClassInstance.mapWithIndex[A, B](self)(f) + def mapAccumulate[S, B](init: S)(f: (S, A) => (S, B)): (S, F[B]) = + typeClassInstance.mapAccumulate[A, S, B](self, init)(f) + def mapWithIndex[B](f: (A, Int) => B): F[B] = + typeClassInstance.mapWithIndex[A, B](self)(f) def traverseWithIndexM[G[_], B](f: (A, Int) => G[B])(implicit G: Monad[G]): G[F[B]] = typeClassInstance.traverseWithIndexM[G, A, B](self)(f)(G) - def zipWithIndex: F[(A, Int)] = typeClassInstance.zipWithIndex[A](self) + def zipWithIndex: F[(A, Int)] = + typeClassInstance.zipWithIndex[A](self) } trait AllOps[F[_], A] extends Ops[F, A] diff --git a/laws/src/main/scala/cats/laws/TraverseLaws.scala b/laws/src/main/scala/cats/laws/TraverseLaws.scala index 61bce31df46..46ba86761a3 100644 --- a/laws/src/main/scala/cats/laws/TraverseLaws.scala +++ b/laws/src/main/scala/cats/laws/TraverseLaws.scala @@ -106,6 +106,17 @@ trait TraverseLaws[F[_]] extends FunctorLaws[F] with FoldableLaws[F] with Unorde first <-> traverseFirst } + def mapAccumulateRef[A, S, B](fa: F[A], init: S, f: (S, A) => (S, B)): IsEq[(S, F[B])] = { + val lhs = F.mapAccumulate(fa, init)(f) + + val rhsState = F.traverse(fa) { a => + State(s => f(s, a)) + } + val rhs = rhsState.run(init).value + + lhs <-> rhs + } + def mapWithIndexRef[A, B](fa: F[A], f: (A, Int) => B): IsEq[F[B]] = { val lhs = F.mapWithIndex(fa)(f) val rhs = F.traverse(fa)(a => State((s: Int) => (s + 1, f(a, s)))).runA(0).value diff --git a/laws/src/main/scala/cats/laws/discipline/TraverseTests.scala b/laws/src/main/scala/cats/laws/discipline/TraverseTests.scala index 2cbbdc7d1c8..7bb28c7c496 100644 --- a/laws/src/main/scala/cats/laws/discipline/TraverseTests.scala +++ b/laws/src/main/scala/cats/laws/discipline/TraverseTests.scala @@ -59,11 +59,12 @@ trait TraverseTests[F[_]] extends FunctorTests[F] with FoldableTests[F] with Uno EqYFM: Eq[Y[F[M]]], EqOptionA: Eq[Option[A]] ): RuleSet = { - implicit def EqXFBYFB: Eq[(X[F[B]], Y[F[B]])] = + implicit val EqXFBYFB: Eq[(X[F[B]], Y[F[B]])] = new Eq[(X[F[B]], Y[F[B]])] { override def eqv(x: (X[F[B]], Y[F[B]]), y: (X[F[B]], Y[F[B]])): Boolean = EqXFB.eqv(x._1, y._1) && EqYFB.eqv(x._2, y._2) } + new RuleSet { def name: String = "traverse" def bases: Seq[(String, RuleSet)] = Nil @@ -76,6 +77,7 @@ trait TraverseTests[F[_]] extends FunctorTests[F] with FoldableTests[F] with Uno "traverse traverseTap" -> forAll(laws.traverseTap[B, M, X] _), "traverse derive foldMap" -> forAll(laws.foldMapDerived[A, M] _), "traverse order consistency" -> forAll(laws.traverseOrderConsistent[A] _), + "traverse ref mapAccumulate" -> forAll(laws.mapAccumulateRef[A, M, C] _), "traverse ref mapWithIndex" -> forAll(laws.mapWithIndexRef[A, C] _), "traverse ref traverseWithIndexM" -> forAll(laws.traverseWithIndexMRef[Option, A, C] _), "traverse ref zipWithIndex" -> forAll(laws.zipWithIndexRef[A, C] _) diff --git a/tests/shared/src/test/scala/cats/tests/TraverseSuite.scala b/tests/shared/src/test/scala/cats/tests/TraverseSuite.scala index 5cc9d5dc68f..2ab357d5b8a 100644 --- a/tests/shared/src/test/scala/cats/tests/TraverseSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/TraverseSuite.scala @@ -36,6 +36,19 @@ abstract class TraverseSuite[F[_]: Traverse](name: String)(implicit ArbFInt: Arb } } + test(s"Traverse[$name].mapAccumulate") { + forAll { (fa: F[Int], init: Int, fn: ((Int, Int)) => (Int, Int)) => + val lhs = fa.mapAccumulate(init)((s, a) => fn((s, a))) + + val rhs = fa.foldLeft((init, List.empty[Int])) { case ((s1, acc), a) => + val (b, s2) = fn((a, s1)) + (s2, b :: acc) + } + + assert(lhs._2.toList === rhs._2.reverse) + } + } + test(s"Traverse[$name].mapWithIndex") { forAll { (fa: F[Int], fn: ((Int, Int)) => Int) => assert(fa.mapWithIndex((a, i) => fn((a, i))).toList === (fa.toList.zipWithIndex.map(fn)))