Algorithmic Differentiation

for the working developer


Mathieu Besançon
@matbesancon
INRIA Lille, Polytechnique Montréal

Slides: https://matbesancon.github.io/slides/ad4dev

In [3]:
import LightGraphs
const LG = LightGraphs
import SimpleWeightedGraphs
using SimpleWeightedGraphs: SimpleWeightedDiGraph
import GraphPlot

Motivations

Finding a shortest path... in Paris... with strikes... 🔥

In [4]:
# Directed graph with 4 vertices and weighted edges (durations in minutes)
const paris_evening = SimpleWeightedDiGraph(4)
const labels = (gare = 1, metro = 2, bike = 3, coffee = 4)
LG.add_edge!(paris_evening, 1, 2, 7.0)
LG.add_edge!(paris_evening, 1, 3, 4.0)
LG.add_edge!(paris_evening, 2, 4, 15.0)
LG.add_edge!(paris_evening, 3, 4, 25.0);
In [5]:
GraphPlot.gplot(
    paris_evening,
    nodelabel = collect(keys(labels)),
)
Out[5]:
[Graph plot: four nodes labelled gare, metro, bike, coffee]
In [6]:
GraphPlot.gplot(
    paris_evening,
    nodelabel = collect(keys(labels)),
    edgelabel = [e.weight for e in LG.edges(paris_evening)],
)
Out[6]:
[Graph plot: nodes gare, metro, bike, coffee; edge labels 7.0, 4.0, 15.0, 25.0]

Did I mention strikes?

prob_failure (p in the code): probability that the metro is not running => we have to go back to the station and take a bike.

  • Metro working: shortest path = 7 + 15 = 22min
  • Metro not working: shortest path = 4 + 25 = 29min
  • Going to metro and then taking a bike: 2 × 7 + 4 + 25 = 43min
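
Combining the three durations above gives the expected travel time when trying the metro first. The symbols $T_{\text{metro}}$ and $T_{\text{bike}}$ are notation introduced here for illustration, with $p$ the probability that the metro is not running:

$$ T_{\text{metro}}(p) = 43\,p + 22\,(1 - p) = 22 + 21\,p, \qquad T_{\text{bike}} = 29 $$

Trying the metro first is the better choice as long as $22 + 21\,p < 29$, that is $p < 1/3$. Above that threshold, taking the bike directly gives a constant 29 min.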
In [7]:
function shortest_path_to_coffee(g::SimpleWeightedDiGraph, p::Real)
    times = LG.weights(g)
    metro_duration = times[1,2] + times[2,4]
    bike_duration = times[1,3] + times[3,4]
    no_metro_duration = 2 * times[1,2] + bike_duration
    average_metro_time = p * no_metro_duration + (1-p) * metro_duration
    if average_metro_time < bike_duration
        average_metro_time
    else
        bike_duration
    end
end
Out[7]:
shortest_path_to_coffee (generic function with 1 method)
In [8]:
import Plots
ps = collect(0.0:0.005:1.0)
Plots.plot(ps, [shortest_path_to_coffee(paris_evening, p) for p in ps], label="", width=5)
Plots.xlabel!("Probability")
Plots.ylabel!("Average time")
Out[8]:
[Plot: x-axis "Probability" (0 to 1), y-axis "Average time"]

Sensitivity

At a given point, how sensitive is the output to a small change in the input?

In [9]:
import ForwardDiff
# ForwardDiff.derivative(f, x)
In [10]:
ForwardDiff.derivative(
    p -> shortest_path_to_coffee(paris_evening, p),
    0.5
)
Out[10]:
0.0
In [11]:
ForwardDiff.derivative(
    p -> shortest_path_to_coffee(paris_evening, p),
    0.2
)
Out[11]:
21.0
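
One way to sanity-check these two values without any AD package is a central finite difference. The sketch below is an assumption-laden stand-in: `expected_time_to_coffee` and `finite_difference` are hypothetical names, the edge weights are hard-coded instead of read from the graph, and the step size `h` is an arbitrary choice.

```julia
# Stand-in for shortest_path_to_coffee with the edge weights hard-coded (minutes)
function expected_time_to_coffee(p)
    metro_duration = 7.0 + 15.0
    bike_duration = 4.0 + 25.0
    no_metro_duration = 2 * 7.0 + bike_duration
    average_metro_time = p * no_metro_duration + (1 - p) * metro_duration
    # Same decision as the if/else in the original function
    return min(average_metro_time, bike_duration)
end

# Central finite difference approximation of f'(x)
finite_difference(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

slope_low = finite_difference(expected_time_to_coffee, 0.2)   # metro branch
slope_high = finite_difference(expected_time_to_coffee, 0.5)  # bike branch
println("slope at p = 0.2: ", slope_low)
println("slope at p = 0.5: ", slope_high)
```

The first slope should be close to 21 (the gap between 43 and 22 minutes) and the second is zero, because the bike duration does not depend on p. Unlike AD, the finite difference is only an approximation and depends on the choice of `h`.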

Maths overdose?

Two approaches

Computing derivatives => non-standard interpretation of a function/program.

  • Source-to-source transformation: transform the initial program to compute sensitivities
  • Operator overloading: create a number-like structure carrying derivatives

This structure is a dual number. It defines all the standard arithmetic operations:
+, -, *, /...

Number: $x$
Dual number: $x + y\varepsilon$
$x$ is the "real" part, $y$ the sensitivity, $\varepsilon$ the small variation.
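
The arithmetic rules follow from treating $\varepsilon$ as a quantity whose square vanishes, $\varepsilon^2 = 0$. A short worked derivation for the product, using the same symbols as above:

$$ (x_1 + y_1\varepsilon)(x_2 + y_2\varepsilon) = x_1 x_2 + (x_1 y_2 + x_2 y_1)\,\varepsilon + y_1 y_2\,\varepsilon^2 = x_1 x_2 + (x_1 y_2 + x_2 y_1)\,\varepsilon $$

More generally, for a differentiable function $f$:

$$ f(x + y\varepsilon) = f(x) + f'(x)\,y\,\varepsilon $$

so evaluating $f$ at $x + 1\varepsilon$ yields the value $f(x)$ in the real part and the derivative $f'(x)$ in the $\varepsilon$ part.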

A Haskell implementation

newtype Dual a = Dual (a, a) deriving (Show)

instance Num a => Num (Dual a) where
    Dual (x1,y1) + Dual (x2,y2) = Dual (x1+x2, y1+y2)
    Dual (x1,y1) - Dual (x2,y2) = Dual (x1-x2, y1-y2)
    Dual (x1,y1) * Dual (x2,y2) = Dual (x1*x2, x1*y2 + x2*y1)
    negate (Dual (x,y))         = Dual (-x,-y)
    abs (Dual (x,y))            = Dual (abs x, y * signum x)
    signum (Dual (x,y))         = Dual (signum x, 0)
    fromInteger i               = Dual ((fromInteger i), 0)

instance Fractional a => Fractional (Dual a) where
    Dual (x1,y1) / Dual (x2,y2)   = Dual ((x1/x2), ((y1*x2 - x1*y2)/(x2*x2)))
    fromRational r                = Dual ((fromRational r), 0)

Usage

$f(x) = 3x^2 + 2x + 1$

f :: (Num a) => a -> a
f x = 3 * x*x + 2*x + 1

f (Dual (2.0, 1.0))
-- Dual (17.0,14.0)

$ f(2) = 3 \times 2^2 + 2 \times 2 + 1 = 17 $
$ f'(2) = 3 \times 2 \times 2 + 2 = 3 \times 4 + 2 = 12 + 2 = 14 $
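
The same usage example can be sketched in Julia. This is a minimal illustration, not the ForwardDiff implementation: `SimpleDual` and its field names `value` and `deriv` are hypothetical, and only the four basic operations are defined.

```julia
# Hypothetical minimal dual number type, for illustration only
struct SimpleDual{T<:Real} <: Number
    value::T  # "real" part
    deriv::T  # sensitivity
end

Base.:+(a::SimpleDual, b::SimpleDual) = SimpleDual(a.value + b.value, a.deriv + b.deriv)
Base.:-(a::SimpleDual, b::SimpleDual) = SimpleDual(a.value - b.value, a.deriv - b.deriv)
Base.:*(a::SimpleDual, b::SimpleDual) =
    SimpleDual(a.value * b.value, a.value * b.deriv + b.value * a.deriv)
Base.:/(a::SimpleDual, b::SimpleDual) =
    SimpleDual(a.value / b.value, (a.deriv * b.value - a.value * b.deriv) / b.value^2)

# Plain reals are promoted to constants (zero sensitivity), so that 3 * x works
Base.convert(::Type{SimpleDual{T}}, x::Real) where {T<:Real} =
    SimpleDual(convert(T, x), zero(T))
Base.promote_rule(::Type{SimpleDual{T}}, ::Type{S}) where {T<:Real, S<:Real} =
    SimpleDual{promote_type(T, S)}

# Generic function: no type annotation on x
f(x) = 3 * x * x + 2 * x + 1

result = f(SimpleDual(2.0, 1.0))
println("f(2)  = ", result.value)
println("f'(2) = ", result.deriv)
```

The promotion rule plays the role of `fromInteger` in the Haskell version: literals such as `3` and `1` become dual numbers with a zero sensitivity before the overloaded operators are applied.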

Requirements for AD

  • Abstract way to deal with numbers: type classes, interfaces
  • Generically-typed algorithms
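
A small sketch of why generic typing matters for the operator-overloading approach. The function names are hypothetical, and a `Rational` stands in for "some other number type" such as a dual number:

```julia
# Hypothetical example functions, not from the slides
square_strict(x::Float64) = x * x   # only accepts Float64
square_generic(x) = x * x           # accepts anything that supports *

# A Rational stands in for another number type, such as a dual number
other_number = 3 // 2

# No method of square_strict applies to a Rational argument
strict_ok = hasmethod(square_strict, Tuple{typeof(other_number)})
println("square_strict accepts it: ", strict_ok)

# The generic version works unchanged
generic_result = square_generic(other_number)
println("square_generic result: ", generic_result)
```

An AD tool based on operator overloading can only push its own number type through functions written like `square_generic`.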

For source-to-source:

  • Ability to hook in the compilation process

Does this work everywhere?

Two algorithms for computing sqrt(s):

Babylonian:

function sqrt_babylonian(s)
    x = s / 2
    while abs(x^2 - s) > 0.01
        x = (x + s/x)/2        
    end
    x
end

Fast approximate sqrt:

C:

#include <stdint.h>
#include <string.h>

float sqrt_approx(float s) {
    int32_t x;
    memcpy(&x, &s, sizeof x); /* Same bits, but as an int (avoids pointer-cast aliasing) */
    x -= 1 << 23; /* Subtract 2^m. */
    x >>= 1; /* Divide by 2. */
    x += 1 << 29; /* Add ((b + 1) / 2) * 2^m. */
    memcpy(&s, &x, sizeof s); /* Interpret again as float */
    return s;
}

Julia

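function sqrt_approx(s::Float32)
    x = reinterpret(Int32, s)   # same bits, but as an Int32
    x -= Int32(1) << 23         # Int32 constants keep x 32 bits wide
    x >>= 1
    x += Int32(1) << 29
    reinterpret(Float32, x)     # would fail if x had been widened to Int64
end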
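
A quick check of the bit-level trick on a few arbitrary inputs. The snippet repeats the definition so that it runs on its own, and uses `Int32` constants so that the integer stays 32 bits wide up to the final `reinterpret`:

```julia
# Repeats the definition so that the snippet runs on its own
function sqrt_approx(s::Float32)
    x = reinterpret(Int32, s)
    x -= Int32(1) << 23
    x >>= 1
    x += Int32(1) << 29
    return reinterpret(Float32, x)
end

# Compare with the built-in sqrt on a few arbitrary inputs
for s in (1.0f0, 2.0f0, 4.0f0, 10.0f0)
    println("s = ", s, "  approx = ", sqrt_approx(s), "  sqrt = ", sqrt(s))
end
```

The approximation is exact for 1 and 4 and gives 1.5 for 2. The function never performs floating-point arithmetic on `s`; it only manipulates its bit pattern, so an overloaded number type has no arithmetic operations to intercept.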

Thanks λ

  • "Does this work in my fav language?"
  • "Why are we doing maths?"
  • "Is there an OCaml example?"

Resources and bonus

Julia version of the Dual number example:
https://repl.it/@matbesancon/dual-prototype

Inspiration for the sqrt example:
AD in 10 minutes, Alan Edelman
sqrt implementations: Wikipedia

AD from a Haskeller:
AD for Dummies - Simon Peyton Jones, ECOOP 2019

Advanced Haskell-flavoured introduction to AD:
The Simple Essence of Automatic Differentiation - Conal Elliott. Paper and resources

Differentiate all the things:
A Differentiable Programming System to Bridge Machine Learning and Scientific Computing, Mike Innes et al.

What's up in Python?
https://github.com/google/jax

GIF sources: Tenor

Proof for sqrt

In [12]:
function sqrt_babylonian(s)
    x = s / 2
    while abs(x^2 - s) > 0.001
        x = (x + s/x)/2        
    end
    x
end

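function sqrt_approx(s::Float32)
    x = reinterpret(Int32, s)   # same bits, but as an Int32
    x -= Int32(1) << 23         # Int32 constants keep x 32 bits wide
    x >>= 1
    x += Int32(1) << 29
    reinterpret(Float32, x)     # would fail if x had been widened to Int64
end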

@show ForwardDiff.derivative(sqrt, 2.0)
@show ForwardDiff.derivative(sqrt_babylonian, 2.0);
ForwardDiff.derivative(sqrt, 2.0) = 0.35355339059327373
ForwardDiff.derivative(sqrt_babylonian, 2.0) = 0.353541906958862
In [13]:
@show ForwardDiff.derivative(sqrt_approx, 2.0);
MethodError: no method matching sqrt_approx(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(sqrt_approx),Float64},Float64,1})
Closest candidates are:
  sqrt_approx(!Matched::Float32) at In[12]:10

Stacktrace:
 [1] derivative(::typeof(sqrt_approx), ::Float64) at /home/matbesancon/.julia/packages/ForwardDiff/Asf4O/src/derivative.jl:14
 [2] top-level scope at show.jl:562
 [3] top-level scope at In[13]:1