torchviz method:
import torch
from torchviz import make_dot

# Calling .requires_grad_(True) adds the input to the graph, so its shape is displayed
inputs_fake = torch.rand(NUM_SAMPLES, NUM_CHANNELS, HEIGHT, WIDTH).requires_grad_(True)
model = vgg()  # model is an instance of the vgg class
vis_graph = make_dot(model(inputs_fake),
                     params=dict(list(model.named_parameters()) + [('x', inputs_fake)]))
vis_graph.view()  # render the graph to a PDF and open it
tensorboardX method:
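import torch
from tensorboardX import SummaryWriter

inputs_fake = torch.rand(NUM_SAMPLES, NUM_CHANNELS, HEIGHT, WIDTH)
# By default the log is written to runs/<date-time>_<hostname>vgg
with SummaryWriter(comment='vgg') as w:
    w.add_graph(model, (inputs_fake,))  # model: the vgg instance from above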
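torchviz generates a PDF (by default named Digraph.gv.pdf, in the current working directory). How to give the PDF a different name is still unclear; perhaps only the default name can be used.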