Question
I'm trying to train a CNN regression net in TF 1.12 on a TPU v3-8 instance (TPU software version 1.12). The model successfully compiles with XLA and training starts, but somewhere past the halfway point of the first epoch it freezes and does nothing. I cannot find the root of the problem.
def read_tfrecord(example):
    features = {
        'image': tf.FixedLenFeature([], tf.string),
        'labels': tf.FixedLenFeature([], tf.string)
    }
    sample = tf.parse_single_example(example, features)
    image = tf.image.decode_jpeg(sample['image'], channels=3)
    image = tf.reshape(image, tf.stack([540, 540, 3]))
    image = augmentation(image)
    labels = tf.decode_raw(sample['labels'], tf.float64)
    labels = tf.reshape(labels, tf.stack([2, 2, 45]))
    labels = tf.cast(labels, tf.float32)
    return image, labels

def load_dataset(filenames):
    files = tf.data.Dataset.list_files(filenames)
    dataset = files.apply(tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4))
    dataset = dataset.apply(tf.data.experimental.map_and_batch(map_func=read_tfrecord, batch_size=BATCH_SIZE, drop_remainder=True))
    dataset = dataset.apply(tf.data.experimental.shuffle_and_repeat(1024, -1))
    dataset = dataset.prefetch(buffer_size=1024)
    return dataset

def augmentation(img):
    image = tf.cast(img, tf.float32) / 255.0
    image = tf.image.random_brightness(image, max_delta=25/255)
    image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
    image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
    image = tf.image.per_image_standardization(image)
    return image

def get_batched_dataset(filenames):
    dataset = load_dataset(filenames)
    return dataset

def get_training_dataset():
    return get_batched_dataset(training_filenames)

def get_validation_dataset():
    return get_batched_dataset(validation_filenames)
Answer 1:
The most likely cause is an issue in the data pre-processing function. Take a look at the "Errors in the middle of training" section of the troubleshooting documentation; it may give you some guidance.
I did not notice anything strange in your code.
Are you using Cloud Storage buckets to hold those images and files? If so, are those buckets in the same region?
You can use Cloud TPU audit logs to determine whether the issue is related to the resources in the system or to how you are accessing your data.
Finally, I suggest taking a look at the "Training Mask RCNN on Cloud TPU" documentation.
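One way to follow up on the pre-processing suspicion is to scan the TFRecord files offline, without the TPU, and look for records that would make the reshape in read_tfrecord fail. The sketch below is only an illustration, not a confirmed diagnosis: it checks that each 'labels' feature holds exactly 2 * 2 * 45 float64 values (1440 bytes), which is what the decode_raw plus reshape in the question requires. The names label_size_ok and find_bad_labels are hypothetical helpers, and the scan assumes TF 1.x, where tf.python_io.tf_record_iterator is available.

```python
import struct

# Shape and dtype taken from read_tfrecord in the question:
# tf.decode_raw(..., tf.float64) followed by a reshape to [2, 2, 45].
LABEL_SHAPE = (2, 2, 45)
BYTES_PER_FLOAT64 = struct.calcsize("<d")  # size of one float64 in bytes

EXPECTED_LABEL_BYTES = BYTES_PER_FLOAT64
for dim in LABEL_SHAPE:
    EXPECTED_LABEL_BYTES *= dim


def label_size_ok(raw_labels):
    """Return True if the raw bytes can be reshaped to LABEL_SHAPE as float64."""
    return len(raw_labels) == EXPECTED_LABEL_BYTES


def find_bad_labels(tfrecord_path):
    """Hypothetical helper: yield (record_index, byte_count) for bad labels.

    Assumes TF 1.x (tf.python_io.tf_record_iterator). Runs on the host CPU
    and does not need a session or a TPU.
    """
    # Imported lazily so the pure-Python size check works without TensorFlow.
    import tensorflow as tf

    records = tf.python_io.tf_record_iterator(tfrecord_path)
    for index, record in enumerate(records):
        example = tf.train.Example.FromString(record)
        values = example.features.feature["labels"].bytes_list.value
        raw = values[0] if values else b""
        if not label_size_ok(raw):
            yield index, len(raw)
```

Calling list(find_bad_labels(path)) on each training file (path being one of your own TFRecord files) returns an empty list when every label has the expected size. This sketch does not cover the image side: confirming that every JPEG really decodes to 540 x 540 x 3 would need an additional decoding step.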
Source: https://stackoverflow.com/questions/57240149/tpu-training-freezes-in-the-middle-of-training