问题
How would one access a training operation from a `tf.keras.models.Model`?
Consider the following:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Flatten
from tensorflow.keras.models import Model
import numpy as np
from sys import exit as xit
# Make some dummy data
dummy_data_shape = (5, 5)


def batch_generator(size):
    """Return a generator that yields one random ``(x_batch, y_batch)`` pair.

    Args:
        size: Number of samples in the batch (an int).

    Returns:
        A generator producing a single tuple ``(x_batch, y_batch)``.
        ``y_batch`` has shape ``(size, 1)`` and holds random labels drawn
        from ``{0, 1}``. ``x_batch`` has shape ``(size, *dummy_data_shape)``
        and every entry of sample ``i`` equals that sample's label. The
        generator is exhausted after this one batch.
    """

    def _gen():
        labels = np.random.randint(0, 2, size=size)
        y_batch = np.expand_dims(labels, -1)
        # Give the labels one singleton axis per data dimension so they
        # broadcast over each sample regardless of dummy_data_shape's rank.
        y_expanded = y_batch.reshape((size,) + (1,) * len(dummy_data_shape))
        x_batch = np.ones((size, *dummy_data_shape)) * y_expanded
        yield x_batch, y_batch

    return _gen()
# Make some simple model
# Stand-alone placeholder for the labels. Note that it is not passed to
# Model(...) or compile(...) below.
Y = tf.placeholder(tf.float32, [None, 1])
X = Input(shape=dummy_data_shape)
layer_mod = Flatten()(X)
# binary_crossentropy is used with its default from_logits=False, so the
# model output must be a probability: squash the single unit with a sigmoid.
layer_mod = Dense(1, activation="sigmoid")(layer_mod)
# Tie it all together and compile
out_model = Model(inputs=[X], outputs=[layer_mod])
out_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    # Take the loss from the losses namespace rather than metrics.
    loss=tf.keras.losses.binary_crossentropy,
)
### How can I access a train_op from the out_model?
with tf.Session() as sess:
    data_iter = batch_generator(10)
    sess.run(tf.global_variables_initializer())
    x, y = next(data_iter)
    ## Here: How to access the operation that trains the model?
    # NOTE(review): `train_op` is the placeholder this question asks about;
    # it is not known to be a real attribute of the compiled model -- confirm
    # against the tf.keras.Model API before running.
    train_op = out_model.train_op  # <-- ?
    # Run one training step on the single generated batch.
    sess.run(train_op, feed_dict={X: x, Y: y})
What should the second-to-last line in the code above (the `train_op = ...` assignment) be for the model to train?
来源:https://stackoverflow.com/questions/55687011/access-training-operation-in-a-tf-keras-model