# Tricking GHC into evaluating recursive functions at compile time

Here is a trick I came up with for a project of mine. Suppose you have a GADT like this very simple one:

```haskell
data T a where
  TInt  :: Int -> T Int
  TPair :: T a -> T b -> T (a,b)
```

and a function which does something with it:

```haskell
sumT :: T a -> Int
sumT (TInt n)    = n
sumT (TPair l r) = sumT l + sumT r
```

Now, let’s use the two:

```haskell
term = TPair (TPair (TInt 1) (TInt 2)) (TInt 3)
foo  = sumT term
```

Since `foo` is constant, we would expect GHC to evaluate it at compile time and just bind it to 6 in the compiled code, right?

Wrong! For this to happen, GHC would have to inline `sumT`. But `sumT` is a recursive function, and GHC never inlines those because otherwise it might get into an infinite loop. This means that it won’t optimise `foo` at all, which was absolutely unacceptable in my program. I spent about two days fiddling with inline pragmas, rewrite rules and other unpleasant things until I found a satisfactory solution.
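A minimal sketch that gathers the snippets above into one module follows; the `GADTs` pragma and the type signatures for `term` and `foo` are additions. At run time it of course yields 6; the question is whether the optimised Core contains the literal 6 or still calls `sumT`.

```haskell
{-# LANGUAGE GADTs #-}

-- The simple GADT from the beginning of the post.
data T a where
  TInt  :: Int -> T Int
  TPair :: T a -> T b -> T (a, b)

-- Plain structural recursion: GHC will not inline this.
sumT :: T a -> Int
sumT (TInt n)    = n
sumT (TPair l r) = sumT l + sumT r

-- A fully static term; its type index mirrors its shape.
term :: T ((Int, Int), Int)
term = TPair (TPair (TInt 1) (TInt 2)) (TInt 3)

-- Constant, so ideally compiled to the literal 6.
foo :: Int
foo = sumT term
```

One way to check is to compile with `ghc -O -ddump-simpl` and look at the binding for `foo` in the dumped Core.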

My first attempt was to inline `sumT` only when it is applied to a constructor. We could try adding a couple of rewrite rules:

```haskell
{-# RULES
"sumT/TInt"  forall n.   sumT (TInt n)    = n
"sumT/TPair" forall l r. sumT (TPair l r) = sumT l + sumT r
  #-}
```

Alas, this doesn’t work most of the time. Basically, trying to match on non-trivial constructors in rewrite rules is never a good idea. We could introduce “virtual” constructors, use them everywhere instead of the real ones and match on them:

```haskell
tInt :: Int -> T Int
{-# NOINLINE CONLIKE tInt #-}
tInt = TInt

tPair :: T a -> T b -> T (a,b)
{-# NOINLINE CONLIKE tPair #-}
tPair = TPair

{-# RULES
"sumT/tInt"  forall n.   sumT (tInt n)    = n
"sumT/tPair" forall l r. sumT (tPair l r) = sumT l + sumT r
  #-}
```

This works much better but, unfortunately, still fails in my program. There, I make extensive use of type families, so the Core generated by GHC has casts all over the place. Casts make rule matching highly unreliable because rules don’t ignore them (a particularly ugly wart that I keep running into). So what to do?

The solution I came up with requires adding a unit component to every recursive constructor.

```haskell
data T a where
  TInt  :: Int -> T Int
  TPair :: () -> T a -> T b -> T (a,b)
```

Where we previously wrote `TPair`, we will now write `TPair ()`. In fact, let’s provide a convenience function for that:

```haskell
tPair :: T a -> T b -> T (a,b)
tPair = TPair ()
```

Now, we define a non-recursive version of `sumT` which is parametrised with a function it is supposed to apply to the pair components.

```haskell
-- The argument has a rank-2 type, so this needs {-# LANGUAGE RankNTypes #-}.
sumT_cont :: (forall a. () -> T a -> Int) -> T a -> Int
{-# INLINE sumT_cont #-}
sumT_cont _    (TInt n)      = n
sumT_cont cont (TPair u l r) = cont u l + cont u r
```

Note that since `sumT_cont` isn’t recursive, it can be freely inlined. Note also that we pass the unit value from the constructor to `cont`; this is absolutely essential.

The actual recursive sum is defined via `sumT_cont`. Of course, it also has to take a `()` argument (which it ignores).

```haskell
sumT' :: () -> T a -> Int
sumT' _ = sumT_cont sumT'

sumT :: T a -> Int
{-# INLINE sumT #-}
sumT = sumT' ()
```

The final missing piece that makes the whole thing work is this simple rewrite rule:

```haskell
{-# RULES "sumT'" sumT' () = sumT_cont sumT' #-}
```

It “inlines” `sumT'` but *only* if it is applied to `()`. Why is this useful? Let’s see what happens if we apply `sumT` to a term which GHC knows nothing about:

```
  sumT x
= {inline sumT}
  sumT' () x
= {apply rule "sumT'"}
  sumT_cont sumT' x
= {inline sumT_cont}
  case x of
    TInt n      -> n
    TPair u l r -> sumT' u l + sumT' u r
```

Rewriting `sumT'` to `sumT_cont sumT'` again would be a disaster, since it would put us into an infinite rewriting loop. This is precisely the reason why GHC won't inline recursive functions. But our rule doesn't match here, because `u` is not known to be `()`!

So what happens if we apply `sumT` to a term that is at least partially static?

```
  sumT (tPair (TInt 1) y)
= {inline sumT and tPair}
  sumT' () (TPair () (TInt 1) y)
= {apply rule "sumT'"}
  sumT_cont sumT' (TPair () (TInt 1) y)
= {inline sumT_cont}
  case TPair () (TInt 1) y of
    TInt n      -> n
    TPair u l r -> sumT' u l + sumT' u r
= {eliminate case}
  sumT' () (TInt 1) + sumT' () y
= {apply rule "sumT'" twice}
  sumT_cont sumT' (TInt 1) + sumT_cont sumT' y
= {inline sumT_cont, eliminate case}
  1 + case y of
        TInt n      -> n
        TPair u l r -> sumT' u l + sumT' u r
```

This looks good! In effect, GHC executed `sumT` for the statically known portion of the term at compile time and deferred the rest to run time. This worked because, when it eliminated the case on `TPair`, it bound `u` in the case alternative to `()`. This allowed it to apply the `"sumT'"` rule again and thus to get rid of the `TInt` constructor in the left component. The right component is unknown, though, so rewriting stops there.

In general, after "inlining" `sumT'` once (via the rewrite rule), GHC will only apply the rule again if it eliminates the case, thus binding `u` to `()`. This, in turn, is only possible if the head of the term is a known constructor, so GHC will continue rewriting and inlining until it has consumed all known constructors, but it will not get into an infinite loop. For `foo` from my first example, which is fully constant, it will perform the entire computation at compile time and reduce it to 6.
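Putting the pieces together, the whole trick might look like the module below. This is a sketch under a few assumptions: the `LANGUAGE` pragma, the `RULES` pragma wrapping, the type signatures, the renaming of the inner type variable to `b`, and the helper `partial` (a hypothetical name for the partially static example above) are not taken from the original project.

```haskell
{-# LANGUAGE GADTs, RankNTypes #-}

-- Every recursive constructor carries an extra () field.
data T a where
  TInt  :: Int -> T Int
  TPair :: () -> T a -> T b -> T (a, b)

-- Convenience constructor that supplies the literal ().
tPair :: T a -> T b -> T (a, b)
tPair = TPair ()

-- One non-recursive step of the sum; recursive calls go through `cont`,
-- which receives the () stored in the constructor.
sumT_cont :: (forall b. () -> T b -> Int) -> T a -> Int
{-# INLINE sumT_cont #-}
sumT_cont _    (TInt n)      = n
sumT_cont cont (TPair u l r) = cont u l + cont u r

-- The knot is tied here; the () argument is ignored at run time.
sumT' :: () -> T a -> Int
sumT' _ = sumT_cont sumT'

sumT :: T a -> Int
{-# INLINE sumT #-}
sumT = sumT' ()

-- Unfold sumT' one step, but only when its argument is literally ().
{-# RULES "sumT'" sumT' () = sumT_cont sumT' #-}

-- Fully static term: ideally reduced to 6 at compile time.
term :: T ((Int, Int), Int)
term = tPair (tPair (TInt 1) (TInt 2)) (TInt 3)

foo :: Int
foo = sumT term

-- Hypothetical helper: left component known, right component not.
partial :: T b -> Int
partial y = sumT (tPair (TInt 1) y)
```

Rewrite rules only fire when optimising (for example with `-O`), and since the rule does not change the meaning of `sumT'`, the results are the same with or without it. Whether `foo` really becomes 6 and `partial` becomes `1 +` a `case` on its argument can be checked in the output of `ghc -O -ddump-simpl`; the exact Core depends on the GHC version.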

A word of warning: it *is* possible to get GHC into an infinite loop with this approach by constructing infinite but statically known terms. For instance, we could apply the same technique to this type.

```haskell
data U = UInt Int | UPair U U
```

But now, this term gets us into trouble:

```haskell
x = UPair (UInt 1) x
```

This technique works best with GADTs like `T`, which do not admit infinite terms: the type index of such a term would itself have to be infinite.
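For concreteness, here is a sketch of the same technique applied to `U`. It assumes that `UPair` gets the extra `()` field just like `TPair`; the names `uPair`, `sumU_cont`, `sumU'`, `sumU` and `finiteU` are hypothetical, chosen to mirror the ones for `T`. Finite terms such as `finiteU` are harmless. The cyclic term is left commented out because, as warned above, a statically known infinite term can keep the rule firing forever.

```haskell
-- U with the extra () field that the technique requires.
data U = UInt Int | UPair () U U

-- Convenience constructor supplying the literal ().
uPair :: U -> U -> U
uPair = UPair ()

-- One non-recursive step; recursion goes through `cont`.
sumU_cont :: (() -> U -> Int) -> U -> Int
{-# INLINE sumU_cont #-}
sumU_cont _    (UInt n)      = n
sumU_cont cont (UPair u l r) = cont u l + cont u r

sumU' :: () -> U -> Int
sumU' _ = sumU_cont sumU'

sumU :: U -> Int
{-# INLINE sumU #-}
sumU = sumU' ()

-- Same rule as for sumT': unfold only when the argument is literally ().
{-# RULES "sumU'" sumU' () = sumU_cont sumU' #-}

-- Finite and statically known: fine.
finiteU :: U
finiteU = uPair (uPair (UInt 1) (UInt 2)) (UInt 3)

-- Cyclic and statically known: per the warning above, avoid this under -O.
-- x :: U
-- x = uPair (UInt 1) x
```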

## Comments

My eyes glazed over about halfway through. What about adding a builtin

```haskell
ifix :: Int -> (a -> a) -> a
```

so that `ifix 5` is like `fix` but limited to five levels of recursion, etc.? The compiler could inline it if the first argument is a small enough integer literal. There would be no danger of looping.

@Roman: Sneaky! I love it.

@solrize: Such a function (implemented directly) would be recursive and run into the same recursive inlining problem. GHC doesn’t provide you with the machinery to tell it that the recursive call is strictly decreasing and well-founded according to such a measure.
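To make the reply concrete, here is a direct definition of the proposed `ifix`. It is a hypothetical function (GHC provides no such builtin), and its exact behaviour is an assumption: `ifix n f` unfolds `f` at most `n` times and fails if the recursion goes any deeper. The definition recurses on `n`, which is exactly the situation the reply describes: to GHC it is just another recursive function. `factStep` is a hypothetical example written in the same open-recursive style as `sumT_cont`.

```haskell
-- Hypothetical depth-bounded fixpoint: unfold f at most n times.
ifix :: Int -> (a -> a) -> a
ifix n f
  | n <= 0    = error "ifix: recursion depth exceeded"
  | otherwise = f (ifix (n - 1) f)   -- recursive call on n - 1

-- Factorial with the recursive call abstracted out, like sumT_cont.
factStep :: (Integer -> Integer) -> Integer -> Integer
factStep self k = if k == 0 then 1 else k * self (k - 1)
```

With this definition, `ifix 4 factStep 3` evaluates to 6: computing `3!` needs four unfoldings, the last one for the base case `0`, while `ifix 3 factStep 3` runs into the error.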

Am I wrong in thinking that the same effect could be achieved with Template Haskell?

Template Haskell doesn’t help if the constructors only become exposed after optimisations (in particular, after inlining).

Given that you seem to be prepared to play tricks on the inliner to make it do what you want, have you considered adding an annotation that basically tells GHC to evaluate a closed expression? With `-fexpose-all-unfoldings` the information to do it is available. Evaluating closed expressions could of course make compilation very expensive, and for some expressions (`replicate 10000 "a"`) it might make the resulting program much worse, but as far as I can tell it would allow you to write the beautiful function that you first came up with.

The problem is that I want to put most of this code in a library. The annotation would have to be provided by the library’s client which is bad. Also, it’s rather tricky to annotate particular expressions because they get moved around and transformed a lot by the optimiser.