Question
I decided to switch from keras to tf.keras (as recommended here), so I installed tf.__version__=2.0.0 and tf.keras.__version__=2.2.4-tf. In an older version of my code (using an older TensorFlow version, tf.__version__=1.x.x), I used a callback to compute custom metrics on the entire validation data at the end of each epoch. The idea was taken from here. However, it seems that the "validation_data" attribute is deprecated, so the following code no longer works.
class ValMetrics(Callback):
    def on_train_begin(self, logs=None):
        self.val_all_mse = []

    def on_epoch_end(self, epoch, logs):
        val_predict = np.asarray(self.model.predict(self.validation_data[0]))
        val_targ = self.validation_data[1]
        val_epoch_mse = mse_score(val_targ, val_predict)
        self.val_all_mse.append(val_epoch_mse)
        # Add custom metrics to the logs, so that we can use them with
        # EarlyStop and csvLogger callbacks
        logs["val_epoch_mse"] = val_epoch_mse
        print(f"\nEpoch: {epoch + 1}")
        print("-----------------")
        print("val_mse: {:+.6f}".format(val_epoch_mse))
My current workaround is the following: I simply pass validation_data as an argument to the ValMetrics class:
class ValMetrics(Callback):
    def __init__(self, validation_data):
        # super(Callback, self) would skip Callback.__init__; call it via super() instead
        super().__init__()
        self.X_val, self.y_val = validation_data
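A fuller sketch of this workaround might look like the following. It assumes validation_data is an (X_val, y_val) tuple of NumPy arrays, and it replaces the undefined mse_score helper with a plain NumPy mean of squared errors. The toy model and random data are purely illustrative.

```python
import numpy as np
import tensorflow as tf


class ValMetrics(tf.keras.callbacks.Callback):
    """Compute MSE on the full validation set at the end of each epoch."""

    def __init__(self, validation_data):
        super().__init__()
        # Assumption: validation_data is an (X_val, y_val) tuple of arrays.
        self.X_val, self.y_val = validation_data
        self.val_all_mse = []

    def on_epoch_end(self, epoch, logs=None):
        val_predict = np.asarray(self.model.predict(self.X_val, verbose=0))
        # Stand-in for mse_score: mean of squared errors over all samples.
        val_epoch_mse = float(np.mean((np.asarray(self.y_val) - val_predict) ** 2))
        self.val_all_mse.append(val_epoch_mse)
        # Store the value in logs; this relies on the same dict being
        # passed on to callbacks that run later in the list.
        if logs is not None:
            logs["val_epoch_mse"] = val_epoch_mse


# Toy usage with random data (shapes are illustrative only).
rng = np.random.default_rng(0)
X_train, y_train = rng.normal(size=(64, 4)), rng.normal(size=(64, 1))
X_val, y_val = rng.normal(size=(16, 4)), rng.normal(size=(16, 1))

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

val_metrics = ValMetrics((X_val, y_val))
model.fit(X_train, y_train, epochs=2, verbose=0, callbacks=[val_metrics])
```

After training, val_metrics.val_all_mse holds one value per epoch. Callbacks that should read "val_epoch_mse" from logs would need to come after ValMetrics in the callbacks list.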
Still I have some questions: Is the "validation_data" attribute really deprecated or can it be found elsewhere? Is there a better way to access the validation data at the end of each epoch than with the above workaround?
Thanks a lot!
Answer 1:
You are right that the validation_data argument is deprecated, as per the TensorFlow Callbacks documentation.
The issue you are facing has been raised on GitHub. Related issues are Issue1, Issue2 and Issue3.
None of the above GitHub issues has been resolved. Your workaround of passing validation_data as an argument to a custom callback is a good one; as per this GitHub comment, many people have found it useful.
The workaround code is reproduced below for the benefit of the Stack Overflow community, even though it is already available on GitHub.
import numpy as np
# The metric functions are assumed to be scikit-learn's.
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.keras.callbacks import Callback


class Metrics(Callback):
    def __init__(self, val_data, batch_size=20):
        super().__init__()
        # val_data must support len() (number of batches) and next() (yielding (x, y)).
        self.validation_data = val_data
        self.batch_size = batch_size

    def on_train_begin(self, logs=None):
        print(self.validation_data)
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs=None):
        batches = len(self.validation_data)
        total = batches * self.batch_size
        # Preallocated buffers; this assumes every batch has exactly batch_size samples.
        val_pred = np.zeros((total, 1))
        val_true = np.zeros((total))
        for batch in range(batches):
            xVal, yVal = next(self.validation_data)
            start, stop = batch * self.batch_size, (batch + 1) * self.batch_size
            # Round the predicted probabilities to hard 0/1 labels.
            val_pred[start:stop] = np.asarray(self.model.predict(xVal)).round()
            val_true[start:stop] = yVal
        val_pred = np.squeeze(val_pred)
        _val_f1 = f1_score(val_true, val_pred)
        _val_precision = precision_score(val_true, val_pred)
        _val_recall = recall_score(val_true, val_pred)
        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
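The preallocated arrays above only work when every batch contains exactly batch_size samples. One possible alternative, sketched here with NumPy only, collects the per-batch results in lists and concatenates them once at the end. The names collect_predictions and fake_predict are hypothetical. In a callback, predict_fn would be something like self.model.predict, and batches a finite iterable of (x, y) pairs. The flattening assumes a single output per sample.

```python
import numpy as np


def collect_predictions(predict_fn, batches):
    """Run predict_fn over (x, y) batches; return (y_true, y_pred) as flat arrays."""
    y_true_parts, y_pred_parts = [], []
    for x_batch, y_batch in batches:
        # Flatten each batch so batches of different sizes can be concatenated.
        y_pred_parts.append(np.asarray(predict_fn(x_batch)).reshape(-1))
        y_true_parts.append(np.asarray(y_batch).reshape(-1))
    return np.concatenate(y_true_parts), np.concatenate(y_pred_parts)


# Toy stand-in for a model: thresholds the first feature at 0.5.
def fake_predict(x):
    return (x[:, :1] > 0.5).astype(float)


batches = [
    (np.array([[0.9], [0.1], [0.7]]), np.array([1, 0, 0])),
    (np.array([[0.2], [0.8]]), np.array([0, 1])),  # smaller final batch
]
y_true, y_pred = collect_predictions(fake_predict, batches)
```

The resulting y_true and y_pred can then be passed to f1_score, precision_score and recall_score as in the callback above.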
I will keep following the GitHub issues mentioned above and will update the answer accordingly.
Hope this helps. Happy Learning!
Source: https://stackoverflow.com/questions/60080646/access-deprecated-attribute-validation-data-in-tf-keras-callbacks-callback