Symbolic differentiation in a few lines of code
Jun 26, 2023
Rich FitzJohn
10 minute read

We are working on automatic differentiation of odin models, which requires support for differentiating expressions symbolically, in order to write new equations that can be used to numerically propagate derivatives of a model.

R already has support for doing this via the D function:

D(quote(2 * x^2 * log(sqrt(x))), "x")
## 2 * (2 * x) * log(sqrt(x)) + 2 * x^2 * (0.5 * x^-0.5/sqrt(x))

and the Deriv package provides an extensible interface. However, odin has peculiar syntax with arrays and we’re interested in trying to differentiate through stochastic functions, so a bespoke solution felt useful.

Symbolic differentiation turns out to be surprisingly easy, and quite elegant, to implement; this post shows the general idea.

To start, consider differentiating the expression x^2 + x^3 with respect to x. Recall from the mechanical rules of differentiation learned at school that we can write this as d/dx x^2 + d/dx x^3, and that we differentiate functions of the form x^n as n x^(n - 1). This gives us the primary insight we need: the process is recursive, because we break every operation down into smaller chunks and keep differentiating interior expressions with respect to x until there’s nothing left.
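Written out in full for this example, the recursion has two levels. The sum rule splits the expression, and the power rule handles each piece:

```latex
\frac{d}{dx}\left(x^2 + x^3\right)
  = \frac{d}{dx} x^2 + \frac{d}{dx} x^3
  = 2x^{2-1} + 3x^{3-1}
  = 2x + 3x^2
```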

The simplest possible differentiation rules concern numbers: d/dx n for any number n is zero (that is, the gradient of n with respect to x is zero). Similarly, the derivative of any symbol other than x (say a, but not a compound expression like a + b) is also zero, and the derivative of x with respect to x is one. With this, we have the base case for a recursive function:

differentiate <- function(expr, name) {
  if (!is.recursive(expr)) {
    if (identical(expr, as.symbol(name))) 1 else 0
  } else {
    stop("not yet implemented")
  }
}

which we can apply like so:

differentiate(quote(x), "x")
## [1] 1
differentiate(quote(a), "x")
## [1] 0
differentiate(quote(1), "x")
## [1] 0

Interesting expressions are not supported yet:

differentiate(quote(x + x * x), "x")
## Error in differentiate(quote(x + x * x), "x"): not yet implemented

To implement the case where we have compound expressions (is.recursive(expr) returning TRUE), consider the way we can represent these expressions:

expr <- quote(x + x * x)
as.list(expr)
## [[1]]
## `+`
## 
## [[2]]
## x
## 
## [[3]]
## x * x

Every call can be represented this way: the first element is the function being called and the remaining elements are its arguments. This structure is recursive:

as.list(expr[[3]])
## [[1]]
## `*`
## 
## [[2]]
## x
## 
## [[3]]
## x

To apply our differentiation rules we need to describe how to handle each function (here, + and *) and put together the results, descending into the subexpressions with differentiate() again until we get our edge cases.
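The same recursive descent can be seen in a smaller setting. The sketch below uses a hypothetical helper, `call_names()`, which is not part of the post's implementation. It walks a call tree the way differentiate() will, collecting the name of each function it meets and stopping at symbols and numbers:

```r
# Hypothetical helper: recursively collect the function names in a call tree.
# Symbols and numbers are leaves (the base case); for a call, the first
# element is the function and the remaining elements are its arguments.
call_names <- function(expr) {
  if (!is.recursive(expr)) {
    return(character(0))
  }
  args <- as.list(expr)[-1]
  c(as.character(expr[[1]]), unlist(lapply(args, call_names)))
}

call_names(quote(x + x * x))
## [1] "+" "*"
```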

The rule for differentiating sums is very straightforward, as noted above: we take the sum of the derivatives!

d_plus <- function(expr, name) {
  call("+", differentiate(expr[[2]], name), differentiate(expr[[3]], name))
}

The call() function constructs an unevaluated call from a function name and its arguments (so call("+", quote(x), 1) returns x + 1), and here we are descending into each argument with differentiate(). We then rewrite differentiate() to call d_plus() when required:
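To see what call() produces, and that it is the same object the parser would have built from source, we can compare it with quote() and then evaluate it:

```r
# call() takes a function name and the arguments to pass to it, and returns
# an unevaluated call; nothing is computed until we eval() it.
e <- call("+", quote(x), 1)
e
## x + 1
identical(e, quote(x + 1))
## [1] TRUE
eval(e, list(x = 2))
## [1] 3
```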

differentiate <- function(expr, name) {
  if (!is.recursive(expr)) {
    if (identical(expr, as.symbol(name))) 1 else 0
  } else {
    fn <- as.character(expr[[1]])
    switch(fn,
           "+" = d_plus(expr, name),
           stop("not yet implemented"))
  }
}

With this we can differentiate a sum of any depth:

differentiate(quote(x + 5), "x")
## 1 + 0
differentiate(quote(x + y + x), "x")
## 1 + 0 + 1

We can then proceed, writing out rules for different functions as we need them. For example, the product rule:

d_product <- function(expr, name) {
  a <- expr[[2]]
  b <- expr[[3]]
  da <- differentiate(a, name)
  db <- differentiate(b, name)
  call("+", call("*", da, b), call("*", a, db))
}

or the quotient rule:

d_quotient <- function(expr, name) {
  a <- expr[[2]]
  b <- expr[[3]]
  da <- differentiate(a, name)
  db <- differentiate(b, name)
  ## da / b - a * db / (b * b)
  call("-", call("/", da, b), call("/", call("*", a, db), call("*", b, b)))
}

For subtraction, we need to distinguish between unary minus (e.g., -a) and binary subtraction (e.g., a - b):

d_minus <- function(expr, name) {
  if (length(expr) == 2) {
    call("-", differentiate(expr[[2]], name))
  } else {
    call("-", differentiate(expr[[2]], name), differentiate(expr[[3]], name))
  }
}
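The two forms of `-` can be told apart by length alone, because a call stores the function in its first element followed by one element per argument:

```r
length(quote(-a))     # `-` plus one argument
## [1] 2
length(quote(a - b))  # `-` plus two arguments
## [1] 3
```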

It turns out that ( is a function too, and also needs a rule, but it is very simple:

d_parenthesis <- function(expr, name) {
  call("(", differentiate(expr[[2]], name))
}
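For instance, parentheses survive parsing as a call to `(` with a single argument, and `(` can be called like any other function:

```r
e <- quote((x + 1))
e[[1]]      # the function being called
## `(`
length(e)   # `(` plus its single argument
## [1] 2
`(`(3)      # calling it directly just returns its argument
## [1] 3
```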

We can put all these rules into a list:

rules <- list(
  "+" = d_plus,
  "-" = d_minus,
  "*" = d_product,
  "/" = d_quotient,
  "(" = d_parenthesis)

and rewrite our differentiate() implementation again:

differentiate <- function(expr, name) {
  if (!is.recursive(expr)) {
    if (identical(expr, as.symbol(name))) 1 else 0
  } else {
    fn <- as.character(expr[[1]])
    if (!(fn %in% names(rules))) {
      stop(sprintf("Differentiation of '%s' not yet implemented", fn))
    }
    rules[[fn]](expr, name)
  }
}

and with this we can differentiate all sorts of things:

differentiate(quote(-2 * x / (x * x - 3 * x)), "x")
## (-0 * x + -2 * 1)/(x * x - 3 * x) - -2 * x * (1 * x + x * 1 - 
##     (0 * x + 3 * 1))/((x * x - 3 * x) * (x * x - 3 * x))

This is fine, except that the generated expressions are fairly ugly, with lots of redundant terms (e.g., -0 * x, which is just 0, and + -2 * 1, which is just -2). However, the expressions agree with those from D once evaluated:

eval(differentiate(quote(-2 * x / (x * x - 3 * x)), "x"), list(x = pi))
## [1] 99.75819
eval(D(quote(-2 * x / (x * x - 3 * x)), "x"), list(x = pi))
## [1] 99.75819
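Comparing against D only works for functions that D knows about. A check that does not depend on another symbolic differentiator is a central finite difference, (f(x + h) - f(x - h)) / (2h). The sketch below restates the call()-based rules from above so that it runs on its own; `check_derivative()` and the default step size h = 1e-6 are our own choices for illustration, not part of the post's implementation:

```r
# call()-based rules from above, restated so this block is self-contained.
d_plus <- function(expr, name) {
  call("+", differentiate(expr[[2]], name), differentiate(expr[[3]], name))
}
d_minus <- function(expr, name) {
  if (length(expr) == 2) {
    call("-", differentiate(expr[[2]], name))
  } else {
    call("-", differentiate(expr[[2]], name), differentiate(expr[[3]], name))
  }
}
d_product <- function(expr, name) {
  a <- expr[[2]]
  b <- expr[[3]]
  call("+", call("*", differentiate(a, name), b),
       call("*", a, differentiate(b, name)))
}
d_quotient <- function(expr, name) {
  a <- expr[[2]]
  b <- expr[[3]]
  call("-", call("/", differentiate(a, name), b),
       call("/", call("*", a, differentiate(b, name)), call("*", b, b)))
}
d_parenthesis <- function(expr, name) {
  call("(", differentiate(expr[[2]], name))
}
rules <- list("+" = d_plus, "-" = d_minus, "*" = d_product,
              "/" = d_quotient, "(" = d_parenthesis)
differentiate <- function(expr, name) {
  if (!is.recursive(expr)) {
    if (identical(expr, as.symbol(name))) 1 else 0
  } else {
    fn <- as.character(expr[[1]])
    if (!(fn %in% names(rules))) {
      stop(sprintf("Differentiation of '%s' not yet implemented", fn))
    }
    rules[[fn]](expr, name)
  }
}

# Hypothetical helper: evaluate the symbolic derivative of `expr` at x = x0
# alongside a central finite-difference estimate with step h.
check_derivative <- function(expr, x0, h = 1e-6) {
  f <- function(x) eval(expr, list(x = x))
  symbolic <- eval(differentiate(expr, "x"), list(x = x0))
  finite_diff <- (f(x0 + h) - f(x0 - h)) / (2 * h)
  c(symbolic = symbolic, finite_diff = finite_diff)
}

check_derivative(quote(-2 * x / (x * x - 3 * x)), pi)
```

The two values should agree to several decimal places; a large disagreement would point at a broken rule.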

Extending the implementation by adding more rules

We can now extend this easily by adding more rules, and the implementation will even tell us what we need to add. So if we try to differentiate

differentiate(quote(exp(2 * x)), "x")
## Error in differentiate(quote(exp(2 * x)), "x"): Differentiation of 'exp' not yet implemented

we are told that we need to implement the rule for exp, which is simply:

d_exp <- function(expr, name) {
  call("*", differentiate(expr[[2]], name), expr)
}

(that is, d/dx exp(f(x)) is f'(x) exp(f(x))). We add this to our set of rules:

rules$exp <- d_exp

and now we can differentiate this new expression:

differentiate(quote(exp(2 * x)), "x")
## (0 * x + 2 * 1) * exp(2 * x)
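The power rule from the start of the post can be added in the same way. The sketch below is our own: the `d_power` name and the restriction to exponents that do not involve the variable are assumptions, not taken from the post's implementation. For such an exponent n, d/dx f(x)^n is n f(x)^(n - 1) f'(x). The minimal pieces of differentiate() are restated so the block runs on its own:

```r
# Sum and product rules, as in the post (call()-based versions).
d_plus <- function(expr, name) {
  call("+", differentiate(expr[[2]], name), differentiate(expr[[3]], name))
}
d_product <- function(expr, name) {
  a <- expr[[2]]
  b <- expr[[3]]
  call("+", call("*", differentiate(a, name), b),
       call("*", a, differentiate(b, name)))
}
# Hypothetical power rule, valid only when the exponent does not involve
# `name`: d/dx a^n = n * a^(n - 1) * da.
d_power <- function(expr, name) {
  a <- expr[[2]]
  n <- expr[[3]]
  if (name %in% all.vars(n)) {
    stop("Exponents that depend on the variable are not supported")
  }
  call("*", call("*", n, call("^", a, call("-", n, 1))), differentiate(a, name))
}
rules <- list("+" = d_plus, "*" = d_product, "^" = d_power)
differentiate <- function(expr, name) {
  if (!is.recursive(expr)) {
    if (identical(expr, as.symbol(name))) 1 else 0
  } else {
    fn <- as.character(expr[[1]])
    if (!(fn %in% names(rules))) {
      stop(sprintf("Differentiation of '%s' not yet implemented", fn))
    }
    rules[[fn]](expr, name)
  }
}

d <- differentiate(quote(x^2 + x^3), "x")
eval(d, list(x = 3))  # 2 * 3 + 3 * 3^2
## [1] 33
```

Refusing variable exponents (such as 2^x) keeps the rule honest; supporting them would need the more general rule involving log().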

Improving the implementation by writing sensible expressions

Simplifying the expressions turns out to be much more work than differentiating them. The trick we use is to avoid calling call() directly to build expressions, and instead to create a simplifying expression builder that applies some simple rules to avoid building overly complicated expressions. This does not simplify everything, but it cuts out the most egregious bits of noise.

This is probably not important for the efficiency of generated code (we’re going to send this to an optimising compiler via some C++ code eventually) but it does make the resulting expressions easier to think about.

Consider replacing call("+", a, b) with something that will avoid creating silly expressions. If given numeric arguments for a and b we should sum them, and if either of a or b is zero we should return the other argument:

m_plus <- function(a, b) {
  if (is.numeric(a) && is.numeric(b)) {
    a + b
  } else if (is.numeric(b)) {
    m_plus(b, a)
  } else if (is.numeric(a) && a == 0) {
    b
  } else {
    call("+", a, b)
  }
}

The pattern here is that only the last branch gives up and actually builds an expression with call():

m_plus(3, 4)
## [1] 7
m_plus(quote(a), 0)
## a
m_plus(quote(a), 1)
## 1 + a
m_plus(0, quote(b))
## b
m_plus(1, quote(b))
## 1 + b
m_plus(quote(a), quote(b))
## a + b

We can do similar things with multiplication:

m_product <- function(a, b) {
  if (is.numeric(a) && is.numeric(b)) {
    a * b
  } else if (is.numeric(b)) {
    m_product(b, a)
  } else if (is.numeric(a) && a == 0) {
    0
  } else if (is.numeric(a) && a == 1) {
    b
  } else {
    call("*", a, b)
  }
}
m_product(3, 4)
## [1] 12
m_product(quote(a), 0)
## [1] 0
m_product(quote(a), 1)
## a
m_product(quote(a), 2)
## 2 * a
m_product(quote(a), quote(b))
## a * b

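For unary minus, we also use a small helper, is_call(), which tests whether an expression is a call to a given function: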

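is_call <- function(x, name) {
  # compare the head of the call with the symbol directly; this also works
  # when the head is itself a call (e.g., pkg::fn), where as.character()
  # would return more than one string
  is.call(x) && identical(x[[1]], as.symbol(name))
}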
m_uminus <- function(a) {
  if (is.numeric(a)) {
    -a
  } else if (length(a) == 2 && identical(a[[1]], quote(`-`))) {
    a[[2]]
  } else if (is_call(a, "(")) {
    m_uminus(a[[2]])
  } else {
    call("-", a)
  }
}

then subtraction:

m_minus <- function(a, b) {
  if (is.numeric(a) && is.numeric(b)) {
    a - b
  } else if (is.numeric(a) && a == 0) {
    m_uminus(b)
  } else if (is.numeric(b) && b == 0) {
    a
  } else {
    call("-", a, b)
  }
}

and division:

m_quotient <- function(a, b) {
  if (is.numeric(a) && is.numeric(b)) {
    a / b
  } else if (is.numeric(a) && a == 0) {
    0
  } else if (is.numeric(b) && b == 0) {
    Inf
  } else if (is.numeric(b) && b == 1) {
    a
  } else {
    call("/", a, b)
  }
}

Finally, parentheses (this is done differently in our implementation but this is a little simpler):

m_parenthesis <- function(a) {
  if (is.symbol(a) || is.numeric(a)) {
    a
  } else {
    call("(", a)
  }
}
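In the same spirit as the m_plus() and m_product() examples above, a few calls show how the unary minus, subtraction, division and parenthesis builders behave. The definitions are repeated so the block runs on its own; is_call() is written with identical() here, a small robustness tweak of our own rather than necessarily the post's exact code:

```r
is_call <- function(x, name) {
  is.call(x) && identical(x[[1]], as.symbol(name))
}
m_uminus <- function(a) {
  if (is.numeric(a)) {
    -a
  } else if (length(a) == 2 && identical(a[[1]], quote(`-`))) {
    a[[2]]                 # -(-b) is b
  } else if (is_call(a, "(")) {
    m_uminus(a[[2]])       # look inside parentheses
  } else {
    call("-", a)
  }
}
m_minus <- function(a, b) {
  if (is.numeric(a) && is.numeric(b)) {
    a - b
  } else if (is.numeric(a) && a == 0) {
    m_uminus(b)
  } else if (is.numeric(b) && b == 0) {
    a
  } else {
    call("-", a, b)
  }
}
m_quotient <- function(a, b) {
  if (is.numeric(a) && is.numeric(b)) {
    a / b
  } else if (is.numeric(a) && a == 0) {
    0
  } else if (is.numeric(b) && b == 0) {
    Inf
  } else if (is.numeric(b) && b == 1) {
    a
  } else {
    call("/", a, b)
  }
}
m_parenthesis <- function(a) {
  if (is.symbol(a) || is.numeric(a)) a else call("(", a)
}

m_minus(5, 2)
## [1] 3
m_minus(quote(a), 0)
## a
m_minus(0, quote(b))
## -b
m_uminus(quote(-b))
## b
m_quotient(quote(a), 1)
## a
m_quotient(0, quote(b))
## [1] 0
m_parenthesis(quote(x))
## x
m_parenthesis(quote(a + b))
## (a + b)
```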

We can then rewrite all our rules to use these functions instead of call() directly:

d_plus <- function(expr, name) {
  m_plus(differentiate(expr[[2]], name), differentiate(expr[[3]], name))
}

d_minus <- function(expr, name) {
  if (length(expr) == 2) {
    m_uminus(differentiate(expr[[2]], name))
  } else {
    m_minus(differentiate(expr[[2]], name), differentiate(expr[[3]], name))
  }
}

d_product <- function(expr, name) {
  a <- expr[[2]]
  b <- expr[[3]]
  da <- differentiate(a, name)
  db <- differentiate(b, name)
  m_plus(m_product(da, b), m_product(a, db))
}

d_quotient <- function(expr, name) {
  a <- expr[[2]]
  b <- expr[[3]]
  da <- differentiate(a, name)
  db <- differentiate(b, name)
  ## da / b - a * db / (b * b)
  m_minus(m_quotient(da, b), m_quotient(m_product(a, db), m_product(b, b)))
}

d_parenthesis <- function(expr, name) {
  m_parenthesis(differentiate(expr[[2]], name))
}

rules <- list(
  "+" = d_plus,
  "-" = d_minus,
  "*" = d_product,
  "/" = d_quotient,
  "(" = d_parenthesis)

Now, when we call differentiate(), things look much nicer:

differentiate(quote(-2 * x / (x * x - 3 * x)), "x")
## -2/(x * x - 3 * x) - -2 * x * (x + x - 3)/((x * x - 3 * x) * 
##     (x * x - 3 * x))

There are still some oddities here (e.g., a - -2 * x * (x + x - 3)), which are surprisingly hard to undo; the right-hand side of this expression is a tree with structure:

lobstr::ast(-2 * x * (x + x - 3))
## █─`*` 
## ├─█─`*` 
## │ ├─█─`-` 
## │ │ └─2 
## │ └─x 
## └─█─`(` 
##   └─█─`-` 
##     ├─█─`+` 
##     │ ├─x 
##     │ └─x 
##     └─3

so to simplify the subtraction we would have to extract the - from within two layers of multiplication, which requires yet more recursion.
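A sketch of that extra recursion follows. `strip_leading_minus()` and `m_minus_flip()` are hypothetical helpers, not part of the post's implementation. The first looks for a unary minus in the leftmost operand of a chain of `*` and `/` calls and returns the expression with that minus removed (or NULL if there is none); the second uses it to rewrite a - -b as a + b:

```r
# Hypothetical helper: if the leftmost operand of a chain of `*` and `/`
# calls is a unary minus, return the expression with that minus removed;
# otherwise return NULL. Moving the sign out is valid because
# (-a) * b = -(a * b) and (-a) / b = -(a / b).
strip_leading_minus <- function(e) {
  if (!is.call(e)) {
    return(NULL)
  }
  fn <- e[[1]]
  if (identical(fn, as.symbol("-")) && length(e) == 2) {
    return(e[[2]])
  }
  if (identical(fn, as.symbol("*")) || identical(fn, as.symbol("/"))) {
    inner <- strip_leading_minus(e[[2]])  # recurse into the left operand
    if (!is.null(inner)) {
      e[[2]] <- inner
      return(e)
    }
  }
  NULL
}

# Hypothetical subtraction builder: rewrite a - (-b) as a + b when possible.
m_minus_flip <- function(a, b) {
  b_pos <- strip_leading_minus(b)
  if (is.null(b_pos)) call("-", a, b) else call("+", a, b_pos)
}

m_minus_flip(quote(y), quote(-2 * x * (x + x - 3)))
## y + 2 * x * (x + x - 3)
```

Only the leftmost operand is searched here, which matches the tree above; a minus buried in a right-hand operand would need a further case.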

Our full implementation can be seen in this pull request, which follows the presentation here fairly closely.

Why go to this effort?

Given that the D function (and the Deriv package) can do all this (and more) already, it’s not very obvious why you might want to do this. We don’t really want to differentiate R, but rather the domain-specific language used by odin. This is a small subset of R, but there are two features with quite different semantics over which we need considerable control (they’re not yet supported in the PR linked above).

Firstly, odin (via odin.dust) supports converting stochastic models into deterministic ones by taking expectations of the stochastic components, and the syntax for stochastic functions already differs from the corresponding R calls. So, for example, we might write:

m <- rbinom(n, p)

to represent a binomial random draw with size n and probability p; R’s first argument (the number of draws to take) does not appear here. The expectation of this draw is simply n * p, so we can easily add rules to our differentiation support that differentiate this part of a model with respect to any other model quantity.
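As a sketch of what such a rule might look like, assuming (as described above) that we differentiate the expectation rather than the random draw, the rule can rewrite rbinom(n, p) as n * p and hand that back to differentiate(). `d_rbinom` is a hypothetical name, and this is not the rule from the post's implementation; the minimal pieces of differentiate() are restated so the block runs on its own:

```r
# Product rule, as in the post (call()-based version).
d_product <- function(expr, name) {
  a <- expr[[2]]
  b <- expr[[3]]
  call("+", call("*", differentiate(a, name), b),
       call("*", a, differentiate(b, name)))
}
# Hypothetical rule: differentiate the expectation n * p of the odin-style
# draw rbinom(n, p) instead of the draw itself.
d_rbinom <- function(expr, name) {
  differentiate(call("*", expr[[2]], expr[[3]]), name)
}
rules <- list("*" = d_product, "rbinom" = d_rbinom)
differentiate <- function(expr, name) {
  if (!is.recursive(expr)) {
    if (identical(expr, as.symbol(name))) 1 else 0
  } else {
    fn <- as.character(expr[[1]])
    if (!(fn %in% names(rules))) {
      stop(sprintf("Differentiation of '%s' not yet implemented", fn))
    }
    rules[[fn]](expr, name)
  }
}

d_expr <- differentiate(quote(rbinom(n, p)), "p")
d_expr
## 0 * p + n * 1
eval(d_expr, list(n = 10, p = 0.3))
## [1] 10
```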

The other tricky part is the way that odin interprets array expressions, especially those that contain sums, as these have specific semantics. For example, the valid odin expression

lambda[] <- beta * sum(s_ij[, i])

could be rewritten as something more R-ish:

for (i in seq_along(lambda)) {
  lambda[i] <- beta * sum(s_ij[, i])
}

and differentiating the odin DSL version requires knowledge of these semantics. The solution to this is left as an exercise to the reader, and possibly a future blog post…



