TensorFlow Federated: How can I write an Input Spec for a model with more than one input

Posted by 泄露秘密 on 2020-12-12 11:05:00

Question


I'm trying to build an image captioning model using the federated learning library provided by TensorFlow (TensorFlow Federated), but I'm stuck at this error:

Input 0 of layer dense is incompatible with the layer: : expected min_ndim=2, found ndim=1.

This is my input_spec:

input_spec=collections.OrderedDict(x=(tf.TensorSpec(shape=(2048,), dtype=tf.float32), tf.TensorSpec(shape=(34,), dtype=tf.int32)), y=tf.TensorSpec(shape=(None), dtype=tf.int32))

The model takes image features as its first input and a list of vocabulary as its second input, but I can't express this in the input_spec variable. I tried expressing it as a list of lists, but that didn't work either. What can I try next?


Answer 1:


Great question! It looks to me like this error is coming out of TensorFlow proper rather than TFF, which indicates that you probably have the correct nested structure, but the leaves (the individual shapes or dtypes) may be off. Your input spec looks like it "should work" from TFF's perspective, so it is probably slightly mismatched with the data you actually have.

The first thing I would try: if you have an example tf.data.Dataset which will be passed in to your client computation, you can simply read input_spec directly off this dataset as its element_spec attribute. This would look something like:

# ds = example dataset
input_spec = ds.element_spec
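To make this concrete, here is a minimal sketch of reading the spec off a dataset shaped like the one in the question. The arrays, their sizes, and the batch size are made-up stand-ins, not the asker's real data. It also shows that batching adds a leading dimension of unknown size to every leaf. That may be relevant to the "expected min_ndim=2, found ndim=1" error, since a Dense layer needs at least a batch dimension plus a feature dimension, but that link is an assumption about the asker's setup.

```python
import collections

import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: 8 examples, each with 2048 image features,
# a 34-token caption prefix, and one integer label.
num_examples = 8
features = np.zeros((num_examples, 2048), dtype=np.float32)
captions = np.zeros((num_examples, 34), dtype=np.int32)
labels = np.zeros((num_examples,), dtype=np.int32)

ds = tf.data.Dataset.from_tensor_slices(
    collections.OrderedDict(x=(features, captions), y=labels)
)

# Unbatched elements carry no leading batch dimension.
unbatched_spec = ds.element_spec

# Batching prepends a dimension of unknown size (None) to every leaf.
batched_ds = ds.batch(4)
input_spec = batched_ds.element_spec

print(unbatched_spec)
print(input_spec)
```

If the client datasets are batched before training, the spec read from the batched dataset is the one that matches what the model will actually receive.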

This is the easiest path. If you have something like "lists of lists of numpy arrays", there is still a way to pull this information off the data itself. The following code snippet should get you there:

# data = list of list of numpy arrays
input_spec = tf.nest.map_structure(lambda x: tf.TensorSpec(x.shape, x.dtype), data)
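As a self-contained illustration of the snippet above, the sketch below builds a hypothetical nested list of numpy arrays (the shapes and the batch size of 4 are assumptions chosen for the example) and maps each leaf to a tf.TensorSpec. Note that the resulting spec records the exact shape of each array, including a fixed batch size, so it only matches data with those same dimensions.

```python
import numpy as np
import tensorflow as tf

# Hypothetical batch of 4 examples stored as a list of lists of numpy arrays.
data = [
    [np.zeros((4, 2048), dtype=np.float32), np.zeros((4, 34), dtype=np.int32)],
    [np.zeros((4,), dtype=np.int32)],
]

# Map every numpy leaf to a TensorSpec with the same shape and dtype;
# the nested list structure is preserved.
input_spec = tf.nest.map_structure(
    lambda x: tf.TensorSpec(x.shape, x.dtype), data
)

print(input_spec)
```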

Finally, if you have a list of lists of tf.Tensors, TensorFlow provides a similar function:

# tensor_structure = list of lists of tensors
input_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, tensor_structure)
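The tensor variant can be tried out the same way. This is a small sketch in which tensor_structure is filled with hypothetical zero tensors (the shapes are assumptions for illustration only), run in eager mode:

```python
import tensorflow as tf

# Hypothetical nested structure of eager tensors.
tensor_structure = [
    [tf.zeros((4, 2048), dtype=tf.float32), tf.zeros((4, 34), dtype=tf.int32)],
    [tf.zeros((4,), dtype=tf.int32)],
]

# TensorSpec.from_tensor reads the shape and dtype off each tensor leaf.
input_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, tensor_structure)

print(input_spec)
```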

In short, I would recommend not specifying input_spec by hand, but rather letting the data tell you what its input spec should be.



Source: https://stackoverflow.com/questions/61034455/tensorflow-federated-how-can-i-write-an-input-spec-for-a-model-with-more-than-o
