Access deprecated attribute “validation_data” in tf.keras.callbacks.Callback

倖福魔咒の submitted on 2021-02-07 07:27:18

Question


I decided to switch from keras to tf.keras (as recommended here), so I installed TensorFlow with tf.__version__=2.0.0 and tf.keras.__version__=2.2.4-tf. In an older version of my code (using an older TensorFlow, tf.__version__=1.x.x), I used a callback to compute custom metrics on the entire validation data at the end of each epoch. The idea was taken from here. However, the "validation_data" attribute seems to be deprecated, so the following code no longer works.

import numpy as np
from tensorflow.keras.callbacks import Callback

class ValMetrics(Callback):

    def on_train_begin(self, logs=None):
        self.val_all_mse = []

    def on_epoch_end(self, epoch, logs=None):
        # self.validation_data is the part that no longer works in tf.keras 2.x
        val_predict = np.asarray(self.model.predict(self.validation_data[0]))
        val_targ = self.validation_data[1]

        # mse_score is the asker's own metric function (not shown)
        val_epoch_mse = mse_score(val_targ, val_predict)

        self.val_all_mse.append(val_epoch_mse)

        # Add custom metrics to the logs, so that we can use them with
        # the EarlyStopping and CSVLogger callbacks
        if logs is not None:
            logs["val_epoch_mse"] = val_epoch_mse

        print(f"\nEpoch: {epoch + 1}")
        print("-----------------")
        print("val_mse:     {:+.6f}".format(val_epoch_mse))

My current workaround is to pass the validation data as an argument to the ValMetrics constructor:

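from tensorflow.keras.callbacks import Callback

class ValMetrics(Callback):

    def __init__(self, validation_data):
        # super().__init__() runs Callback.__init__; super(Callback, self)
        # would skip it and call object.__init__ instead
        super().__init__()
        self.X_val, self.y_val = validation_data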
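For reference, a self-contained version of this workaround might look like the sketch below. It assumes a regression model and swaps the question's own mse_score helper for a plain NumPy mean squared error; the random data, layer size, and epoch count are illustrative only. fit() calls callbacks in list order with the same logs dict, so ValMetrics is listed before EarlyStopping to make "val_epoch_mse" visible to it. mode="min" is set explicitly because some Keras versions cannot infer the direction for a custom metric name.

```python
import numpy as np
import tensorflow as tf


class ValMetrics(tf.keras.callbacks.Callback):
    """Computes MSE on the full validation set at the end of every epoch."""

    def __init__(self, validation_data):
        super().__init__()
        self.X_val, self.y_val = validation_data

    def on_train_begin(self, logs=None):
        self.val_all_mse = []

    def on_epoch_end(self, epoch, logs=None):
        val_predict = np.asarray(self.model.predict(self.X_val, verbose=0))
        # Plain NumPy MSE stands in for the question's mse_score helper
        val_epoch_mse = float(np.mean((self.y_val - val_predict) ** 2))
        self.val_all_mse.append(val_epoch_mse)
        if logs is not None:
            # Callbacks later in the list (e.g. EarlyStopping) can read this key
            logs["val_epoch_mse"] = val_epoch_mse


# Illustrative random regression data
rng = np.random.default_rng(0)
x_train, y_train = rng.normal(size=(64, 3)), rng.normal(size=(64, 1))
x_val, y_val = rng.normal(size=(16, 3)), rng.normal(size=(16, 1))

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

val_metrics = ValMetrics((x_val, y_val))
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_epoch_mse", mode="min", patience=5)
model.fit(x_train, y_train, epochs=3, verbose=0,
          callbacks=[val_metrics, early_stop])
print(val_metrics.val_all_mse)
```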

Still, I have two questions: Is the "validation_data" attribute really deprecated, or can it be found elsewhere? And is there a better way to access the validation data at the end of each epoch than the workaround above?

Thanks a lot!


Answer 1:


You are right: the validation_data attribute is deprecated, as per the TensorFlow Callbacks documentation.
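To see what the attribute holds on a particular install, a small probe callback can record it during fit(). This is a sketch on random toy data with a one-layer model (both purely illustrative); getattr is used because newer Keras releases may not define the attribute at all. If every recorded value is None, fit() never fills the attribute in, which is why self.validation_data[0] fails.

```python
import numpy as np
import tensorflow as tf


class ValidationDataProbe(tf.keras.callbacks.Callback):
    """Records what self.validation_data holds at the end of each epoch."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def on_epoch_end(self, epoch, logs=None):
        # getattr: some Keras versions do not define the attribute at all
        self.seen.append(getattr(self, "validation_data", None))


# Illustrative random data
rng = np.random.default_rng(0)
x_train, y_train = rng.normal(size=(64, 3)), rng.normal(size=(64, 1))
x_val, y_val = rng.normal(size=(16, 3)), rng.normal(size=(16, 1))

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

probe = ValidationDataProbe()
model.fit(x_train, y_train, validation_data=(x_val, y_val),
          epochs=2, verbose=0, callbacks=[probe])
print(probe.seen)
```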

The issue you are facing has been raised on GitHub; the related issues are Issue1, Issue2, and Issue3.

None of these GitHub issues has been resolved. Your workaround of passing validation_data as an argument to a custom callback is a good one; according to this GitHub comment, many people have found it useful.

For the benefit of the Stack Overflow community, here is the code for the workaround, even though it is already available on GitHub:

import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.keras.callbacks import Callback

class Metrics(Callback):

    def __init__(self, val_data):
        super().__init__()
        # val_data must support both len() and next(), e.g. the iterator
        # returned by ImageDataGenerator.flow()
        self.validation_data = val_data

    def on_train_begin(self, logs=None):
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs=None):
        batches = len(self.validation_data)
        val_pred, val_true = [], []

        for _ in range(batches):
            xVal, yVal = next(self.validation_data)
            # Round predicted probabilities to 0/1 labels (binary classification)
            val_pred.append(np.asarray(self.model.predict(xVal)).round().ravel())
            val_true.append(np.asarray(yVal).ravel())

        # Concatenating per-batch results also handles a smaller last batch
        val_pred = np.concatenate(val_pred)
        val_true = np.concatenate(val_true)

        _val_f1 = f1_score(val_true, val_pred)
        _val_precision = precision_score(val_true, val_pred)
        _val_recall = recall_score(val_true, val_pred)

        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
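Collecting per-batch predictions and concatenating them, rather than preallocating len(val_data) * batch_size slots, keeps the arrays at the true sample count when the last batch is smaller. The framework-free sketch below isolates that step; collect_batches is a hypothetical helper, and the lambda is a stand-in for model.predict.

```python
import numpy as np


def collect_batches(predict_fn, batches):
    """Run predict_fn over (x, y) batches and concatenate the results.

    Concatenating per-batch arrays (instead of preallocating
    n_batches * batch_size slots) keeps the output length equal to the
    real number of samples even when the last batch is smaller.
    """
    preds, trues = [], []
    for x, y in batches:
        # Round probabilities to hard 0/1 labels, as the callback above does
        preds.append(np.asarray(predict_fn(x)).round().ravel())
        trues.append(np.asarray(y).ravel())
    return np.concatenate(preds), np.concatenate(trues)


# Hypothetical data: 50 samples in batches of 20 -> sizes 20, 20, 10
x = np.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = (x[:, 0] > 0.5).astype(int)
batches = [(x[i:i + 20], y[i:i + 20]) for i in range(0, 50, 20)]

# Stand-in for model.predict: the "probability" is the input itself
val_pred, val_true = collect_batches(lambda xb: xb, batches)
print(val_pred.shape, val_true.shape)
```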

I will keep following the GitHub issues mentioned above and update this answer accordingly.

Hope this helps. Happy Learning!



Source: https://stackoverflow.com/questions/60080646/access-deprecated-attribute-validation-data-in-tf-keras-callbacks-callback
