Question
I'm getting back into Python and have been experimenting with TensorFlow and Keras. I wanted to use the plot_model function, and after sorting out some Graphviz issues I am now getting this error:
TypeError: add_node() received a non node class object:
I've tried to find an answer myself but have come up short: the only answer I found for this error didn't seem to be related to TensorFlow. Any suggestions or alternative ideas would be greatly appreciated. Here are the code and the error message. This is my first question here, so sorry if I missed anything; just let me know.
I'm using miniconda3 with Python 3.8.
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense, Dropout
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import plot_model
from tensorflow.keras.callbacks import EarlyStopping
from numpy import argmax
from matplotlib import pyplot
from random import randint

tf.keras.backend.set_floatx("float64")

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

class mnist_model(Model):
    def __init__(self):
        super(mnist_model, self).__init__()
        self.conv = Conv2D(32, 3, activation = tf.nn.leaky_relu, kernel_initializer = 'he_uniform', input_shape = (28, 28, 3))
        self.pool = MaxPool2D((2,2))
        self.flat = Flatten()
        self.den1 = Dense(128, activation = tf.nn.relu, kernel_initializer = 'he_normal')
        self.drop = Dropout(0.25)
        self.den2 = Dense(10, activation = tf.nn.softmax)

    def call(self, inputs):
        n = self.conv(inputs)
        n = self.pool(n)
        n = self.flat(n)
        n = self.den1(n)
        n = self.drop(n)
        return self.den2(n)

model = mnist_model()
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
limit = EarlyStopping(monitor = 'val_loss', patience = 5)
history = model.fit(x_train, y_train, batch_size=64, epochs = 1, verbose = 2, validation_split = 0.15, steps_per_epoch = 100, callbacks = [limit])
print("\nTraining finished\n\nTesting 10000 samples")
model.evaluate(x_test, y_test, verbose = 1)
print("Testing finished\n")
plot_model(model, show_shapes = True, rankdir = 'LR')
##################################################################################################################################################################
## Error message: ##
Train on 51000 samples, validate on 9000 samples
Training finished
Testing 10000 samples
10000/10000 [==============================] - 7s 682us/sample - loss: 0.2447 - accuracy: 0.9242
Testing finished
Traceback (most recent call last):
  File "C:\Users\Thomas\Desktop\Various Python\Tensorflow\Tensorflow_experimentation\tc_mnist.py", line 60, in <module>
    plot_model(model, show_shapes = True, rankdir = 'LR')
  File "C:\Users\Thomas\miniconda3\envs\tensorflow\lib\site-packages\tensorflow_core\python\keras\utils\vis_utils.py", line 283, in plot_model
    dpi=dpi)
  File "C:\Users\Thomas\miniconda3\envs\tensorflow\lib\site-packages\tensorflow_core\python\keras\utils\vis_utils.py", line 131, in model_to_dot
    dot.add_node(node)
  File "C:\Users\Thomas\miniconda3\envs\tensorflow\lib\site-packages\pydotplus\graphviz.py", line 1281, in add_node
    'class object: {}'.format(str(graph_node))
TypeError: add_node() received a non node class object: <pydotplus.graphviz.Node object at 0x00000221C7E3E888>
Answer 1:
I think the root cause of the issue is shape inference in subclassed models, where model.summary() shows "multiple" as the Output Shape. I added a model method to the subclassed model, as shown below.
def model(self):
    x = tf.keras.layers.Input(shape=(28, 28, 1))
    return Model(inputs=[x], outputs=self.call(x))
With this modification, shape inference happens automatically, as it does in the Functional API. Because Functional and Sequential models are static graphs of layers, their shapes can be inferred easily. A subclassed model, however, is a piece of Python code (a call method), and there is no graph of layers to inspect. We cannot know how the layers are connected to each other (that is defined in the body of call, not as an explicit data structure), so we cannot infer input and output shapes.
Please check full code here for your reference.
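For illustration, here is a minimal, self-contained sketch of the same idea, not the answerer's full code. It assumes TensorFlow 2.x is installed. The names MnistModel, build_graph and plot_subclassed are made up for this example (the answer above calls the helper model), and plain "relu" activations are used to keep the sketch short. The plotting helper is defined but not called, since plot_model additionally needs pydot and Graphviz to be installed.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, Input, MaxPool2D
from tensorflow.keras.utils import plot_model


class MnistModel(Model):
    """Subclassed model; its layer connections exist only inside call()."""

    def __init__(self):
        super().__init__()
        self.conv = Conv2D(32, 3, activation="relu")
        self.pool = MaxPool2D((2, 2))
        self.flat = Flatten()
        self.den1 = Dense(128, activation="relu")
        self.drop = Dropout(0.25)
        self.den2 = Dense(10, activation="softmax")

    def call(self, inputs):
        n = self.conv(inputs)
        n = self.pool(n)
        n = self.flat(n)
        n = self.den1(n)
        n = self.drop(n)
        return self.den2(n)

    def build_graph(self, input_shape=(28, 28, 1)):
        # Hypothetical helper name (assumption, not a Keras API).
        # Tracing call() on a symbolic Input yields a Functional model,
        # which has an explicit layer graph and known shapes.
        x = Input(shape=input_shape)
        return Model(inputs=[x], outputs=self.call(x))


def plot_subclassed(model, path="model.png"):
    # Plots the Functional wrapper instead of the subclassed model itself.
    # Requires pydot and Graphviz; not invoked in this sketch.
    return plot_model(model.build_graph(), to_file=path,
                      show_shapes=True, rankdir="LR")


subclassed = MnistModel()
wrapper = subclassed.build_graph()

# The wrapper shares its layers (and weights) with the subclassed model.
batch = np.zeros((2, 28, 28, 1), dtype="float32")
probs = wrapper(batch)
```

Calling wrapper.summary() should then list concrete output shapes per layer, and plot_subclassed(subclassed) passes the wrapper to plot_model rather than the subclassed model.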
Source: https://stackoverflow.com/questions/61603966/tf-keras-plot-model-add-node-received-a-non-node-class-object