Access training operation in a tf.keras.Model

Submitted by 这一生的挚爱 on 2019-12-11 08:36:15

Question


How would one access a training operation from a tf.keras.models.Model? Consider the following:

# TensorFlow 1.x code (graph mode: tf.placeholder and tf.Session)
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Flatten
from tensorflow.keras.models import Model
import numpy as np
# Make some dummy data
dummy_data_shape=(5,5)
def batch_generator(size):
    """ Makes some random data """
    def _gen():
        y_batch=np.random.randint(0,2, size=size)
        y_batch=np.expand_dims(y_batch,-1)
        y_expanded=np.expand_dims(y_batch,-1)
        x_batch=np.ones((size,*dummy_data_shape))*y_expanded
        yield x_batch,y_batch
    return _gen()

# Make some simple model
Y=tf.placeholder(tf.float32,[None,1])
X = Input(shape=dummy_data_shape)
layer_mod = Flatten()(X)
layer_mod = Dense(1)(layer_mod)

# Tie it all together and compile
out_model = Model(inputs=[X], outputs=[layer_mod])
out_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.metrics.binary_crossentropy
)

### How can I access a train_op from the out_model?
with tf.Session() as sess:
    data_iter=batch_generator(10)
    sess.run(tf.global_variables_initializer())
    x,y=next(data_iter)
    ## Here: How to access the operation that trains the model?
    train_op=out_model.train_op #<-- ?
    sess.run(train_op, feed_dict={X:x,Y:y})

What should the second-to-last line of the code above be for the model to train?
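
TensorFlow 2.x removes tf.Session and tf.placeholder. Its closest counterpart of a train_op is a function that runs one optimization step on the model's trainable variables. The sketch below illustrates that idea under stated assumptions and is not the answer from the linked thread. It rebuilds the same Flatten + Dense(1) model and makes the following changes:

- The Dense output is treated as logits (BinaryCrossentropy(from_logits=True)), because that layer has no activation.
- Adam uses learning_rate=0.01 instead of the default.
- batch_generator is rewritten to yield batches forever.

The helper name train_step is hypothetical.

```python
import numpy as np
import tensorflow as tf

np.random.seed(0)
tf.random.set_seed(0)

dummy_data_shape = (5, 5)

def batch_generator(size):
    """Yield (x, y) batches forever; each x is all ones or all zeros, matching its label."""
    while True:
        y_batch = np.random.randint(0, 2, size=(size, 1)).astype("float32")
        # Broadcast (size, 1, 1) labels against (size, 5, 5) ones.
        x_batch = np.ones((size, *dummy_data_shape), dtype="float32") * y_batch[:, :, None]
        yield x_batch, y_batch

# Same architecture as the question: Flatten -> Dense(1), no activation.
X = tf.keras.Input(shape=dummy_data_shape)
layer_mod = tf.keras.layers.Flatten()(X)
layer_mod = tf.keras.layers.Dense(1)(layer_mod)
out_model = tf.keras.Model(inputs=X, outputs=layer_mod)

# Assumption: a larger learning rate than Adam's default, so training visibly converges.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
# Assumption: the Dense output is treated as logits.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

@tf.function
def train_step(x, y):
    """Hypothetical helper: one gradient step, the TF2 analogue of sess.run(train_op, ...)."""
    with tf.GradientTape() as tape:
        logits = out_model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, out_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, out_model.trainable_variables))
    return loss

data_iter = batch_generator(10)
losses = []
for _ in range(300):
    x, y = next(data_iter)
    losses.append(float(train_step(tf.constant(x), tf.constant(y))))

print(f"mean loss, first 20 steps: {np.mean(losses[:20]):.3f}; "
      f"last 20 steps: {np.mean(losses[-20:]):.3f}")
```

For a model that has already been compiled, out_model.train_on_batch(x, y) runs a single step with the optimizer and loss passed to compile(). That may be enough if you only need the model to train and do not need a handle on the operation itself.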

Source: https://stackoverflow.com/questions/55687011/access-training-operation-in-a-tf-keras-model