Intro

Basics of Differentiation

To start from the very start: let's square a number.

import scala.math.*

def sq(x: Double) = x * x

List(1.0, 2.0, 3.0).map(sq)
// res0: List[Double] = List(1.0, 4.0, 9.0)

Something you'll notice is that the square of a number increases faster than the input does. The rate of increase of a function is called its derivative. In mathematical notation:

f(x) = x²

f′(x) = 2x

This is one mathematical notation for a derivative. Such derivatives can (sometimes!) be derived symbolically (see ChatGPT or a maths textbook), but they can also be approximated numerically at a point.
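To make "numerically at a point" concrete, here's a small sketch using a central finite difference; the name numDeriv and the step size eps are illustrative choices, not from any library:

// Central finite difference: approximate f'(x) from two nearby evaluations.
def numDeriv(f: Double => Double, x: Double, eps: Double = 1e-6): Double =
  (f(x + eps) - f(x - eps)) / (2 * eps)

numDeriv(sq, 3.0)
// ≈ 6.0, matching f′(3) = 2 · 3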

"Duals" and "Jets"

In case you haven't read Spire's scaladoc yet, you should. I've copied and pasted this bit:

While a complete treatment of the mechanics of automatic differentiation is beyond the scope of this header (see http://en.wikipedia.org/wiki/Automatic_differentiation for details), the basic idea is to extend normal arithmetic with an extra element h such that h² = 0.

h itself is non-zero: an infinitesimal.

Dual numbers are extensions of the real numbers analogous to complex numbers: whereas complex numbers augment the reals by introducing an imaginary unit i such that i² = −1, dual numbers introduce an "infinitesimal" unit h such that h² = 0. Analogously to a complex number c = x + yi, a dual number d = x + yh has two components: the "real" component x, and an "infinitesimal" component y. Surprisingly, this leads to a convenient method for computing exact derivatives without needing to manipulate complicated symbolic expressions.
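To make the arithmetic concrete, here's a minimal dual number sketch in plain Scala (this Dual class is illustrative, not Spire's). Multiplication follows (a + bh)(c + dh) = ac + (ad + bc)h, since the bd·h² term vanishes:

// Minimal dual number: a real part plus the coefficient of h.
case class Dual(re: Double, eps: Double):
  def *(that: Dual): Dual =
    // (a + bh)(c + dh) = ac + (ad + bc)h + bd·h², and h² = 0
    Dual(re * that.re, re * that.eps + eps * that.re)

val d = Dual(10.0, 1.0) // seed the h coefficient with 1.0
d * d
// Dual(100.0, 20.0): the value and the derivative in one pass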

For example, consider the function f(x) = x * x evaluated at x = 10. Using normal arithmetic,

f(10 + h) = (10 + h) * (10 + h)
          = 100 + 2 * 10 * h + h * h
          = 100 + 20 * h     + 0
                  |            |
                  |            +--- this is zero, because h² = 0
                  +---------------- this is df/dx

Spire offers us the ability to compute derivatives using dual numbers through its Jet implementation.

import spire.*
import spire.math.*
import spire.implicits.*
import spire.math.Jet.*

given jd: JetDim = JetDim(1)
val y = Jet(10.0) + Jet.h[Double](0)
// y: Jet[Double] = Jet(real = 10.0, infinitesimal = Array(1.0))
y * y
// res1: Jet[Double] = Jet(real = 100.0, infinitesimal = Array(20.0))

Here we tracked the derivative with respect to the first (and only) dimension.
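The same mechanism composes through ordinary arithmetic and the maths functions defined on Jet. As a sketch (assuming Spire's Trig and VectorSpace instances resolve from the imports and the JetDim given above), differentiating sin(x) * x at x = 2:

val z = Jet(2.0) + Jet.h[Double](0)
z.sin * z
// real part: sin(2.0) * 2.0
// infinitesimal part: sin(2.0) + 2.0 * cos(2.0), the derivative of sin(x) * x at x = 2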

Backward Mode

Backward mode is something of a mind-bender. For the example of squaring a number, we need to track the computation graph: a little AST / compiler for the calculation.

Let's visualize the computation graph of squaring a number with a Mermaid diagram:

graph LR
  input["input x"] --> x["x"]
  input --> y["x"]
  x --> mul["*"]
  y --> mul
  mul --> result["f(x) = x²"]

This is a representation of the "forward" pass of the computation graph for squaring a number.

Now, let's visualize the "backward" pass of the computation graph, where we propagate derivatives from the output back to the input:

graph RL
  result["f(x) = x²"] --> mul["* (gradient = 1.0)"]
  mul --> x["x (gradient = x)"]
  mul --> y["x (gradient = x)"]
  x --> input["input x (gradient = x + x)"]
  y --> input

The Chain Rule

The chain rule is a fundamental concept in calculus that allows us to find the derivative of a composite function. If we have a function h(x) = f(g(x)), the chain rule states that:

dh/dx = df/dg · dg/dx

Or in the more common notation:

h′(x) = f′(g(x)) · g′(x)
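For example, if h(x) = sin(x²), then f = sin and g(x) = x², so h′(x) = cos(x²) · 2x.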

This rule is particularly important in backward mode automatic differentiation, as it allows us to propagate gradients backwards through complex computational graphs.

In this backward pass:

  1. We start with a gradient of 1.0 at the output: the gradient of anything with respect to itself is 1.0
  2. At the multiplication node, the gradient splits and flows to both inputs
  3. Each input to the multiplication receives the incoming gradient multiplied by the other input
  4. The gradients from both paths sum at the input node, giving us the total derivative of 2x (see the sketch below)
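
Here is a minimal reverse-mode sketch in plain Scala that follows this recipe for x * x. The Node class is illustrative and handles only this single-node graph; a real implementation would topologically sort a deeper graph before propagating:

// Minimal reverse-mode node: stores a value and accumulates a gradient.
class Node(val value: Double):
  var grad: Double = 0.0
  private var back: () => Unit = () => ()

  def *(that: Node): Node =
    val out = Node(value * that.value)
    out.back = () =>
      // each input receives the output gradient times the other input
      this.grad += that.value * out.grad
      that.grad += this.value * out.grad
    out

  def runBackward(): Unit =
    grad = 1.0 // the gradient of the output with respect to itself
    back()

val x = Node(10.0)
val fx = x * x
fx.runBackward()
x.grad
// 20.0 = 2x: the gradients from the two multiplication inputs, summed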