Converting a pretrained saved model from NCHW to NHWC to make it compatible with TensorFlow Lite

Submitted by 本秂侑毒 on 2021-02-18 19:21:46

Question


I have converted a model from PyTorch to Keras and used the backend to extract the TensorFlow graph. Since PyTorch uses the NCHW data format, the extracted and saved model is also NCHW. Converting this model to TFLite fails because of the NCHW format. Is there a way to convert the whole graph to NHWC?
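To make the two layouts concrete: NCHW and NHWC hold the same values and differ only in the order of the four axes in memory. The pure-Python sketch below (the helper `nchw_to_nhwc` is hypothetical, not a TensorFlow API) moves each element of a flat, row-major buffer from its NCHW offset to its NHWC offset. It only relayouts one tensor's data; it does not rewrite the ops in a graph, which is what the question is really about. For a real tensor, `tf.transpose(x, [0, 2, 3, 1])` performs the same reordering.

```python
def nchw_to_nhwc(buf, n, c, h, w):
    """Reorder a flat, row-major NCHW buffer into NHWC order (illustrative helper)."""
    assert len(buf) == n * c * h * w, "buffer size must match n*c*h*w"
    out = [None] * len(buf)
    for ni in range(n):
        for ci in range(c):
            for hi in range(h):
                for wi in range(w):
                    src = ((ni * c + ci) * h + hi) * w + wi  # offset in NCHW order
                    dst = ((ni * h + hi) * w + wi) * c + ci  # offset in NHWC order
                    out[dst] = buf[src]
    return out


# One image, two channels, a 1x2 spatial grid:
# channel 0 holds [0, 1], channel 1 holds [10, 11].
print(nchw_to_nhwc([0, 1, 10, 11], n=1, c=2, h=1, w=2))  # [0, 10, 1, 11]
```

In NHWC the channel values of each pixel end up adjacent, which is the layout TFLite's convolution kernels work with.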


Answer 1:


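It is better to have a graph whose data format matches what TFLite expects (NHWC) for faster inference. One approach is to manually insert transpose ops into the graph around the layout-sensitive ops, as in this example from "How to convert the CIFAR10 tutorial to NCHW". The example goes from NHWC to NCHW and back; the same pattern applies in the opposite direction.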

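import tensorflow as tf

# TF 1.x API (tf.ConfigProto / tf.Session). Under TF 2.x, use the
# tf.compat.v1 equivalents with eager execution disabled.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand

with tf.Session(config=config) as session:

    kernel = tf.ones(shape=[5, 5, 3, 64])    # [height, width, in_channels, out_channels]
    images = tf.ones(shape=[64, 24, 24, 3])  # NHWC batch

    imgs = tf.transpose(images, [0, 3, 1, 2])  # NHWC -> NCHW
    # In TF 1.x, NCHW conv2d typically needs a GPU; the CPU kernel only supports NHWC.
    conv = tf.nn.conv2d(imgs, kernel, [1, 1, 1, 1], padding='SAME', data_format='NCHW')
    conv = tf.transpose(conv, [0, 2, 3, 1])  # NCHW -> NHWC

    print("conv=", conv.eval())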
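The two `tf.transpose` calls in that snippet have to be exact inverses of each other, otherwise the ops after the second transpose see scrambled axes. The small pure-Python sketch below (helper names `inverse_perm` and `permute_shape` are hypothetical; it tracks shapes only and needs no TensorFlow) checks that `[0, 3, 1, 2]` and `[0, 2, 3, 1]` undo each other and traces the shapes from the example. It follows the `tf.transpose` convention that output axis `i` is input axis `perm[i]`.

```python
def inverse_perm(perm):
    """Return the permutation that undoes `perm` (tf.transpose convention)."""
    inv = [0] * len(perm)
    for out_axis, in_axis in enumerate(perm):
        inv[in_axis] = out_axis
    return inv


def permute_shape(shape, perm):
    """Shape of the result of transposing a tensor of `shape` with `perm`."""
    return [shape[p] for p in perm]


TO_NCHW = [0, 3, 1, 2]  # NHWC -> NCHW, the answer's first transpose
TO_NHWC = [0, 2, 3, 1]  # NCHW -> NHWC, the answer's second transpose

images = [64, 24, 24, 3]                      # NHWC input shape from the example
print(permute_shape(images, TO_NCHW))         # [64, 3, 24, 24]
print(inverse_perm(TO_NCHW) == TO_NHWC)       # True

# NCHW conv output: 64 filters; SAME padding with stride 1 keeps 24x24.
conv_nchw = [64, 64, 24, 24]
print(permute_shape(conv_nchw, TO_NHWC))      # [64, 24, 24, 64]
```

Checking the permutations this way is cheap, but it is only bookkeeping: the data movement itself is done by `tf.transpose`, and each inserted transpose adds runtime cost.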



Answer 2:


Unfortunately, there is currently no way to convert an NCHW graph to NHWC. If you later want to run the model with TF Lite, you have to start with an NHWC graph and train it from the very beginning.



Source: https://stackoverflow.com/questions/50119751/converting-pretrained-saved-model-from-nchw-to-nhwc-to-make-it-compatible-for-te
