Have you ever wondered why Scala’s standard library offers no efficient and convenient way to merge unrelated collection traversals, for instance to compute the average of a list as its sum divided by its count, but in a single pass?

In this post, I show how to abuse for comprehensions to define aggregators, the dual of iterators, with the following syntax:

def average = for { s <- sum[Double]; c <- count } yield s/c
val ls = List(1.0, 2.0, 3.0)
assert(average(ls) == 2.0)

// `average` can then be composed within bigger loops modularly:
def avgmax = for {
  avg <- average
  m   <- max[Double]
} yield (avg,m)
// this does a single traversal of `ls`:
assert(avgmax(ls) == (2.0, Some(3.0)))

The Problem

In Scala, what options does one have to compute several aggregations in parallel, in a single traversal of a collection? As far as I know, the standard library offers no solution to this problem. There certainly are third-party solutions, such as origami or scala-fold, which rely on applicative functors; but they can be awkward to use and incur non-negligible overhead.

So, to do parallel aggregations in vanilla Scala, one currently has to either traverse the collection several times (not efficient), reimplement each aggregation as part of a mega-aggregation (not modular, as every sub-aggregation has to be rewritten), or use imperative variables and mutation (error-prone and not modular). Here I show a low-level, efficient alternative that is dual to Iterator and supports a nice for comprehension syntax.

But first let’s review the Scala concept of Iterator.

Iterators

Iterators are an abstraction for fast stateful iteration, usually hidden behind a stateless interface such as Iterable.

In essence, the iterator interface could be defined as:

trait Iterator[+A] {
   /** returns the next element, if any, and advances the iteration state */
   def next(): Option[A]
}
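To make this Option-based protocol concrete, here is a small sketch of it over a List. The trait is named OptionIterator, a hypothetical name chosen so it does not clash with scala.collection.Iterator, and fromList and sumAll are illustrative helpers. Note the Some allocated for every element: that is the overhead the real Iterator avoids.

```scala
// `OptionIterator` is a hypothetical name for the Option-based protocol,
// chosen to avoid clashing with scala.collection.Iterator.
trait OptionIterator[+A] {
  /** returns the next element, if any, and advances the iteration state */
  def next(): Option[A]
}

object OptionIterator {
  // Walks a List by keeping the not-yet-visited suffix in a variable.
  def fromList[A](xs: List[A]): OptionIterator[A] = new OptionIterator[A] {
    private var rest: List[A] = xs
    def next(): Option[A] = rest match {
      case head :: tail =>
        rest = tail
        Some(head) // one Some is allocated per element
      case Nil => None
    }
  }

  // Consumes the iterator until it returns None.
  def sumAll(it: OptionIterator[Int]): Int = {
    var total = 0
    var elem = it.next()
    while (elem.isDefined) {
      total += elem.get
      elem = it.next()
    }
    total
  }
}
```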

However, for performance reasons (avoiding the extra allocation/indirection of Option), Scala defines Iterator as:

trait Iterator[+A] {
  /** whether there is a next element */
  def hasNext: Boolean
  /** returns the next element and advances the iteration state;
    * requires hasNext to have returned true */
  def next(): A
}
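For comparison, here is a minimal sketch of a hand-written iterator implementing the actual scala.collection.Iterator protocol, along with the usual consumption loop. CountingIterator and CountingIteratorDemo are hypothetical names used only for illustration.

```scala
// A hypothetical iterator over the integers in [first, limit), implementing the
// real scala.collection.Iterator protocol: no Option is allocated per element.
final class CountingIterator(first: Int, limit: Int) extends Iterator[Int] {
  private var i = first
  def hasNext: Boolean = i < limit
  def next(): Int = {
    if (!hasNext) throw new NoSuchElementException("next on exhausted iterator")
    val res = i
    i += 1
    res
  }
}

object CountingIteratorDemo {
  // The usual consumption loop: test hasNext, then call next().
  def sumAll(it: Iterator[Int]): Int = {
    var total = 0
    while (it.hasNext) total += it.next()
    total
  }
}
```

Because CountingIterator extends the standard trait, it also inherits every library method (map, flatMap, toList, and so on) for free.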

Scala also provides all kinds of useful methods on iterators, not the least of which are map, flatMap, filter/withFilter, and foreach, which allow us to use iterators in for comprehensions. For example, below is a nested loop that computes a filtered cartesian product using iterators:

// computes all `(x,y)` tuples with `x` in `xs` and `y` in `ys` where `x < y`:
for { x <- xs.iterator; y <- ys.iterator; if x < y } yield (x,y)

This desugars into:

xs.iterator.flatMap(x => ys.iterator.withFilter(y => x < y).map(y => (x,y)))
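The equivalence of the two forms can be checked directly; the sample lists xs and ys below are arbitrary choices for illustration.

```scala
object DesugarDemo {
  val xs = List(1, 2, 3)
  val ys = List(2, 3)

  // The for comprehension, materialized as a List for comparison.
  def viaFor: List[(Int, Int)] =
    (for { x <- xs.iterator; y <- ys.iterator; if x < y } yield (x, y)).toList

  // The hand-desugared form: the guard becomes withFilter, the yield becomes map.
  def viaDesugared: List[(Int, Int)] =
    xs.iterator.flatMap(x => ys.iterator.withFilter(y => x < y).map(y => (x, y))).toList
}
```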

Aggregators

The dual of Iterator, which we will call Aggregator, could logically be defined as:

trait Aggregator[-A, +To] {
  /** Whether we accept more elements */
  def wantsNext(): Boolean
  /** Accept one more element */
  def next(a: A): Unit
  /** Retrieve the aggregation result */
  def result(): To
}

Instead of producing elements one by one, an Aggregator consumes them one by one; and instead of starting from a source of elements (an iterable), it ends with a result, retrieved via the result() method.
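The consumer side of this protocol is a loop that mirrors the usual `while (it.hasNext) ... it.next()` idiom, except that it pushes elements instead of pulling them. Below is a minimal sketch. feed and firstN are hypothetical helpers not taken from the post, and firstN shows how wantsNext lets an aggregator stop the traversal early.

```scala
import scala.collection.mutable.ListBuffer

trait Aggregator[-A, +To] {
  def wantsNext(): Boolean
  def next(a: A): Unit
  def result(): To
}

object Feed {
  // Hypothetical driver, the push-based dual of `while (it.hasNext) use(it.next())`:
  // elements are fed as long as the aggregator wants them and input remains.
  def feed[A, To](xs: Iterable[A], agg: Aggregator[A, To]): To = {
    val it = xs.iterator
    while (agg.wantsNext() && it.hasNext) agg.next(it.next())
    agg.result()
  }

  // Hypothetical aggregator that accepts at most `n` elements, then declines more.
  def firstN[A](n: Int): Aggregator[A, List[A]] = new Aggregator[A, List[A]] {
    private val buf = ListBuffer.empty[A]
    def wantsNext(): Boolean = buf.size < n
    def next(a: A): Unit = buf += a
    def result(): List[A] = buf.toList
  }
}
```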

Here is how to implement sum using this interface:

def sum[N](implicit N: scala.Numeric[N]): Aggregator[N,N] = new Aggregator[N,N] {
  private var acc = N.zero  // running total
  def wantsNext(): Boolean = true
  def next(a: N): Unit = acc = N.plus(acc, a)
  def result(): N = acc
}
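The other primitives used in the introduction, count and max, can be sketched in the same style. These definitions, and the hypothetical feed helper that pushes a collection through an aggregator, are assumptions for illustration; the companion repository may define them differently.

```scala
trait Aggregator[-A, +To] {
  def wantsNext(): Boolean
  def next(a: A): Unit
  def result(): To
}

object Primitives {
  def sum[N](implicit N: Numeric[N]): Aggregator[N, N] = new Aggregator[N, N] {
    private var acc: N = N.zero
    def wantsNext(): Boolean = true
    def next(a: N): Unit = acc = N.plus(acc, a)
    def result(): N = acc
  }

  // Counts the elements fed to it; accepts anything, hence Aggregator[Any, Int].
  def count: Aggregator[Any, Int] = new Aggregator[Any, Int] {
    private var n = 0
    def wantsNext(): Boolean = true
    def next(a: Any): Unit = n += 1
    def result(): Int = n
  }

  // Largest element seen so far, or None if no element was fed.
  def max[N](implicit ord: Ordering[N]): Aggregator[N, Option[N]] =
    new Aggregator[N, Option[N]] {
      private var best: Option[N] = None
      def wantsNext(): Boolean = true
      def next(a: N): Unit = if (best.forall(b => ord.gt(a, b))) best = Some(a)
      def result(): Option[N] = best
    }

  // Hypothetical driver: pushes the elements of `xs` into `agg`.
  def feed[A, To](xs: Iterable[A], agg: Aggregator[A, To]): To = {
    val it = xs.iterator
    while (agg.wantsNext() && it.hasNext) agg.next(it.next())
    agg.result()
  }
}
```

Making count an `Aggregator[Any, Int]` relies on the contravariance of A: it can then be used wherever an aggregator of any more specific element type is expected.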

Now, we would like to define aggregators and compose them in a modular way, to achieve the stated goal of this article: for instance, computing the average from sum and count in a single pass.

Abusing For Comprehension

The trick is surprisingly simple: define flatMap with a twist, so that the function it receives takes its own argument by name rather than by value.

trait Aggregator[-A, +To] { outer =>
  // ...

  // type parameter A0 <: A is necessary to avoid violating the contravariance of A
  def flatMap[A0 <: A, R](f: (=> To) => Aggregator[A0, R]): Aggregator[A0, R] =
    new Aggregator[A0, R] {
      val inner = f(outer.result())  // `outer.result()` is passed to f as a by-name thunk
      def wantsNext(): Boolean = inner.wantsNext() || outer.wantsNext()
      def result(): R = inner.result()
      def next(elem: A0): Unit = {
        if (outer.wantsNext()) outer.next(elem)  // update the outer aggregation first
        if (inner.wantsNext()) inner.next(elem)
      }
    }
}

We initialize the inner Aggregator by calling the function f passed to flatMap with a thunk that computes the current result of the outer aggregator. We then return a new Aggregator that combines both sub-aggregations. The implementation of other methods like map and foreach is straightforward, and can be seen in the companion repository of this blog post.
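Putting the pieces together, here is a self-contained sketch that runs the introductory average and avgmax examples. The apply method (implied by the introduction's `average(ls)` syntax) and this particular map, which caches its result after every element, are assumptions of the sketch rather than the companion repository's code; sum, count and max are sketched as simple primitives.

```scala
trait Aggregator[-A, +To] { outer =>
  def wantsNext(): Boolean
  def next(a: A): Unit
  def result(): To

  // Assumed driver: feeds `xs` until the input or the aggregator is exhausted.
  def apply(xs: Iterable[A]): To = {
    val it = xs.iterator
    while (wantsNext() && it.hasNext) next(it.next())
    result()
  }

  def flatMap[A0 <: A, R](f: (=> To) => Aggregator[A0, R]): Aggregator[A0, R] =
    new Aggregator[A0, R] {
      // `outer.result()` is passed by name, so inner re-reads it on each use
      private val inner = f(outer.result())
      def wantsNext(): Boolean = inner.wantsNext() || outer.wantsNext()
      def next(a: A0): Unit = {
        if (outer.wantsNext()) outer.next(a) // earlier generators are updated first
        if (inner.wantsNext()) inner.next(a)
      }
      def result(): R = inner.result()
    }

  // Assumed map: caches f(outer.result()) after every element.
  def map[R](f: To => R): Aggregator[A, R] = new Aggregator[A, R] {
    private var cached: R = f(outer.result())
    def wantsNext(): Boolean = outer.wantsNext()
    def next(a: A): Unit = { outer.next(a); cached = f(outer.result()) }
    def result(): R = cached
  }
}

object Aggregators {
  def sum[N](implicit N: Numeric[N]): Aggregator[N, N] = new Aggregator[N, N] {
    private var acc: N = N.zero
    def wantsNext(): Boolean = true
    def next(a: N): Unit = acc = N.plus(acc, a)
    def result(): N = acc
  }

  def count: Aggregator[Any, Int] = new Aggregator[Any, Int] {
    private var n = 0
    def wantsNext(): Boolean = true
    def next(a: Any): Unit = n += 1
    def result(): Int = n
  }

  def max[N](implicit ord: Ordering[N]): Aggregator[N, Option[N]] =
    new Aggregator[N, Option[N]] {
      private var best: Option[N] = None
      def wantsNext(): Boolean = true
      def next(a: N): Unit = if (best.forall(b => ord.gt(a, b))) best = Some(a)
      def result(): Option[N] = best
    }
}

object AverageDemo {
  import Aggregators._

  def average = for { s <- sum[Double]; c <- count } yield s / c

  def avgmax = for {
    avg <- average
    m   <- max[Double]
  } yield (avg, m)

  val ls = List(1.0, 2.0, 3.0)
}
```

Because next updates outer before inner, an expression like `s / c` in the yield always sees the sum and the count of the same prefix of the input. Note also that aggregators are stateful: each call to the def average builds a fresh one, whereas applying a single instance to two collections would accumulate across both.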

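The beauty of this approach is that it builds the control structure eagerly, once; while iterating, the composition itself (flatMap and the like) then performs no extra allocations, and only the individual aggregators may allocate (for example, when boxing primitive values in generic code). It uses efficient imperative variable updates behind the scenes, hidden from the user.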

More Advanced Examples

An interesting property of Aggregator comprehensions is that later aggregators can refer to the current value of earlier ones. For example, here is how to sum the elements of a list that have an even index:

// sum all elements of even index:
def sumEvenIdx = for {
  idx <- count.map(_ - 1) // the current index is the current count minus 1
  s   <- sum[Double] when (idx % 2 == 0)
} yield s

assert(sumEvenIdx(ls) == 
  (for { (x,i) <- ls.zipWithIndex if i % 2 == 0 } yield x).sum)

The when method simply takes a condition by name, and evaluates it whenever a value is fed to the aggregator, to decide whether to aggregate the new value or not.

  def when(cond: => Boolean): Aggregator[A, To] = new Aggregator[A, To] {
    def wantsNext(): Boolean = outer.wantsNext()
    def result(): To = outer.result()
    def next(x: A): Unit = if (cond) outer.next(x)  // cond is re-evaluated for each element
  }
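Below is a self-contained sketch that runs sumEvenIdx. It assumes a driving apply method and a map that caches its result after each element (neither is shown in the post); flatMap and when follow the definitions above, and sum and count are sketched as simple primitives.

```scala
trait Aggregator[-A, +To] { outer =>
  def wantsNext(): Boolean
  def next(a: A): Unit
  def result(): To

  // Assumed driver: feeds `xs` until the input or the aggregator is exhausted.
  def apply(xs: Iterable[A]): To = {
    val it = xs.iterator
    while (wantsNext() && it.hasNext) next(it.next())
    result()
  }

  def flatMap[A0 <: A, R](f: (=> To) => Aggregator[A0, R]): Aggregator[A0, R] =
    new Aggregator[A0, R] {
      private val inner = f(outer.result())
      def wantsNext(): Boolean = inner.wantsNext() || outer.wantsNext()
      def next(a: A0): Unit = {
        if (outer.wantsNext()) outer.next(a)
        if (inner.wantsNext()) inner.next(a)
      }
      def result(): R = inner.result()
    }

  // Assumed map: caches f(outer.result()) after every element.
  def map[R](f: To => R): Aggregator[A, R] = new Aggregator[A, R] {
    private var cached: R = f(outer.result())
    def wantsNext(): Boolean = outer.wantsNext()
    def next(a: A): Unit = { outer.next(a); cached = f(outer.result()) }
    def result(): R = cached
  }

  // Forwards an element only if `cond`, re-evaluated for each element, holds.
  def when(cond: => Boolean): Aggregator[A, To] = new Aggregator[A, To] {
    def wantsNext(): Boolean = outer.wantsNext()
    def next(a: A): Unit = if (cond) outer.next(a)
    def result(): To = outer.result()
  }
}

object Aggregators {
  def sum[N](implicit N: Numeric[N]): Aggregator[N, N] = new Aggregator[N, N] {
    private var acc: N = N.zero
    def wantsNext(): Boolean = true
    def next(a: N): Unit = acc = N.plus(acc, a)
    def result(): N = acc
  }

  def count: Aggregator[Any, Int] = new Aggregator[Any, Int] {
    private var n = 0
    def wantsNext(): Boolean = true
    def next(a: Any): Unit = n += 1
    def result(): Int = n
  }
}

object EvenIdxDemo {
  import Aggregators._

  // idx is passed by name, so `idx % 2 == 0` reads the up-to-date count.
  def sumEvenIdx = for {
    idx <- count.map(_ - 1) // current index = current count minus 1
    s   <- sum[Double].when(idx % 2 == 0)
  } yield s
}
```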

The wantsNext method can be used to short-circuit a computation, so that when we are only interested in the result of an iteration up until a certain point, we can stop it early. For instance, an until method can be used to take only n elements into a toBuffer aggregator, as in toBuffer[Int].until(_.size == n).
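The until method is only mentioned in passing, so the version below is a guess at its shape: it reports wantsNext() as false once the predicate holds for the current result. Feeding it an infinite LazyList (Scala 2.13 or later) shows that the traversal really stops after n elements. toBuffer, until, takeN and the driving apply method are assumptions of this sketch, not the companion repository's definitions.

```scala
import scala.collection.mutable

trait Aggregator[-A, +To] { outer =>
  def wantsNext(): Boolean
  def next(a: A): Unit
  def result(): To

  // Assumed driver: stops as soon as the aggregator declines more input.
  def apply(xs: Iterable[A]): To = {
    val it = xs.iterator
    while (wantsNext() && it.hasNext) next(it.next())
    result()
  }

  // Hypothetical `until`: declines further input once `cond` holds for the current result.
  def until(cond: To => Boolean): Aggregator[A, To] = new Aggregator[A, To] {
    def wantsNext(): Boolean = outer.wantsNext() && !cond(outer.result())
    def next(a: A): Unit = outer.next(a)
    def result(): To = outer.result()
  }
}

object Aggregators {
  // Hypothetical `toBuffer`: appends every element it is fed to a buffer.
  def toBuffer[A]: Aggregator[A, mutable.Buffer[A]] =
    new Aggregator[A, mutable.Buffer[A]] {
      private val buf = mutable.ArrayBuffer.empty[A]
      def wantsNext(): Boolean = true
      def next(a: A): Unit = buf += a
      def result(): mutable.Buffer[A] = buf
    }
}

object UntilDemo {
  import Aggregators._

  def takeN(n: Int): Aggregator[Int, mutable.Buffer[Int]] =
    toBuffer[Int].until(_.size == n)
}
```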

Finally, notice that an interesting, if suboptimal, way of defining every possible Aggregator would be to compose uses of the primitive current aggregator, of type Aggregator[T,Option[T]], whose result is the most recently fed element, if any. For example, this is how to define sum based on current:

def sum = {
  lazy val s: Aggregator[Double,Double] =
    for { x <- current[Double] } yield x.fold(0.0)(_ + s.result())
  s
}
assert(sum(List(1.0, 2.0, 3.0)) == 6.0)

(That is, as long as the implementation of map stores a local variable holding the current mapped result. Another implementation of map would compute the transformed result lazily, only when the result() method is called; that would make the example above crash with a stack overflow, since s.result() would apply the function again, which in turn calls s.result(), and so on.)
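Here is a runnable version of the current-based sum under the caching strategy for map. current is sketched as an aggregator whose result is the last element fed, if any; its exact definition, like the driving apply method, is an assumption of this sketch.

```scala
trait Aggregator[-A, +To] { outer =>
  def wantsNext(): Boolean
  def next(a: A): Unit
  def result(): To

  // Assumed driver: feeds `xs` until the input or the aggregator is exhausted.
  def apply(xs: Iterable[A]): To = {
    val it = xs.iterator
    while (wantsNext() && it.hasNext) next(it.next())
    result()
  }

  // map stores the mapped result in a local variable, updated after each element.
  def map[R](f: To => R): Aggregator[A, R] = new Aggregator[A, R] {
    private var cached: R = f(outer.result())
    def wantsNext(): Boolean = outer.wantsNext()
    def next(a: A): Unit = { outer.next(a); cached = f(outer.result()) }
    def result(): R = cached
  }
}

object Aggregators {
  // Hypothetical `current`: its result is the most recently fed element, if any.
  def current[T]: Aggregator[T, Option[T]] = new Aggregator[T, Option[T]] {
    private var last: Option[T] = None
    def wantsNext(): Boolean = true
    def next(a: T): Unit = last = Some(a)
    def result(): Option[T] = last
  }
}

object CurrentDemo {
  import Aggregators._

  // Each new element x is added to the previous (cached) result of `s` itself.
  def sum: Aggregator[Double, Double] = {
    lazy val s: Aggregator[Double, Double] =
      for { x <- current[Double] } yield x.fold(0.0)(_ + s.result())
    s
  }
}
```

The lazy val is only safe because map's constructor applies the function to None, which never touches s; the self-reference is first followed on the first element, when s is already initialized.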

Limitations

It is better not to implement the withFilter method (which enables if guards inside for comprehensions): given the way Scala desugars for comprehensions, it does not seem possible to give it satisfactory semantics. For instance, we could define it so that the previous sumEvenIdx example can be rewritten as for { idx <- count; s <- sum[Double] if idx % 2 == 0 }, but if the user instead wrote for { idx <- count; if idx % 2 == 0; s <- sum[Double] }, the condition would be evaluated too eagerly, resulting in surprising behavior.

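Unfortunately, there is also a problem with the (horrible) way in which Scala desugars intermediate bindings in for comprehensions: writing c <- count; idx = c - 1; ... instead of idx <- count.map(_ - 1); ... in the previous example does not have the expected semantics, as c will only ever be assigned the initial value of count, namely 0. In Scala 2, such a binding is desugared into a map that packs c and idx into a tuple, and the following flatMap pattern-matches on that tuple, which forces the by-name argument once, when the inner aggregator is built. This is a more serious problem, because we cannot prevent users from writing such bindings (whereas we can prevent them from using if guards, simply by not defining withFilter).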
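The pitfall can be reproduced with a self-contained sketch. It assumes a driving apply method and a caching map, as in a minimal implementation. The expected values follow from tracing the desugared code: idx is computed once, from the initial count of 0, so it stays at -1, the condition never holds, and nothing is summed.

```scala
trait Aggregator[-A, +To] { outer =>
  def wantsNext(): Boolean
  def next(a: A): Unit
  def result(): To

  // Assumed driver: feeds `xs` until the input or the aggregator is exhausted.
  def apply(xs: Iterable[A]): To = {
    val it = xs.iterator
    while (wantsNext() && it.hasNext) next(it.next())
    result()
  }

  def flatMap[A0 <: A, R](f: (=> To) => Aggregator[A0, R]): Aggregator[A0, R] =
    new Aggregator[A0, R] {
      private val inner = f(outer.result())
      def wantsNext(): Boolean = inner.wantsNext() || outer.wantsNext()
      def next(a: A0): Unit = {
        if (outer.wantsNext()) outer.next(a)
        if (inner.wantsNext()) inner.next(a)
      }
      def result(): R = inner.result()
    }

  // Assumed map: caches f(outer.result()) after every element.
  def map[R](f: To => R): Aggregator[A, R] = new Aggregator[A, R] {
    private var cached: R = f(outer.result())
    def wantsNext(): Boolean = outer.wantsNext()
    def next(a: A): Unit = { outer.next(a); cached = f(outer.result()) }
    def result(): R = cached
  }

  def when(cond: => Boolean): Aggregator[A, To] = new Aggregator[A, To] {
    def wantsNext(): Boolean = outer.wantsNext()
    def next(a: A): Unit = if (cond) outer.next(a)
    def result(): To = outer.result()
  }
}

object Aggregators {
  def sum[N](implicit N: Numeric[N]): Aggregator[N, N] = new Aggregator[N, N] {
    private var acc: N = N.zero
    def wantsNext(): Boolean = true
    def next(a: N): Unit = acc = N.plus(acc, a)
    def result(): N = acc
  }

  def count: Aggregator[Any, Int] = new Aggregator[Any, Int] {
    private var n = 0
    def wantsNext(): Boolean = true
    def next(a: Any): Unit = n += 1
    def result(): Int = n
  }
}

object AliasDemo {
  import Aggregators._

  // Works: idx is a by-name parameter, re-read for every element.
  def sumEvenIdx = for {
    idx <- count.map(_ - 1)
    s   <- sum[Double].when(idx % 2 == 0)
  } yield s

  // Broken: the alias is evaluated only once, when the inner aggregator is
  // built, so idx stays at -1 and the condition is never true.
  def sumEvenIdxAlias = for {
    c   <- count
    idx = c - 1
    s   <- sum[Double].when(idx % 2 == 0)
  } yield s
}
```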

Going Further

What I show here is little more than a simple gimmick based on a wacky combination of Scala features (lambdas with by-name parameters used as part of for comprehensions).

Can we take this trick further, and use it to enable more powerful patterns? In a future post, I will show how this mechanism can be used to zip arbitrary data sources together “efficiently” (without creating intermediate tuples), while still relying on a nice for comprehension syntax.