PyTorch model visualization: the torchviz and tensorboardX approaches
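torchviz approach:

```python
import torch
from torchviz import make_dot

# NUM_SAMPLES, NUM_CHANNELS, HEIGHT, WIDTH are the input dimensions; set them for your model
inputs_fake = torch.rand(NUM_SAMPLES, NUM_CHANNELS, HEIGHT, WIDTH).requires_grad_(True)  # with .requires_grad_(True) the graph shows the input's shape
model = vgg()  # model is an instance of the vgg class
vis_graph = make_dot(model(inputs_fake), params=dict(list(model.named_parameters()) + [('x', inputs_fake)]))
vis_graph.view()  # render the graph to a PDF and open it
```

tensorboardX approach (reusing `model` from above):

```python
import torch
from tensorboardX import SummaryWriter

inputs_fake = torch.rand(NUM_SAMPLES, NUM_CHANNELS, HEIGHT, WIDTH)
# comment='vgg' is appended to the default run directory name under ./runs
with SummaryWriter(comment='vgg') as w:
    w.add_graph(model, (inputs_fake,))
```

torchviz generates a PDF. By default `view()` writes the DOT source as `Digraph.gv` and the rendered file as `Digraph.gv.pdf` in the working directory; I haven't yet worked out how to give the PDF a different name, so perhaps only the default name is possible.

Source: https://www.cnblogs.com/zhangziyan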