Log Sum of Exponentials for Robust Sums on the Log Scale

This is a public service announcement in the interest of more robust numerical calculations.

Like matrix inverse, exponentiation is bad news. It’s prone to overflow or underflow. Just try this in R:

> exp(-800)
> exp(800)

That’s not rounding error you see. The first one evaluates to zero (underflows) and the second to infinity (overflows).

A log density of -800 is not unusual with the log likelihood of even a modestly sized data set. So what do we do? Work on the log scale, of course. It turns products into sums, and sums are much less prone to overflow or underflow.

log(a * b) = log(a) + log(b)
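A minimal C++ sketch of the point. C++ is used because that is what Stan's internals are written in. The helper names and the 1,000 hypothetical densities of 0.1 are illustrations, not anything from Stan. Multiplying the densities underflows, while summing their logs stays comfortably in range.

```cpp
#include <cmath>
#include <vector>

// Multiply densities directly: the running product can underflow to 0.
double product_of_densities(const std::vector<double>& dens) {
  double prod = 1.0;
  for (double d : dens) prod *= d;
  return prod;
}

// Sum log densities instead, using log(a * b) = log(a) + log(b).
// The sum grows only linearly in magnitude, so it stays representable.
double sum_of_log_densities(const std::vector<double>& dens) {
  double total = 0.0;
  for (double d : dens) total += std::log(d);
  return total;
}
```

With 1,000 copies of 0.1, the true product would be 1e-1000. That is far below the smallest positive double (about 4.9e-324), so product_of_densities returns 0. By contrast, sum_of_log_densities returns about -2302.6.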

But what do we do if we need the log of a sum, not a product? We turn to a mainstay of statistical computing, the log sum of exponentials function.

log(a + b) = log(exp(log(a)) + exp(log(b)))

           = log_sum_exp(log(a), log(b))

We use a little algebraic trick to prevent overflow and underflow while preserving as many accurate leading digits in the result as possible:

log_sum_exp(u, v) = max(u, v) + log(exp(u - max(u, v)) + exp(v - max(u, v)))
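The algebra behind the trick is a single factoring step, with m = max(u, v):

```latex
\begin{aligned}
\log\left(e^{u} + e^{v}\right)
  &= \log\left(e^{m}\left(e^{u - m} + e^{v - m}\right)\right) \\
  &= m + \log\left(e^{u - m} + e^{v - m}\right),
  \qquad m = \max(u, v).
\end{aligned}
```

One of u - m and v - m is exactly 0, and the other is at most 0. The sum inside the log therefore lies between 1 and 2, so the log term lies between 0 and log 2 and can never become log(0).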

The leading digits are preserved by pulling the maximum outside. The arithmetic is robust because subtracting the maximum on the inside makes sure that only negative numbers or zero are ever exponentiated, so there can be no overflow in those calculations. If there is underflow, we know the leading digits have already been returned as part of the max term on the outside.
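Here is a minimal C++ sketch of the two-argument formula above. It is an illustration, not Stan's actual implementation. The extra infinity check reflects an assumption about how edge cases might be handled: without it, -inf - (-inf) would produce NaN when both inputs are -infinity.

```cpp
#include <algorithm>
#include <cmath>
#include <limits>

// log(exp(u) + exp(v)) computed with the max trick.
double log_sum_exp(double u, double v) {
  double m = std::max(u, v);
  // If m is -infinity, both terms are exactly 0 and the answer is -infinity.
  // If m is +infinity, the sum is infinite. Either way, return m directly
  // rather than computing inf - inf = NaN below.
  if (std::isinf(m)) return m;
  // Both exponents are <= 0, so exp() can underflow but never overflow.
  return m + std::log(std::exp(u - m) + std::exp(v - m));
}
```

For example, log_sum_exp(800.0, 800.0) returns 800 + log(2), about 800.693, whereas std::log(std::exp(800.0) + std::exp(800.0)) is infinite. A common refinement computes m + std::log1p(std::exp(std::min(u, v) - m)), which recovers a little more accuracy when one term dominates.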

Mixtures

We use log-sum-of-exponentials extensively in the internal C++ code for Stan, and it also pops up in user programs when there is a need to marginalize out discrete parameters (as in mixture models or state-space models). For instance, if we have a normal log density function, we can compute the log of the mixture density with mixing proportion lambda as

log_sum_exp(log(lambda) + normal_log(y, mu[1], sigma[1]),
            log1m(lambda) + normal_log(y, mu[2], sigma[2]));
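The same mixture computation can be sketched in C++. The helpers normal_log, log_sum_exp, log1m, and mixture_log are hypothetical stand-ins written for this illustration and are named after the Stan functions in the post. They are not Stan's C++ API.

```cpp
#include <algorithm>
#include <cmath>

const double kPi = 3.14159265358979323846;

// Log of the normal density N(y | mu, sigma).
double normal_log(double y, double mu, double sigma) {
  double z = (y - mu) / sigma;
  return -0.5 * z * z - std::log(sigma) - 0.5 * std::log(2.0 * kPi);
}

// Two-argument log-sum-exp using the max trick.
double log_sum_exp(double u, double v) {
  double m = std::max(u, v);
  if (std::isinf(m)) return m;  // avoid inf - inf = NaN
  return m + std::log(std::exp(u - m) + std::exp(v - m));
}

// log(1 - u), computed via log1p as described in the post.
double log1m(double u) { return std::log1p(-u); }

// Log density of a two-component normal mixture with mixing proportion lambda.
double mixture_log(double y, double lambda,
                   double mu1, double sigma1, double mu2, double sigma2) {
  return log_sum_exp(std::log(lambda) + normal_log(y, mu1, sigma1),
                     log1m(lambda) + normal_log(y, mu2, sigma2));
}
```

Consider y = 100 with components N(0, 1) and N(1, 1). Computed directly, both densities underflow to 0, so the naive log of the mixture is -infinity. mixture_log stays finite, at roughly -4902.1.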

The function log1m is used for robustness; its value is defined algebraically by

log1m(u) = log(1 - u)

But unlike the naive calculation, it won’t return 0 when u is close to 0 and 1 - u rounds to 1. Try this in R:

log(1 - 10e-20)
log1p(-10e-20)

log1m isn’t built into R, but log1p is, and negating the argument doesn’t lose any digits. In the first expression, 1 - 10e-20 rounds to 1, so the log returns 0 instead of a tiny negative number. The second expression returns the correct non-zero result, about -1e-19.
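Here is the same comparison in C++, as a small sketch. log1m_naive and log1m are hypothetical helper names; C++, like R, provides std::log1p but no log1m.

```cpp
#include <cmath>

// log(1 - u) computed naively: 1 - u is rounded before the log ever sees it.
double log1m_naive(double u) { return std::log(1.0 - u); }

// log(1 - u) via std::log1p, which works directly with the small argument -u.
double log1m(double u) { return std::log1p(-u); }
```

Here log1m_naive(10e-20) returns exactly 0 because 1.0 - 10e-20 rounds to 1.0. By contrast, log1m(10e-20) returns approximately -1e-19.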

What goes on under the hood is that log1p uses different approximations to the log function depending on the value of its argument, typically switching to low-order series expansions where the standard algorithm would lose precision.
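To make the idea concrete, here is a toy C++ sketch of a log1p-style function that switches to a truncated Taylor series for small arguments. The cutoff of 1e-4 and the four-term series are arbitrary choices for illustration. Production implementations such as std::log1p use carefully tuned approximations and are accurate across the whole range.

```cpp
#include <cmath>

// Toy log(1 + x): use a truncated Taylor series when |x| is small, where
// forming 1 + x would round away most of x's significant digits.
double log1p_sketch(double x) {
  if (std::fabs(x) < 1e-4) {
    // log(1 + x) = x - x^2/2 + x^3/3 - x^4/4 + ..., evaluated in Horner form.
    // For |x| < 1e-4 the omitted terms are below about 1e-16 relative to x.
    return x * (1.0 - x * (0.5 - x * (1.0 / 3.0 - x * 0.25)));
  }
  return std::log(1.0 + x);
}
```

Just above the cutoff, std::log(1.0 + x) still loses roughly four of the sixteen or so significant digits. Real libraries avoid this with more careful argument reduction rather than a single switch point.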

32 thoughts on “Log Sum of Exponentials for Robust Sums on the Log Scale”

  1. “1 – u overflows to 1”

    Small nitpick: 1-u is /rounded/ to 1, because overflow specifically means that the number is larger than the largest representable floating-point number (about 1.8e308 for double precision). The difference in terminology matters because when 1-u is rounded to 1 the /relative/ error is still very small, but if u underflowed to 0, the relative error would be 100%.

    • Thanks. I’m a numerical computing novice, so I appreciate it. Is there a finer-grained term than “rounding” for this behavior when you add or subtract something too small (ideally something not confusable with another standard operation, like rounding to integers)? I know “catastrophic cancellation”, but that’s what happens with 1-u when u is close to 1.

      • I don’t know such a term, although there might be one. Perhaps saturation, but from a quick search it seems to mean something else. I looked it up in Higham’s Accuracy and Stability of Numerical Algorithms (a great book for these questions; Chapter 4 covers error analysis of summation methods), and he doesn’t seem to introduce any such term, just uses rounding. I think from the point of view of floating-point arithmetic there isn’t anything terribly special about such rounding compared to any other rounding: the basic promise is that the result of 1-u is the mathematically precise number rounded to the nearest (according to a certain rule) representable number. So in a way it’s not a “special” occurrence the way underflow would be, because the promise is fulfilled (in an unhelpful kind of way). 1-u is just an ill-conditioned function of u near u=1, so it’s not really floating-point arithmetic’s fault that it fails.

        • In fact, there /used/ to be some CPUs with FP units that would raise an exception whenever a number is subtracted from a number very close to it (this was pre-IEEE754 I think). While I never thought much about it, my understanding is that the consensus was that this was a bad/unhelpful idea. Perhaps it’s discussed in Bill Kahan’s papers somewhere.

        • From “An Interview with the Old Man of Floating-Point”

          The primary objection to Gradual Underflow arose from a fear that it would degrade the performance of the fastest arithmetics. Microcoded implementations like the i8087 and MC68881/2 had nothing to lose, but hardwired floating-point, especially if pipelined, seemed slowed by the extra steps Gradual Underflow required even if no underflows occurred. Two implementation strategies were discussed during meetings of the p754 committee. One used traps, one trap to catch results that underflowed and denormalize them to produce subnormal numbers, and another trap to catch subnormal operands and prenormalize them. The second strategy inserted two extra stages into the pipeline, one stage to prenormalize subnormal operands and another to denormalize underflowed results. Both strategies seemed to inject delays into the critical path for all floating-point operations, slowing all of them.

      • Is it an option to use a bigger data type? Or do they not exist?

        Assuming these sorts of issues crop up in many other places in numerical and scientific computing, isn’t the time ripe for languages to handle this internally, under the hood?

        • There isn’t a much bigger data type with widespread hardware support. On laptops, desktops, and servers you can pretty much rely on the C ‘long double’ type compiling to a hardware-supported type with at least the 15 bits of exponent in the Intel 80-bit extended precision format, but I don’t think you can practically rely on more.

          15 bits instead of 11 doesn’t get you very far when the exponent increases with both the sample size and the number of parameters — it only slightly postpones needing to know a bit about floating point computing. In contrast, going from single to double precision was a big deal, and let lots of people get away with ignoring rounding.

        • Could there be a non-floating-point type defined for these cases? Maybe a compiler or language abstraction that maps things behind the scenes onto the same old types supported by existing hardware.

          Can those types coexist with the other regular types (say, within expressions) or are they such an exotic beast that type-casting and inter-conversion would be too onerous?

        • They potentially could, but honestly it’d be forever before they were actually usable everywhere.

          Languages like C – the baseline and mainstay compiled languages – tend to have new versions published on a schedule. Once the new version is out (or just advanced enough that it’s clear what direction it’s going in), compiler developers go off and build support for the new features or changed features into their compilers.

          The schedule on this compiler support, however, can be totally variable. As an example, let’s take C++11, the C++ standard from, well, 2011.

          The two “main” compilers used for these kinds of languages are Clang and GCC – GCC as the default on a lot of GNU/Linux platforms, Clang as the default on OSX and others. Both got the finished C++11 Spec in September 2011. Clang switched over to C++11 in April 2013; GCC in June *2016*.

          Now, 11 was a pretty big revision, so there’s that. But in a roundabout way what I’m getting at is that while a language abstraction might be doable, it could be anywhere between 2 and 5 years – even once that abstraction is accepted and supported – before it’s actually reliably usable on an inter-platform basis.

  2. How hard would it be given the design of the Stan compiler to detect these patterns and convert them for the user? I’m assuming there’s an intermediate representation in the compiler that these sorts of optimizations could be done on without the user having to worry about it.

    • Mjskay:

      Doing this correction automatically could be tricky, but we are planning to write a parser that catches common errors and flags them for the user. So this could be one of the patterns in our list.

    • The source-to-source compiler that takes a Stan program and produces a C++ header file defining a class is factored through an abstract syntax tree. What we’re talking about here are algebraic transforms (called “peephole optimizations”) that would replace log(1 + u) with log1p(u). One issue that arises, as with much optimization, is that runtime error messages will reference the log1p() function, not the log() function; fixing that issue would be much more involved, as we’d have to track the original and transformed version and pass them to where errors arise at runtime.

      • If you can do this kind of peephole optimization and throw compiler warnings I think it would be valuable. If the compiler warning is sufficiently informative the user should be able to figure out what went on:

        “Warning: converted log(1-x) to equivalent computation log1m(x) for numerical stability at line LLL ”

        or “Warning: converted SOMECODEHERE to equivalent form log_sum_exp(SOMEOTHERCODE) for numerical stability at line LLL”

        The thing that might be difficult is to figure out whether the numerical stability is really happening. For example, suppose someone has log(1+x) in their code: if x is very large, then log1m(-x) might not necessarily be more numerically accurate, since it’s probably optimized for the case where x = O(1).

        So, it might be better to simply throw warnings like “consider converting log(1-x) to log1m(x) for numerical accuracy if x is expected to be close to 1 at line LLL”

    • In what language? I don’t know what R does. The C++ standard library has underlying flags that get set when operations like exp() overflow. You query them through the floating-point environment functions in <cfenv> (such as std::fetestexcept), which C++11 added, and they can be reset with std::feclearexcept. The standard doesn’t provide a way to make these events trap; that takes platform-specific extensions such as glibc’s feenableexcept. I don’t think there’s anything that will detect when x + epsilon evaluates to x because epsilon is too small relative to x. Integer overflow was the cause of binary search being broken (or at least not scalable) in Java for many years; see Josh Bloch’s excellent post Extra, Extra – Read All About It: Nearly All Binary Searches and Mergesorts are Broken. Bloch recommends Bentley’s Programming Pearls, which also made a deep impression on me when I was starting to get serious about programming as opposed to algorithm analysis (that’s what a move from academia to industry does).

        • This is actually a proposed alternative to IEEE-style floating point (storing the logarithm of the magnitude of a number). The main difficulty is that it makes addition and subtraction complex calculations, whereas multiplication and division are very easy. There was an article in American Scientist several years back about the merits and difficulties of this alternative.

        • Among the cool properties of Hakaru, a functional probabilistic programming language embedded in Haskell, is that it automatically converts arithmetic to the more stable log scale behind the scenes.

    • Here’s what the GNU C library manual says:

      http://www.gnu.org/software/libc/manual/html_node/FP-Exceptions.html

      Specifically, it looks like at the processor level, when underflow occurs the processor sets a register or something (called the “status word” in this documentation) to let you know. The GNU C library is then responsible for checking that status, seeing whether a “trap” has been set, and executing the appropriate signal-handling code. What happens in other languages depends on their runtime code and exception-handling mechanisms. But the processor itself doesn’t silently underflow.

      • However, the processor WILL silently round non-special values for example 1+1e-75 = 1 without any warning. Underflow specifically means calculating a number that is too close to 0 to be represented as different from 0. For example exp(-770) = 0 (underflow status set, I’m guessing).

  3. Takes me back to my college course in computation. A big exercise was to write a program that solved quadratic equations… we all laughed until we saw how unintuitive an algorithm had to be to come up with sensible answers regardless of the relative magnitudes of the inputs… I think my first effort got 2 out of 10.

    • The symbolic solution (-b ± sqrt(b^2 - 4ac)) / (2a) can help us figure out what goes wrong. If b^2 - 4ac is the difference of two large numbers, then it can be subject to cancellation: b^2 - 4ac might be something like 5 only because b^2 is about 10 million and 4ac is 5 less than that.

      Next, in -b ± sqrt(foo), if foo is close to b^2, then again you get cancellation errors. Also, foo ≈ b^2 when 4ac is negligible relative to b^2, which can happen when a is near zero, which means that 1/(2a) could potentially overflow.

      All of these issues are similar to the ones involved in log_sum_exp.

      I was a TA for a stats class where the professor assigned a problem in which there was a simple probability density that was constant for an interval and then dropped linearly to zero. She asked the students to find the probability of a random draw being in a certain interval in which the triangular portion was just a little bit of the interval. The students wound up having to either be smart about it or solve a quadratic that had bad properties when the calculations were carried out by hand on a calculator and paper (i.e., subject to roundoff at, say, the 4th decimal digit, which is typical of engineering paper calculations). Many people got wrong answers even though they clearly knew the method needed. The professor said she hadn’t really thought about how solving a quadratic could be difficult.

      • All of that… Plus there is an important branch based on whether b^2-4ac is negative, positive, or 0, giving 2 complex, 2 real, or 1 repeated solution, so precision on the sign of b^2-4ac is critical as well… Not to mention degenerate cases, namely a=0, or c=0, or a and b=0, or a, b, and c all equal to 0.

  4. If we wanted to alter the log_sum_exp function to take an input array/vector/whatever of arbitrary length in R, rather than just two scalars, would the following be correct?

    log_sum_exp <- function(x) {
      m <- max(x)                 # largest term, pulled outside
      m + log(sum(exp(x - m)))    # every exponent is <= 0
    }

  5. This is extremely helpful, thank you! I have been working on mixture models of animal body sizes (estimated through metrics of bones – I’m an archaeologist) and have been having this exact issue when trying to calculate log densities while using JAGS. Now to head back to my code and tinker some more. Or better yet, learn how to use Stan…
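Comment 4’s vectorized R function applies the same max trick to a whole vector, and the approach is sound. One edge case: in R it returns NaN when every element is -Inf, because -Inf - (-Inf) is NaN. Below is a C++ analogue as a sketch. It is an illustration, not Stan's implementation, and it guards that case and the empty vector.

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// log(sum_i exp(x[i])) for any number of terms, using the max trick.
double log_sum_exp(const std::vector<double>& x) {
  // The log of an empty sum is log(0) = -infinity.
  if (x.empty()) return -std::numeric_limits<double>::infinity();
  double m = *std::max_element(x.begin(), x.end());
  // If every term is exp(-inf) = 0, or some term is exp(+inf), return m
  // directly instead of computing inf - inf = NaN below.
  if (std::isinf(m)) return m;
  double sum = 0.0;
  for (double xi : x) sum += std::exp(xi - m);  // each exponent is <= 0
  return m + std::log(sum);
}
```

For example, log_sum_exp({-800, -800, -800}) returns -800 + log(3), and log_sum_exp({-inf, -inf}) returns -infinity rather than NaN.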
