Question
For my use case, I need to take a PyTorch module and interpret the sequence of layers in it so that I can record the "connections" between the layers in some file format. Let's say I have a simple module like the one below:
import torch
import torch.nn as nn

class mymodel(nn.Module):
    def __init__(self, input_channels):
        super(mymodel, self).__init__()
        self.fc = nn.Linear(input_channels, input_channels)

    def forward(self, x):
        out = self.fc(x)
        out += x
        return out

if __name__ == "__main__":
    net = mymodel(5)
    for mod in net.modules():
        print(mod)
This prints:
mymodel(
  (fc): Linear(in_features=5, out_features=5, bias=True)
)
Linear(in_features=5, out_features=5, bias=True)
As you can see, the in-place addition (out += x) is not captured, because it is a plain tensor operation inside forward() rather than an nn.Module. My goal is to turn the PyTorch module object into a graph of connections, written out as something like this JSON:
{
  "layers": {
    "fc": {
      "inputTensor": "t0",
      "outputTensor": "t1"
    },
    "addOp": {
      "inputTensor": "t1",
      "outputTensor": "t2"
    }
  }
}
The tensor names are arbitrary, but the format captures the essence of the graph: the connections between layers.
My question is: is there a way to extract this information from a PyTorch object? I was thinking of using .modules(), but then realized that hand-written operations are not captured as modules this way. I guess that if everything were an nn.Module, .modules() might give me the network's layer arrangement. I'm looking for some help here: I want to know the connections between tensors so that I can produce a format like the one above.
Answer 1:
The information you are looking for is not stored in the nn.Module, but rather in the grad_fn attribute of the output tensor:
import torch

channels = 5
model = mymodel(channels)
pred = model(torch.rand((1, channels)))
print(pred.grad_fn)  # all the information is in the computation graph of the output tensor
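To see what is stored there, here is a minimal sketch of a recursive walk over grad_fn.next_functions, roughly the traversal torchviz performs. walk_grad_fn is a hypothetical helper, not a PyTorch API. Keep in mind that this graph points backwards (from the output towards the inputs), only records operations on tensors that require grad, and labels nodes by their autograd class (for example AddBackward0), not by module attribute names such as fc.

```python
import torch
import torch.nn as nn


class mymodel(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.fc = nn.Linear(input_channels, input_channels)

    def forward(self, x):
        out = self.fc(x)
        out += x
        return out


def walk_grad_fn(fn, depth=0, seen=None, out=None):
    """Depth-first walk of the backward graph (hypothetical helper).

    Returns a list of (depth, autograd node class name) pairs.
    """
    if seen is None:
        seen, out = set(), []
    # None marks an input that does not require grad (e.g. x here)
    if fn is None or fn in seen:
        return out
    seen.add(fn)
    out.append((depth, type(fn).__name__))
    # next_functions holds (node, output_index) pairs for this node's inputs
    for next_fn, _ in fn.next_functions:
        walk_grad_fn(next_fn, depth + 1, seen, out)
    return out


channels = 5
model = mymodel(channels)
pred = model(torch.rand((1, channels)))
nodes = walk_grad_fn(pred.grad_fn)
for depth, name in nodes:
    print("  " * depth + name)
```

The root should be the addition node, with the linear layer's backward node and the AccumulateGrad leaves for its parameters below it. The exact class names vary between PyTorch versions, so treat them as labels to map, not as stable identifiers.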
It is not trivial to extract this information. You might want to look at the torchviz package, which draws a nice graph from the grad_fn information.
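If a forward-direction graph that keeps module names is closer to what the question asks for, a different technique (not mentioned in the answer above) is torch.fx, available since PyTorch 1.8. Here is a sketch under these assumptions. graph_to_layers and op_type are hypothetical helper names. The output layout (an inputTensors list plus a type field) is an adaptation of the question's JSON, since the addition actually consumes two tensors. symbolic_trace only works when forward() has no data-dependent control flow.

```python
import json

import torch
import torch.fx
import torch.nn as nn


class mymodel(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.fc = nn.Linear(input_channels, input_channels)

    def forward(self, x):
        out = self.fc(x)
        out += x
        return out


def op_type(gm, node):
    """Human-readable type of an FX node (hypothetical helper)."""
    if node.op == "call_module":
        # node.target is the submodule's qualified name, e.g. "fc"
        return type(dict(gm.named_modules())[node.target]).__name__
    if node.op == "call_function":
        # e.g. operator.add or operator.iadd for the `+=` in forward()
        return getattr(node.target, "__name__", str(node.target))
    return str(node.target)  # call_method: target is the method name string


def graph_to_layers(module):
    """Trace `module` with torch.fx and map each op to named input/output tensors."""
    gm = torch.fx.symbolic_trace(module)
    tensor_names = {}  # fx.Node -> arbitrary tensor name "t0", "t1", ...
    layers = {}
    for node in gm.graph.nodes:
        if node.op == "output":
            continue  # the output node only marks what forward() returns
        tensor_names[node] = f"t{len(tensor_names)}"
        if node.op in ("placeholder", "get_attr"):
            continue  # graph inputs and attribute reads are tensors, not ops
        layers[node.name] = {
            "type": op_type(gm, node),
            "inputTensors": [tensor_names[n] for n in node.all_input_nodes],
            "outputTensor": tensor_names[node],
        }
    return layers


if __name__ == "__main__":
    print(json.dumps({"layers": graph_to_layers(mymodel(5))}, indent=2))
```

For mymodel(5) this should give an fc entry (type Linear, t0 to t1) and a single addition entry, named add or iadd depending on the PyTorch version, that lists both t1 and t0 as inputs. The hand-written `+=` shows up as a node even though it is not an nn.Module.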
Source: https://stackoverflow.com/questions/58253003/deriving-the-structure-of-a-pytorch-network