Skip to content

Commit

Permalink
Add mapAccumulate to Traverse
Browse files Browse the repository at this point in the history
  • Loading branch information
BalmungSan committed May 21, 2022
1 parent 3cb7639 commit f707496
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 16 deletions.
46 changes: 31 additions & 15 deletions core/src/main/scala/cats/Traverse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import cats.data.State
import cats.data.StateT

import simulacrum.typeclass
import scala.annotation.implicitNotFound

/**
* Traverse, also known as Traversable.
Expand All @@ -37,6 +38,7 @@ import simulacrum.typeclass
*
* See: [[https://www.cs.ox.ac.uk/jeremy.gibbons/publications/iterator.pdf The Essence of the Iterator Pattern]]
*/
@implicitNotFound("Could not find an instance of Traverse for ${F}")
@typeclass trait Traverse[F[_]] extends Functor[F] with Foldable[F] with UnorderedTraverse[F] { self =>

/**
Expand Down Expand Up @@ -135,12 +137,24 @@ import simulacrum.typeclass
override def map[A, B](fa: F[A])(f: A => B): F[B] =
traverse[Id, A, B](fa)(f)

/**
* Akin to [[map]], but allows to keep track of a state value
* when calling the function.
*/
def mapAccumulate[A, S, B](fa: F[A])(init: S)(f: (A, S) => (B, S)): (F[B], S) =
traverse(fa) { a =>
State.get[S].flatMap { s1 =>
val (b, s2) = f(a, s1)
State.set(s2).map(_ => b)
}
}.run(init).value.swap

/**
* Akin to [[map]], but also provides the value's index in structure
* F when calling the function.
*/
def mapWithIndex[A, B](fa: F[A])(f: (A, Int) => B): F[B] =
traverse(fa)(a => State((s: Int) => (s + 1, f(a, s)))).runA(0).value
mapAccumulate(fa)(0)((a, i) => f(a, i) -> (i + 1))._1

/**
* Akin to [[traverse]], but also provides the value's index in
Expand Down Expand Up @@ -185,12 +199,11 @@ object Traverse {
object ops {
implicit def toAllTraverseOps[F[_], A](target: F[A])(implicit tc: Traverse[F]): AllOps[F, A] {
type TypeClassType = Traverse[F]
} =
new AllOps[F, A] {
type TypeClassType = Traverse[F]
val self: F[A] = target
val typeClassInstance: TypeClassType = tc
}
} = new AllOps[F, A] {
type TypeClassType = Traverse[F]
val self: F[A] = target
val typeClassInstance: TypeClassType = tc
}
}
trait Ops[F[_], A] extends Serializable {
type TypeClassType <: Traverse[F]
Expand All @@ -206,10 +219,14 @@ object Traverse {
typeClassInstance.sequence[G, B](self.asInstanceOf[F[G[B]]])
def flatSequence[G[_], B](implicit ev$1: A <:< G[F[B]], G: Applicative[G], F: FlatMap[F]): G[F[B]] =
typeClassInstance.flatSequence[G, B](self.asInstanceOf[F[G[F[B]]]])(G, F)
def mapWithIndex[B](f: (A, Int) => B): F[B] = typeClassInstance.mapWithIndex[A, B](self)(f)
def mapAccumulate[S, B](init: S)(f: (A, S) => (B, S)): (F[B], S) =
typeClassInstance.mapAccumulate[A, S, B](self)(init)(f)
def mapWithIndex[B](f: (A, Int) => B): F[B] =
typeClassInstance.mapWithIndex[A, B](self)(f)
def traverseWithIndexM[G[_], B](f: (A, Int) => G[B])(implicit G: Monad[G]): G[F[B]] =
typeClassInstance.traverseWithIndexM[G, A, B](self)(f)(G)
def zipWithIndex: F[(A, Int)] = typeClassInstance.zipWithIndex[A](self)
def zipWithIndex: F[(A, Int)] =
typeClassInstance.zipWithIndex[A](self)
}
trait AllOps[F[_], A]
extends Ops[F, A]
Expand All @@ -221,12 +238,11 @@ object Traverse {
trait ToTraverseOps extends Serializable {
implicit def toTraverseOps[F[_], A](target: F[A])(implicit tc: Traverse[F]): Ops[F, A] {
type TypeClassType = Traverse[F]
} =
new Ops[F, A] {
type TypeClassType = Traverse[F]
val self: F[A] = target
val typeClassInstance: TypeClassType = tc
}
} = new Ops[F, A] {
type TypeClassType = Traverse[F]
val self: F[A] = target
val typeClassInstance: TypeClassType = tc
}
}
@deprecated("Use cats.syntax object imports", "2.2.0")
object nonInheritedOps extends ToTraverseOps
Expand Down
14 changes: 14 additions & 0 deletions laws/src/main/scala/cats/laws/TraverseLaws.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,20 @@ trait TraverseLaws[F[_]] extends FunctorLaws[F] with FoldableLaws[F] with Unorde
first <-> traverseFirst
}

def mapAccumulateRef[A, S, B](fa: F[A], init: S, f: (A, S) => (B, S)): IsEq[(F[B], S)] = {
val lhs = F.mapAccumulate(fa)(init)(f)

val rhsState = F.traverse(fa) { a =>
State.get[S].flatMap { s1 =>
val (b, s2) = f(a, s1)
State.set(s2).map(_ => b)
}
}
val rhs = rhsState.run(init).value.swap

lhs <-> rhs
}

def mapWithIndexRef[A, B](fa: F[A], f: (A, Int) => B): IsEq[F[B]] = {
val lhs = F.mapWithIndex(fa)(f)
val rhs = F.traverse(fa)(a => State((s: Int) => (s + 1, f(a, s)))).runA(0).value
Expand Down
4 changes: 3 additions & 1 deletion laws/src/main/scala/cats/laws/discipline/TraverseTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ trait TraverseTests[F[_]] extends FunctorTests[F] with FoldableTests[F] with Uno
EqYFM: Eq[Y[F[M]]],
EqOptionA: Eq[Option[A]]
): RuleSet = {
implicit def EqXFBYFB: Eq[(X[F[B]], Y[F[B]])] =
implicit val EqXFBYFB: Eq[(X[F[B]], Y[F[B]])] =
new Eq[(X[F[B]], Y[F[B]])] {
override def eqv(x: (X[F[B]], Y[F[B]]), y: (X[F[B]], Y[F[B]])): Boolean =
EqXFB.eqv(x._1, y._1) && EqYFB.eqv(x._2, y._2)
}

new RuleSet {
def name: String = "traverse"
def bases: Seq[(String, RuleSet)] = Nil
Expand All @@ -76,6 +77,7 @@ trait TraverseTests[F[_]] extends FunctorTests[F] with FoldableTests[F] with Uno
"traverse traverseTap" -> forAll(laws.traverseTap[B, M, X] _),
"traverse derive foldMap" -> forAll(laws.foldMapDerived[A, M] _),
"traverse order consistency" -> forAll(laws.traverseOrderConsistent[A] _),
"traverse ref mapAccumulate" -> forAll(laws.mapAccumulateRef[A, M, C] _),
"traverse ref mapWithIndex" -> forAll(laws.mapWithIndexRef[A, C] _),
"traverse ref traverseWithIndexM" -> forAll(laws.traverseWithIndexMRef[Option, A, C] _),
"traverse ref zipWithIndex" -> forAll(laws.zipWithIndexRef[A, C] _)
Expand Down
13 changes: 13 additions & 0 deletions tests/shared/src/test/scala/cats/tests/TraverseSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ abstract class TraverseSuite[F[_]: Traverse](name: String)(implicit ArbFInt: Arb
}
}

test(s"Traverse[$name].mapAccumulate") {
forAll { (fa: F[Int], init: Int, fn: ((Int, Int)) => (Int, Int)) =>
val lhs = fa.mapAccumulate(init)((a, s) => fn((a, s)))

val rhs = fa.foldLeft((init, List.empty[Int])) { case ((s1, acc), a) =>
val (b, s2) = fn((a, s1))
(s2, b :: acc)
}

assert(lhs._1.toList === rhs._2.reverse)
}
}

test(s"Traverse[$name].mapWithIndex") {
forAll { (fa: F[Int], fn: ((Int, Int)) => Int) =>
assert(fa.mapWithIndex((a, i) => fn((a, i))).toList === (fa.toList.zipWithIndex.map(fn)))
Expand Down

0 comments on commit f707496

Please sign in to comment.