Question
This question is about TensorFlow (and TensorBoard) version 2.2rc3, but I have experienced the same issue with 2.1.
Consider the following weird code:
```python
from datetime import datetime
import tensorflow as tf
from tensorflow import keras

inputs = keras.layers.Input(shape=(784, ))

x1 = keras.layers.Dense(32, activation='relu', name='Model/Block1/relu')(inputs)
x1 = keras.layers.Dropout(0.2, name='Model/Block1/dropout')(x1)
x1 = keras.layers.Dense(10, activation='softmax', name='Model/Block1/softmax')(x1)

x2 = keras.layers.Dense(32, activation='relu', name='Model/Block2/relu')(inputs)
x2 = keras.layers.Dropout(0.2, name='Model/Block2/dropout')(x2)
x2 = keras.layers.Dense(10, activation='softmax', name='Model/Block2/softmax')(x2)

x3 = keras.layers.Dense(32, activation='relu', name='Model/Block3/relu')(inputs)
x3 = keras.layers.Dropout(0.2, name='Model/Block3/dropout')(x3)
x3 = keras.layers.Dense(10, activation='softmax', name='Model/Block3/softmax')(x3)

x4 = keras.layers.Dense(32, activation='relu', name='Model/Block4/relu')(inputs)
x4 = keras.layers.Dropout(0.2, name='Model/Block4/dropout')(x4)
x4 = keras.layers.Dense(10, activation='softmax', name='Model/Block4/softmax')(x4)

outputs = x1 + x2 + x3 + x4

model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer=keras.optimizers.RMSprop(),
              metrics=['accuracy'])

logdir = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

model.fit(x_train, y_train,
          batch_size=64,
          epochs=5,
          validation_split=0.2,
          callbacks=[tensorboard_callback])
```
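When running it and looking at the graph created in TensorBoard, you will see the following.
As can be seen, the addition operations are really ugly: they show up as separate nodes with automatically generated names that are not grouped with the rest of the model.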
When replacing the line
```python
outputs = x1 + x2 + x3 + x4
```
with the lines
```python
outputs = keras.layers.add([x1, x2], name='Model/add/add1')
outputs = keras.layers.add([outputs, x3], name='Model/add/add2')
outputs = keras.layers.add([outputs, x4], name='Model/add/add3')
```
a much nicer graph is created by TensorBoard (in this second screenshot, the Model as well as one of its inner blocks are shown in detail).
The difference between the two representations of the model is that in the second one, we could name the addition operations and group them.
I could not find any way to name these operations other than by using `keras.layers.add()`. In this model the problem does not look that critical, as the model is simple and it is easy to replace `+` with `keras.layers.add()`. In more complex models, however, it can become a real pain. For example, a slice such as `t[:, start:end]` would have to be rewritten as a complex explicit call to `tf.strided_slice()`. As a result, my model representations are quite messy, with plenty of cryptic gather, strided-slice and concat operations.
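To make the slicing point concrete: on a tensor, `t[:, start:end]` computes the same result as a `tf.strided_slice()` call, but only the explicit call accepts a `name`. The sketch below assumes TF 2.x in eager mode; the tensor, the bounds and the name `slice_cols` are hypothetical. The masks are what make the explicit form verbose: bit 0 of `begin_mask`/`end_mask` tells TF to ignore `begin[0]`/`end[0]` and take the whole first axis, which is what `:` means.

```python
import tensorflow as tf

# A small 2x10 integer tensor standing in for `t` (hypothetical data).
t = tf.reshape(tf.range(20), (2, 10))
start, end = 2, 5

# Python slicing syntax: there is no way to attach a name to this op.
sliced = t[:, start:end]

# Explicit equivalent. Bit 0 of begin_mask/end_mask makes TF ignore
# begin[0]/end[0] and take the whole first axis, i.e. the ':' above.
explicit = tf.strided_slice(
    t,
    begin=[0, start],
    end=[0, end],
    strides=[1, 1],
    begin_mask=1,
    end_mask=1,
    name="slice_cols",  # hypothetical name; only recorded when traced into a graph
)
```

In eager execution the `name` is ignored; it only shows up when the op is traced into a graph, for example inside `tf.function`.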
I wonder if there is a way to wrap / group such operations to allow nicer graphs in TensorBoard.
Answer 1:
```python
outputs = keras.layers.Add()([x1, x2, x3, x4])
```
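`keras.layers.Add` takes a list of two or more same-shaped tensors, so a single layer replaces the chain of three `add` calls, and like any Keras layer it accepts a `name` argument. A minimal sketch assuming TF 2.x: four `Input`s stand in for the question's `x1`...`x4`, and `add_all` is a hypothetical name. A plain name is used so the sketch does not depend on how a given Keras version treats `/` in layer names; in the asker's TF 2.2 setup, a `/`-scoped name like the ones in the question is what grouped the node in TensorBoard.

```python
import numpy as np
from tensorflow import keras

# Four inputs standing in for the question's x1..x4 (hypothetical shapes).
branches = [keras.layers.Input(shape=(2,)) for _ in range(4)]

# One Add layer sums the whole list; the name is a hypothetical example.
total = keras.layers.Add(name="add_all")(branches)
model = keras.Model(inputs=branches, outputs=total)

# Feed constant arrays 1, 2, 3, 4 so the expected sum (10) is easy to check.
arrays = [np.full((1, 2), v, dtype="float32") for v in (1.0, 2.0, 3.0, 4.0)]
result = np.asarray(model(arrays))
```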
Answer 2:
Following the hint from Marco Cerliani, the `Lambda` layer is indeed very useful here. The following code groups the `+` operations nicely:
```python
outputs = keras.layers.Lambda(lambda x: x[0] + x[1], name='Model/add/add1')([x1, x2])
outputs = keras.layers.Lambda(lambda x: x[0] + x[1], name='Model/add/add2')([outputs, x3])
outputs = keras.layers.Lambda(lambda x: x[0] + x[1], name='Model/add/add3')([outputs, x4])
```
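A quick numerical check of chained `Lambda` additions: with constant inputs 1, 2, 3 and 4, the final output should be 10 everywhere. A minimal sketch assuming TF 2.x; the shapes and the plain names `add1`...`add3` are hypothetical. Each `Lambda` receives the list `[a, b]` as its single argument, which is why the function indexes `x[0]` and `x[1]`.

```python
import numpy as np
from tensorflow import keras

# Four inputs standing in for the question's x1..x4 (hypothetical shapes).
x1, x2, x3, x4 = [keras.layers.Input(shape=(2,)) for _ in range(4)]

# Each Lambda receives the list [a, b] as its single argument x.
outputs = keras.layers.Lambda(lambda x: x[0] + x[1], name="add1")([x1, x2])
outputs = keras.layers.Lambda(lambda x: x[0] + x[1], name="add2")([outputs, x3])
outputs = keras.layers.Lambda(lambda x: x[0] + x[1], name="add3")([outputs, x4])
model = keras.Model(inputs=[x1, x2, x3, x4], outputs=outputs)

# Constant inputs 1, 2, 3, 4: every element of the result should be 10.
arrays = [np.full((1, 2), v, dtype="float32") for v in (1.0, 2.0, 3.0, 4.0)]
result = np.asarray(model(arrays))
```

Since `Lambda` wraps arbitrary Python code, saving and reloading such a model can be less portable than with built-in layers; for plain addition, `keras.layers.Add` from Answer 1 avoids that.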
Or, if strided slices need to be wrapped, the following code groups the `t[]` operations nicely:
```python
x1 = keras.layers.Lambda(lambda x: x[:, 0:5], name='Model/stride_concat/stride1')(x1)  # instead of x1 = x1[:, 0:5]
x2 = keras.layers.Lambda(lambda x: x[:, 5:10], name='Model/stride_concat/stride2')(x2)  # instead of x2 = x2[:, 5:10]
outputs = keras.layers.concatenate([x1, x2], name='Model/stride_concat/concat')
```
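A check that `Lambda`-wrapped slices followed by `concatenate` compute the same thing as plain slicing. A minimal sketch assuming TF 2.x; the two 10-wide inputs stand in for the question's `Dense(10)` block outputs, and the plain names `stride1`, `stride2` and `concat` are hypothetical.

```python
import numpy as np
from tensorflow import keras

# Two 10-wide inputs standing in for the question's block outputs.
in1 = keras.layers.Input(shape=(10,))
in2 = keras.layers.Input(shape=(10,))

# Same structure as above, with plain (hypothetical) names.
x1 = keras.layers.Lambda(lambda x: x[:, 0:5], name="stride1")(in1)
x2 = keras.layers.Lambda(lambda x: x[:, 5:10], name="stride2")(in2)
outputs = keras.layers.concatenate([x1, x2], name="concat")
model = keras.Model(inputs=[in1, in2], outputs=outputs)

# First five columns of a, last five columns of b.
a = np.arange(10, dtype="float32").reshape(1, 10)
b = np.arange(10, 20, dtype="float32").reshape(1, 10)
result = np.asarray(model([a, b]))
```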
This answers the question as asked. However, there is still an open issue, described in another question: 'TensorFlowOpLayer messes up the TensorBoard graphs'.
Source: https://stackoverflow.com/questions/61581114/messed-up-tensorboard-graphs-due-to-python-operations