Recurse Center

List comprehensions in eight lines of Clojure

Zach Allaun

List comprehensions are a syntactic construct that allows you to create new lists in an elegant and concise way.

Here’s a question, though: If your language did not support list comprehensions, what would it take to add them? For a language like Ruby, you may have to dig into the language implementation itself. If you’re using a Lisp, you can write a macro.

Lisp macros allow you to transform your code at compile time. Macro calls run before the rest of your code, during a process called macro expansion. You define a macro much like any other function, except that its arguments are unevaluated Lisp code. Because Lisp code is just data, you can operate on it like any other data.
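This is a small sketch of that idea using two hypothetical macros, `show-form` and `unless`. Neither is part of Clojure; the names are made up for this illustration. The first macro returns its unevaluated argument as quoted data. The second rearranges its arguments into an `if` form before the code is compiled.

```clojure
;; Hypothetical macro (illustration only): the argument arrives as an
;; unevaluated list, and we hand it back as quoted data.
(defmacro show-form [form]
  (list 'quote form))

(show-form (+ 1 2))
;;=> (+ 1 2)

;; Hypothetical macro (illustration only): because the arguments are
;; plain data, we can reorder them into an if form with the branches swapped.
(defmacro unless [test then else]
  (list 'if test else then))

(unless (> 1 2) :yes :no)
;;=> :yes
```

In `show-form`, the form `(+ 1 2)` is never evaluated. The macro just returns it as a list.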

This is how you can implement list comprehensions in eight lines of Clojure.

(defmacro list-comp [[binding seq-expr & bindings] body-expr]
  (cond (not binding)
        `(list ~body-expr)

        (= :when binding)
        `(when ~seq-expr (list-comp ~bindings ~body-expr))

        :else
        `(mapcat (fn [~binding] (list-comp ~bindings ~body-expr))
                 ~seq-expr)))

I’ve been thinking about this macro for three days now, and it’s still terribly exciting to me. It’s the macro I wish I had seen when I was first learning about macros. It takes full advantage of having all of Clojure available at compile time, and it uses both higher-order functions and recursion.

Let’s see what it can do.

;; copy lists
(list-comp [x (range 10)]
  x)
;;=> (0 1 2 3 4 5 6 7 8 9)

;; double every number
(list-comp [x (range 5)]
  (* x 2))
;;=> (0 2 4 6 8)

;; filter out elements
(list-comp [x (range 10) :when (odd? x)]
  x)
;;=> (1 3 5 7 9)

;; nested bindings
(list-comp [x "abc"
            y [0 1 2]]
  [x y])
;;=> ([\a 0] [\a 1] [\a 2] [\b 0] [\b 1] [\b 2] [\c 0] [\c 1] [\c 2])

;; list permutations
(defn permutations [xs]
  (if-not (seq xs)
    (list ())
    (list-comp [x xs
                ys (permutations (list-comp [z xs :when (not= z x)]
                                   z))]
      (conj ys x))))

(permutations [1 2 3])
;;=> ((1 2 3) (1 3 2) (2 1 3) (2 3 1) (3 1 2) (3 2 1))

At compile time, each list-comp call is rewritten into ordinary Clojure code built from mapcat, when, and list. It’s a pretty handy macro, and we can see how it works by looking at some macro expansions.

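At its core, list-comp is just a recursive transformation. This becomes obvious when you look at the first macro expansion of a comprehension with nested bindings. In the expansions below, namespace qualifiers such as clojure.core/ are omitted for readability.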

(macroexpand '(list-comp [x "ab" y [0 1]] [x y]))

;; yields

(mapcat (fn [x] (list-comp (y [0 1]) [x y]))
        "ab")
;;=> ([\a 0] [\a 1] [\b 0] [\b 1])

Suppose we simply assume that the inner call to list-comp does the right thing, that is, that it successfully binds y to 0 and then to 1. Then it is clear that this expansion works. mapcat calls the function once for each character of "ab", binding x to \a and then to \b, and concatenates the resulting lists. The same reasoning applies at any depth of nesting.
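As a quick check of that argument, this sketch writes the fully expanded form out by hand. There is one mapcat per binding, and the innermost body is the empty-bindings case, `(list [x y])`. The name `pairs` is just a label for this example.

```clojure
;; Hand-written equivalent of (list-comp [x "ab" y [0 1]] [x y])
;; after both levels are expanded: one mapcat per binding pair.
(def pairs
  (mapcat (fn [x]
            (mapcat (fn [y] (list [x y]))   ; innermost: body as a singleton list
                    [0 1]))
          "ab"))                            ; a string is a seq of characters

pairs
;;=> ([\a 0] [\a 1] [\b 0] [\b 1])
```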

What happens when we use the :when clause?

(macroexpand '(list-comp [x "ab" :when (= x \a) y [0 1]] [x y]))

;; yields (after also expanding the inner :when step)

(mapcat (fn [x]
          (when (= x \a)
            (list-comp (y [0 1]) [x y])))
        "ab")
;;=> ([\a 0] [\a 1])

If the condition is truthy, we execute the inner comprehension. Otherwise, when returns nil. mapcat treats nil as an empty sequence, so concatenating it with the other results leaves them unchanged.
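This small sketch shows the nil behavior directly. It uses only `concat`, `mapcat`, and `when` from clojure.core.

```clojure
;; concat (which mapcat uses) treats nil as an empty sequence.
(concat [1 2] nil [3])
;;=> (1 2 3)

;; So a nil returned by when simply vanishes from a mapcat result.
(mapcat (fn [x] (when (odd? x) (list x))) [1 2 3 4 5])
;;=> (1 3 5)
```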

However, the skeptical reader may have noticed a possible wrench in the works. If :when wraps the recursive call, what happens when there is no recursive call left? That is, doesn’t everything break if :when is the last form in the bindings?

Thankfully, it doesn’t. The base case is not a single binding pair but no bindings at all. In that case, we just return the body’s result as a singleton list.
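The recursion reaches that empty case because of how Clojure's vector destructuring handles nil and exhausted sequences. This sketch reuses the parameter names from the list-comp definition. After the last pair is consumed, the `& bindings` rest is nil. The recursive call then receives nil, and every name is bound to nil, which triggers the `(not binding)` branch.

```clojure
;; Destructuring nil binds every name to nil, so binding is nil
;; and list-comp takes its (not binding) branch.
(let [[binding seq-expr & bindings] nil]
  [binding seq-expr bindings])
;;=> [nil nil nil]

;; With one pair left, the & rest is already nil, so the next
;; recursive call is (list-comp nil body-expr).
(let [[binding seq-expr & bindings] '(y [0 1])]
  [binding seq-expr bindings])
;;=> [y [0 1] nil]
```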

(macroexpand '(list-comp [] :foo))

;; yields

(list :foo)

This means that our :when clause can appear at the end of the bindings vector.

(require 'clojure.walk)
(clojure.walk/macroexpand-all
  '(list-comp [x (range 10) :when (even? x)] x))

;; yields (simplified: macroexpand-all also expands when and fn into the special forms if and fn*)

(mapcat (fn [x]
          (when (even? x)
            (list x))) ;; the empty-binding case
        (range 10))
;;=> (0 2 4 6 8)
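The expansion above does the same work as a filter. This sketch checks that against `clojure.core/filter`. The name `evens-via-mapcat` is just a label for this example.

```clojure
;; The hand-written expansion: keep x (as a singleton list) only when
;; the predicate holds; nils from when disappear in mapcat.
(def evens-via-mapcat
  (mapcat (fn [x] (when (even? x) (list x)))
          (range 10)))

(= evens-via-mapcat (filter even? (range 10)))
;;=> true
```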

And that’s all there is to it.

A special thanks to Alan O'Donnell and Darius Bacon for joining me on the quest for the simplest list comprehension macro.

Edit Jan 2, 2013: I should have mentioned this, but for posterity’s sake, Clojure comes with list comprehensions already built in. See Clojure’s for macro.
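For comparison, here are two of the examples from this post rewritten with clojure.core/for. It accepts the same name/sequence pairs and the same :when modifier. It also supports :let and :while, which list-comp does not.

```clojure
;; Nested bindings with the built-in for macro.
(for [x "abc"
      y [0 1 2]]
  [x y])
;;=> ([\a 0] [\a 1] [\a 2] [\b 0] [\b 1] [\b 2] [\c 0] [\c 1] [\c 2])

;; Filtering with :when.
(for [x (range 10) :when (odd? x)]
  x)
;;=> (1 3 5 7 9)
```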