How to use a pre-trained Keras model for inference in tf.data.Dataset.map?

廉价感情. Submitted on 2019-12-11 09:54:25

Question


I have a pre-trained model, and I'm trying to build a second model that takes the first model's output as its input. I don't want to train the two models end-to-end; I want to use the first model for inference only. The first model was trained using a tf.data.Dataset pipeline, so my first inclination was to integrate it as just another dataset.map() operation at the tail of the pipeline, but I'm running into problems. I've encountered 20 different errors in the process, each unrelated to the previous one. The batch-normalization layer seems to be a particular pain point.

Below is a minimal starter example that illustrates the issue. It's written in R, but answers in Python are also welcome.

I'm using tensorflow-gpu version 1.13.1 and Keras from tf.keras.

library(reticulate)
library(tensorflow)
library(keras)
library(tfdatasets)
use_implementation("tensorflow")

model_weights_path <- 'model-weights.h5'

arr <- function(...) 
  np_array(array(seq_len(prod(unlist(c(...)))), unlist(c(...))), dtype = 'float32')

new_model <- function(load_weights = TRUE) {
  model <- keras_model_sequential() %>% 
    layer_conv_1d(5, 5, activation = 'relu', input_shape = shape(150, 10)) %>%
    layer_batch_normalization() %>%
    layer_flatten() %>%
    layer_dense(10, activation = 'softmax')
  if (load_weights)
    load_model_weights_hdf5(model, model_weights_path)
  freeze_weights(model)
  model
}

if(!file.exists(model_weights_path)) {
  model <- new_model(FALSE) 
  save_model_weights_hdf5(model, model_weights_path)
}

model <- new_model()

data <- arr(20, 150, 10)
ds <- tfdatasets::tensors_dataset(data) %>% 
  dataset_repeat()

ds2 <- ds %>% 
  dataset_map(function(x) {
    model(x)
  })

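# Attempt 1: request the next-batch tensor directly from the mapped dataset
try(nb <- next_batch(ds2))

# Attempt 2: use an initializable iterator inside the Keras session
sess <- k_get_session()
it <- make_iterator_initializable(ds2)
sess$run(iterator_initializer(it))
nb <- it$get_next()

try(sess$run(nb))

# Attempt 3: run the variable initializer first, then retry
sess$run(tf$initialize_all_variables())

try(sess$run(nb))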

Source: https://stackoverflow.com/questions/55227732/how-to-use-a-pre-trained-keras-model-for-inference-in-tf-data-dataset-map
