Create a tf.keras callback to save model predictions and targets for each batch during training in TF 2.0

笑着哭i submitted on 2020-01-15 06:18:09


In TensorFlow 2, fetches and assign are no longer supported. In TF 1.x, batch results can be accessed in a custom Keras callback by following a previously posted answer. In tf.keras and TF 2.0 under eager execution, however, fetches are not supported, so the solution provided for TF 1.x does not work. Is there a way to get y_true and y_pred inside the on_batch_end method of a custom tf.keras callback?
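Under eager execution there is nothing to fetch: a tensor computed outside a `tf.function` already holds its values, and `.numpy()` returns them. By default, though, `model.fit` wraps each training step in a `tf.function`, so inside `fit` the step works with graph tensors that callbacks never see. If giving up `model.fit` is acceptable, a custom training loop with `tf.GradientTape` sidesteps the problem, because `y_true` and `y_pred` are ordinary eager tensors at every step. This is a minimal sketch; the toy data, model, and the `targets`/`outputs` lists are hypothetical placeholders.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 64 samples, 4 features, 1 regression target.
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)  # no shuffling

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

targets, outputs = [], []  # one NumPy array per batch

for x_batch, y_batch in dataset:
    with tf.GradientTape() as tape:
        y_pred = model(x_batch, training=True)
        loss = loss_fn(y_batch, y_pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # Eager tensors: the values are available directly, no session fetches needed.
    targets.append(y_batch.numpy())
    outputs.append(y_pred.numpy())
```

The trade-off is that the loop replaces `fit`, so any callbacks, metrics, and progress reporting have to be driven by hand.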

I have tried to adapt the TF 1.x answer as follows:

from tensorflow.keras.callbacks import Callback

class CollectOutputAndTarget(Callback):
    def __init__(self):
        super(CollectOutputAndTarget, self).__init__()
        self.targets = []  # collect y_true batches
        self.outputs = []  # collect y_pred batches

    def on_batch_end(self, batch, logs=None):
        # evaluate the variables and save them into lists
        # How do I change the following 2 lines so that they collect the batch results under TF 2 eager execution?
        self.targets.append(self.model._targets[0])
        self.outputs.append(self.model.outputs[0])

When I run the code above, it fails: accessing the data in self.model._targets[0] or self.model.outputs[0] is apparently not possible.
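One workaround that keeps `model.fit` is sketched below under two assumptions: a TensorFlow release whose Keras `Model` exposes an overridable `compute_loss(x, y, y_pred, sample_weight)` method (it is not available in TF 2.0 itself, so check your version), and `compile(..., run_eagerly=True)` so that the training step runs eagerly. The hypothetical `RecordingModel` copies each batch's `y` and `y_pred` into NumPy attributes when the default training step calls `compute_loss`, and a TF 2 version of the callback reads them in `on_train_batch_end`, the training-specific form of `on_batch_end`. `RecordingModel`, `last_y_true`, and `last_y_pred` are illustrative names, not part of the Keras API.

```python
import numpy as np
import tensorflow as tf


class RecordingModel(tf.keras.Model):
    """Hypothetical model that keeps the targets/predictions of the last batch."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)
        self.last_y_true = None  # NumPy copy of the last batch's targets
        self.last_y_pred = None  # NumPy copy of the last batch's predictions

    def call(self, inputs):
        return self.dense(inputs)

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None, **kwargs):
        # Assumes run_eagerly=True: y and y_pred are eager tensors with concrete values.
        self.last_y_true = np.array(y)
        self.last_y_pred = np.array(y_pred)
        return super().compute_loss(x, y, y_pred, sample_weight, **kwargs)


class CollectOutputAndTarget(tf.keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.targets = []  # collect y_true batches
        self.outputs = []  # collect y_pred batches

    def on_train_batch_end(self, batch, logs=None):
        # Read the arrays stored by RecordingModel.compute_loss for this batch.
        self.targets.append(self.model.last_y_true)
        self.outputs.append(self.model.last_y_pred)


# Hypothetical toy data: 64 samples, 4 features, 1 target.
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = RecordingModel()
model.compile(optimizer="sgd", loss="mse", run_eagerly=True)
collector = CollectOutputAndTarget()
model.fit(x, y, batch_size=16, epochs=1, shuffle=False, verbose=0, callbacks=[collector])
```

Running eagerly makes training noticeably slower; without `run_eagerly=True` the step is traced into a graph and `np.array` would fail on the symbolic tensors. `compute_loss` is also called during evaluation, so validation batches overwrite the attributes, which is why the callback reads them right after each training batch.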

