Have you ever wondered why there is no efficient and convenient way of merging unrelated collection traversals in Scala’s standard library? For example, computing the average of a list as its sum divided by its count, but in a single pass.
In this post, I show how to abuse `for` comprehensions and leverage the following syntax to define aggregators, the dual to iterators:
def average = for { s <- sum[Double]; c <- count } yield s/c
val ls = List(1.0, 2.0, 3.0)
assert(average(ls) == 2.0)
// `average` can then be composed within bigger loops modularly:
def avgmax = for {
  avg <- average
  m <- max[Double]
} yield (avg,m)
// this does a single traversal of `ls`:
assert(avgmax(ls) == (2.0, Some(3.0)))
The Problem
In Scala, what options does one have to aggregate different things in parallel, in a single collection traversal? As far as I know, there is no solution to this problem in the standard library. There certainly are third-party solutions, such as origami or scala-fold, which rely on applicative functors; but they can be awkward to use and incur non-negligible overhead.
So, to do parallel aggregations in vanilla Scala, one currently has to either traverse the collection several times (inefficient), reimplement each aggregation as part of a mega-aggregation (not modular), or use imperative variables and mutation (error-prone and not modular). Here I show a low-level, efficient approach that is dual to `Iterator` and supports a nice `for` comprehension syntax.
But first, let’s review the Scala concept of `Iterator`.
Iterators
Iterators are an abstraction for fast stateful iteration, usually hidden behind a stateless interface such as `Iterable`.
In essence, the iterator interface is defined as:
trait Iterator[+A] {
  /** returns the next element, if any, and advances the iteration state */
  def next(): Option[A]
}
However, for performance reasons (avoiding the extra allocation/indirection of `Option`), Scala defines `Iterator` as:
trait Iterator[+A] {
  /** whether there is a next element */
  def hasNext: Boolean
  /** gets the next element and advances the iteration state;
   *  requires hasNext to have returned true */
  def next(): A
}
Scala also provides all kinds of useful methods on iterators, not the least of which are `map`, `flatMap`, `filter`/`withFilter`, and `foreach`, which allow us to use iterators in `for` comprehensions. For example, below is a nested loop that computes a filtered cartesian product using iterators:
// computes all `(x,y)` tuples with `x` in `xs` and `y` in `ys` where `x < y`:
for { x <- xs.iterator; y <- ys.iterator; if x < y } yield (x,y)
This desugars into:
xs.iterator.flatMap(x => ys.iterator.withFilter(y => x < y).map(y => (x,y)))
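To check that the two forms really agree, here is a small self-contained sketch. The object name `CartesianDemo` and the sample lists are made up for illustration; only standard `Iterator` methods are used.

```scala
object CartesianDemo {
  // Sample inputs (hypothetical, chosen for illustration)
  val xs = List(1, 2, 3)
  val ys = List(2, 3)

  // The for-comprehension form from the text
  def viaFor: List[(Int, Int)] =
    (for { x <- xs.iterator; y <- ys.iterator; if x < y } yield (x, y)).toList

  // The same loop, desugared by hand into method calls
  def viaMethods: List[(Int, Int)] =
    xs.iterator.flatMap(x => ys.iterator.withFilter(y => x < y).map(y => (x, y))).toList
}
```

Both `CartesianDemo.viaFor` and `CartesianDemo.viaMethods` yield `List((1,2), (1,3), (2,3))`.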
Aggregators
It seems like the dual to `Iterator`, which we will call `Aggregator`, could logically be defined as:
trait Aggregator[-A, +To] {
  /** Whether we accept more elements */
  def wantsNext(): Boolean
  /** Accept one more element */
  def next(a: A): Unit
  /** Retrieve the aggregation result */
  def result(): To
}
Instead of producing elements one by one, `Aggregator` consumes elements one by one; and instead of starting with an initial value (an iterable), you end up with a result, via the `result()` method.
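To make the protocol concrete, here is a minimal sketch of how such an aggregator could be driven over a collection. The `apply` method is an assumption on my part (the examples at the top call aggregators like functions, so something like it must exist), and `first` is a hypothetical aggregator added only to show `wantsNext()` stopping the traversal early.

```scala
object AggregatorDriver {
  trait Aggregator[-A, +To] {
    def wantsNext(): Boolean
    def next(a: A): Unit
    def result(): To

    // Assumed driver: feed elements until the input runs out
    // or the aggregator no longer wants any
    def apply(xs: Iterable[A]): To = {
      val it = xs.iterator
      while (wantsNext() && it.hasNext) next(it.next())
      result()
    }
  }

  // `count` ignores the elements themselves, hence Aggregator[Any, Int]
  def count: Aggregator[Any, Int] = new Aggregator[Any, Int] {
    private var n = 0
    def wantsNext(): Boolean = true
    def next(a: Any): Unit = n += 1
    def result(): Int = n
  }

  // Hypothetical: remembers the first element, then refuses further input
  def first[T]: Aggregator[T, Option[T]] = new Aggregator[T, Option[T]] {
    private var seen: Option[T] = None
    def wantsNext(): Boolean = seen.isEmpty
    def next(a: T): Unit = seen = Some(a)
    def result(): Option[T] = seen
  }
}
```

With these definitions, `AggregatorDriver.count(List("a", "b", "c"))` gives `3`, and `AggregatorDriver.first[Int](List(7, 8, 9))` gives `Some(7)` after consuming a single element.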
Here is how to implement `sum` using this interface:
def sum[N](implicit N: scala.Numeric[N]): Aggregator[N,N] = new Aggregator[N,N] {
  private var acc = N.zero               // running total
  def wantsNext(): Boolean = true        // a sum never stops early
  def next(a: N): Unit = acc = N.plus(acc, a)
  def result(): N = acc
}
Now, we would like to be able to define aggregators and compose them in a modular way, to achieve the stated goal of this article: for instance, computing `average` from `sum` and `count` in a single pass.
Abusing For Comprehension
The trick is surprisingly simple, and consists in defining `flatMap` with a twist: its argument function should accept its own argument by name, and not by value…
trait Aggregator[-A, +To] { outer =>
  // ...
  // the type parameter A0 <: A is necessary to avoid violating the contravariance of A
  def flatMap[A0 <: A, R](f: (=> To) => Aggregator[A0, R]): Aggregator[A0, R] =
    new Aggregator[A0, R] {
      val inner = f(outer.result()) // `outer.result()` is a thunk passed by name
      def wantsNext(): Boolean = inner.wantsNext() || outer.wantsNext()
      def result(): R = inner.result()
      def next(elem: A0): Unit = {
        if (outer.wantsNext()) outer.next(elem)
        if (inner.wantsNext()) inner.next(elem)
      }
    }
}
We initialize the inner `Aggregator` from the `f` function passed to `flatMap`, by passing `f` a thunk that knows how to compute the current result of the outer aggregator. We then return a new `Aggregator` that combines both sub-aggregations. The implementation of other functions like `map` and `foreach` is straightforward, and can be seen in the companion repository of this blog post.
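Putting the pieces together, the following is a self-contained sketch that is enough to run the `average` example. It is not the companion repository's code: the `apply` driver is assumed, and `map` is written in its simplest form, recomputing `f` each time `result()` is called, which may differ from the real implementation.

```scala
object AggregatorSketch {
  trait Aggregator[-A, +To] { outer =>
    def wantsNext(): Boolean
    def next(a: A): Unit
    def result(): To

    // Assumed driver, so that `average(ls)` can be written
    def apply(xs: Iterable[A]): To = {
      val it = xs.iterator
      while (wantsNext() && it.hasNext) next(it.next())
      result()
    }

    // Simplest possible map (an assumption): f is re-run on every result() call
    def map[R](f: (=> To) => R): Aggregator[A, R] = new Aggregator[A, R] {
      def wantsNext(): Boolean = outer.wantsNext()
      def next(a: A): Unit = outer.next(a)
      def result(): R = f(outer.result())
    }

    // Same flatMap as in the text: the outer result is handed to f by name
    def flatMap[A0 <: A, R](f: (=> To) => Aggregator[A0, R]): Aggregator[A0, R] =
      new Aggregator[A0, R] {
        val inner = f(outer.result())
        def wantsNext(): Boolean = inner.wantsNext() || outer.wantsNext()
        def next(a: A0): Unit = {
          if (outer.wantsNext()) outer.next(a)
          if (inner.wantsNext()) inner.next(a)
        }
        def result(): R = inner.result()
      }
  }

  def sum[N](implicit N: Numeric[N]): Aggregator[N, N] = new Aggregator[N, N] {
    private var acc = N.zero
    def wantsNext(): Boolean = true
    def next(a: N): Unit = acc = N.plus(acc, a)
    def result(): N = acc
  }

  def count: Aggregator[Any, Int] = new Aggregator[Any, Int] {
    private var n = 0
    def wantsNext(): Boolean = true
    def next(a: Any): Unit = n += 1
    def result(): Int = n
  }

  // Desugars to sum[Double].flatMap(s => count.map(c => s / c));
  // `s` and `c` are by-name, so they always read the current state
  def average: Aggregator[Double, Double] =
    for { s <- sum[Double]; c <- count } yield s / c
}
```

With this in place, `AggregatorSketch.average(List(1.0, 2.0, 3.0))` evaluates to `2.0`, and each element of the list is visited only once.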
The beauty of this approach is that it initializes the control structure eagerly, and then while iterating does not produce any extra allocations at all. It uses efficient imperative variable updates behind the scenes, hidden from the user.
More Advanced Examples
An interesting fact about `Aggregator` comprehensions is that later aggregators can refer to the current value of earlier ones. For example, here is how to sum the elements of a list that have an even index:
// sum all elements of even index:
def sumEvenIdx = for {
  idx <- count.map(_ - 1) // the current index is the current count minus 1
  s <- sum[Double] when (idx % 2 == 0)
} yield s
assert(sumEvenIdx(ls) ==
  (for { (x,i) <- ls.zipWithIndex if i % 2 == 0 } yield x).sum)
The `when` method simply takes a condition by name, and evaluates it whenever a value is fed to the aggregator, to decide whether to aggregate the new value or not.
// defined inside Aggregator, where `outer` aliases `this`
def when(cond: => Boolean) = new Aggregator[A, To] {
  def wantsNext() = outer.wantsNext()
  def result() = outer.result()
  def next(x: A): Unit = if (cond) outer.next(x)
}
The `wantsNext` method can be used to short-circuit a computation, so that when we are only interested in the result of an iteration up until a certain point, we can stop it early. For instance, an `until` method can be used to take only `n` elements into a `toBuffer` aggregator, as in `toBuffer[Int].until(_.size == n)`.
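The text above does not show `until` or `toBuffer`, so the following is only a guess at their shape: `until` takes a predicate on the current result and makes `wantsNext()` return false once it holds. The `apply` driver is likewise assumed.

```scala
import scala.collection.mutable.ArrayBuffer

object UntilSketch {
  trait Aggregator[-A, +To] { outer =>
    def wantsNext(): Boolean
    def next(a: A): Unit
    def result(): To

    // Assumed driver: stops as soon as wantsNext() turns false
    def apply(xs: Iterable[A]): To = {
      val it = xs.iterator
      while (wantsNext() && it.hasNext) next(it.next())
      result()
    }

    // Assumed shape of `until`: refuse input once `done` holds for the current result
    def until(done: To => Boolean): Aggregator[A, To] = new Aggregator[A, To] {
      def wantsNext(): Boolean = outer.wantsNext() && !done(outer.result())
      def next(a: A): Unit = outer.next(a)
      def result(): To = outer.result()
    }
  }

  // Assumed shape of `toBuffer`: appends every element to a mutable buffer
  def toBuffer[T]: Aggregator[T, ArrayBuffer[T]] = new Aggregator[T, ArrayBuffer[T]] {
    private val buf = ArrayBuffer.empty[T]
    def wantsNext(): Boolean = true
    def next(a: T): Unit = { buf += a; () }
    def result(): ArrayBuffer[T] = buf
  }
}
```

For example, `UntilSketch.toBuffer[Int].until(_.size == 2)(List(1, 2, 3, 4))` returns a buffer containing `1` and `2`, and the driver never reads the third element.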
Finally, notice that an interesting, if suboptimal, way of defining all possible aggregators would be to compose usages of a primitive `current` aggregator, of type `Aggregator[T,Option[T]]`. For example, this is how to define `sum` based on `current`:
def sum = {
  lazy val s: Aggregator[Double,Double] =
    for { x <- current[Double] } yield x.fold(0.0)(_ + s.result())
  s
}
assert(sum(List(1.0, 2.0, 3.0)) == 6.0)
(That is, as long as the implementation of `map` stores a local variable accumulating the current state. Another implementation of `map` would compute the transformed result lazily, only when the `result()` method is called, which would make the above crash with a stack overflow.)
Limitations
It is better not to implement the `withFilter` method (which enables `if` guards inside `for` comprehensions) because, due to the way Scala desugars `for` comprehensions, it does not seem possible to give it satisfying semantics. For instance, we could define the semantics so that the previous `sumEvenIdx` example can be rewritten as `for { idx <- count; s <- sum[Double] if idx % 2 == 0 }`, but if the user instead wrote `for { idx <- count; if idx % 2 == 0; s <- sum[Double] }`, the condition would be evaluated too eagerly, resulting in surprising behavior.
Unfortunately, there is also a problem with the (horrible) way in which Scala desugars intermediate bindings in `for` comprehensions, such that writing `c <- count; idx = c - 1; ...` instead of `idx <- count.map(_ - 1); ...` in the previous example does not have the expected semantics: `c` will only ever be assigned the initial value of `count`, namely `0`. This is a more serious problem, because we cannot prevent users from writing such bindings (whereas we can prevent them from using `if` guards).
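The underlying mechanism can be reproduced without any aggregator at all. In the sketch below, everything is a hypothetical stand-in: `bind` plays the role of `flatMap` (it hands its function a by-name view of some mutable state), `live` mimics `idx <- count.map(_ - 1)`, and `frozen` mimics what the intermediate binding `idx = c - 1` amounts to, namely forcing the by-name value once into a plain `val`.

```scala
object BindingDemo {
  // Stands in for the running state of an aggregator such as `count`
  private var count = 0
  def bump(): Unit = count += 1

  // Stand-in for flatMap: `f` receives the current count by name
  def bind[R](f: (=> Int) => (() => R)): () => R = f(count)

  // `c` stays by-name, so each call re-reads the current count
  val live: () => Int = bind[Int](c => () => c - 1)

  // Binding to a val forces `c` once, when the structure is built
  val frozen: () => Int = bind[Int] { c =>
    val idx = c - 1 // evaluated exactly once, while count is still 0
    () => idx
  }
}
```

After two calls to `BindingDemo.bump()`, `BindingDemo.live()` returns `1` while `BindingDemo.frozen()` still returns `-1`, the value computed from the initial count.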
Going Further
What I show here is little more than a simple gimmick based on a wacky combination of Scala features (lambdas with by-name parameters used as part of `for` comprehensions).
Can we take this trick further, and use it to enable more powerful patterns? In a future post, I will show how this mechanism can be used to zip arbitrary data sources together “efficiently” (without creating intermediate tuples), while still relying on a nice `for` comprehension syntax.