Load keras model h5 unknown metrics

别等时光非礼了梦想. Submitted on 2020-11-29 19:10:32

Question


I have trained a Keras CNN, monitoring the following metrics:

METRICS = [
  TruePositives(name='tp'),
  FalsePositives(name='fp'),
  TrueNegatives(name='tn'),
  FalseNegatives(name='fn'), 
  BinaryAccuracy(name='accuracy'),
  Precision(name='precision'),
  Recall(name='recall'),
  AUC(name='auc'),
 ]

and then compiled the model:

 model.compile(optimizer='nadam', loss='binary_crossentropy',
         metrics=METRICS)

It works perfectly, and I saved my model in h5 format (model.h5).

Now I have downloaded the model and would like to use it in another script, importing it with:

 from keras.models import load_model
 model = load_model('model.h5')
 model.predict(....)

but when I run the script, it raises:

 ValueError: Unknown metric function: {'class_name': 'TruePositives', 'config': {'name': 'tp', 'dtype': 'float32', 'thresholds': None}}

How should I handle this issue?

Thank you in advance


Answer 1:


When you have custom metrics, you need to follow a slightly different approach:

  1. Create, train, and save the model.
  2. Load the model with custom_objects and compile=False.
  3. Finally, compile the loaded model again with the custom objects.

The full approach is shown here:

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Custom Loss1 (for example) 
#@tf.function() 
def customLoss1(yTrue,yPred):
  return tf.reduce_mean(yTrue-yPred) 

# Custom Loss2 (for example) 
#@tf.function() 
def customLoss2(yTrue, yPred):
  return tf.reduce_mean(tf.square(tf.subtract(yTrue,yPred))) 

def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),  
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', customLoss1, customLoss2])
  return model 

# Create a basic model instance
model=create_model()

# Fit and evaluate model 
model.fit(x_train, y_train, epochs=5)

loss, acc,loss1, loss2 = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc)) # Original model, accuracy: 98.11%

# saving the model
model.save('./Mymodel',save_format='tf')

# load the model
loaded_model = tf.keras.models.load_model(
    './Mymodel',
    custom_objects={'customLoss1': customLoss1, 'customLoss2': customLoss2},
    compile=False)

# compile the model
loaded_model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', customLoss1, customLoss2])

# loaded model also has same accuracy, metrics and loss
loss, acc,loss1, loss2 = loaded_model.evaluate(x_test, y_test,verbose=1)
print("Loaded model, accuracy: {:5.2f}%".format(100*acc)) #Loaded model, accuracy: 98.11%



Answer 2:


It looks like you are working from a TensorFlow tutorial. I used these exact metrics and had the same problem. What worked for me was to load the model with compile=False and then compile it again with the same metrics. After that you should be able to use model.predict(....) as expected.

import keras

model = keras.models.load_model('model.h5', compile = False)

METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
]

model.compile(optimizer = keras.optimizers.Adam(learning_rate=1e-4),
              loss = 'binary_crossentropy',
              metrics = METRICS
             )
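
If the loaded model is only used for inference, the recompile step may not be needed at all, since predict() does not depend on the loss, optimizer, or metrics. The following is a minimal sketch of that round trip, assuming TensorFlow 2.x with h5py installed. The tiny dense model, the random data, and the temporary file path are hypothetical stand-ins for the CNN and model.h5 from the question.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the trained CNN from the question.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(
    optimizer="nadam",
    loss="binary_crossentropy",
    metrics=[
        tf.keras.metrics.TruePositives(name="tp"),
        tf.keras.metrics.AUC(name="auc"),
    ],
)

# Random data, only so that the model has been fitted before saving.
x = np.random.rand(16, 4).astype("float32")
y = np.random.randint(0, 2, size=(16, 1)).astype("float32")
model.fit(x, y, epochs=1, verbose=0)
expected = model.predict(x, verbose=0)

# Save to HDF5, then reload without restoring the compile state,
# so the stored metric configuration is never deserialized.
path = os.path.join(tempfile.mkdtemp(), "model.h5")
model.save(path)
restored = tf.keras.models.load_model(path, compile=False)

# predict() runs on the uncompiled model; compare against the original outputs.
actual = restored.predict(x, verbose=0)
print("Predictions match:", np.allclose(expected, actual, atol=1e-6))
```

Calling evaluate() or fit() on the restored model would still require compiling it first, as in the snippet above.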


Source: https://stackoverflow.com/questions/61513447/load-keras-model-h5-unknown-metrics
