Question
I'm trying to use the @tf.function decorator with the Keras API (a Sequential model in this case) to build a TF graph for the training step of a simple neural network. I'm using TensorFlow v2.1.0, installed with Python 3.7. However, I get the runtime error "Cannot get session inside Tensorflow graph function", and I would appreciate any hint as to why.
The code is as follows.
import tensorflow as tf
import numpy as np
# import the CIFAR10 dataset and normalise the feature distributions
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
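# compute the maximum once, before train_images is overwritten, and use it for both splits
# (otherwise np.max(train_images) is already 1.0 and test_images is never rescaled)
max_value = np.max(train_images)
train_images = train_images / max_value
test_images = test_images / max_value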
# convert the datasets to tf.data, batching the data
train_data = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).batch(128)
test_data = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(128)
# make a model with a single dense layer
# note that the Flatten layer is needed to convert each 32x32x3 image into a flat vector for the Dense layer
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(units = 10, activation = "relu"))
# compile the model
model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001),
              loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),
              metrics = ["accuracy"])
# training step
@tf.function
def train(model, train_data, test_data):
    model.fit(x = train_data,
              validation_data = test_data,
              epochs = 10)
    return
# train the model
train(model = model, train_data = train_data, test_data = test_data)
The error that I get at runtime is as follows.
2020-04-01 11:33:27.084545: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 1228800000 exceeds 10% of system memory.
Traceback (most recent call last):
  File "report.py", line 41, in <module>
    train(model = model, train_data = train_data, test_data = test_data)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 568, in __call__
    result = self._call(*args, **kwds)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 615, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 497, in _initialize
    *args, **kwds))
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2389, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2703, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2593, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 978, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 439, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
RuntimeError: in converted code:
    report.py:34 train *
        model.fit(x = train_data,
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:819 fit
        use_multiprocessing=use_multiprocessing)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:648 fit
        shuffle=shuffle)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2346 _standardize_user_data
        all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2523 _build_model_with_inputs
        inputs, targets, _ = training_utils.extract_tensors_from_dataset(inputs)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py:1677 extract_tensors_from_dataset
        iterator = get_iterator(dataset)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py:1658 get_iterator
        initialize_iterator(iterator)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py:1665 initialize_iterator
        K.get_session((init_op,)).run(init_op)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:493 get_session
        session = _get_session(op_input_list)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:453 _get_session
        raise RuntimeError('Cannot get session inside Tensorflow graph function.')
    RuntimeError: Cannot get session inside Tensorflow graph function.
Please note that the same code runs fine without the @tf.function decorator. I also get the same error with other datasets and other models.
Thanks in advance.
Answer 1:
Looking at the documentation (https://www.tensorflow.org/guide/function), it isn't clear to me that the function you have defined can be compiled into a graph. I think the decorator is meant for functions such as those used in a Lambda layer (https://www.tensorflow.org/api_docs/python/tf/keras/layers/Lambda) or similar.
You have already called compile on the model, which takes care of converting it into a graph, so there is nothing more to do.
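To illustrate that nothing more is needed, here is a minimal sketch of the question's setup with the decorator simply removed, so that fit is called directly and Keras is left to build and run its own training function. It assumes small random arrays shaped like CIFAR10 as a hypothetical stand-in for the real data, so it runs quickly; the Dense layer also has no activation here, so that it outputs raw logits as expected by from_logits=True.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for CIFAR10: 256 random 32x32x3 images with integer labels 0..9
rng = np.random.default_rng(0)
images = rng.random((256, 32, 32, 3), dtype=np.float32)
labels = rng.integers(0, 10, size=(256, 1))
train_data = tf.data.Dataset.from_tensor_slices((images, labels)).batch(128)

model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=10),  # no activation: raw logits for from_logits=True
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])

# No @tf.function wrapper: fit is called directly on the tf.data dataset
history = model.fit(train_data, epochs=2, verbose=0)
print(history.history["loss"])
```

With the real data, train_data and test_data from the question can be passed in the same way, together with validation_data=test_data.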
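My guess is that it throws because it has no idea how to build a graph from the model.fit call, although the error message is very confusing. The traceback is consistent with this: inside the traced function, fit goes through training_arrays.py and tries to initialize the dataset iterator via K.get_session, which raises because a session cannot be obtained inside a graph function.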
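If the goal is a graph-compiled training step, a common pattern is to put @tf.function on a function that performs one optimisation step on one batch, rather than around fit. The sketch below assumes the same hypothetical random stand-in data as a CIFAR10 substitute; the names train_step and epoch_losses are illustrative, not part of any API. It uses tf.GradientTape for the gradients and keeps the epoch and batch loop in plain Python.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for CIFAR10: 256 random 32x32x3 images with integer labels 0..9
rng = np.random.default_rng(0)
images = rng.random((256, 32, 32, 3), dtype=np.float32)
labels = rng.integers(0, 10, size=(256, 1))
train_data = tf.data.Dataset.from_tensor_slices((images, labels)).batch(128)

model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=10),  # raw logits
])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

@tf.function  # traces one step (forward pass, gradients, update) into a graph
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# The epoch/batch loop stays in Python; only train_step runs as a graph
epoch_losses = []
for epoch in range(2):
    for x_batch, y_batch in train_data:
        loss = train_step(x_batch, y_batch)
    epoch_losses.append(float(loss))  # loss of the last batch in each epoch
print(epoch_losses)
```

Because both batches have the same shape and dtype, train_step should be traced once and then reused for every call.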
If you try a simple arithmetic function like
@tf.function
def add(x, y):
    return x + y
add(1, 2)
This now outputs a tensor:
<tf.Tensor: shape=(), dtype=int32, numpy=3>
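To see what the decorator does with such a function, here is a minimal sketch that records when TensorFlow traces it. The traces list is a hypothetical helper: appending to it is a plain Python side effect, so it runs only while a new graph is being traced, not on every call.

```python
import tensorflow as tf

traces = []  # hypothetical log, appended to only while `add` is being traced

@tf.function
def add(x, y):
    traces.append((x.dtype, x.shape))  # Python side effect: runs at trace time only
    return x + y

r1 = add(tf.constant(1), tf.constant(2))      # first call: traces a graph for int32 scalars
r2 = add(tf.constant(3), tf.constant(4))      # same dtype and shape: reuses that graph
r3 = add(tf.constant(1.5), tf.constant(2.5))  # float32 inputs: traces a second graph
print(r1.numpy(), r2.numpy(), r3.numpy(), len(traces))
```

Three calls produce only two traces, which is why Python code such as fit, which expects to run eagerly each time, does not fit naturally inside a tf.function.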
TensorFlow is fast. I wouldn't worry about performance until you really understand what is going on in the library and you know there is an issue.
Source: https://stackoverflow.com/questions/60968096/tensorflow-tf-function-cannot-get-session-inside-tensorflow-graph-function