Tricking GHC into evaluating recursive functions at compile time

November 5, 2009

Here is a trick I came up with for a project of mine. Suppose you have a GADT like this very simple one:

data T a where
  TInt  :: Int -> T Int
  TPair :: T a -> T b -> T (a,b)

and a function which does something with it:

sumT :: T a -> Int
sumT (TInt n) = n
sumT (TPair l r) = sumT l + sumT r

Now, let’s use the two:

term = TPair (TPair (TInt 1) (TInt 2)) (TInt 3)
foo = sumT term

Since foo is constant, we would expect GHC to evaluate it at compile time and just bind it to 6 in the compiled code, right?

Wrong! For this to happen, GHC would have to inline sumT. But sumT is a recursive function, and GHC never inlines those because it might otherwise get into an infinite loop. This means that it won’t optimise foo at all, which was absolutely unacceptable in my program. I spent about two days fiddling with inline pragmas, rewrite rules and other unpleasant things until I found a satisfactory solution.
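To see what GHC actually does with foo, one can dump the optimised Core. The following is a minimal sketch that just collects the snippets above into one module; the file name Naive.hs, the type signatures for term and foo, and the main wrapper are my assumptions for illustration. Going by the discussion above, one would expect the binding for foo to still contain a call to sumT rather than the literal 6, though the exact output depends on the GHC version.

```haskell
{-# LANGUAGE GADTs #-}
-- Assumed file name: Naive.hs. To inspect the optimised Core, compile with
--   ghc -O2 -ddump-simpl -dsuppress-all Naive.hs
-- and look for the binding of foo in the dump.

data T a where
  TInt  :: Int -> T Int
  TPair :: T a -> T b -> T (a,b)

-- Plain recursive sum over the leaves of a term.
sumT :: T a -> Int
sumT (TInt n)    = n
sumT (TPair l r) = sumT l + sumT r

term :: T ((Int,Int),Int)
term = TPair (TPair (TInt 1) (TInt 2)) (TInt 3)

foo :: Int
foo = sumT term

-- Only here so that the module builds as a program.
main :: IO ()
main = print foo  -- prints 6
```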

My first attempt was to only inline sumT if it is applied to a constructor. We could try adding a couple of rewrite rules.

{-# RULES
"sumT/TInt"  forall n. sumT (TInt n) = n
"sumT/TPair" forall l r. sumT (TPair l r) = sumT l + sumT r
  #-}

Alas, this doesn’t work most of the time. Basically, trying to match on non-trivial constructors in rewrite rules is never a good idea. Instead, we could introduce “virtual” constructors, use them everywhere in place of the real ones and match on them.

tInt :: Int -> T Int
{-# NOINLINE CONLIKE tInt #-}
tInt = TInt

tPair :: T a -> T b -> T (a,b)
{-# NOINLINE CONLIKE tPair #-}
tPair = TPair

{-# RULES
"sumT/tInt"  forall n. sumT (tInt n) = n
"sumT/tPair" forall l r. sumT (tPair l r) = sumT l + sumT r
  #-}

This works much better but, unfortunately, still fails in my program. There, I make extensive use of type families so the Core generated by GHC has casts all over the place. Casts make rule matching highly unreliable because rules don’t ignore them (a particularly ugly wart that I keep running into). So what to do?
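For reference, the virtual-constructor variant can be assembled into a single module roughly as follows. This is a sketch under assumptions: the surrounding RULES pragma, the foo binding and the main wrapper are added for illustration, and whether the rules actually fire depends on the GHC version and on how the terms look after optimisation. The run-time result is the same either way; only the generated Core differs.

```haskell
{-# LANGUAGE GADTs #-}

data T a where
  TInt  :: Int -> T Int
  TPair :: T a -> T b -> T (a,b)

sumT :: T a -> Int
sumT (TInt n)    = n
sumT (TPair l r) = sumT l + sumT r

-- "Virtual" constructors. NOINLINE keeps the names visible so that the
-- rules below can match on them; CONLIKE asks GHC to treat applications
-- of them as cheap, constructor-like expressions during rule matching.
tInt :: Int -> T Int
{-# NOINLINE CONLIKE tInt #-}
tInt = TInt

tPair :: T a -> T b -> T (a,b)
{-# NOINLINE CONLIKE tPair #-}
tPair = TPair

-- Unfold sumT one step whenever it meets a virtual constructor.
{-# RULES
"sumT/tInt"  forall n.   sumT (tInt n)    = n
"sumT/tPair" forall l r. sumT (tPair l r) = sumT l + sumT r
  #-}

foo :: Int
foo = sumT (tPair (tPair (tInt 1) (tInt 2)) (tInt 3))

main :: IO ()
main = print foo  -- prints 6
```

Compiling with -O2 -ddump-rule-firings should list the rules that were applied, which is a quick way to check whether matching succeeded.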

The solution I came up with requires adding a unit component to every recursive constructor.

data T a where
  TInt  :: Int -> T Int
  TPair :: () -> T a -> T b -> T (a,b)

Where we previously wrote TPair, we will now write TPair (). In fact, let’s provide a convenience function for that:

tPair :: T a -> T b -> T (a,b)
tPair = TPair ()

Now, we define a non-recursive version of sumT which is parametrised with a function it is supposed to apply to the pair components. (Its rank-2 type requires the RankNTypes extension.)

sumT_cont :: (forall a. () -> T a -> Int) -> T a -> Int
{-# INLINE sumT_cont #-}
sumT_cont cont (TInt  n) = n
sumT_cont cont (TPair u l r) = cont u l + cont u r

Note that since sumT_cont isn’t recursive it can be freely inlined. Note also that we pass the unit value from the constructor to cont. This is absolutely essential.

The actual recursive sum is defined via sumT_cont. Of course, it has to be parametrised with () (which it ignores).

sumT' :: () -> T a -> Int
sumT' _ = sumT_cont sumT'

sumT :: T a -> Int
{-# INLINE sumT #-}
sumT = sumT' ()

The final missing piece that makes the whole thing work is this simple rewrite rule:

{-# RULES "sumT'" sumT' () = sumT_cont sumT' #-}

It “inlines” sumT' but only if it is applied to (). Why is this useful? Let’s see what happens if we apply sumT to a term which GHC knows nothing about:

   sumT x
= {inline sumT}
   sumT' () x
= {apply rule "sumT'"}
   sumT_cont sumT' x
= {inline sumT_cont}
    case x of TInt n -> n
              TPair u l r -> sumT' u l + sumT' u r

Rewriting sumT' to sumT_cont sumT' again would be a disaster as it would put us into an infinite rewriting loop. This is precisely the reason why GHC won’t inline recursive functions. But our rule doesn’t match here because u is not guaranteed to be ()!

So what happens if we apply sumT to a term that is at least partially static?

   sumT (tPair (TInt 1) y)
= {inline sumT and tPair}
   sumT' () (TPair () (TInt 1) y)
= {apply rule "sumT'"}
   sumT_cont sumT' (TPair () (TInt 1) y)
= {inline sumT_cont}
    case TPair () (TInt 1) y of TInt n -> n
                                TPair u l r -> sumT' u l + sumT' u r
= {eliminate case}
   sumT' () (TInt 1) + sumT' () y
= {apply rule "sumT'" twice}
    sumT_cont sumT' (TInt 1) + sumT_cont sumT' y
= {inline sumT_cont, eliminate case}
    1 + case y of TInt n -> n
                  TPair u l r -> sumT' u l + sumT' u r

This looks good! In effect, GHC executed sumT for the statically known portion of the term at compile time and deferred the rest to run time. This worked because when it eliminated the case on TPair, it bound u in the case alternative to (). This allowed it to apply the "sumT'" rule again and thus to get rid of the TInt constructor in the left component. The right component is unknown, though, so rewriting stops there.

In general, after “inlining” sumT' once (via the rewrite rule), GHC will only apply the rule again if it eliminates the case, thus binding u to (). This, in turn, is only possible if the head of the term is a known constructor, so GHC will continue rewriting and inlining until it consumes all known constructors but will not get into an infinite loop.

For foo from my first example, which is fully constant, it will perform the entire computation at compile time and reduce it to 6.
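Putting the pieces of the final approach together gives a module along the following lines. Treat it as a sketch: the LANGUAGE pragmas, the RULES pragma wrapper, the foo binding written with tPair, and the main wrapper are assumptions added so that it compiles on its own, and the compile-time reduction described above depends on GHC's optimiser (try -O2 -ddump-simpl to check the result for your version).

```haskell
{-# LANGUAGE GADTs, RankNTypes #-}

data T a where
  TInt  :: Int -> T Int
  TPair :: () -> T a -> T b -> T (a,b)   -- unit component on the recursive constructor

-- Convenience wrapper that supplies the unit.
tPair :: T a -> T b -> T (a,b)
tPair = TPair ()

-- Non-recursive worker: the recursive calls go through cont, and the
-- unit stored in the constructor is handed on to it.
sumT_cont :: (forall b. () -> T b -> Int) -> T a -> Int
{-# INLINE sumT_cont #-}
sumT_cont _    (TInt n)      = n
sumT_cont cont (TPair u l r) = cont u l + cont u r

-- The actual recursion. The unit argument is ignored at run time and
-- exists only so that the rule below has something to match on.
sumT' :: () -> T a -> Int
sumT' _ = sumT_cont sumT'

sumT :: T a -> Int
{-# INLINE sumT #-}
sumT = sumT' ()

-- Unfold sumT' one step, but only when its first argument is literally ().
{-# RULES "sumT'" sumT' () = sumT_cont sumT' #-}

foo :: Int
foo = sumT (tPair (tPair (TInt 1) (TInt 2)) (TInt 3))

main :: IO ()
main = print foo  -- prints 6
```

The value of foo is 6 with or without optimisation; what changes is whether the addition happens at compile time or at run time.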

A word of warning: it is possible to get GHC into an infinite loop with this approach by constructing infinite but statically known terms. For instance, we could apply the same technique to this type.

data U = UInt Int | UPair U U

But now, this term gets us into trouble:

x = UPair (UInt 1) x

This technique works best with GADTs like T that do not admit infinite terms: a self-referential T term would need an infinite type, so the type checker rejects it.
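The difference between the two types can be seen without any rewrite rules at all. In the sketch below, sumUpTo is a hypothetical depth-limited helper that I made up purely to show that x is a perfectly legal (lazy, infinite) value; the commented-out binding y is the analogous term for T, which GHC refuses to type check.

```haskell
{-# LANGUAGE GADTs #-}

data U = UInt Int | UPair U U

-- Accepted: U carries no type index, so a term may refer to itself.
x :: U
x = UPair (UInt 1) x

-- Hypothetical helper (not part of the technique): sums at most d levels
-- of pairs, so it terminates even on the infinite term x.
sumUpTo :: Int -> U -> Int
sumUpTo _ (UInt n) = n
sumUpTo d (UPair l r)
  | d <= 0    = 0
  | otherwise = sumUpTo (d - 1) l + sumUpTo (d - 1) r

data T a where
  TInt  :: Int -> T Int
  TPair :: T a -> T b -> T (a,b)

-- Rejected if uncommented: y would need a type T t with t ~ (Int, t),
-- which fails the occurs check (infinite type).
-- y = TPair (TInt 1) y

main :: IO ()
main = print (sumUpTo 3 x)  -- prints 3
```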

6 Comments
  1. solrize
    November 5, 2009 20:17

    My eyes glazed over about halfway through. What about adding a builtin
    ifix :: Int -> (a->a) -> a
    so “ifix 5” is like “fix” but limited to 5 levels of recursion, etc.? The compiler could inline it if the parameter is an int literal of small enough size. There would be no danger of looping.

  2. November 6, 2009 00:53

    @Roman: Sneaky! I love it.

    @solrize: Such a function (implemented directly) would be recursive and run into the same recursive inlining problem. GHC doesn’t provide you with the machinery to tell it that the recursive call is strictly decreasing and well-founded according to such a measure.

  3. Richard G.
    November 6, 2009 08:05

    Am I wrong in thinking that the same effect could be achieved with Template Haskell?

    • November 6, 2009 13:15

      Template Haskell doesn’t help if the constructors only become exposed after optimisations (in particular, after inlining).

  4. Peter
    November 21, 2009 12:30

    Given that you seem to be prepared to play tricks on the inliner to make it do what you want, have you considered adding an annotation that basically tells GHC to evaluate a closed expression? With -fexpose-all-unfoldings the information to do it is available. Evaluating closed expressions could of course make compilation very expensive, and for some expressions (replicate 10000 "a") it might make the resulting program much worse, but as far as I can tell it would allow you to write the beautiful function that you first came up with.

    • November 27, 2009 01:44

      The problem is that I want to put most of this code in a library. The annotation would have to be provided by the library’s client which is bad. Also, it’s rather tricky to annotate particular expressions because they get moved around and transformed a lot by the optimiser.
