We've already seen that Spire provides the ability to get the derivative of a function at a point through its Jet
class. Through function composition, we can find the derivative of pretty much any function.
import spire._
import spire.math._
import spire.implicits.*
import _root_.algebra.ring.Field
import spire.algebra.Trig
import spire.math.Jet.*
import io.github.quafadas.spireAD.*
/** Softmax of `x`: every entry is exponentiated, then divided by the sum of all the exponentials.
  *
  * Written against Spire's `Trig` and `Field` typeclasses rather than a concrete number type, so the same
  * definition runs on `Double` and on `Jet[Double]`.
  *
  * No max-shift is applied before exponentiating; the inputs are used as given.
  *
  * @param x
  *   input values
  * @param f
  *   field instance for `T`; supplies zero, addition and division
  * @return
  *   a new array, the same length as `x`, holding `exp(x(i)) / sum_j exp(x(j))`
  */
def softmax[T: Trig: ClassTag](x: Array[T])(using
    f: Field[T]
) = {
  val trig = summon[Trig[T]]
  val exponentials = x.map(v => trig.exp(v))
  // Left fold starting from zero keeps the summation order fixed.
  val normaliser = exponentials.foldLeft(f.zero)((acc, e) => f.plus(acc, e))
  exponentials.map(e => f.div(e, normaliser))
}
/** Sum of the sines of the entries of `x`.
  *
  * Generic over any `T` with Spire `Trig` and `Field` instances, so it accepts `Double` as well as
  * `Jet[Double]`.
  *
  * @param x
  *   input values
  * @param f
  *   field instance for `T`; supplies zero and addition
  * @return
  *   `sin(x(0)) + sin(x(1)) + ...`, accumulated left to right starting from `f.zero`
  */
def sumSin[T: Trig: ClassTag](x: Array[T])(using
    f: Field[T]
) = {
  val trig = summon[Trig[T]]
  // Single pass: take the sine of each entry and add it to the running total.
  x.foldLeft(f.zero)((runningTotal, v) => f.plus(runningTotal, trig.sin(v)))
}
// Number of inputs to differentiate with respect to; also used as the Jet dimension below.
val dim = 4
// dim: Int = 4
// Provide the Jet dimension as a given so it is available implicitly to the Jet code that follows.
given jd: JetDim = JetDim(dim)
// Evaluation point: 1.0, 2.0, ..., dim.
val range = (1 to dim).toArray.map(_.toDouble)
// range: Array[Double] = Array(1.0, 2.0, 3.0, 4.0)
// Plain Double evaluation: values only, no derivative information.
softmax[Double](range)
// res0: Array[Double] = Array(
// 0.03205860328008499,
// 0.08714431874203257,
// 0.23688281808991013,
// 0.6439142598879722
// )
// Same function evaluated on Jets. Judging by the output, `jetArr` seeds element i with a unit
// infinitesimal in slot i (TODO confirm against spireAD), so each result's `real` matches res0 and its
// `infinitesimal` array is one row of the softmax Jacobian: d(output i) / d(input j).
softmax[Jet[Double]](range.jetArr)
// res1: Array[Jet[Double]] = Array(
// Jet(
// real = 0.03205860328008499,
// infinitesimal = Array(
// 0.03103084923581511,
// -0.002793725142664097,
// -0.007594132289012969,
// -0.020642991804138044
// )
// ),
// Jet(
// real = 0.08714431874203257,
// infinitesimal = Array(
// -0.002793725142664097,
// 0.0795501864530196,
// -0.02064299180413805,
// -0.056113469506217456
// )
// ),
// Jet(
// real = 0.23688281808991013,
// infinitesimal = Array(
// -0.007594132289012968,
// -0.020642991804138047,
// 0.18076934858369267,
// -0.15253222449054166
// )
// ),
// Jet(
// real = 0.6439142598879722,
// infinitesimal = Array(
// -0.020642991804138044,
// -0.05611346950621745,
// -0.15253222449054163,
// 0.22928868580089717
// )
// )
// )
// Composition on Doubles: softmax followed by sumSin gives a single scalar.
sumSin(softmax[Double](range))
// res2: Double = 0.9540912841722475
// The same composition on Jets: `real` matches res2, and `infinitesimal` holds the partial derivatives
// of the composed function with respect to each of the four inputs.
sumSin(softmax[Jet[Double]](range.jetArr))
// res3: Jet[Double] = Jet(
// real = 0.9540912841722475,
// infinitesimal = Array(
// 0.004340445639588408,
// 0.011512648745511889,
// 0.02557837632094498,
// -0.04143147070604522
// )
// )
Once you're past the (somewhat formidable) list of Spire's typeclasses, we can use function composition to track the derivatives of arbitrarily complex functions. Pretty neat!
This is forward-mode automatic differentiation: it calculates the partial derivatives of the output value with respect to its inputs.