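import numpy as np

a = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]])
b = np.array([[1, 2, 3]]).T  # column vector, shape (3, 1)
c = a.dot(b)                 # the function: c = a @ b, which is linear in b
jacobian = a                 # for a linear map, the partial derivatives dc[i]/db[j] are just a[i, j]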
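
As a quick sanity check on the claim that the Jacobian of this linear map is `a` itself, here is a minimal sketch that compares `a` against a central finite-difference estimate. The helper `numerical_jacobian` is a hypothetical name introduced only for illustration, and `b` is taken as a 1-D vector rather than a column vector to keep the shapes simple.

```python
import numpy as np

a = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]], dtype=float)
b = np.array([1.0, 2.0, 3.0])  # assumption: 1-D vector instead of a (3, 1) column

def f(v):
    # The linear function whose Jacobian we want.
    return a.dot(v)

def numerical_jacobian(func, x, eps=1e-6):
    # Hypothetical helper: central differences, one input coordinate at a time.
    x = np.asarray(x, dtype=float)
    y0 = func(x)
    jac = np.zeros((y0.size, x.size))
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = eps
        jac[:, j] = (func(x + step) - func(x - step)) / (2 * eps)
    return jac

J_num = numerical_jacobian(f, b)
print(np.allclose(J_num, a))  # compares the estimate with the matrix a
```

Finite differences are only a check here; for anything nonlinear an automatic differentiation library is usually the more accurate and convenient choice.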
While autograd is a good library, make sure to check out its successor, JAX, which is much better documented than autograd.
A simple example:
import jax.numpy as jnp
from jax import jacfwd
# Define some simple function.
def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)
# Here I want the derivative of a vector-valued function (inputs @ a + b is a vector)
# with respect to the input vector a, evaluated at a0. The derivative of a vector
# with respect to another vector is a matrix: the Jacobian.
def simpleJ(a, b, inputs):  # inputs is a matrix; a and b are vectors
    return sigmoid(jnp.dot(inputs, a) + b)
inputs = jnp.array([[0.52, 1.12, 0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
b = jnp.array([0.2, 0.1, 0.3, 0.2])
a0 = jnp.array([0.1, 0.7, 0.7])
# Isolate the variable to be differentiated from the constant parameters.
f = lambda a: simpleJ(a, b, inputs)  # f is now a function of a only
J = jacfwd(f)
# So far J is only the derivative function; it still has to be evaluated at a0.
J(a0)  # a 4x3 matrix: one row per output, one column per entry of a
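
To see what the matrix returned by `jacfwd` represents, here is a small sketch that compares it with the closed-form Jacobian. It relies on the identity sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)), so row i of the Jacobian should be s_i * (1 - s_i) times row i of `inputs`. The names `J_auto` and `J_closed` are introduced only for this illustration, and the tolerance assumes JAX's default 32-bit floats.

```python
import jax.numpy as jnp
from jax import jacfwd

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

inputs = jnp.array([[0.52, 1.12, 0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
b = jnp.array([0.2, 0.1, 0.3, 0.2])
a0 = jnp.array([0.1, 0.7, 0.7])

def f(a):
    # Same function as above, with b and inputs held constant.
    return sigmoid(jnp.dot(inputs, a) + b)

# Jacobian from forward-mode automatic differentiation.
J_auto = jacfwd(f)(a0)

# Closed form: scale each row of inputs by sigmoid'(z_i) = s_i * (1 - s_i).
s = f(a0)
J_closed = (s * (1 - s))[:, None] * inputs

print(J_auto.shape)  # (4, 3)
print(jnp.allclose(J_auto, J_closed, atol=1e-5))
```

Forward mode (`jacfwd`) is generally the better fit when there are fewer inputs than outputs, as here; `jax.jacrev` computes the same matrix in reverse mode and tends to be preferable in the opposite case.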