Intro
Basics of Differentiation
To start from the very beginning, let's square a number.
```scala
import scala.math.*

def sq(x: Double) = x * x

List(1.0, 2.0, 3.0).map(sq)
// res0: List[Double] = List(1.0, 4.0, 9.0)
```
Something you'll notice is that the square of a number grows faster than the input does. The rate of increase of a function is called its derivative. In mathematical notation:

$$f'(x) = \lim_{h \to 0} \frac{f(x + h) - f(x)}{h}$$
This is one mathematical notation for a derivative. Such derivatives can (sometimes!) be derived symbolically (see ChatGPT or a maths textbook), but they can also be computed numerically at a point.
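In fact, the definition above suggests how to approximate a derivative numerically: evaluate the function at two nearby points and divide by the gap. Here's a minimal sketch (the helper name and step size are my own choices for illustration):

```scala
// Approximate f'(x) with a central difference; eps is an arbitrary small step.
def numericDerivative(f: Double => Double)(x: Double, eps: Double = 1e-6): Double =
  (f(x + eps) - f(x - eps)) / (2 * eps)

numericDerivative(sq)(3.0)
// ≈ 6.0, which matches the symbolic derivative 2x at x = 3
```

This works, but the result is approximate, and the choice of step size is a delicate trade-off between truncation error and floating-point error - which is part of what makes the dual-number trick below so appealing.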
"Dual" Numbers and "Jets"
In case you haven't read Spire's scaladoc yet, you should. I've copied and pasted this bit.
While a complete treatment of the mechanics of automatic differentiation is beyond the scope of this header (see http://en.wikipedia.org/wiki/Automatic_differentiation for details), the basic idea is to extend normal arithmetic with an extra element "h" such that $h^2 = 0$, but h itself is non-zero - an infinitesimal.
Dual numbers are extensions of the real numbers analogous to complex numbers: whereas complex numbers augment the reals by introducing an imaginary unit $i$ such that $i^2 = -1$, dual numbers introduce an "infinitesimal" unit $h$ such that $h^2 = 0$. Analogously to a complex number $x + y \, i$, a dual number $x + y \, h$ has two components: the "real" component $x$, and an "infinitesimal" component $y$. Surprisingly, this leads to a convenient method for computing exact derivatives without needing to manipulate complicated symbolic expressions.
For example, consider the function $f(x) = x \cdot x$ evaluated at 10. Using normal arithmetic,
```
f(10 + h) = (10 + h) * (10 + h)
          = 100 + 2 * 10 * h + h * h
          = 100 + 20 * h     + 0
                  |            |
                  |            +--- This is zero, since h² = 0
                  |
                  +---------------- This is df/dx
```
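To make that arithmetic concrete, here's a hand-rolled dual number - a toy sketch of the idea, not a production implementation:

```scala
// A toy dual number: the real part carries the value, the infinitesimal
// part carries the derivative.
case class Dual(real: Double, infinitesimal: Double):
  def +(that: Dual): Dual =
    Dual(real + that.real, infinitesimal + that.infinitesimal)
  def *(that: Dual): Dual =
    // (a + b·h) * (c + d·h) = a·c + (a·d + b·c)·h, because h² = 0
    Dual(real * that.real, real * that.infinitesimal + infinitesimal * that.real)

// Seed the infinitesimal with 1.0 (i.e. dx/dx) and square:
Dual(10.0, 1.0) * Dual(10.0, 1.0)
// Dual(100.0, 20.0) - the real part is f(10), the infinitesimal part is f'(10)
```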
Spire offers us the ability to compute derivatives using dual numbers through its Jet implementation.
```scala
import spire.*
import spire.math.*
import spire.implicits.*
import spire.math.Jet.*

// One independent variable to differentiate with respect to.
given jd: JetDim = JetDim(1)

// Seed dimension 0 with the infinitesimal h.
val y = Jet(10.0) + Jet.h[Double](0)
// y: Jet[Double] = Jet(real = 10.0, infinitesimal = Array(1.0))

y * y
// res1: Jet[Double] = Jet(real = 100.0, infinitesimal = Array(20.0))
```
The infinitesimal component of the result holds the derivative we tracked for the first (and only) dimension: df/dx = 20 at x = 10.
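Wrapping the pattern up, we can sketch a small helper (my own, not part of Spire's API) that differentiates any function written against Jet at a point:

```scala
// Seed a one-dimensional Jet at x, apply f, and read the derivative back
// out of the infinitesimal component. Assumes the JetDim(1) given above.
def derivative(f: Jet[Double] => Jet[Double])(x: Double)(using JetDim): Double =
  f(Jet(x) + Jet.h[Double](0)).infinitesimal(0)

derivative(j => j * j)(3.0)
// 6.0, i.e. 2x at x = 3
```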
Backward Mode
Backward mode is something of a mind bend. For the example of squaring a number, we need to track the computation graph - a little AST / compiler of the calculation.
Let's visualize the computation graph of squaring a number with a Mermaid diagram:
```mermaid
graph LR
    input["input x"] --> x[x]
    input["input x"] --> y[x]
    x --> mul["*"]
    y --> mul
    mul --> result["f(x) = x²"]
```
This is a representation of the "forward" pass of the computation graph for squaring a number.
Now, let's visualize the "backward" pass of the computation graph, where we propagate derivatives from the output back to the input:
```mermaid
graph RL
    result["f(x) = x²"] --> mul["* (gradient = 1.0)"]
    mul --> x["x (gradient = x)"]
    mul --> y["x (gradient = x)"]
    x --> input["input x (gradient = x + x)"]
    y --> input
```
The Chain Rule
The chain rule is a fundamental concept in calculus that allows us to find the derivative of a composite function. If we have a function $h(x) = f(g(x))$, the chain rule states that:

$$h'(x) = f'(g(x)) \cdot g'(x)$$

Or in the more common notation:

$$\frac{dh}{dx} = \frac{dh}{dg} \cdot \frac{dg}{dx}$$
This rule is particularly important in backward mode automatic differentiation, as it allows us to propagate gradients backwards through complex computational graphs.
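We can already see the chain rule working for us in forward mode: compose two functions over Jets and the derivative of the composition falls out (the function names here are just for illustration):

```scala
// g(x) = x² and f(u) = u³, so f(g(x)) = x⁶ with derivative 6x⁵.
def g(u: Jet[Double]): Jet[Double] = u * u
def f(u: Jet[Double]): Jet[Double] = u * u * u

f(g(Jet(2.0) + Jet.h[Double](0)))
// Jet(real = 64.0, infinitesimal = Array(192.0)) - and indeed 6 · 2⁵ = 192
```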
In this backward pass:
- We start with a gradient of 1.0 at the output - the gradient of anything with respect to itself is 1.0
- At the multiplication node, the gradient splits and flows to both inputs
- Each input to the multiplication receives the gradient multiplied by the other input
- The gradients from both paths sum at the input node, giving us the total derivative of 2x
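Those steps translate almost line-for-line into code. Here's a toy reverse-mode sketch (hypothetical code of my own, not Spire's API) for f(x) = x * x:

```scala
// Each Var stores a value, an accumulated gradient, and a closure that
// knows how to push the output gradient back to its inputs.
final class Var(val value: Double):
  var grad: Double = 0.0
  private var back: () => Unit = () => ()

  def *(that: Var): Var =
    val out = Var(this.value * that.value)
    out.back = () => {
      // each input receives the output gradient times the other input
      this.grad += out.grad * that.value
      that.grad += out.grad * this.value
    }
    out

  def backward(): Unit =
    grad = 1.0 // the gradient of the output with respect to itself is 1.0
    back()

val a = Var(3.0)
val fa = a * a
fa.backward()
a.grad
// 6.0: both paths into the multiplication sum at the input, giving 2x
```

This toy only propagates gradients one step - a real implementation would walk the whole graph in reverse topological order - but it shows the bookkeeping that backward mode does.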