How to graph tf.keras model in Tensorflow-2.0?

Submitted by 末鹿安然 on 2019-12-09 17:27:25

Question


I upgraded to TensorFlow 2.0, and tf.summary.FileWriter("tf_graphs", sess.graph) no longer exists. Other StackOverflow questions on this suggest using tf.compat.v1.summary and similar compatibility APIs. Surely there must be a way to graph and visualize a tf.keras model natively in TensorFlow 2. What is it? I'm looking for a TensorBoard graph visualization of the model. Thank you!


Answer 1:


According to the docs, you can use TensorBoard to visualise the graph once your model has been trained.

First, define your model and train it. Then, open TensorBoard and switch to the Graph tab.


Minimal Compilable Example

This example is taken from the docs. First, define your model and data.

# Relevant imports.
# %load_ext is IPython/Jupyter magic that loads the TensorBoard notebook extension.
%load_ext tensorboard

from datetime import datetime

import tensorflow as tf
from tensorflow import keras

# Define the model.
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

(train_images, train_labels), _ = keras.datasets.fashion_mnist.load_data()
# Scale pixel values from [0, 255] to [0, 1].
train_images = train_images / 255.0

Next, train your model. Here, you need to define a callback that TensorBoard will use for visualising stats and graphs.

# Define the Keras TensorBoard callback. Its write_graph argument defaults to
# True, which makes it log the model graph shown in the Graph tab.
logdir = "logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

# Train the model.
model.fit(
    train_images,
    train_labels,
    batch_size=64,
    epochs=5,
    callbacks=[tensorboard_callback])
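The timestamp appended to logdir gives every training run its own subdirectory, so TensorBoard can show runs side by side instead of mixing their event files. A minimal sketch of that naming scheme in plain Python; the helper name make_run_logdir is hypothetical, and a fixed datetime is passed in so the result is reproducible.

```python
from datetime import datetime


def make_run_logdir(base="logs/fit", when=None):
    """Build a per-run log directory such as logs/fit/20191209-172725.

    Hypothetical helper mirroring the logdir expression used above.
    """
    if when is None:
        when = datetime.now()
    return f"{base}/{when.strftime('%Y%m%d-%H%M%S')}"


print(make_run_logdir(when=datetime(2019, 12, 9, 17, 27, 25)))
```

Passing the time in as an argument (instead of always calling datetime.now() inside) is only a convenience for testing; in a notebook, make_run_logdir() with no arguments behaves like the original expression.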

After training, in your notebook, run

%tensorboard --logdir logs

Then switch to the Graph tab in the navbar, where you will see the model's graph.
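Because --logdir logs points at the parent directory, TensorBoard searches that tree recursively and treats each subdirectory containing event files (named events.out.tfevents.*) as a separate run. If the Graph tab comes up empty, it can help to check that event files were actually written. A minimal standard-library sketch; the helper find_event_files is hypothetical.

```python
import os


def find_event_files(logdir):
    """Return sorted paths of TensorBoard event files under logdir (recursive)."""
    matches = []
    for root, _dirs, files in os.walk(logdir):
        for name in files:
            # TensorBoard event files share this filename prefix.
            if name.startswith("events.out.tfevents"):
                matches.append(os.path.join(root, name))
    return sorted(matches)


# Usage, assuming the logs/ directory from the training step exists.
# A missing directory simply yields an empty list.
for path in find_event_files("logs"):
    print(path)
```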




Answer 2:


You can visualize the graph of any tf.function-decorated function, but first you have to trace its execution.

Visualizing the graph of a Keras model means visualizing its call method.

By default, this method is not tf.function-decorated, so you have to wrap the model call in a function that is decorated with tf.function and execute it once so that the trace is recorded.

import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)


@tf.function
def traceme(x):
    return model(x)


logdir = "log"
writer = tf.summary.create_file_writer(logdir)
tf.summary.trace_on(graph=True, profiler=True)
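# Forward pass: calling traceme once traces it and records the graph.
# The input shape matches the model's input_shape=(28, 28), plus a batch dimension.
traceme(tf.zeros((1, 28, 28)))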
with writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)
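The key point of this answer is that a tf.function runs its Python body only while tracing; later calls with the same input shape and dtype reuse the recorded graph. A small sketch that makes this visible with a Python-side list, then inspects the traced graph directly via get_concrete_function, without TensorBoard. Assumptions: a single Dense layer stands in for the full model to keep it short, and trace_log is an illustrative name.

```python
import tensorflow as tf

# Assumption: one Dense layer stands in for the model; a full Keras model
# is traced the same way when called inside a tf.function.
layer = tf.keras.layers.Dense(10)
layer(tf.zeros((1, 784)))  # create the weights eagerly, outside tf.function

trace_log = []  # illustrative: grows each time the Python body runs


@tf.function
def traceme(x):
    trace_log.append(x.shape)  # Python side effect: happens only while tracing
    return layer(x)


traceme(tf.zeros((1, 784)))  # first call: traces the function into a graph
traceme(tf.ones((1, 784)))   # same shape and dtype: reuses the traced graph
traces_after_calls = len(trace_log)

# The traced graph can be inspected directly, without TensorBoard.
concrete = traceme.get_concrete_function(tf.TensorSpec((1, 784), tf.float32))
op_types = sorted({op.type for op in concrete.graph.get_operations()})
print("traces:", traces_after_calls)
print("op types:", op_types)
```

This is also why trace_on must be active before the first call in the answer above: if the function has already been traced, a repeated call with the same input signature does not trace again, so there is nothing new to export.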



Answer 3:


Here's what is working for me at the moment (TF 2.0.0), based on the tf.keras.callbacks.TensorBoard code:

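# After the model has been compiled.
import tensorflow as tf
# Note: these are TensorFlow-internal (private) modules, not public API. They
# worked in TF 2.0 but may be moved or removed in later releases.
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.keras.backend import get_graph

tb_path = '/tmp/tensorboard/'
tb_writer = tf.summary.create_file_writer(tb_path)
with tb_writer.as_default():
    # Mirrors the TensorBoard callback: skip models compiled with run_eagerly=True.
    if not model.run_eagerly:
        summary_ops_v2.graph(get_graph(), step=0)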


Source: https://stackoverflow.com/questions/56690089/how-to-graph-tf-keras-model-in-tensorflow-2-0
