pytorch的模型解析

荒凉一梦 提交于 2020-11-17 03:41:59

如何获取pytorch的动态图?

model = torch.jit.load("test.pth")
graph = model.graph.copy()
torch._C._jit_pass_inline(graph)
node_list = graph.nodes()

加载模型后,获取模型的graph,这个graph就是需要的动态图。graph node就是计算图的计算节点(有序),关于各个层的相关参数都可以从node节点中获取,各个参数的相对位置需要查找一下该op的实现。

需要注意的是,需要使用 _jit_pass_inline来将graph的sub module展开。

如何获取pytorch的权重等参数?

对于非量化模型:

可以通过named_parameters或者state_dict获取。

对于量化模型:

在一次次的尝试和接口的设置中终于找到了!!!

a = 0
for model_name, module in model.named_modules():
    print(model_name)
    print(module)
    if a == 2:
        mod_c = module._c
        #print(mod_c.dump())
        param = module.__getattr__('_packed_params')
        print(param)
        print(type(param))
        print(dir(param))
        print(param._method_names())
        weight,bias = param.unpack()
        print(bias)
        break
    else:
        a = a + 1

首先可查看model的各个attribute的内容,方便后续直接getattr,这个过程可通过_c属性获得,然后将获得的属性值dump出来就可查看各个子module的所有内容,然后根据需要就可以获取想要的属性了。

解析时遇到以下问题:

1、权重的layout如何确定?

通过shape属性可以获取Tensor的shape,通过Tensor.storage()可以获取Tensor里的值,通过Tensor.layout可以获取数据布局。也就是说这个Tensor的值是通过.layout的布局方式来排列,并没有按照shape的顺序来排列。

解析的模型layout是torch.strided,这种布局是按照stride来排列的,这种解释还是比较模糊。

假设3x3卷积的权重的shape为[32,3,3,3](nchw),通过Tensor.stride()就可以获取各个维度的stride信息。权重的stride信息为[27,1,9,3](nchw),含义是:每跨一个n,步长为27,每跨一个c,步长为1,每跨一个h,步长为9,每跨一个w,步长为3。这样也就意味着排列顺序是nhwc,所以解析的时候需要按照框架要求进行转换。

2、权重是int8,解析后的框架是uint8,如何使之能适用于uint8的框架?

看了下量化计算的实现:如果计算是需要减去zeropoint,即计算的时候使用value-zeropoint的值参与计算,那么就可以将权重的所有值都加一个128(int8最小值为-128),zeropoint的值也得加128就可以了

3、quant和dequant op做了啥事情?

pytorch的每个量化tensor存的数据其实还是float数据,在通过storage获取数据时会将其按照scale/zeropoint量化成int8数据。

quant op做的事情:input / scale + zeropoint,将float转换为uint8数据,再将uint8数据反量化后存下来:(input - zeropoint) * scale。即quant op做了量化反量化的事情,将一个float Tensor转换为一个int8 scale tensor。

dequant op就做了一个反量化操作:(input - zeropoint) * scale

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!