Skip to content

Commit

Permalink
Add mapAccumulate to Traverse
Browse files Browse the repository at this point in the history
  • Loading branch information
BalmungSan committed May 22, 2022
1 parent 3cb7639 commit 0fffb03
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 4 deletions.
19 changes: 16 additions & 3 deletions core/src/main/scala/cats/Traverse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,21 @@ import simulacrum.typeclass
override def map[A, B](fa: F[A])(f: A => B): F[B] =
traverse[Id, A, B](fa)(f)

/**
* Akin to [[map]], but allows to keep track of a state value
* when calling the function.
*/
def mapAccumulate[A, S, B](fa: F[A], init: S)(f: (S, A) => (S, B)): (S, F[B]) =
traverse(fa) { a =>
State(s => f(s, a))
}.run(init).value

/**
* Akin to [[map]], but also provides the value's index in structure
* F when calling the function.
*/
def mapWithIndex[A, B](fa: F[A])(f: (A, Int) => B): F[B] =
traverse(fa)(a => State((s: Int) => (s + 1, f(a, s)))).runA(0).value
mapAccumulate(fa, 0)((i, a) => (i + 1) -> f(a, i))._2

/**
* Akin to [[traverse]], but also provides the value's index in
Expand Down Expand Up @@ -206,10 +215,14 @@ object Traverse {
typeClassInstance.sequence[G, B](self.asInstanceOf[F[G[B]]])
def flatSequence[G[_], B](implicit ev$1: A <:< G[F[B]], G: Applicative[G], F: FlatMap[F]): G[F[B]] =
typeClassInstance.flatSequence[G, B](self.asInstanceOf[F[G[F[B]]]])(G, F)
def mapWithIndex[B](f: (A, Int) => B): F[B] = typeClassInstance.mapWithIndex[A, B](self)(f)
def mapAccumulate[S, B](init: S)(f: (S, A) => (S, B)): (S, F[B]) =
typeClassInstance.mapAccumulate[A, S, B](self, init)(f)
def mapWithIndex[B](f: (A, Int) => B): F[B] =
typeClassInstance.mapWithIndex[A, B](self)(f)
def traverseWithIndexM[G[_], B](f: (A, Int) => G[B])(implicit G: Monad[G]): G[F[B]] =
typeClassInstance.traverseWithIndexM[G, A, B](self)(f)(G)
def zipWithIndex: F[(A, Int)] = typeClassInstance.zipWithIndex[A](self)
def zipWithIndex: F[(A, Int)] =
typeClassInstance.zipWithIndex[A](self)
}
trait AllOps[F[_], A]
extends Ops[F, A]
Expand Down
11 changes: 11 additions & 0 deletions laws/src/main/scala/cats/laws/TraverseLaws.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ trait TraverseLaws[F[_]] extends FunctorLaws[F] with FoldableLaws[F] with Unorde
first <-> traverseFirst
}

def mapAccumulateRef[A, S, B](fa: F[A], init: S, f: (S, A) => (S, B)): IsEq[(S, F[B])] = {
val lhs = F.mapAccumulate(fa, init)(f)

val rhsState = F.traverse(fa) { a =>
State(s => f(s, a))
}
val rhs = rhsState.run(init).value

lhs <-> rhs
}

def mapWithIndexRef[A, B](fa: F[A], f: (A, Int) => B): IsEq[F[B]] = {
val lhs = F.mapWithIndex(fa)(f)
val rhs = F.traverse(fa)(a => State((s: Int) => (s + 1, f(a, s)))).runA(0).value
Expand Down
4 changes: 3 additions & 1 deletion laws/src/main/scala/cats/laws/discipline/TraverseTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ trait TraverseTests[F[_]] extends FunctorTests[F] with FoldableTests[F] with Uno
EqYFM: Eq[Y[F[M]]],
EqOptionA: Eq[Option[A]]
): RuleSet = {
implicit def EqXFBYFB: Eq[(X[F[B]], Y[F[B]])] =
implicit val EqXFBYFB: Eq[(X[F[B]], Y[F[B]])] =
new Eq[(X[F[B]], Y[F[B]])] {
override def eqv(x: (X[F[B]], Y[F[B]]), y: (X[F[B]], Y[F[B]])): Boolean =
EqXFB.eqv(x._1, y._1) && EqYFB.eqv(x._2, y._2)
}

new RuleSet {
def name: String = "traverse"
def bases: Seq[(String, RuleSet)] = Nil
Expand All @@ -76,6 +77,7 @@ trait TraverseTests[F[_]] extends FunctorTests[F] with FoldableTests[F] with Uno
"traverse traverseTap" -> forAll(laws.traverseTap[B, M, X] _),
"traverse derive foldMap" -> forAll(laws.foldMapDerived[A, M] _),
"traverse order consistency" -> forAll(laws.traverseOrderConsistent[A] _),
"traverse ref mapAccumulate" -> forAll(laws.mapAccumulateRef[A, M, C] _),
"traverse ref mapWithIndex" -> forAll(laws.mapWithIndexRef[A, C] _),
"traverse ref traverseWithIndexM" -> forAll(laws.traverseWithIndexMRef[Option, A, C] _),
"traverse ref zipWithIndex" -> forAll(laws.zipWithIndexRef[A, C] _)
Expand Down
13 changes: 13 additions & 0 deletions tests/shared/src/test/scala/cats/tests/TraverseSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ abstract class TraverseSuite[F[_]: Traverse](name: String)(implicit ArbFInt: Arb
}
}

test(s"Traverse[$name].mapAccumulate") {
forAll { (fa: F[Int], init: Int, fn: ((Int, Int)) => (Int, Int)) =>
val lhs = fa.mapAccumulate(init)((s, a) => fn((s, a)))

val rhs = fa.foldLeft((init, List.empty[Int])) { case ((s1, acc), a) =>
val (b, s2) = fn((a, s1))
(s2, b :: acc)
}

assert(lhs._2.toList === rhs._2.reverse)
}
}

test(s"Traverse[$name].mapWithIndex") {
forAll { (fa: F[Int], fn: ((Int, Int)) => Int) =>
assert(fa.mapWithIndex((a, i) => fn((a, i))).toList === (fa.toList.zipWithIndex.map(fn)))
Expand Down

0 comments on commit 0fffb03

Please sign in to comment.