问题
I am working with TensorFlow 2 and trying to implement max unpooling with indices in order to build SegNet.
When I run the code below, I get an error. I define the function MaxUnpooling2D
and then call it in the model. I suspect the problem is caused by the fact that updates and mask have shape (None, H, W, ch).
def MaxUnpooling2D(updates, mask):
    """Undo a 2x2 max pooling by scattering values back to their argmax slots.

    Args:
        updates: pooled tensor, indexed here as (batch, height, width,
            channels).
        mask: one flat argmax index per element of ``updates``. The decoding
            below assumes per-image indices of the form
            ``(y * width + x) * channels + c`` in the *unpooled* layout.

    Returns:
        A tensor of shape (batch, 2 * height, 2 * width, channels) holding the
        values of ``updates`` at the positions given by ``mask`` (other
        positions are left at the zeros that ``tf.scatter_nd`` starts from).
    """
    size = 2
    mask = tf.cast(mask, 'int32')
    input_shape = tf.shape(updates, out_type='int32')
    # Dynamic shape of the unpooled result.
    output_shape = (
        input_shape[0],
        input_shape[1] * size,
        input_shape[2] * size,
        input_shape[3])

    # Build one (batch, y, x, channel) index per element of `updates`.
    one_like_mask = tf.ones_like(mask, dtype='int32')
    batch_shape = tf.concat(
        [[input_shape[0]], [1], [1], [1]],
        axis=0)
    batch_range = tf.reshape(
        tf.range(output_shape[0], dtype='int32'),
        shape=batch_shape)
    b = one_like_mask * batch_range
    y = mask // (output_shape[2] * output_shape[3])
    x = (mask // output_shape[3]) % output_shape[2]
    feature_range = tf.range(output_shape[3], dtype='int32')
    f = one_like_mask * feature_range

    updates_size = tf.size(updates)
    indices = tf.transpose(tf.reshape(
        tf.stack([b, y, x, f]),
        [4, updates_size]))
    values = tf.reshape(updates, [updates_size])
    unpooled = tf.scatter_nd(indices, values, output_shape)

    # scatter_nd is driven by a dynamic shape tensor, so the result can lose
    # its static shape. Layers such as Conv2D need a statically known channel
    # dimension, so re-attach whatever is statically known about `updates`.
    static_shape = tf.TensorShape(getattr(updates, 'shape', None))
    if static_shape.rank == 4:
        batch, height, width, channels = static_shape.as_list()
        unpooled = tf.ensure_shape(
            unpooled,
            [
                batch,
                None if height is None else height * size,
                None if width is None else width * size,
                channels,
            ])
    return unpooled
def segnet_conv(
        inputs,
        kernel_size=3,
        kernel_initializer='glorot_uniform',
        batch_norm=False,
        **kwargs):
    """Apply two 64-filter conv stages followed by a 2x2 max pool with argmax.

    Each stage is Conv2D -> optional BatchNormalization -> LeakyReLU(0.3).

    Args:
        inputs: tensor fed to the first convolution.
        kernel_size: kernel size used by both convolutions.
        kernel_initializer: initializer used by both convolutions.
        batch_norm: when true, a BatchNormalization layer follows each conv.
        **kwargs: accepted and ignored.

    NOTE(review): the visible body ends after the pooling call without a
    return statement, while ``segnet`` unpacks six values from this function;
    the remaining encoder stages appear to have been omitted from the excerpt.
    """
    features = inputs
    for stage in (1, 2):
        features = Conv2D(
            filters=64,
            kernel_size=kernel_size,
            padding='same',
            activation=None,
            kernel_initializer=kernel_initializer,
            name='conv_%d' % stage,
        )(features)
        if batch_norm:
            features = BatchNormalization(name='bn_%d' % stage)(features)
        features = LeakyReLU(alpha=0.3, name='activation_%d' % stage)(features)

    # Keep the argmax indices so the decoder can unpool to the same positions.
    pool1, mask1 = tf.nn.max_pool_with_argmax(
        input=features,
        ksize=2,
        strides=2,
        padding='SAME',
    )
def segnet_deconv(
        pool1,
        mask1,
        kernel_size=3,
        kernel_initializer='glorot_uniform',
        batch_norm=False,
        **kwargs):
    """Unpool with the stored argmax indices, then apply a 512-filter conv.

    Args:
        pool1: pooled tensor to be unpooled.
        mask1: argmax indices that belong to ``pool1``.
        kernel_size: kernel size of the convolution.
        kernel_initializer: initializer of the convolution.
        batch_norm: accepted but not used by the visible code.
        **kwargs: accepted and ignored.

    Returns:
        The output of the ``upconv_13`` convolution.

    NOTE(review): only one decoder stage is visible here, and ``segnet``
    passes six positional tensors to this function; the full definition
    presumably takes more masks. Confirm against the complete source.
    """
    # The original body read `pool5` / `mask5`, which are neither parameters
    # nor locals of this function (NameError); use the arguments instead.
    dec = MaxUnpooling2D(pool1, mask1)
    dec = Conv2D(
        filters=512,
        kernel_size=kernel_size,
        padding='same',
        activation=None,
        kernel_initializer=kernel_initializer,
        name='upconv_13',
    )(dec)
    # The caller consumes the result, so hand it back instead of returning None.
    return dec
def classifier(
        dec,
        ch_out=2,
        kernel_size=3,
        final_activation=None,
        batch_norm=False,
        **kwargs):
    """Apply a 64-filter ReLU convolution named ``dec_out1`` to ``dec``.

    Args:
        dec: tensor coming from the decoder.
        ch_out: not used by the visible code.
        kernel_size: kernel size of the convolution.
        final_activation: not used by the visible code.
        batch_norm: not used by the visible code.
        **kwargs: accepted and ignored.

    NOTE(review): the visible body has no return statement and never uses
    ``ch_out`` / ``final_activation``; the final projection layer appears to
    have been omitted from the excerpt.
    """
    refine = Conv2D(
        filters=64,
        kernel_size=kernel_size,
        activation='relu',
        padding='same',
        name='dec_out1',
    )
    dec = refine(dec)
# NOTE(review): this function builds Keras layers; confirm that wrapping a
# model-building function in tf.function is intended.
@tf.function
def segnet(
        inputs,
        ch_out=2,
        kernel_size=3,
        kernel_initializer='glorot_uniform',
        final_activation=None,
        batch_norm=False,
        **kwargs):
    """Chain encoder, decoder and classifier into one forward graph.

    Args:
        inputs: input tensor of the network.
        ch_out: forwarded to ``classifier``.
        kernel_size: forwarded to the encoder, decoder and classifier.
        kernel_initializer: forwarded to the encoder and decoder.
        final_activation: forwarded to ``classifier``.
        batch_norm: forwarded to the encoder, decoder and classifier.
        **kwargs: accepted and ignored.

    Returns:
        Whatever ``classifier`` returns for the decoded features.
    """
    # Forward the caller's settings; the original passed hard-coded constants
    # here, so non-default arguments were silently ignored.
    pool5, mask1, mask2, mask3, mask4, mask5 = segnet_conv(
        inputs,
        kernel_size=kernel_size,
        kernel_initializer=kernel_initializer,
        batch_norm=batch_norm,
    )
    dec = segnet_deconv(
        pool5,
        mask1,
        mask2,
        mask3,
        mask4,
        mask5,
        kernel_size=kernel_size,
        kernel_initializer=kernel_initializer,
        batch_norm=batch_norm,
    )
    output = classifier(
        dec,
        ch_out=ch_out,
        kernel_size=kernel_size,
        final_activation=final_activation,
        batch_norm=batch_norm,
    )
    return output
inputs = Input(shape=(*params['image_size'], params['num_channels']), name='input')
# segnet() takes `ch_out` and `kernel_size`; the former `n_labels`, `kernel`,
# `pool_size` and `output_mode` keywords were swallowed by **kwargs and ignored.
outputs = segnet(inputs, ch_out=2, kernel_size=3)
# No final activation is requested, so the SegNet model outputs logits.
model = Model(inputs, outputs)
Can you please help me with this problem?
回答1:
I have solved the problem. In case anyone needs it, here is the code for MaxUnpooling2D:
def MaxUnpooling2D(pool, ind, output_shape, batch_size, name=None):
    """
    Unpooling layer after max_pool_with_argmax.

    Args:
        pool: max pooled output tensor
        ind: argmax indices, one per element of ``pool``
        output_shape: shape of the unpooled tensor; entries 1, 2 and 3 are
            read, entry 0 is not
        batch_size: number of samples in the batch; must equal the leading
            dimension of ``pool``
        name: optional scope name; a default is used when omitted
    Return:
        unpool: tensor of shape
            (batch, output_shape[1], output_shape[2], output_shape[3])
    """
    # variable_scope raises when both the scope name and default_name are
    # None, so supply a default_name to make the `name=None` default usable.
    with tf.compat.v1.variable_scope(name, default_name='MaxUnpooling2D'):
        pool_ = tf.reshape(pool, [-1])
        # Prefix every flat argmax index with the sample it belongs to.
        batch_range = tf.reshape(
            tf.range(batch_size, dtype=ind.dtype),
            [tf.shape(pool)[0], 1, 1, 1])
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, [-1, 1])
        ind_ = tf.reshape(ind, [-1, 1])
        ind_ = tf.concat([b, ind_], 1)
        flat_size = output_shape[1] * output_shape[2] * output_shape[3]
        # tf.scatter_nd writes the pooled values (the sparse updates) into a
        # zero tensor of the flattened output shape at the positions in ind_.
        # It is used instead of tf.sparse_tensor_to_dense because the latter
        # yields None gradients, which cuts the network off during training.
        # With the sparse-tensor approach, the equivalent fix is to replace
        # tf.sparse_tensor_to_dense(sparse_tensor) with
        # tf.sparse_add(tf.zeros(output_shape), sparse_tensor).
        ret = tf.scatter_nd(ind_, pool_, shape=[batch_size, flat_size])
        ret = tf.reshape(
            ret,
            [tf.shape(pool)[0], output_shape[1], output_shape[2], output_shape[3]])
        return ret
来源:https://stackoverflow.com/questions/60104102/custom-max-pool-layer-valueerror-the-channel-dimension-of-the-inputs-should-be