Question
Update:
Did a bit more testing and I can't reproduce the behaviour with:
import tensorflow as tf
import numpy as np

@tf.function
def tf_being_unpythonic(an_input, another_input):
    return an_input + another_input

@tf.function
def example(*inputs, other_args=True):
    return tf_being_unpythonic(*inputs)

class TestClass(tf.keras.Model):
    def __init__(self, a, b):
        super().__init__()
        self.a = a
        self.b = b

    @tf.function
    def call(self, *inps, some_kwarg=False):
        if some_kwarg:
            return self.a(*inps)
        return self.b(*inps)

class Model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.inps = tf.keras.layers.Flatten()
        self.hl1 = tf.keras.layers.Dense(5)
        self.hl2 = tf.keras.layers.Dense(4)
        self.out = tf.keras.layers.Dense(1)

    @tf.function
    def call(self, observation):
        x = self.inps(observation)
        x = self.hl1(x)
        x = self.hl2(x)
        return self.out(x)

class Model2(Model):
    def __init__(self):
        super().__init__()
        self.prein = tf.keras.layers.Concatenate()

    @tf.function
    def call(self, b, c):
        x = self.prein([b, c])
        return super().call(x)

am = Model()
pm = Model2()
test = TestClass(am, pm)

a = np.random.normal(size=(1, 2, 3))
b = np.random.normal(size=(1, 2, 4))

test(a, some_kwarg=True)
test(a, b)
So it's probably a bug somewhere else.
@tf.function
def call(self, *inp, target=False, training=False):
    if not len(inp):
        raise ValueError("Call requires some input")
    if target:
        return self._target_network(*inp, training)
    return self._network(*inp, training)
I get:
ValueError: Input 0 of layer flatten is incompatible with the layer: : expected min_ndim=1, found ndim=0. Full shape received: []
But print(inp) gives:
(<tf.Tensor 'inp_0:0' shape=(1, 3) dtype=float32>,)
I've since edited the code, and it was just uncommitted toy code, so I can't investigate further. I'll leave the question here so that anyone who does run into this issue has something to read.
Answer 1:
I don't think that using a *args construct is good practice for a tf.function. As you can see, most of the TF functions that accept a variable number of inputs take a tuple.
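For instance (my own illustration, not part of the original answer), built-in ops such as tf.add_n and tf.concat take a single list of tensors rather than a variable number of positional tensor arguments:
import tensorflow as tf

a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])

tf.add_n([a, b])            # one list argument, not tf.add_n(a, b)
tf.concat([a, b], axis=0)   # same pattern: the tensors go in as one list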
So, you can rewrite your function signature as:
def call(self, inputs, target=False, training=False)
and call it with:
instance.call((i1, i2, i3), [...])
# instead of instance.call(i1, i2, i3, [...])
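To make that concrete, here is a minimal sketch (my own example with made-up layer sizes and names, not code from the question) of a Keras model whose call accepts one tuple of tensors and unpacks it internally:
import tensorflow as tf

class TupleModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.prein = tf.keras.layers.Concatenate()
        self.out = tf.keras.layers.Dense(1)

    @tf.function
    def call(self, inputs, training=False):
        # `inputs` is a single tuple/list of tensors instead of *args.
        b, c = inputs
        x = self.prein([b, c])
        return self.out(x)

model = TupleModel()
b = tf.random.normal((1, 3))
c = tf.random.normal((1, 4))
model((b, c))  # the two tensors are passed as one tuple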
Edit
By the way, I don't see any error while using tf.function with a *args construct:
import tensorflow as tf

@tf.function
def call(*inp, target=False, training=False):
    if not len(inp):
        raise ValueError("Call requires some input")
    return inp[0]

def main():
    print(call(1))
    print(call(2, 2))
    print(call(3, 3, 3))

if __name__ == '__main__':
    main()
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
So you should give us more information about what you are trying to do and where the error occurs.
Answer 2:
This may have been a bug that was resolved recently. *args and **kwargs should work fine.
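For completeness, a quick sketch (my own check against a recent TF 2.x release, not code from the thread) that exercises both *args and **kwargs under tf.function:
import tensorflow as tf

@tf.function
def combine(*args, **kwargs):
    # Sum the positional tensors, then scale by an optional keyword factor.
    total = tf.add_n(list(args))
    return total * kwargs.get("scale", 1.0)

x = tf.constant(1.0)
y = tf.constant(2.0)
print(combine(x, y))             # tf.Tensor(3.0, shape=(), dtype=float32)
print(combine(x, y, scale=2.0))  # tf.Tensor(6.0, shape=(), dtype=float32)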
Source: https://stackoverflow.com/questions/59167107/tensorflow-2-how-to-use-args-in-tf-function