How to get weights in tflite using c++ api?

Submitted by 你。 on 2021-01-05 06:39:06

Question


I am using a .tflite model on device. The last layer is a ConditionalRandomField (CRF) layer, and I need that layer's weights to do prediction. How do I get the weights with the C++ API?

Related: How can I view weights in a .tflite file?

Netron and flatc don't meet my needs; they are too heavy to use on device.

It seems TfLiteNode stores weights in void* user_data or void* builtin_data. How do I read them?
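For context: in TFLite, builtin_data typically points to an op-specific parameters struct (for example TfLiteFullyConnectedParams for a FULLY_CONNECTED op) holding hyperparameters such as the fused activation, and user_data holds whatever the kernel's init function returned. Neither normally holds weight values. The sketch below is only an illustration of that pattern: FakeNode, FakeFullyConnectedParams and FakeActivation are hypothetical, simplified stand-ins, not the real TFLite types.

```cpp
#include <cassert>

// Hypothetical stand-in for a fused-activation enum.
enum class FakeActivation { kNone, kRelu };

// Hypothetical stand-in for an op-specific params struct
// (the real TfLiteFullyConnectedParams has more fields).
struct FakeFullyConnectedParams {
  FakeActivation activation;
};

// Hypothetical, simplified stand-in for TfLiteNode.
struct FakeNode {
  void* builtin_data;  // op hyperparameters, not weights
  void* user_data;     // kernel-private state returned by init()
};

// A kernel knows its own op type, so it casts builtin_data back to the
// matching params struct. No weight values are stored here.
FakeActivation GetActivation(const FakeNode& node) {
  assert(node.builtin_data != nullptr);
  const auto* params =
      static_cast<const FakeFullyConnectedParams*>(node.builtin_data);
  return params->activation;
}
```

Because builtin_data is an untyped pointer, reading it correctly requires knowing which op the node runs; the weights themselves live in tensors referenced by the node's inputs.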

UPDATE:

Conclusion: the .tflite file doesn't store the CRF weights, while the .h5 file does. (Maybe because they do not affect the output.)

WHAT I DID:

// Obtained from the model (e.g. built with tflite::InterpreterBuilder).
std::unique_ptr<tflite::Interpreter> interpreter;
// Get the node with the last index.
// I'm not sure whether the node index order follows the direction in which tensors/layers flow.
const TfLiteNode *node = &(interpreter->node_and_registration(interpreter->nodes_size() - 1)->first);

// Then follow @yyoon's answer.
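One way to avoid assuming that the highest node index is the last layer is to find the node that writes the model's output tensor. With the real C++ API this would roughly mean walking interpreter->execution_plan() and comparing each node's outputs against interpreter->outputs(). The sketch below shows only that search logic on a hypothetical SimpleNode type (tensor indices held in std::vector<int>), not on TFLite's own structs.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical simplified node: indices of the tensors it reads and writes.
struct SimpleNode {
  std::vector<int> inputs;
  std::vector<int> outputs;
};

// Returns the index of the node that writes `output_tensor`, visiting nodes
// in `execution_plan` order, or -1 if no node produces that tensor.
int FindProducer(const std::vector<SimpleNode>& nodes,
                 const std::vector<int>& execution_plan,
                 int output_tensor) {
  for (int node_index : execution_plan) {
    assert(node_index >= 0 && node_index < static_cast<int>(nodes.size()));
    const std::vector<int>& outs = nodes[node_index].outputs;
    if (std::find(outs.begin(), outs.end(), output_tensor) != outs.end()) {
      return node_index;
    }
  }
  return -1;
}
```

For a model with several outputs, the same search can be repeated per output index; which of them belongs to the CRF layer still has to be checked against the model.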

Answer 1:


In a TFLite node, the weights are normally among the node's inputs: node->inputs is an array of tensor indices, each referring to the corresponding TfLiteTensor.
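This indirection (integer indices into a shared tensor table rather than tensor pointers) can be sketched with simplified types. The real TfLiteIntArray stores a size and a data array, and an index of -1 (kTfLiteOptionalTensor) marks an omitted optional input; here std::vector<int> stands in for the index array, and SimpleTensor and kOptionalTensor are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical simplified tensor: just a name, for illustration.
struct SimpleTensor {
  std::string name;
};

// Mirrors the value of kTfLiteOptionalTensor (-1).
constexpr int kOptionalTensor = -1;

// Resolves a node's input indices to tensors in the shared table,
// skipping omitted optional inputs.
std::vector<const SimpleTensor*> ResolveInputs(
    const std::vector<int>& input_indices,
    const std::vector<SimpleTensor>& tensors) {
  std::vector<const SimpleTensor*> result;
  for (int index : input_indices) {
    if (index == kOptionalTensor) continue;
    assert(index >= 0 && index < static_cast<int>(tensors.size()));
    result.push_back(&tensors[index]);
  }
  return result;
}
```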

If you have already obtained the TfLiteNode* of the last layer, you could do something like this to read the weight values:

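#include "tensorflow/lite/kernels/kernel_util.h"             // GetInput, NumElements
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"  // GetTensorData

TfLiteContext* context;  // You would usually have access to this already.
const TfLiteNode* node;  // <obtain this from the graph>

for (int i = 0; i < node->inputs->size; ++i) {
  // Skip optional inputs that were omitted (index -1).
  if (node->inputs->data[i] == kTfLiteOptionalTensor) continue;
  const TfLiteTensor* input_tensor = tflite::GetInput(context, node, i);

  // Determine whether this is a weight tensor.
  // Weights are usually memory-mapped, read-only tensors
  // baked directly into the TFLite model (flatbuffer).
  if (input_tensor->allocation_type == kTfLiteMmapRo &&
      input_tensor->type == kTfLiteFloat32) {
    // Read the values according to the tensor's type; here, float weights.
    const float* weights = tflite::GetTensorData<float>(input_tensor);
    const int64_t count = tflite::NumElements(input_tensor);

    // <read weights[0] .. weights[count - 1]...>
  }
}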
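The filtering and reading steps in the loop above can be mirrored in a self-contained sketch. WeightTensor, AllocationType and DataType are hypothetical stand-ins for TfLiteTensor, its allocation_type (kTfLiteMmapRo) and its type (kTfLiteFloat32); CountElements multiplies the dimensions, which is what tflite::NumElements computes for a real tensor.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for TfLiteAllocationType and TfLiteType.
enum class AllocationType { kMmapRo, kArenaRw };
enum class DataType { kFloat32, kInt8 };

// Hypothetical simplified tensor.
struct WeightTensor {
  AllocationType allocation_type;
  DataType type;
  std::vector<int> dims;
  const void* data;
};

// Number of elements = product of all dimensions.
int64_t CountElements(const WeightTensor& t) {
  int64_t count = 1;
  for (int d : t.dims) count *= d;
  return count;
}

// Copies the values of a read-only float tensor; returns an empty vector
// for anything else (activations or non-float weights).
std::vector<float> ReadFloatWeights(const WeightTensor& t) {
  if (t.allocation_type != AllocationType::kMmapRo ||
      t.type != DataType::kFloat32) {
    return {};
  }
  assert(t.data != nullptr);
  const float* values = static_cast<const float*>(t.data);
  return std::vector<float>(values, values + CountElements(t));
}
```

Copying into a std::vector keeps the values valid independently of the model buffer; for large weights it may be preferable to read them in place through the pointer instead.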


Source: https://stackoverflow.com/questions/61930021/how-to-get-weights-in-tflite-using-c-api
