Question
I have a neural network that computes a vector quantity u. I'd like to compute the first- and second-order Jacobians with respect to the input x, a single element.
Would anybody know how to do that in PyTorch? Below is the code snippet from my project.
import torch
import torch.nn as nn

class PINN(torch.nn.Module):

    def __init__(self, layers: list):
        super(PINN, self).__init__()
        self.linears = nn.ModuleList([])
        for i, dim in enumerate(layers[:-2]):
            self.linears.append(nn.Linear(dim, layers[i+1]))
            self.linears.append(nn.ReLU())
        self.linears.append(nn.Linear(layers[-2], layers[-1]))

    def forward(self, x):
        for layer in self.linears:
            x = layer(x)
        return x
I then instantiate my network
n_in = 1
units = 50
q = 500
pinn = PINN([n_in, units, units, units, q+1])
pinn
Which returns
PINN(
  (linears): ModuleList(
    (0): Linear(in_features=1, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=50, bias=True)
    (5): ReLU()
    (6): Linear(in_features=50, out_features=501, bias=True)
  )
)
Then I compute both the first- and second-order Jacobians
x = torch.randn(1, requires_grad=False)
u_x = torch.autograd.functional.jacobian(pinn, x, create_graph=True)
print("First Order Jacobian du/dx of shape {}, and features\n{}".format(u_x.shape, u_x))
u_xx = torch.autograd.functional.jacobian(lambda _: u_x, x)
print("Second Order Jacobian du_x/dx of shape {}, and features\n{}".format(u_xx.shape, u_xx))
Returns
First Order Jacobian du/dx of shape torch.Size([501, 1]), and features
tensor([[-0.0310],
        [ 0.0139],
        [-0.0081],
        [-0.0248],
        [-0.0033],
        ...,
        [ 0.0152],
        [ 0.0209],
        [ 0.0320],
        [-0.0197]], grad_fn=<ViewBackward>)
Second Order Jacobian du_x/dx of shape torch.Size([501, 1, 1]), and features
tensor([[[0.]],
        [[0.]],
        [[0.]],
        ...,
        [[0.]],
        [[0.]]])
Shouldn't u_xx be a None vector if it didn't depend on x?
Thanks in advance
Answer 1:
So, as @jodag mentioned in his comment, ReLU is either zero or linear, so its gradient is constant (except at 0, which is a rare event) and its second-order derivative is zero. I changed the activation function to Tanh, which finally allows me to compute the Jacobian twice.
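To illustrate the point, here is a small check of my own (not part of the original answer) using torch.autograd.functional.hessian on toy scalar functions: the ReLU version has an identically zero second derivative, while the Tanh version does not.

import torch
from torch.autograd.functional import hessian

x = torch.randn(1)

# ReLU is piecewise linear, so its second derivative vanishes almost everywhere.
print(hessian(lambda t: torch.relu(3.0 * t).sum(), x))  # tensor([[0.]])

# Tanh is smooth, so its second derivative is generally non-zero.
print(hessian(lambda t: torch.tanh(3.0 * t).sum(), x))  # non-zero unless x happens to be exactly 0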
Final code is
import torch
import torch.nn as nn

class PINN(torch.nn.Module):

    def __init__(self, layers: list):
        super(PINN, self).__init__()
        self.linears = nn.ModuleList([])
        for i, dim in enumerate(layers[:-2]):
            self.linears.append(nn.Linear(dim, layers[i+1]))
            self.linears.append(nn.Tanh())
        self.linears.append(nn.Linear(layers[-2], layers[-1]))

    def forward(self, x):
        for layer in self.linears:
            x = layer(x)
        return x

    def compute_u_x(self, x):
        self.u_x = torch.autograd.functional.jacobian(self, x, create_graph=True)
        self.u_x = torch.squeeze(self.u_x)
        return self.u_x

    def compute_u_xx(self, x):
        self.u_xx = torch.autograd.functional.jacobian(self.compute_u_x, x)
        self.u_xx = torch.squeeze(self.u_xx)
        return self.u_xx
Then calling compute_u_xx(x) on an instance of PINN with x.requires_grad set to True gets me there. How to get rid of the useless dimensions introduced by torch.autograd.functional.jacobian remains to be understood, though...
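For completeness, a minimal usage sketch of my own (same layer sizes as above). The extra dimensions come from jacobian returning a tensor of shape output.shape + input.shape, so with a single-element input they are trailing singleton axes that torch.squeeze strips away:

pinn = PINN([1, 50, 50, 50, 501])
x = torch.randn(1, requires_grad=True)

u_x = pinn.compute_u_x(x)    # raw jacobian shape (501, 1), squeezed to (501,)
u_xx = pinn.compute_u_xx(x)  # raw jacobian shape (501, 1), squeezed to (501,)

print(u_x.shape, u_xx.shape)  # torch.Size([501]) torch.Size([501])
print(u_xx.abs().max())       # non-zero, thanks to the Tanh activations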
Source: https://stackoverflow.com/questions/64978232/pytorch-how-to-compute-second-order-jacobian