keras.layers.TimeDistributed with hub.KerasLayer NotImplementedError

Submitted by 给你一囗甜甜゛ on 2021-01-20 09:36:27

Question


I want to use the tf.keras.layers.TimeDistributed() layer with the TF Hub inception_v3 CNN model in the latest TensorFlow 2 version (tf-nightly-gpu-2.0-preview). The error output is shown below. It seems that tf.keras.layers.TimeDistributed() is not fully implemented to work with TF Hub models: somehow, the output shape of the wrapped layer cannot be computed. My question: is there a workaround for this problem?

tf.keras.layers.TimeDistributed works fine with regular tf.keras.layers layers. I just want to apply the CNN model to each time step.

Model

import tensorflow as tf
import tensorflow_hub as hub 
from tensorflow.keras import layers, Model

model_url = "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/3"

feature_layer = hub.KerasLayer(model_url,
                               input_shape = (299, 299, 3),
                               output_shape = [2048],
                               trainable = False)

video = layers.Input(shape = (None, 299, 299, 3))

encoded_frames = layers.TimeDistributed(feature_layer)(video)

model = Model(inputs = video, outputs = encoded_frames)

Expected output

A tf.keras model that builds without errors.

Error messages

File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 489, in compute_output_shape
    raise NotImplementedError
NotImplementedError


Answer 1:


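In TensorFlow 2 it is possible to use custom layers in combination with the TimeDistributed layer. The error is thrown because TimeDistributed cannot compute the output shape of the wrapped layer: as the traceback shows, it falls through to the base Layer.compute_output_shape, which raises NotImplementedError.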

So in your case you should be able to subclass hub.KerasLayer and implement compute_output_shape yourself.
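A minimal sketch of the compute_output_shape override suggested above, assuming a TensorFlow 2.x release with tf.keras. So that it runs offline, FramePooling is a hypothetical stand-in for the hub layer: it averages each frame over its spatial axes instead of running Inception v3, and the 8x8 RGB frames are illustrative. For the real model, the same method on a subclass of hub.KerasLayer would return (batch, 2048).

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model


class FramePooling(layers.Layer):
    """Hypothetical per-frame encoder standing in for hub.KerasLayer.

    Maps a batch of images (batch, H, W, C) to feature vectors (batch, C)
    by averaging over the spatial axes.
    """

    def call(self, frames):
        return tf.reduce_mean(frames, axis=[1, 2])

    def compute_output_shape(self, input_shape):
        # Explicit shape rule so TimeDistributed does not have to infer it:
        # (batch, H, W, C) -> (batch, C).
        input_shape = tuple(input_shape)
        return (input_shape[0], input_shape[-1])


# Same structure as the question's model, with small 8x8 RGB frames.
video = layers.Input(shape=(None, 8, 8, 3))           # (batch, time, H, W, C)
encoded_frames = layers.TimeDistributed(FramePooling())(video)
model = Model(inputs=video, outputs=encoded_frames)
print(model.outputs[0].shape)                          # batch and time stay dynamic

# Run it on 2 clips of 4 frames each; every frame becomes one 3-d vector.
clip = np.ones((2, 4, 8, 8, 3), dtype="float32")
features = model(clip)
print(features.shape)
```

The shape rule here is trivial, but defining it explicitly is what lets TimeDistributed report an output shape without having to infer it from the wrapped layer.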




Answer 2:


Wrapper layers like TimeDistributed require a layer instance to be passed. If you build the model out of custom layers, you'll need to at least wrap them in tf.keras.layers.Lambda. This might not be possible for models loaded with hub.KerasLayer, so you might consider the solutions posted here:

TimeDistributed of a KerasLayer in Tensorflow 2.0
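For a stateless per-frame function, the Lambda wrapping suggested above can be sketched as follows, again assuming TensorFlow 2.x with tf.keras. pool_frames is a hypothetical stand-in for the CNN, and passing output_shape (which excludes the batch axis) gives TimeDistributed an explicit shape rule. The Keras documentation presents Lambda as best suited to simple operations and recommends subclassing Layer for anything more involved, which is one reason this route may not fit a weight-bearing hub.KerasLayer.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model


def pool_frames(frames):
    """Hypothetical stateless per-frame encoder: (batch, H, W, C) -> (batch, C)."""
    return tf.reduce_mean(frames, axis=[1, 2])


video = layers.Input(shape=(None, 8, 8, 3))           # (batch, time, H, W, C)
# output_shape excludes the batch axis: each frame becomes a 3-d vector.
frame_encoder = layers.Lambda(pool_frames, output_shape=(3,))
encoded_frames = layers.TimeDistributed(frame_encoder)(video)
model = Model(inputs=video, outputs=encoded_frames)

# 2 clips of 4 frames each, every pixel set to 2.0.
clip = np.full((2, 4, 8, 8, 3), 2.0, dtype="float32")
features = model(clip)
print(features.shape)
```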



Source: https://stackoverflow.com/questions/56173992/keras-layers-timedistributed-with-hub-keraslayer-notimplementederror
