pytorch模型可视化,torchviz和tensorboardX方式

淺唱寂寞╮ 提交于 2019-12-05 06:17:52

torchviz方式:

1 from torchviz import make_dot
2 inputs_fake = torch.rand(NUM_SAMPLES, NUM_CHANNELS, HIGTHT, WIDTH).requires_grad_(True)  #有).requires_grad_(True)显示输入形状
3 model = vgg()  #model是vgg类的实例
4 vis_graph = make_dot(model(inputs_fake), params=dict(list(model.named_parameters()) + [('x', inputs_fake)]))
5 vis_graph.view()

tensorboardX方式:

from tensorboardX import SummaryWriter
inputs_fake = torch.rand(NUM_SAMPLES, NUM_CHANNELS, HIGTHT, WIDTH)
with SummaryWriter(comment='vgg') as w:
    w.add_graph(model, (inputs_fake,))

torchviz生成一个pdf,pdf怎样命名还不知道,或许只能默认命名。

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!