How to save a TensorFlow checkpoint file from Google Colaboratory when using TPU mode?

Submitted by 喜夏-厌秋 on 2019-12-04 15:29:15

Another way to do this is to rewrite the model in Keras and use tf.contrib.tpu.keras_to_tpu_model(...) with tf.contrib.tpu.TPUDistributionStrategy(...). This applies to TensorFlow 1.x only, since tf.contrib was removed in TensorFlow 2.0. Here is a small code snippet for this:

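import os

import tensorflow as tf  # TensorFlow 1.x; tf.contrib does not exist in 2.x
from tensorflow import keras

# Colab exposes the address of its TPU worker in this environment variable
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

def get_model():
  return keras.Sequential([
      keras.layers.Dense(10, input_shape=(4,), activation=tf.nn.relu, name="Dense_1"),
      keras.layers.Dense(10, activation=tf.nn.relu, name="Dense_2"),
      keras.layers.Dense(3, activation=None, name="logits"),
      keras.layers.Dense(3, activation=tf.nn.softmax, name="softmax"),
  ])

dnn_model = get_model()

dnn_model.compile(optimizer=tf.train.AdagradOptimizer(learning_rate=0.1),
                  loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_crossentropy'])

# Convert the Keras model into one that runs on the Colab TPU
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    dnn_model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)))

# Train the model (train_x, train_y, steps_per_epoch and epochs are
# assumed to be defined earlier in the notebook)
tpu_model.fit(
    train_x, train_y,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
)

# Write the trained weights to an HDF5 file on the Colab VM's local disk
tpu_model.save_weights('./saved_weights.h5', overwrite=True)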
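The snippet leaves steps_per_epoch and epochs undefined. A minimal sketch of one way to derive steps_per_epoch, assuming the 8-core TPU that Colab provided at the time and that each global batch is split evenly across the cores; tpu_training_schedule and its parameters are hypothetical, not part of TensorFlow:

```python
def tpu_training_schedule(num_examples, per_core_batch, num_cores=8):
    """Hypothetical helper: return (global_batch, steps_per_epoch).

    Assumes the global batch is split evenly across `num_cores` TPU cores
    (8 on the Colab TPU of that era), so it is built as a multiple of
    num_cores.
    """
    if num_examples <= 0 or per_core_batch <= 0 or num_cores <= 0:
        raise ValueError("all arguments must be positive")
    # Every core receives the same number of examples per step.
    global_batch = per_core_batch * num_cores
    # Count only full global batches; a trailing partial batch is dropped.
    steps_per_epoch = num_examples // global_batch
    if steps_per_epoch == 0:
        raise ValueError("dataset is smaller than one global batch")
    return global_batch, steps_per_epoch
```

For example, 1,000 examples with 16 examples per core give a global batch of 128 and 7 steps per epoch; the helper leaves the remaining 104 examples out of its count rather than scheduling a partial step.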

To save to Google Cloud Storage instead, create a Google Cloud account under the free tier and then create a GCS bucket. You can then authenticate in Colab to get write access to that bucket:

from google.colab import auth
# Prompts for a Google sign-in so this Colab runtime can use your Cloud credentials
auth.authenticate_user()
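Once the bucket exists, checkpoints are written under a gs:// path. The sketch below builds such a path; gcs_checkpoint_dir, the bucket name and the run name are hypothetical, and the name check is a simplified subset of the GCS bucket naming rules (it also rejects dotted, domain-style names that GCS itself allows):

```python
import re

# Simplified GCS bucket name check: 3-63 characters of lowercase letters,
# digits, dashes and underscores, starting and ending with a letter or digit.
_BUCKET_RE = re.compile(r"[a-z0-9][a-z0-9_-]{1,61}[a-z0-9]")


def gcs_checkpoint_dir(bucket, run_name):
    """Hypothetical helper: build a gs:// directory for one run's checkpoints."""
    if not _BUCKET_RE.fullmatch(bucket):
        raise ValueError(f"invalid bucket name: {bucket!r}")
    # Strip stray slashes so the result has no empty path segments.
    run_name = run_name.strip("/")
    if not run_name:
        raise ValueError("run_name must not be empty")
    return f"gs://{bucket}/{run_name}/checkpoints"
```

For example, gcs_checkpoint_dir("my-colab-bucket", "dnn_run1") returns "gs://my-colab-bucket/dnn_run1/checkpoints". Building the string does not create anything in the bucket; writing to it still requires the authentication step above.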

Here is a sample Colab notebook that uses Cloud TPUs and GCS.
